whispercpp 1.3.1 → 1.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (797)
  1. checksums.yaml +4 -4
  2. data/.gitignore +4 -3
  3. data/README.md +92 -31
  4. data/Rakefile +26 -7
  5. data/ext/.gitignore +5 -7
  6. data/ext/dependencies.rb +61 -0
  7. data/ext/extconf.rb +21 -198
  8. data/ext/options.rb +221 -0
  9. data/ext/ruby_whisper.c +159 -0
  10. data/ext/ruby_whisper.h +17 -2
  11. data/ext/ruby_whisper_context.c +641 -0
  12. data/ext/ruby_whisper_error.c +52 -0
  13. data/ext/ruby_whisper_model.c +232 -0
  14. data/ext/ruby_whisper_params.c +1301 -0
  15. data/ext/ruby_whisper_segment.c +143 -0
  16. data/ext/ruby_whisper_transcribe.cpp +87 -0
  17. data/ext/ruby_whisper_vad_params.c +288 -0
  18. data/ext/sources/.dockerignore +3 -0
  19. data/ext/sources/.github/workflows/bindings-ruby.yml +21 -0
  20. data/ext/sources/CMakeGraphVizOptions.cmake +8 -0
  21. data/ext/sources/CMakeLists.txt +251 -0
  22. data/ext/sources/bindings/javascript/CMakeLists.txt +41 -0
  23. data/ext/sources/bindings/javascript/emscripten.cpp +93 -0
  24. data/ext/sources/bindings/javascript/libwhisper.worker.js +1 -0
  25. data/ext/sources/bindings/javascript/package-tmpl.json +26 -0
  26. data/ext/sources/bindings/javascript/package.json +26 -0
  27. data/ext/sources/bindings/javascript/whisper.js +19 -0
  28. data/ext/sources/build-xcframework.sh +547 -0
  29. data/ext/sources/ci/run.sh +336 -0
  30. data/ext/sources/close-issue.yml +28 -0
  31. data/ext/sources/cmake/DefaultTargetOptions.cmake +16 -0
  32. data/ext/sources/cmake/FindFFmpeg.cmake +163 -0
  33. data/ext/sources/cmake/build-info.cmake +60 -0
  34. data/ext/sources/cmake/git-vars.cmake +22 -0
  35. data/ext/sources/cmake/whisper-config.cmake.in +65 -0
  36. data/ext/sources/cmake/whisper.pc.in +10 -0
  37. data/ext/sources/examples/CMakeLists.txt +124 -0
  38. data/ext/sources/examples/addon.node/CMakeLists.txt +31 -0
  39. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +37 -0
  40. data/ext/sources/examples/addon.node/addon.cpp +438 -0
  41. data/ext/sources/examples/addon.node/index.js +54 -0
  42. data/ext/sources/examples/addon.node/package.json +16 -0
  43. data/ext/sources/examples/bench/CMakeLists.txt +8 -0
  44. data/ext/sources/examples/bench/bench.cpp +175 -0
  45. data/ext/sources/examples/bench.wasm/CMakeLists.txt +49 -0
  46. data/ext/sources/examples/bench.wasm/emscripten.cpp +87 -0
  47. data/ext/sources/examples/bench.wasm/index-tmpl.html +284 -0
  48. data/ext/sources/examples/cli/CMakeLists.txt +8 -0
  49. data/ext/sources/examples/cli/cli.cpp +1294 -0
  50. data/ext/sources/examples/coi-serviceworker.js +146 -0
  51. data/ext/sources/examples/command/CMakeLists.txt +10 -0
  52. data/ext/sources/examples/command/command.cpp +776 -0
  53. data/ext/sources/examples/command/commands.txt +9 -0
  54. data/ext/sources/examples/command.wasm/CMakeLists.txt +50 -0
  55. data/ext/sources/examples/command.wasm/emscripten.cpp +327 -0
  56. data/ext/sources/examples/command.wasm/index-tmpl.html +414 -0
  57. data/ext/sources/examples/common-ggml.cpp +238 -0
  58. data/ext/sources/examples/common-ggml.h +18 -0
  59. data/ext/sources/examples/common-sdl.cpp +227 -0
  60. data/ext/sources/examples/common-sdl.h +49 -0
  61. data/ext/sources/examples/common-whisper.cpp +168 -0
  62. data/ext/sources/examples/common-whisper.h +24 -0
  63. data/ext/sources/examples/common.cpp +675 -0
  64. data/ext/sources/examples/common.h +322 -0
  65. data/ext/sources/examples/deprecation-warning/CMakeLists.txt +6 -0
  66. data/ext/sources/examples/deprecation-warning/deprecation-warning.cpp +38 -0
  67. data/ext/sources/examples/ffmpeg-transcode.cpp +368 -0
  68. data/ext/sources/examples/generate-karaoke.sh +57 -0
  69. data/ext/sources/examples/grammar-parser.cpp +423 -0
  70. data/ext/sources/examples/grammar-parser.h +29 -0
  71. data/ext/sources/examples/helpers.js +191 -0
  72. data/ext/sources/examples/json.hpp +24596 -0
  73. data/ext/sources/examples/livestream.sh +112 -0
  74. data/ext/sources/examples/lsp/CMakeLists.txt +9 -0
  75. data/ext/sources/examples/lsp/lsp.cpp +467 -0
  76. data/ext/sources/examples/lsp/whisper.vim +362 -0
  77. data/ext/sources/examples/miniaudio.h +93468 -0
  78. data/ext/sources/examples/python/test_whisper_processor.py +7 -0
  79. data/ext/sources/examples/python/whisper_processor.py +54 -0
  80. data/ext/sources/examples/quantize/CMakeLists.txt +6 -0
  81. data/ext/sources/examples/quantize/quantize.cpp +223 -0
  82. data/ext/sources/examples/server/CMakeLists.txt +12 -0
  83. data/ext/sources/examples/server/bench.js +29 -0
  84. data/ext/sources/examples/server/httplib.h +10497 -0
  85. data/ext/sources/examples/server/server.cpp +1091 -0
  86. data/ext/sources/examples/server.py +115 -0
  87. data/ext/sources/examples/stb_vorbis.c +5584 -0
  88. data/ext/sources/examples/stream/CMakeLists.txt +10 -0
  89. data/ext/sources/examples/stream/stream.cpp +429 -0
  90. data/ext/sources/examples/stream.wasm/CMakeLists.txt +49 -0
  91. data/ext/sources/examples/stream.wasm/emscripten.cpp +216 -0
  92. data/ext/sources/examples/stream.wasm/index-tmpl.html +414 -0
  93. data/ext/sources/examples/sycl/CMakeLists.txt +9 -0
  94. data/ext/sources/examples/sycl/build.sh +22 -0
  95. data/ext/sources/examples/sycl/ls-sycl-device.cpp +11 -0
  96. data/ext/sources/examples/sycl/run-whisper.sh +17 -0
  97. data/ext/sources/examples/talk-llama/CMakeLists.txt +40 -0
  98. data/ext/sources/examples/talk-llama/eleven-labs.py +80 -0
  99. data/ext/sources/examples/talk-llama/llama-adapter.cpp +388 -0
  100. data/ext/sources/examples/talk-llama/llama-adapter.h +76 -0
  101. data/ext/sources/examples/talk-llama/llama-arch.cpp +1746 -0
  102. data/ext/sources/examples/talk-llama/llama-arch.h +437 -0
  103. data/ext/sources/examples/talk-llama/llama-batch.cpp +374 -0
  104. data/ext/sources/examples/talk-llama/llama-batch.h +89 -0
  105. data/ext/sources/examples/talk-llama/llama-chat.cpp +663 -0
  106. data/ext/sources/examples/talk-llama/llama-chat.h +58 -0
  107. data/ext/sources/examples/talk-llama/llama-context.cpp +2676 -0
  108. data/ext/sources/examples/talk-llama/llama-context.h +276 -0
  109. data/ext/sources/examples/talk-llama/llama-cparams.cpp +5 -0
  110. data/ext/sources/examples/talk-llama/llama-cparams.h +41 -0
  111. data/ext/sources/examples/talk-llama/llama-grammar.cpp +1229 -0
  112. data/ext/sources/examples/talk-llama/llama-grammar.h +173 -0
  113. data/ext/sources/examples/talk-llama/llama-graph.cpp +1618 -0
  114. data/ext/sources/examples/talk-llama/llama-graph.h +640 -0
  115. data/ext/sources/examples/talk-llama/llama-hparams.cpp +95 -0
  116. data/ext/sources/examples/talk-llama/llama-hparams.h +190 -0
  117. data/ext/sources/examples/talk-llama/llama-impl.cpp +167 -0
  118. data/ext/sources/examples/talk-llama/llama-impl.h +61 -0
  119. data/ext/sources/examples/talk-llama/llama-io.cpp +15 -0
  120. data/ext/sources/examples/talk-llama/llama-io.h +35 -0
  121. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2739 -0
  122. data/ext/sources/examples/talk-llama/llama-kv-cache.h +502 -0
  123. data/ext/sources/examples/talk-llama/llama-kv-cells.h +379 -0
  124. data/ext/sources/examples/talk-llama/llama-memory.cpp +1 -0
  125. data/ext/sources/examples/talk-llama/llama-memory.h +32 -0
  126. data/ext/sources/examples/talk-llama/llama-mmap.cpp +600 -0
  127. data/ext/sources/examples/talk-llama/llama-mmap.h +68 -0
  128. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +1138 -0
  129. data/ext/sources/examples/talk-llama/llama-model-loader.h +169 -0
  130. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +281 -0
  131. data/ext/sources/examples/talk-llama/llama-model-saver.h +37 -0
  132. data/ext/sources/examples/talk-llama/llama-model.cpp +13814 -0
  133. data/ext/sources/examples/talk-llama/llama-model.h +425 -0
  134. data/ext/sources/examples/talk-llama/llama-quant.cpp +966 -0
  135. data/ext/sources/examples/talk-llama/llama-quant.h +1 -0
  136. data/ext/sources/examples/talk-llama/llama-sampling.cpp +2575 -0
  137. data/ext/sources/examples/talk-llama/llama-sampling.h +32 -0
  138. data/ext/sources/examples/talk-llama/llama-vocab.cpp +3340 -0
  139. data/ext/sources/examples/talk-llama/llama-vocab.h +131 -0
  140. data/ext/sources/examples/talk-llama/llama.cpp +354 -0
  141. data/ext/sources/examples/talk-llama/llama.h +1377 -0
  142. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +23 -0
  143. data/ext/sources/examples/talk-llama/speak +40 -0
  144. data/ext/sources/examples/talk-llama/speak.bat +1 -0
  145. data/ext/sources/examples/talk-llama/speak.ps1 +14 -0
  146. data/ext/sources/examples/talk-llama/talk-llama.cpp +808 -0
  147. data/ext/sources/examples/talk-llama/unicode-data.cpp +7034 -0
  148. data/ext/sources/examples/talk-llama/unicode-data.h +20 -0
  149. data/ext/sources/examples/talk-llama/unicode.cpp +849 -0
  150. data/ext/sources/examples/talk-llama/unicode.h +66 -0
  151. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +8 -0
  152. data/ext/sources/examples/vad-speech-segments/speech.cpp +143 -0
  153. data/ext/sources/examples/wchess/CMakeLists.txt +10 -0
  154. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +19 -0
  155. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +803 -0
  156. data/ext/sources/examples/wchess/libwchess/Chessboard.h +33 -0
  157. data/ext/sources/examples/wchess/libwchess/WChess.cpp +193 -0
  158. data/ext/sources/examples/wchess/libwchess/WChess.h +63 -0
  159. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +117 -0
  160. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +8 -0
  161. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +249 -0
  162. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +50 -0
  163. data/ext/sources/examples/whisper.wasm/emscripten.cpp +118 -0
  164. data/ext/sources/examples/whisper.wasm/index-tmpl.html +658 -0
  165. data/ext/sources/ggml/CMakeLists.txt +390 -0
  166. data/ext/sources/ggml/cmake/BuildTypes.cmake +54 -0
  167. data/ext/sources/ggml/cmake/GitVars.cmake +22 -0
  168. data/ext/sources/ggml/cmake/common.cmake +26 -0
  169. data/ext/sources/ggml/cmake/ggml-config.cmake.in +152 -0
  170. data/ext/{ggml → sources/ggml}/include/ggml-alloc.h +1 -1
  171. data/ext/{ggml → sources/ggml}/include/ggml-backend.h +9 -7
  172. data/ext/{ggml → sources/ggml}/include/ggml-cpp.h +2 -1
  173. data/ext/{ggml → sources/ggml}/include/ggml-cpu.h +9 -1
  174. data/ext/{ggml → sources/ggml}/include/ggml-metal.h +1 -1
  175. data/ext/{ggml → sources/ggml}/include/ggml-opt.h +49 -28
  176. data/ext/{ggml → sources/ggml}/include/ggml-rpc.h +6 -1
  177. data/ext/{ggml → sources/ggml}/include/ggml-vulkan.h +0 -2
  178. data/ext/{ggml → sources/ggml}/include/ggml.h +182 -265
  179. data/ext/sources/ggml/include/gguf.h +202 -0
  180. data/ext/sources/ggml/src/CMakeLists.txt +346 -0
  181. data/ext/{ggml → sources/ggml}/src/ggml-alloc.c +34 -29
  182. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  183. data/ext/{ggml → sources/ggml}/src/ggml-backend-impl.h +1 -2
  184. data/ext/{ggml → sources/ggml}/src/ggml-backend-reg.cpp +87 -53
  185. data/ext/{ggml → sources/ggml}/src/ggml-backend.cpp +26 -14
  186. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +87 -0
  187. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +74 -0
  188. data/ext/sources/ggml/src/ggml-cann/Doxyfile +2579 -0
  189. data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.cpp +10 -4
  190. data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.h +5 -5
  191. data/ext/{ggml → sources/ggml}/src/ggml-cann/aclnn_ops.cpp +1272 -1506
  192. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +1125 -0
  193. data/ext/{ggml → sources/ggml}/src/ggml-cann/common.h +135 -1
  194. data/ext/{ggml → sources/ggml}/src/ggml-cann/ggml-cann.cpp +564 -146
  195. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +30 -0
  196. data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/dup.cpp +3 -5
  197. data/ext/{ggml → sources/ggml}/src/ggml-common.h +12 -8
  198. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +504 -0
  199. data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.cpp +2 -1
  200. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  201. data/ext/sources/ggml/src/ggml-cpu/binary-ops.h +16 -0
  202. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSIMD.cmake +100 -0
  203. data/ext/sources/ggml/src/ggml-cpu/common.h +72 -0
  204. data/ext/{ggml → sources/ggml}/src/ggml-cpu/cpu-feats-x86.cpp +5 -1
  205. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +6431 -0
  206. data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-impl.h +163 -41
  207. data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-quants.c +4029 -1117
  208. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +3510 -0
  209. data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu.cpp +67 -18
  210. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +337 -0
  211. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +95 -0
  212. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +482 -0
  213. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
  214. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3544 -0
  215. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
  216. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +8903 -0
  217. data/ext/sources/ggml/src/ggml-cpu/ops.h +110 -0
  218. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +892 -0
  219. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  220. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +28 -0
  221. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +252 -0
  222. data/ext/sources/ggml/src/ggml-cpu/vec.h +818 -0
  223. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +184 -0
  224. data/ext/sources/ggml/src/ggml-cuda/acc.cu +61 -0
  225. data/ext/sources/ggml/src/ggml-cuda/acc.cuh +5 -0
  226. data/ext/sources/ggml/src/ggml-cuda/arange.cu +34 -0
  227. data/ext/sources/ggml/src/ggml-cuda/arange.cuh +5 -0
  228. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +91 -0
  229. data/ext/sources/ggml/src/ggml-cuda/argmax.cuh +3 -0
  230. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +104 -0
  231. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +3 -0
  232. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +363 -0
  233. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +9 -0
  234. data/ext/sources/ggml/src/ggml-cuda/clamp.cu +45 -0
  235. data/ext/sources/ggml/src/ggml-cuda/clamp.cuh +5 -0
  236. data/ext/sources/ggml/src/ggml-cuda/common.cuh +828 -0
  237. data/ext/sources/ggml/src/ggml-cuda/concat.cu +221 -0
  238. data/ext/sources/ggml/src/ggml-cuda/concat.cuh +5 -0
  239. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +89 -0
  240. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cuh +5 -0
  241. data/ext/sources/ggml/src/ggml-cuda/convert.cu +730 -0
  242. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +26 -0
  243. data/ext/sources/ggml/src/ggml-cuda/count-equal.cu +64 -0
  244. data/ext/sources/ggml/src/ggml-cuda/count-equal.cuh +5 -0
  245. data/ext/sources/ggml/src/ggml-cuda/cp-async.cuh +57 -0
  246. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +705 -0
  247. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +11 -0
  248. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +189 -0
  249. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cuh +7 -0
  250. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +103 -0
  251. data/ext/sources/ggml/src/ggml-cuda/diagmask.cu +40 -0
  252. data/ext/sources/ggml/src/ggml-cuda/diagmask.cuh +5 -0
  253. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +881 -0
  254. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +1471 -0
  255. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +357 -0
  256. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +3 -0
  257. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +365 -0
  258. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +3 -0
  259. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +482 -0
  260. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +472 -0
  261. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +634 -0
  262. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +3 -0
  263. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +346 -0
  264. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +3 -0
  265. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +275 -0
  266. data/ext/sources/ggml/src/ggml-cuda/getrows.cuh +15 -0
  267. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +3505 -0
  268. data/ext/sources/ggml/src/ggml-cuda/gla.cu +93 -0
  269. data/ext/sources/ggml/src/ggml-cuda/gla.cuh +3 -0
  270. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +103 -0
  271. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +5 -0
  272. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +396 -0
  273. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +324 -0
  274. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +3217 -0
  275. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +336 -0
  276. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +12 -0
  277. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +595 -0
  278. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +12 -0
  279. data/ext/sources/ggml/src/ggml-cuda/norm.cu +458 -0
  280. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +11 -0
  281. data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cu +78 -0
  282. data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cuh +5 -0
  283. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +68 -0
  284. data/ext/sources/ggml/src/ggml-cuda/out-prod.cuh +3 -0
  285. data/ext/sources/ggml/src/ggml-cuda/pad.cu +49 -0
  286. data/ext/sources/ggml/src/ggml-cuda/pad.cuh +5 -0
  287. data/ext/sources/ggml/src/ggml-cuda/pool2d.cu +94 -0
  288. data/ext/sources/ggml/src/ggml-cuda/pool2d.cuh +5 -0
  289. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +190 -0
  290. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +27 -0
  291. data/ext/sources/ggml/src/ggml-cuda/rope.cu +456 -0
  292. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +7 -0
  293. data/ext/sources/ggml/src/ggml-cuda/scale.cu +31 -0
  294. data/ext/sources/ggml/src/ggml-cuda/scale.cuh +5 -0
  295. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +283 -0
  296. data/ext/sources/ggml/src/ggml-cuda/softmax.cuh +7 -0
  297. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +148 -0
  298. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +3 -0
  299. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +153 -0
  300. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cuh +3 -0
  301. data/ext/sources/ggml/src/ggml-cuda/sum.cu +45 -0
  302. data/ext/sources/ggml/src/ggml-cuda/sum.cuh +5 -0
  303. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +39 -0
  304. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +5 -0
  305. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +5 -0
  306. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +10 -0
  307. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu +10 -0
  308. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu +10 -0
  309. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +10 -0
  310. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +5 -0
  311. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +10 -0
  312. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +10 -0
  313. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu +10 -0
  314. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu +10 -0
  315. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +5 -0
  316. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu +10 -0
  317. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +10 -0
  318. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +10 -0
  319. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu +10 -0
  320. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu +10 -0
  321. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu +10 -0
  322. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +10 -0
  323. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +10 -0
  324. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
  325. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
  326. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
  328. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
  329. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
  330. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
  331. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
  332. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
  333. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
  334. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
  335. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
  336. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
  337. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
  338. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
  339. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
  340. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
  341. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
  342. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
  343. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
  344. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
  345. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
  346. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
  347. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
  348. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
  349. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
  350. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
  351. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
  352. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
  353. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
  354. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
  355. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
  356. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
  357. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
  358. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
  359. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
  360. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
  361. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
  362. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
  363. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
  364. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
  365. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
  366. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
  367. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
  368. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
  369. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
  370. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
  371. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
  372. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
  373. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
  374. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
  375. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
  376. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
  377. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
  378. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
  379. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
  380. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
  381. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
  382. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
  383. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
  384. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
  385. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
  386. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
  387. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
  388. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
  389. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
  390. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
  391. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
  392. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
  393. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
  394. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
  395. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
  396. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
  397. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
  398. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
  399. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
  400. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
  401. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
  402. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
  403. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
  404. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
  405. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
  406. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
  407. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
  408. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
  409. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
  410. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +78 -0
  411. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu +5 -0
  412. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu +5 -0
  413. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu +5 -0
  414. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu +5 -0
  415. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu +5 -0
  416. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu +5 -0
  417. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu +5 -0
  418. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu +5 -0
  419. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu +5 -0
  420. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu +5 -0
  421. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu +5 -0
  422. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu +5 -0
  423. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu +5 -0
  424. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu +5 -0
  425. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu +5 -0
  426. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu +5 -0
  427. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu +5 -0
  428. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu +5 -0
  429. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +47 -0
  430. data/ext/sources/ggml/src/ggml-cuda/tsembd.cuh +5 -0
  431. data/ext/sources/ggml/src/ggml-cuda/unary.cu +289 -0
  432. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +59 -0
  433. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +51 -0
  434. data/ext/sources/ggml/src/ggml-cuda/upscale.cuh +5 -0
  435. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +1135 -0
  436. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/cuda.h +1 -0
  437. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/hip.h +57 -0
  438. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/musa.h +7 -1
  439. data/ext/sources/ggml/src/ggml-cuda/wkv.cu +199 -0
  440. data/ext/sources/ggml/src/ggml-cuda/wkv.cuh +7 -0
  441. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +131 -0
  442. data/ext/{ggml → sources/ggml}/src/ggml-impl.h +64 -19
  443. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
  444. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +112 -0
  445. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +58 -0
  446. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +25 -0
  447. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +52 -0
  448. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +52 -0
  449. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +52 -0
  450. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +52 -0
  451. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +30 -0
  452. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +22 -0
  453. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +17 -0
  454. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +31 -0
  455. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +31 -0
  456. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +38 -0
  457. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +39 -0
  458. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +44 -0
  459. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +52 -0
  460. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +69 -0
  461. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +51 -0
  462. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +33 -0
  463. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +35 -0
  464. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +140 -0
  465. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +106 -0
  466. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +73 -0
  467. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +52 -0
  468. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +28 -0
  469. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +84 -0
  470. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +21 -0
  471. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +53 -0
  472. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +52 -0
  473. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +52 -0
  474. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +52 -0
  475. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +52 -0
  476. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +19 -0
  477. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +23 -0
  478. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +22 -0
  479. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +72 -0
  480. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +71 -0
  481. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +120 -0
  482. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +622 -0
  483. data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.m +2178 -1064
  484. data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.metal +1575 -1218
  485. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +113 -0
  486. data/ext/sources/ggml/src/ggml-musa/mudnn.cu +112 -0
  487. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +12 -0
  488. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +96 -0
  489. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +5124 -0
  490. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +83 -0
  491. data/ext/sources/ggml/src/ggml-opencl/kernels/clamp.cl +20 -0
  492. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +184 -0
  493. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +118 -0
  494. data/ext/sources/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl +58 -0
  495. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +26 -0
  496. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +62 -0
  497. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl +268 -0
  498. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl +274 -0
  499. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +163 -0
  500. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +57 -0
  501. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +57 -0
  502. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +79 -0
  503. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl +139 -0
  504. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl +118 -0
  505. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl +118 -0
  506. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl +94 -0
  507. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl +84 -0
  508. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl +118 -0
  509. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl +192 -0
  510. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl +307 -0
  511. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl +265 -0
  512. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl +272 -0
  513. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl +254 -0
  514. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl +190 -0
  515. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +81 -0
  516. data/ext/sources/ggml/src/ggml-opencl/kernels/relu.cl +16 -0
  517. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +96 -0
  518. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +721 -0
  519. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +16 -0
  520. data/ext/sources/ggml/src/ggml-opencl/kernels/silu.cl +30 -0
  521. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +87 -0
  522. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +87 -0
  523. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +86 -0
  524. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +86 -0
  525. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +84 -0
  526. data/ext/{ggml → sources/ggml}/src/ggml-opt.cpp +373 -190
  527. data/ext/{ggml → sources/ggml}/src/ggml-quants.c +114 -120
  528. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
  529. data/ext/{ggml → sources/ggml}/src/ggml-rpc/ggml-rpc.cpp +480 -73
  530. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +189 -0
  531. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +37 -0
  532. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +345 -0
  533. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  534. data/ext/{ggml → sources/ggml}/src/ggml-sycl/common.cpp +20 -32
  535. data/ext/sources/ggml/src/ggml-sycl/common.hpp +589 -0
  536. data/ext/{ggml → sources/ggml}/src/ggml-sycl/concat.cpp +32 -33
  537. data/ext/sources/ggml/src/ggml-sycl/concat.hpp +20 -0
  538. data/ext/{ggml → sources/ggml}/src/ggml-sycl/conv.cpp +4 -2
  539. data/ext/sources/ggml/src/ggml-sycl/conv.hpp +20 -0
  540. data/ext/{ggml → sources/ggml}/src/ggml-sycl/convert.cpp +104 -28
  541. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +34 -0
  542. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +700 -0
  543. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +11 -0
  544. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +791 -0
  545. data/ext/{ggml → sources/ggml}/src/ggml-sycl/dmmv.cpp +156 -17
  546. data/ext/sources/ggml/src/ggml-sycl/dmmv.hpp +27 -0
  547. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +2957 -0
  548. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1511 -0
  549. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +75 -0
  550. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +99 -0
  551. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +309 -0
  552. data/ext/sources/ggml/src/ggml-sycl/getrows.hpp +20 -0
  553. data/ext/{ggml → sources/ggml}/src/ggml-sycl/ggml-sycl.cpp +1004 -1240
  554. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +106 -0
  555. data/ext/sources/ggml/src/ggml-sycl/gla.hpp +8 -0
  556. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +136 -0
  557. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +21 -0
  558. data/ext/{ggml → sources/ggml}/src/ggml-sycl/mmq.cpp +0 -1
  559. data/ext/sources/ggml/src/ggml-sycl/mmq.hpp +33 -0
  560. data/ext/{ggml → sources/ggml}/src/ggml-sycl/mmvq.cpp +261 -166
  561. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +27 -0
  562. data/ext/{ggml → sources/ggml}/src/ggml-sycl/norm.cpp +204 -81
  563. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +26 -0
  564. data/ext/{ggml → sources/ggml}/src/ggml-sycl/outprod.cpp +8 -17
  565. data/ext/sources/ggml/src/ggml-sycl/outprod.hpp +10 -0
  566. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +74 -0
  567. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +83 -0
  568. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +361 -0
  569. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +20 -0
  570. data/ext/{ggml → sources/ggml}/src/ggml-sycl/softmax.cpp +35 -25
  571. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +20 -0
  572. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
  573. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
  574. data/ext/{ggml → sources/ggml}/src/ggml-sycl/tsembd.cpp +3 -3
  575. data/ext/sources/ggml/src/ggml-sycl/tsembd.hpp +20 -0
  576. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +1215 -0
  577. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +293 -0
  578. data/ext/sources/ggml/src/ggml-sycl/wkv.hpp +10 -0
  579. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +196 -0
  580. data/ext/sources/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in +15 -0
  581. data/ext/{ggml → sources/ggml}/src/ggml-vulkan/ggml-vulkan.cpp +3130 -1087
  582. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +39 -0
  583. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +29 -0
  584. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +29 -0
  585. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +51 -0
  586. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +69 -0
  587. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +17 -0
  588. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +41 -0
  589. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +49 -0
  590. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +105 -0
  591. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +23 -0
  592. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +51 -0
  593. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +242 -0
  594. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +17 -0
  595. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +31 -0
  596. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +20 -0
  597. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +462 -0
  598. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +699 -0
  599. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp +13 -0
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +42 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +35 -0
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +44 -0
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +43 -0
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +48 -0
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +39 -0
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +49 -0
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +32 -0
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +34 -0
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +34 -0
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +42 -0
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +30 -0
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +32 -0
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +68 -0
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +34 -0
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +35 -0
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +70 -0
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +33 -0
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +31 -0
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +34 -0
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +27 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +337 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +267 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +59 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +25 -0
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +23 -0
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +64 -0
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp +9 -0
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp +76 -0
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +33 -0
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +41 -0
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +66 -0
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +100 -0
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +41 -0
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +22 -0
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +27 -0
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp +48 -0
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +169 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +118 -0
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +82 -0
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +79 -0
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +90 -0
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +87 -0
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +87 -0
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +90 -0
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +88 -0
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +118 -0
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +154 -0
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +130 -0
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +132 -0
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +136 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +167 -0
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +130 -0
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +868 -0
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +441 -0
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +442 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +99 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +44 -0
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +42 -0
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +28 -0
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +74 -0
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +77 -0
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +21 -0
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +26 -0
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +37 -0
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +52 -0
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +55 -0
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +58 -0
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +60 -0
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +43 -0
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +43 -0
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +47 -0
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +24 -0
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +20 -0
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +22 -0
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +26 -0
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +17 -0
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +173 -0
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +50 -0
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +17 -0
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +29 -0
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +37 -0
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +20 -0
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp +7 -0
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp +7 -0
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp +7 -0
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp +7 -0
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +41 -0
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +1373 -0
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +36 -0
  692. data/ext/{ggml → sources/ggml}/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -35
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp +87 -0
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp +91 -0
  695. data/ext/{ggml → sources/ggml}/src/ggml.c +676 -1820
  696. data/ext/sources/ggml/src/gguf.cpp +1330 -0
  697. data/ext/{include → sources/include}/whisper.h +68 -2
  698. data/ext/sources/src/CMakeLists.txt +143 -0
  699. data/ext/{src → sources/src}/coreml/whisper-decoder-impl.h +27 -15
  700. data/ext/{src → sources/src}/coreml/whisper-decoder-impl.m +35 -10
  701. data/ext/{src → sources/src}/coreml/whisper-encoder-impl.h +21 -9
  702. data/ext/{src → sources/src}/coreml/whisper-encoder-impl.m +28 -3
  703. data/ext/sources/src/coreml/whisper-encoder.mm +73 -0
  704. data/ext/sources/src/whisper-arch.h +197 -0
  705. data/ext/{src → sources/src}/whisper.cpp +1905 -374
  706. data/ext/sources/tests/CMakeLists.txt +105 -0
  707. data/ext/sources/tests/earnings21/eval.mk +58 -0
  708. data/ext/sources/tests/earnings21/eval.py +68 -0
  709. data/ext/sources/tests/earnings21/normalizers/__init__.py +2 -0
  710. data/ext/sources/tests/earnings21/normalizers/basic.py +80 -0
  711. data/ext/sources/tests/earnings21/normalizers/english.json +1741 -0
  712. data/ext/sources/tests/earnings21/normalizers/english.py +550 -0
  713. data/ext/sources/tests/earnings21/requirements.txt +6 -0
  714. data/ext/sources/tests/en-0-ref.txt +1 -0
  715. data/ext/sources/tests/en-1-ref.txt +1 -0
  716. data/ext/sources/tests/en-2-ref.txt +1 -0
  717. data/ext/sources/tests/es-0-ref.txt +1 -0
  718. data/ext/sources/tests/librispeech/eval.mk +39 -0
  719. data/ext/sources/tests/librispeech/eval.py +47 -0
  720. data/ext/sources/tests/librispeech/normalizers/__init__.py +2 -0
  721. data/ext/sources/tests/librispeech/normalizers/basic.py +80 -0
  722. data/ext/sources/tests/librispeech/normalizers/english.json +1741 -0
  723. data/ext/sources/tests/librispeech/normalizers/english.py +550 -0
  724. data/ext/sources/tests/librispeech/requirements.txt +6 -0
  725. data/ext/sources/tests/run-tests.sh +130 -0
  726. data/ext/sources/tests/test-c.c +3 -0
  727. data/ext/sources/tests/test-vad-full.cpp +54 -0
  728. data/ext/sources/tests/test-vad.cpp +83 -0
  729. data/ext/sources/tests/test-whisper.js +58 -0
  730. data/extsources.rb +33 -5
  731. data/lib/whisper/model/uri.rb +149 -128
  732. data/sig/whisper.rbs +480 -0
  733. data/tests/helper.rb +28 -0
  734. data/tests/test_callback.rb +45 -3
  735. data/tests/test_error.rb +2 -2
  736. data/tests/test_model.rb +38 -0
  737. data/tests/test_package.rb +18 -3
  738. data/tests/test_params.rb +145 -8
  739. data/tests/test_segment.rb +10 -19
  740. data/tests/test_vad.rb +19 -0
  741. data/tests/test_vad_params.rb +103 -0
  742. data/tests/test_whisper.rb +37 -37
  743. data/whispercpp.gemspec +5 -4
  744. metadata +766 -111
  745. data/ext/cpu.mk +0 -9
  746. data/ext/examples/dr_wav.h +0 -8815
  747. data/ext/ggml/src/ggml-cann/aclnn_ops.h +0 -592
  748. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -4262
  749. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +0 -14123
  750. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +0 -1884
  751. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +0 -14
  752. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +0 -288
  753. data/ext/ggml/src/ggml-sycl/element_wise.cpp +0 -1030
  754. data/ext/ggml/src/ggml-sycl/im2col.cpp +0 -126
  755. data/ext/ggml/src/ggml-sycl/rope.cpp +0 -276
  756. data/ext/ggml/src/ggml-sycl/wkv6.cpp +0 -141
  757. data/ext/metal-embed.mk +0 -17
  758. data/ext/metal.mk +0 -6
  759. data/ext/ruby_whisper.cpp +0 -1909
  760. data/ext/scripts/get-flags.mk +0 -38
  761. data/lib/whisper.rb +0 -2
  762. /data/ext/{ggml → sources/ggml}/include/ggml-blas.h +0 -0
  763. /data/ext/{ggml → sources/ggml}/include/ggml-cann.h +0 -0
  764. /data/ext/{ggml → sources/ggml}/include/ggml-cuda.h +0 -0
  765. /data/ext/{ggml → sources/ggml}/include/ggml-kompute.h +0 -0
  766. /data/ext/{ggml → sources/ggml}/include/ggml-opencl.h +0 -0
  767. /data/ext/{ggml → sources/ggml}/include/ggml-sycl.h +0 -0
  768. /data/ext/{ggml → sources/ggml}/src/ggml-amx/common.h +0 -0
  769. /data/ext/{ggml → sources/ggml}/src/ggml-amx/ggml-amx.cpp +0 -0
  770. /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.cpp +0 -0
  771. /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.h +0 -0
  772. /data/ext/{ggml → sources/ggml}/src/ggml-blas/ggml-blas.cpp +0 -0
  773. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/ascendc_kernels.h +0 -0
  774. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f16.cpp +0 -0
  775. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f32.cpp +0 -0
  776. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -0
  777. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -0
  778. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -0
  779. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -0
  780. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -0
  781. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.h +0 -0
  782. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/common.h +0 -0
  783. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.cpp +0 -0
  784. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.h +0 -0
  785. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-aarch64.h +0 -0
  786. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-hbm.cpp +0 -0
  787. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-hbm.h +0 -0
  788. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-quants.h +0 -0
  789. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-traits.cpp +0 -0
  790. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-traits.h +0 -0
  791. /data/ext/{ggml → sources/ggml}/src/ggml-kompute/ggml-kompute.cpp +0 -0
  792. /data/ext/{ggml → sources/ggml}/src/ggml-quants.h +0 -0
  793. /data/ext/{ggml → sources/ggml}/src/ggml-threading.cpp +0 -0
  794. /data/ext/{ggml → sources/ggml}/src/ggml-threading.h +0 -0
  795. /data/ext/{src → sources/src}/coreml/whisper-encoder.h +0 -0
  796. /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.cpp +0 -0
  797. /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.h +0 -0
@@ -3,8 +3,7 @@
3
3
  #if defined(GGML_METAL_EMBED_LIBRARY)
4
4
  __embed_ggml-common.h__
5
5
  #else
6
- // TODO: this should not be a relative path, but can't figure out how to set Metal include paths in Package.swift
7
- #include "../ggml-common.h"
6
+ #include "ggml-common.h"
8
7
  #endif
9
8
  #include "ggml-metal-impl.h"
10
9
 
@@ -49,7 +48,7 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
49
48
 
50
49
  template <typename type4>
51
50
  void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
52
- reg = (type4)(*(src + il));
51
+ reg = (type4)(*(src));
53
52
  }
54
53
 
55
54
  #if defined(GGML_METAL_USE_BF16)
@@ -57,6 +56,11 @@ template <typename type4x4>
57
56
  void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
58
57
  reg = (type4x4)(*src);
59
58
  }
59
+
60
+ template <typename type4>
61
+ void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg) {
62
+ reg = (type4)(*(src));
63
+ }
60
64
  #endif
61
65
 
62
66
  template <typename type4x4>
@@ -373,24 +377,33 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
373
377
  template <typename type4x4>
374
378
  void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
375
379
  const half d_all = xb->d;
376
- device const uint8_t * ql = (device const uint8_t *)xb->ql;
377
- device const uint8_t * qh = (device const uint8_t *)xb->qh;
380
+ device const uint16_t * ql = (device const uint16_t *)xb->ql;
381
+ device const uint16_t * qh = (device const uint16_t *)xb->qh;
378
382
  device const int8_t * scales = (device const int8_t *)xb->scales;
379
383
 
380
- ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
381
- qh = qh + 32*(il/8) + 16*(il&1);
384
+ ql = ql + 32*(il/8) + 16*((il/2)&1) + 8*(il&1);
385
+ qh = qh + 16*(il/8) + 8*(il&1);
382
386
  float sc = scales[(il%2) + 2 * ((il/2))];
383
387
  il = (il/2) & 3;
384
388
 
385
- const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
386
- const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
387
- const float coef = il>1 ? 1.f/16.f : 1.f;
389
+ const uint32_t kmask1 = il>1 ? (il>2 ? 0xC0C0C0C0 : 0x30303030) : (il>0 ? 0x0C0C0C0C : 0x03030303);
390
+ const uint32_t kmask2 = il>1 ? 0xF0F0F0F0 : 0x0F0F0F0F;
388
391
  const float ml = d_all * sc * 32.f;
389
- const float dl = d_all * sc * coef;
390
- for (int i = 0; i < 16; ++i) {
391
- const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
392
- : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
393
- reg[i/4][i%4] = dl * q - ml;
392
+ const float dl0 = d_all * sc;
393
+ const float dl1 = dl0 / 256.f;
394
+ const float dl2 = dl0 / (256.f * 256.f);
395
+ const float dl3 = dl0 / (256.f * 256.f * 256.f);
396
+ const uint8_t shr_h = il>2 ? 2 : 0;
397
+ const uint8_t shl_h = il>1 ? 0 : (il>0 ? 2 : 4);
398
+ const uint8_t shr_l = il>1 ? 4 : 0;
399
+ for (int i = 0; i < 4; ++i) {
400
+ const uint32_t low = (ql[2*i] | (uint32_t)(ql[2*i+1] << 16)) & kmask2;
401
+ const uint32_t high = (qh[2*i] | (uint32_t)(qh[2*i+1] << 16)) & kmask1;
402
+ const uint32_t q = ((high << shl_h) >> shr_h) | (low >> shr_l);
403
+ reg[i][0] = dl0 * ((half)(q & 0xFF)) - ml;
404
+ reg[i][1] = dl1 * ((float)(q & 0xFF00)) - ml;
405
+ reg[i][2] = dl2 * ((float)(q & 0xFF0000)) - ml;
406
+ reg[i][3] = dl3 * ((float)(q & 0xFF000000)) - ml;
394
407
  }
395
408
  }
396
409
 
@@ -843,6 +856,7 @@ kernel void kernel_tanh(
843
856
  constant float GELU_COEF_A = 0.044715f;
844
857
  constant float GELU_QUICK_COEF = -1.702f;
845
858
  constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
859
+ constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
846
860
 
847
861
  kernel void kernel_gelu(
848
862
  device const float * src0,
@@ -884,6 +898,42 @@ kernel void kernel_gelu_quick_4(
884
898
  dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
885
899
  }
886
900
 
901
+ // based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
902
+ // ref: https://www.johndcook.com/blog/python_erf/
903
+ constant float p_erf = 0.3275911f;
904
+ constant float a1_erf = 0.254829592f;
905
+ constant float a2_erf = -0.284496736f;
906
+ constant float a3_erf = 1.421413741f;
907
+ constant float a4_erf = -1.453152027f;
908
+ constant float a5_erf = 1.061405429f;
909
+
910
+ template<typename T>
911
+ T erf_approx(T x) {
912
+ T sign_x = sign(x);
913
+ x = fabs(x);
914
+ T t = 1.0f / (1.0f + p_erf * x);
915
+ T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
916
+ return sign_x * y;
917
+ }
918
+
919
+ kernel void kernel_gelu_erf(
920
+ device const float * src0,
921
+ device float * dst,
922
+ uint tpig[[thread_position_in_grid]]) {
923
+ device const float & x = src0[tpig];
924
+
925
+ dst[tpig] = 0.5f*x*(1.0f+erf_approx<float>(x*SQRT_2_INV));
926
+ }
927
+
928
+ kernel void kernel_gelu_erf_4(
929
+ device const float4 * src0,
930
+ device float4 * dst,
931
+ uint tpig[[thread_position_in_grid]]) {
932
+ device const float4 & x = src0[tpig];
933
+
934
+ dst[tpig] = 0.5f*x*(1.0f+erf_approx<float4>(x*SQRT_2_INV));
935
+ }
936
+
887
937
  kernel void kernel_silu(
888
938
  device const float * src0,
889
939
  device float * dst,
@@ -936,48 +986,32 @@ kernel void kernel_cos(
936
986
  dst[tpig] = cos(src0[tpig]);
937
987
  }
938
988
 
989
+ kernel void kernel_neg(
990
+ device const float * src0,
991
+ device float * dst,
992
+ uint tpig[[thread_position_in_grid]]) {
993
+ dst[tpig] = -src0[tpig];
994
+ }
995
+
939
996
  kernel void kernel_sum_rows(
940
997
  device const float * src0,
941
998
  device float * dst,
942
- constant int64_t & ne00,
943
- constant int64_t & ne01,
944
- constant int64_t & ne02,
945
- constant int64_t & ne03,
946
- constant uint64_t & nb00,
947
- constant uint64_t & nb01,
948
- constant uint64_t & nb02,
949
- constant uint64_t & nb03,
950
- constant int64_t & ne10,
951
- constant int64_t & ne11,
952
- constant int64_t & ne12,
953
- constant int64_t & ne13,
954
- constant uint64_t & nb10,
955
- constant uint64_t & nb11,
956
- constant uint64_t & nb12,
957
- constant uint64_t & nb13,
958
- constant int64_t & ne0,
959
- constant int64_t & ne1,
960
- constant int64_t & ne2,
961
- constant int64_t & ne3,
962
- constant uint64_t & nb0,
963
- constant uint64_t & nb1,
964
- constant uint64_t & nb2,
965
- constant uint64_t & nb3,
999
+ constant ggml_metal_kargs_sum_rows & args,
966
1000
  uint3 tpig[[thread_position_in_grid]]) {
967
1001
  int64_t i3 = tpig.z;
968
1002
  int64_t i2 = tpig.y;
969
1003
  int64_t i1 = tpig.x;
970
1004
 
971
- if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
1005
+ if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
972
1006
  return;
973
1007
  }
974
1008
 
975
- device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
976
- device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
1009
+ device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
1010
+ device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
977
1011
 
978
1012
  float row_sum = 0;
979
1013
 
980
- for (int64_t i0 = 0; i0 < ne00; i0++) {
1014
+ for (int64_t i0 = 0; i0 < args.ne00; i0++) {
981
1015
  row_sum += src_row[i0];
982
1016
  }
983
1017
 
@@ -989,36 +1023,29 @@ kernel void kernel_soft_max(
989
1023
  device const char * src0,
990
1024
  device const char * src1,
991
1025
  device char * dst,
992
- constant int64_t & ne00,
993
- constant int64_t & ne01,
994
- constant int64_t & ne02,
995
- constant float & scale,
996
- constant float & max_bias,
997
- constant float & m0,
998
- constant float & m1,
999
- constant uint32_t & n_head_log2,
1026
+ constant ggml_metal_kargs_soft_max & args,
1000
1027
  threadgroup float * buf [[threadgroup(0)]],
1001
1028
  uint tgpig[[threadgroup_position_in_grid]],
1002
1029
  uint tpitg[[thread_position_in_threadgroup]],
1003
1030
  uint sgitg[[simdgroup_index_in_threadgroup]],
1004
1031
  uint tiisg[[thread_index_in_simdgroup]],
1005
1032
  uint ntg[[threads_per_threadgroup]]) {
1006
- const int64_t i03 = (tgpig) / (ne02*ne01);
1007
- const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
1008
- const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
1033
+ const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
1034
+ const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
1035
+ const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
1009
1036
 
1010
- device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
1011
- device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
1012
- device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
1037
+ device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
1038
+ device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00 : nullptr;
1039
+ device float * pdst = (device float *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
1013
1040
 
1014
1041
  float slope = 1.0f;
1015
1042
 
1016
1043
  // ALiBi
1017
- if (max_bias > 0.0f) {
1044
+ if (args.max_bias > 0.0f) {
1018
1045
  const int64_t h = i02;
1019
1046
 
1020
- const float base = h < n_head_log2 ? m0 : m1;
1021
- const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
1047
+ const float base = h < args.n_head_log2 ? args.m0 : args.m1;
1048
+ const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
1022
1049
 
1023
1050
  slope = pow(base, exp);
1024
1051
  }
@@ -1026,8 +1053,8 @@ kernel void kernel_soft_max(
1026
1053
  // parallel max
1027
1054
  float lmax = -INFINITY;
1028
1055
 
1029
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
1030
- lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
1056
+ for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
1057
+ lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
1031
1058
  }
1032
1059
 
1033
1060
  // find the max value in the block
@@ -1051,14 +1078,14 @@ kernel void kernel_soft_max(
1051
1078
 
1052
1079
  // parallel sum
1053
1080
  float lsum = 0.0f;
1054
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
1055
- const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
1081
+ for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
1082
+ const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
1056
1083
  lsum += exp_psrc0;
1057
1084
  pdst[i00] = exp_psrc0;
1058
1085
  }
1059
1086
 
1060
1087
  // This barrier fixes a failing test
1061
- // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
1088
+ // ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335
1062
1089
  threadgroup_barrier(mem_flags::mem_none);
1063
1090
 
1064
1091
  float sum = simd_sum(lsum);
@@ -1082,7 +1109,7 @@ kernel void kernel_soft_max(
1082
1109
 
1083
1110
  const float inv_sum = 1.0f/sum;
1084
1111
 
1085
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
1112
+ for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
1086
1113
  pdst[i00] *= inv_sum;
1087
1114
  }
1088
1115
  }
@@ -1092,35 +1119,28 @@ kernel void kernel_soft_max_4(
1092
1119
  device const char * src0,
1093
1120
  device const char * src1,
1094
1121
  device char * dst,
1095
- constant int64_t & ne00,
1096
- constant int64_t & ne01,
1097
- constant int64_t & ne02,
1098
- constant float & scale,
1099
- constant float & max_bias,
1100
- constant float & m0,
1101
- constant float & m1,
1102
- constant uint32_t & n_head_log2,
1122
+ constant ggml_metal_kargs_soft_max & args,
1103
1123
  threadgroup float * buf [[threadgroup(0)]],
1104
1124
  uint tgpig[[threadgroup_position_in_grid]],
1105
1125
  uint tpitg[[thread_position_in_threadgroup]],
1106
1126
  uint sgitg[[simdgroup_index_in_threadgroup]],
1107
1127
  uint tiisg[[thread_index_in_simdgroup]],
1108
1128
  uint ntg[[threads_per_threadgroup]]) {
1109
- const int64_t i03 = (tgpig) / (ne02*ne01);
1110
- const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
1111
- const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
1129
+ const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
1130
+ const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
1131
+ const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
1112
1132
 
1113
- device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
1114
- device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
1115
- device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
1133
+ device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
1134
+ device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00/4 : nullptr;
1135
+ device float4 * pdst4 = (device float4 *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
1116
1136
 
1117
1137
  float slope = 1.0f;
1118
1138
 
1119
- if (max_bias > 0.0f) {
1139
+ if (args.max_bias > 0.0f) {
1120
1140
  const int64_t h = i02;
1121
1141
 
1122
- const float base = h < n_head_log2 ? m0 : m1;
1123
- const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
1142
+ const float base = h < args.n_head_log2 ? args.m0 : args.m1;
1143
+ const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
1124
1144
 
1125
1145
  slope = pow(base, exp);
1126
1146
  }
@@ -1128,8 +1148,8 @@ kernel void kernel_soft_max_4(
1128
1148
  // parallel max
1129
1149
  float4 lmax4 = -INFINITY;
1130
1150
 
1131
- for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
1132
- lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
1151
+ for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
1152
+ lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
1133
1153
  }
1134
1154
 
1135
1155
  const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@@ -1154,8 +1174,8 @@ kernel void kernel_soft_max_4(
1154
1174
 
1155
1175
  // parallel sum
1156
1176
  float4 lsum4 = 0.0f;
1157
- for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
1158
- const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
1177
+ for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
1178
+ const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
1159
1179
  lsum4 += exp_psrc4;
1160
1180
  pdst4[i00] = exp_psrc4;
1161
1181
  }
@@ -1163,7 +1183,7 @@ kernel void kernel_soft_max_4(
1163
1183
  const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
1164
1184
 
1165
1185
  // This barrier fixes a failing test
1166
- // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
1186
+ // ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335
1167
1187
  threadgroup_barrier(mem_flags::mem_none);
1168
1188
 
1169
1189
  float sum = simd_sum(lsum);
@@ -1187,7 +1207,7 @@ kernel void kernel_soft_max_4(
1187
1207
 
1188
1208
  const float inv_sum = 1.0f/sum;
1189
1209
 
1190
- for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
1210
+ for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
1191
1211
  pdst4[i00] *= inv_sum;
1192
1212
  }
1193
1213
  }
@@ -1203,27 +1223,23 @@ template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kerne
1203
1223
  kernel void kernel_diag_mask_inf(
1204
1224
  device const float * src0,
1205
1225
  device float * dst,
1206
- constant int64_t & ne00,
1207
- constant int64_t & ne01,
1208
- constant int & n_past,
1226
+ constant ggml_metal_kargs_diag_mask_inf & args,
1209
1227
  uint3 tpig[[thread_position_in_grid]]) {
1210
1228
  const int64_t i02 = tpig[2];
1211
1229
  const int64_t i01 = tpig[1];
1212
1230
  const int64_t i00 = tpig[0];
1213
1231
 
1214
- if (i00 > n_past + i01) {
1215
- dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
1232
+ if (i00 > args.n_past + i01) {
1233
+ dst[i02*args.ne01*args.ne00 + i01*args.ne00 + i00] = -INFINITY;
1216
1234
  } else {
1217
- dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
1235
+ dst[i02*args.ne01*args.ne00 + i01*args.ne00 + i00] = src0[i02*args.ne01*args.ne00 + i01*args.ne00 + i00];
1218
1236
  }
1219
1237
  }
1220
1238
 
1221
1239
  kernel void kernel_diag_mask_inf_8(
1222
1240
  device const float4 * src0,
1223
1241
  device float4 * dst,
1224
- constant int64_t & ne00,
1225
- constant int64_t & ne01,
1226
- constant int & n_past,
1242
+ constant ggml_metal_kargs_diag_mask_inf & args,
1227
1243
  uint3 tpig[[thread_position_in_grid]]) {
1228
1244
 
1229
1245
  const int64_t i = 2*tpig[0];
@@ -1231,42 +1247,26 @@ kernel void kernel_diag_mask_inf_8(
1231
1247
  dst[i+0] = src0[i+0];
1232
1248
  dst[i+1] = src0[i+1];
1233
1249
  int64_t i4 = 4*i;
1234
- const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
1235
- const int64_t i01 = i4/(ne00); i4 -= i01*ne00;
1250
+ const int64_t i02 = i4/(args.ne00*args.ne01); i4 -= i02*args.ne00*args.ne01;
1251
+ const int64_t i01 = i4/(args.ne00); i4 -= i01*args.ne00;
1236
1252
  const int64_t i00 = i4;
1237
1253
  for (int k = 3; k >= 0; --k) {
1238
- if (i00 + 4 + k <= n_past + i01) {
1254
+ if (i00 + 4 + k <= args.n_past + i01) {
1239
1255
  break;
1240
1256
  }
1241
1257
  dst[i+1][k] = -INFINITY;
1242
- if (i00 + k > n_past + i01) {
1258
+ if (i00 + k > args.n_past + i01) {
1243
1259
  dst[i][k] = -INFINITY;
1244
1260
  }
1245
1261
  }
1246
1262
  }
1247
1263
 
1248
1264
  // ref: ggml.c:ggml_compute_forward_ssm_conv_f32
1249
- // TODO: optimize
1250
1265
  kernel void kernel_ssm_conv_f32(
1251
1266
  device const void * src0,
1252
1267
  device const void * src1,
1253
1268
  device float * dst,
1254
- constant int64_t & ne00,
1255
- constant int64_t & ne01,
1256
- constant int64_t & ne02,
1257
- constant uint64_t & nb00,
1258
- constant uint64_t & nb01,
1259
- constant uint64_t & nb02,
1260
- constant int64_t & ne10,
1261
- constant int64_t & ne11,
1262
- constant uint64_t & nb10,
1263
- constant uint64_t & nb11,
1264
- constant int64_t & ne0,
1265
- constant int64_t & ne1,
1266
- constant int64_t & ne2,
1267
- constant uint64_t & nb0,
1268
- constant uint64_t & nb1,
1269
- constant uint64_t & nb2,
1269
+ constant ggml_metal_kargs_ssm_conv & args,
1270
1270
  uint3 tgpig[[threadgroup_position_in_grid]],
1271
1271
  uint3 tpitg[[thread_position_in_threadgroup]],
1272
1272
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -1274,15 +1274,15 @@ kernel void kernel_ssm_conv_f32(
1274
1274
  const int64_t i2 = tgpig.y;
1275
1275
  const int64_t i3 = tgpig.z;
1276
1276
 
1277
- const int64_t nc = ne10;
1278
- //const int64_t ncs = ne00;
1279
- //const int64_t nr = ne01;
1280
- //const int64_t n_t = ne1;
1281
- //const int64_t n_s = ne2;
1277
+ const int64_t nc = args.ne10;
1278
+ //const int64_t ncs = args.ne00;
1279
+ //const int64_t nr = args.ne01;
1280
+ //const int64_t n_t = args.ne1;
1281
+ //const int64_t n_s = args.ne2;
1282
1282
 
1283
- device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02);
1284
- device const float * c = (device const float *) ((device const char *) src1 + ir*nb11);
1285
- device float * x = (device float *) ((device char *) dst + ir*nb0 + i2*nb1 + i3*nb2);
1283
+ device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
1284
+ device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
1285
+ device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
1286
1286
 
1287
1287
  float sumf = 0.0f;
1288
1288
 
@@ -1294,7 +1294,6 @@ kernel void kernel_ssm_conv_f32(
1294
1294
  }
1295
1295
 
1296
1296
  // ref: ggml.c:ggml_compute_forward_ssm_scan_f32
1297
- // TODO: optimize
1298
1297
  kernel void kernel_ssm_scan_f32(
1299
1298
  device const void * src0,
1300
1299
  device const void * src1,
@@ -1303,48 +1302,27 @@ kernel void kernel_ssm_scan_f32(
1303
1302
  device const void * src4,
1304
1303
  device const void * src5,
1305
1304
  device float * dst,
1306
- constant int64_t & d_state,
1307
- constant int64_t & d_inner,
1308
- constant int64_t & n_seq_tokens,
1309
- constant int64_t & n_seqs,
1310
- constant uint64_t & nb00,
1311
- constant uint64_t & nb01,
1312
- constant uint64_t & nb02,
1313
- constant uint64_t & nb10,
1314
- constant uint64_t & nb11,
1315
- constant uint64_t & nb12,
1316
- constant uint64_t & nb13,
1317
- constant uint64_t & nb20,
1318
- constant uint64_t & nb21,
1319
- constant uint64_t & nb22,
1320
- constant uint64_t & nb30,
1321
- constant uint64_t & nb31,
1322
- constant uint64_t & nb40,
1323
- constant uint64_t & nb41,
1324
- constant uint64_t & nb42,
1325
- constant uint64_t & nb50,
1326
- constant uint64_t & nb51,
1327
- constant uint64_t & nb52,
1305
+ constant ggml_metal_kargs_ssm_scan & args,
1328
1306
  uint3 tgpig[[threadgroup_position_in_grid]],
1329
1307
  uint3 tpitg[[thread_position_in_threadgroup]],
1330
1308
  uint3 ntg[[threads_per_threadgroup]]) {
1331
1309
  const int64_t ir = tgpig.x;
1332
1310
  const int64_t i3 = tgpig.y;
1333
1311
 
1334
- const int64_t nc = d_state;
1335
- //const int64_t nr = d_inner;
1336
- const int64_t n_t = n_seq_tokens;
1337
- //const int64_t n_s = n_seqs;
1312
+ const int64_t nc = args.d_state;
1313
+ // const int64_t nr = args.d_inner;
1314
+ const int64_t n_t = args.n_seq_tokens;
1315
+ // const int64_t n_s = args.n_seqs;
1338
1316
 
1339
1317
  for (int64_t i2 = 0; i2 < n_t; ++i2) {
1340
- device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02);
1341
- device const float * x = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12);
1342
- device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22);
1343
- device const float * A = (device const float *) ((device const char *) src3 + ir*nb31);
1344
- device const float * B = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42);
1345
- device const float * C = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52);
1346
- device float * y = (device float *) ((device char *) dst + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides
1347
- device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb02 + nb13);
1318
+ device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb01 + i3*args.nb02);
1319
+ device const float * x = (device const float *) ((device const char *) src1 + ir*args.nb10 + i2*args.nb11 + i3*args.nb12);
1320
+ device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22);
1321
+ device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
1322
+ device const float * B = (device const float *) ((device const char *) src4 + i2*args.nb41 + i3*args.nb42);
1323
+ device const float * C = (device const float *) ((device const char *) src5 + i2*args.nb51 + i3*args.nb52);
1324
+ device float * y = (device float *) ((device char *) dst + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); // TODO: do not use src1 strides
1325
+ device float * s = (device float *) ((device char *) dst + ir*args.nb01 + i3*args.nb02 + args.nb13);
1348
1326
 
1349
1327
  if (i2 > 0) {
1350
1328
  s0 = s;
@@ -1366,6 +1344,184 @@ kernel void kernel_ssm_scan_f32(
1366
1344
  }
1367
1345
  }
1368
1346
 
1347
+ kernel void kernel_rwkv_wkv6_f32(
1348
+ device const float * k,
1349
+ device const float * v,
1350
+ device const float * r,
1351
+ device const float * tf,
1352
+ device const float * td,
1353
+ device const float * state_in,
1354
+ device float * dst,
1355
+ constant uint & B,
1356
+ constant uint & T,
1357
+ constant uint & C,
1358
+ constant uint & H,
1359
+ uint3 tgpig[[threadgroup_position_in_grid]],
1360
+ uint3 tpitg[[thread_position_in_threadgroup]],
1361
+ uint3 ntg[[threads_per_threadgroup]]) {
1362
+
1363
+ const uint head_size = 64; // TODO: support head_size = 128
1364
+ const uint batch_id = tgpig.x / H;
1365
+ const uint head_id = tgpig.x % H;
1366
+ const uint tid = tpitg.x;
1367
+
1368
+ if (batch_id >= B || head_id >= H) {
1369
+ return;
1370
+ }
1371
+
1372
+ const uint state_size = C * head_size;
1373
+ const uint n_seq_tokens = T / B;
1374
+
1375
+ threadgroup float _k[head_size];
1376
+ threadgroup float _r[head_size];
1377
+ threadgroup float _tf[head_size];
1378
+ threadgroup float _td[head_size];
1379
+
1380
+ float state[head_size];
1381
+
1382
+ for (uint i = 0; i < head_size; i++) {
1383
+ state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
1384
+ + i * head_size + tid];
1385
+ }
1386
+
1387
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1388
+ _tf[tid] = tf[head_id * head_size + tid];
1389
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1390
+
1391
+ const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
1392
+ const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
1393
+
1394
+ for (uint t = start_t; t < end_t; t += C) {
1395
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1396
+ _k[tid] = k[t];
1397
+ _r[tid] = r[t];
1398
+ _td[tid] = td[t];
1399
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1400
+
1401
+ const float v_val = v[t];
1402
+ float y = 0.0;
1403
+
1404
+ for (uint j = 0; j < head_size; j += 4) {
1405
+ float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
1406
+ float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
1407
+ float4 tf_vec = float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
1408
+ float4 td_vec = float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
1409
+ float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
1410
+
1411
+ float4 kv = k_vec * v_val;
1412
+
1413
+ float4 temp = tf_vec * kv + s_vec;
1414
+ y += dot(r_vec, temp);
1415
+
1416
+ s_vec = s_vec * td_vec + kv;
1417
+ state[j] = s_vec[0];
1418
+ state[j+1] = s_vec[1];
1419
+ state[j+2] = s_vec[2];
1420
+ state[j+3] = s_vec[3];
1421
+ }
1422
+
1423
+ dst[t] = y;
1424
+ }
1425
+
1426
+ for (uint i = 0; i < head_size; i++) {
1427
+ dst[T * C + batch_id * state_size + head_id * head_size * head_size
1428
+ + i * head_size + tid] = state[i];
1429
+ }
1430
+ }
1431
+
1432
+ kernel void kernel_rwkv_wkv7_f32(
1433
+ device const float * r,
1434
+ device const float * w,
1435
+ device const float * k,
1436
+ device const float * v,
1437
+ device const float * a,
1438
+ device const float * b,
1439
+ device const float * state_in,
1440
+ device float * dst,
1441
+ constant uint & B,
1442
+ constant uint & T,
1443
+ constant uint & C,
1444
+ constant uint & H,
1445
+ uint3 tgpig[[threadgroup_position_in_grid]],
1446
+ uint3 tpitg[[thread_position_in_threadgroup]],
1447
+ uint3 ntg[[threads_per_threadgroup]]) {
1448
+
1449
+ const uint head_size = 64; // TODO: support head_size = 128
1450
+ const uint batch_id = tgpig.x / H;
1451
+ const uint head_id = tgpig.x % H;
1452
+ const uint tid = tpitg.x;
1453
+
1454
+ if (batch_id >= B || head_id >= H) {
1455
+ return;
1456
+ }
1457
+
1458
+ const uint state_size = C * head_size;
1459
+ const uint n_seq_tokens = T / B;
1460
+
1461
+ threadgroup float _r[head_size];
1462
+ threadgroup float _w[head_size];
1463
+ threadgroup float _k[head_size];
1464
+ threadgroup float _a[head_size];
1465
+ threadgroup float _b[head_size];
1466
+
1467
+ float state[head_size];
1468
+
1469
+ for (uint i = 0; i < head_size; i++) {
1470
+ state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
1471
+ + tid * head_size + i];
1472
+ }
1473
+
1474
+ const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
1475
+ const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
1476
+
1477
+ for (uint t = start_t; t < end_t; t += C) {
1478
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1479
+ _r[tid] = r[t];
1480
+ _w[tid] = w[t];
1481
+ _k[tid] = k[t];
1482
+ _a[tid] = a[t];
1483
+ _b[tid] = b[t];
1484
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1485
+
1486
+ const float v_val = v[t];
1487
+ float y = 0.0, sa = 0.0;
1488
+
1489
+ float4 sa_vec(0.0);
1490
+
1491
+ for (uint j = 0; j < head_size; j += 4) {
1492
+ float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
1493
+ float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
1494
+ sa_vec += a_vec * s_vec;
1495
+ }
1496
+ sa = sa_vec[0] + sa_vec[1] + sa_vec[2] + sa_vec[3];
1497
+
1498
+ for (uint j = 0; j < head_size; j += 4) {
1499
+ float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
1500
+ float4 w_vec = float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
1501
+ float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
1502
+ float4 b_vec = float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
1503
+ float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
1504
+
1505
+ float4 kv = k_vec * v_val;
1506
+
1507
+ s_vec = s_vec * w_vec + kv + sa * b_vec;
1508
+ y += dot(s_vec, r_vec);
1509
+
1510
+ state[j] = s_vec[0];
1511
+ state[j+1] = s_vec[1];
1512
+ state[j+2] = s_vec[2];
1513
+ state[j+3] = s_vec[3];
1514
+ }
1515
+
1516
+ dst[t] = y;
1517
+ }
1518
+
1519
+ for (uint i = 0; i < head_size; i++) {
1520
+ dst[T * C + batch_id * state_size + head_id * head_size * head_size
1521
+ + tid * head_size + i] = state[i];
1522
+ }
1523
+ }
1524
+
1369
1525
  kernel void kernel_argmax(
1370
1526
  device const void * x,
1371
1527
  device int32_t * dst,
@@ -1534,25 +1690,61 @@ kernel void kernel_rms_norm(
1534
1690
  }
1535
1691
  }
1536
1692
 
1693
+ kernel void kernel_l2_norm(
1694
+ constant ggml_metal_kargs_l2_norm & args,
1695
+ device const char * src0,
1696
+ device char * dst,
1697
+ threadgroup float * shmem_f32 [[threadgroup(0)]],
1698
+ uint tgpig[[threadgroup_position_in_grid]],
1699
+ ushort tpitg[[thread_position_in_threadgroup]],
1700
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
1701
+ ushort tiisg[[thread_index_in_simdgroup]],
1702
+ ushort ntg[[threads_per_threadgroup]]) {
1703
+ if (sgitg == 0) {
1704
+ shmem_f32[tiisg] = 0.0f;
1705
+ }
1706
+
1707
+ device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
1708
+
1709
+ float sumf = 0.0f;
1710
+
1711
+ // parallel sum
1712
+ for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
1713
+ sumf += dot(x[i00], x[i00]);
1714
+ }
1715
+ sumf = simd_sum(sumf);
1716
+
1717
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1718
+
1719
+ if (tiisg == 0) {
1720
+ shmem_f32[sgitg] = sumf;
1721
+ }
1722
+
1723
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1724
+
1725
+ sumf = shmem_f32[tiisg];
1726
+ sumf = simd_sum(sumf);
1727
+
1728
+ const float scale = 1.0f/sqrt(max(sumf, args.eps));
1729
+
1730
+ device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
1731
+ for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
1732
+ y[i00] = x[i00] * scale;
1733
+ }
1734
+ }
1735
+
1537
1736
  kernel void kernel_group_norm(
1538
1737
  device const float * src0,
1539
1738
  device float * dst,
1540
- constant int64_t & ne00,
1541
- constant int64_t & ne01,
1542
- constant int64_t & ne02,
1543
- constant uint64_t & nb00,
1544
- constant uint64_t & nb01,
1545
- constant uint64_t & nb02,
1546
- constant int32_t & n_groups,
1547
- constant float & eps,
1739
+ constant ggml_metal_kargs_group_norm & args,
1548
1740
  threadgroup float * buf [[threadgroup(0)]],
1549
1741
  uint tgpig[[threadgroup_position_in_grid]],
1550
1742
  uint tpitg[[thread_position_in_threadgroup]],
1551
1743
  uint sgitg[[simdgroup_index_in_threadgroup]],
1552
1744
  uint tiisg[[thread_index_in_simdgroup]],
1553
1745
  uint ntg[[threads_per_threadgroup]]) {
1554
- const int64_t ne = ne00*ne01*ne02;
1555
- const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
1746
+ const int64_t ne = args.ne00*args.ne01*args.ne02;
1747
+ const int64_t gs = args.ne00*args.ne01*((args.ne02 + args.n_groups - 1) / args.n_groups);
1556
1748
 
1557
1749
  int start = tgpig * gs;
1558
1750
  int end = start + gs;
@@ -1616,7 +1808,7 @@ kernel void kernel_group_norm(
1616
1808
  }
1617
1809
 
1618
1810
  const float variance = tmp / gs;
1619
- const float scale = 1.0f/sqrt(variance + eps);
1811
+ const float scale = 1.0f/sqrt(variance + args.eps);
1620
1812
  for (int j = start; j < end; j += ntg) {
1621
1813
  dst[j] *= scale;
1622
1814
  }
@@ -1710,14 +1902,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
1710
1902
  return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
1711
1903
  }
1712
1904
 
1713
- // putting them in the kernel cause a significant performance penalty
1714
- #define N_DST 4 // each SIMD group works on 4 rows
1715
- #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
1716
- //Note: This is a template, but strictly speaking it only applies to
1717
- // quantizations where the block size is 32. It also does not
1718
- // guard against the number of rows not being divisible by
1719
- // N_DST, so this is another explicit assumption of the implementation.
1720
- template<typename block_q_type, int nr, int nsg, int nw, typename args_t>
1905
+ template<typename block_q_type, int nr0, int nsg, int nw, typename args_t>
1721
1906
  void mul_vec_q_n_f32_impl(
1722
1907
  args_t args,
1723
1908
  device const char * src0,
@@ -1733,7 +1918,7 @@ void mul_vec_q_n_f32_impl(
1733
1918
  const int r1 = tgpig.y;
1734
1919
  const int im = tgpig.z;
1735
1920
 
1736
- const int first_row = (r0 * nsg + sgitg) * nr;
1921
+ const int first_row = (r0 * nsg + sgitg) * nr0;
1737
1922
 
1738
1923
  const uint i12 = im%args.ne12;
1739
1924
  const uint i13 = im/args.ne12;
@@ -1745,15 +1930,15 @@ void mul_vec_q_n_f32_impl(
1745
1930
  device const float * y = (device const float *) (src1 + offset1);
1746
1931
 
1747
1932
  // pointers to src0 rows
1748
- device const block_q_type * ax[nr];
1749
- for (int row = 0; row < nr; ++row) {
1933
+ device const block_q_type * ax[nr0];
1934
+ for (int row = 0; row < nr0; ++row) {
1750
1935
  const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
1751
1936
 
1752
1937
  ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
1753
1938
  }
1754
1939
 
1755
1940
  float yl[16]; // src1 vector cache
1756
- float sumf[nr] = {0.f};
1941
+ float sumf[nr0] = {0.f};
1757
1942
 
1758
1943
  const short ix = (tiisg/2);
1759
1944
  const short il = (tiisg%2)*8;
@@ -1765,7 +1950,7 @@ void mul_vec_q_n_f32_impl(
1765
1950
  float sumy[2] = { 0.f, 0.f };
1766
1951
 
1767
1952
  #pragma unroll
1768
- for (int i = 0; i < 8; i += 2) {
1953
+ for (short i = 0; i < 8; i += 2) {
1769
1954
  sumy[0] += yb[i + 0] + yb[i + 1];
1770
1955
  yl[i + 0] = yb[i + 0];
1771
1956
  yl[i + 1] = yb[i + 1]/256.f;
@@ -1776,7 +1961,7 @@ void mul_vec_q_n_f32_impl(
1776
1961
  }
1777
1962
 
1778
1963
  #pragma unroll
1779
- for (int row = 0; row < nr; row++) {
1964
+ for (short row = 0; row < nr0; row++) {
1780
1965
  sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
1781
1966
  }
1782
1967
 
@@ -1785,7 +1970,7 @@ void mul_vec_q_n_f32_impl(
1785
1970
 
1786
1971
  device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
1787
1972
 
1788
- for (int row = 0; row < nr; ++row) {
1973
+ for (int row = 0; row < nr0; ++row) {
1789
1974
  const float tot = simd_sum(sumf[row]);
1790
1975
 
1791
1976
  if (tiisg == 0 && first_row + row < args.ne01) {
@@ -1802,7 +1987,7 @@ kernel void kernel_mul_mv_q4_0_f32(
1802
1987
  uint3 tgpig[[threadgroup_position_in_grid]],
1803
1988
  ushort tiisg[[thread_index_in_simdgroup]],
1804
1989
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
1805
- mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
1990
+ mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SG_Q4_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
1806
1991
  }
1807
1992
 
1808
1993
  kernel void kernel_mul_mv_q4_1_f32(
@@ -1813,7 +1998,7 @@ kernel void kernel_mul_mv_q4_1_f32(
1813
1998
  uint3 tgpig[[threadgroup_position_in_grid]],
1814
1999
  ushort tiisg[[thread_index_in_simdgroup]],
1815
2000
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
1816
- mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
2001
+ mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SG_Q4_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
1817
2002
  }
1818
2003
 
1819
2004
  kernel void kernel_mul_mv_q5_0_f32(
@@ -1824,7 +2009,7 @@ kernel void kernel_mul_mv_q5_0_f32(
1824
2009
  uint3 tgpig[[threadgroup_position_in_grid]],
1825
2010
  ushort tiisg[[thread_index_in_simdgroup]],
1826
2011
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
1827
- mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
2012
+ mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SG_Q5_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
1828
2013
  }
1829
2014
 
1830
2015
  kernel void kernel_mul_mv_q5_1_f32(
@@ -1835,12 +2020,12 @@ kernel void kernel_mul_mv_q5_1_f32(
1835
2020
  uint3 tgpig[[threadgroup_position_in_grid]],
1836
2021
  ushort tiisg[[thread_index_in_simdgroup]],
1837
2022
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
1838
- mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
2023
+ mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SG_Q5_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
1839
2024
  }
1840
2025
 
1841
2026
  #define NB_Q8_0 8
1842
2027
 
1843
- template<typename args_t>
2028
+ template<int nr0, int nsg, int nw, typename args_t>
1844
2029
  void kernel_mul_mv_q8_0_f32_impl(
1845
2030
  args_t args,
1846
2031
  device const char * src0,
@@ -1850,16 +2035,13 @@ void kernel_mul_mv_q8_0_f32_impl(
1850
2035
  uint3 tgpig,
1851
2036
  ushort tiisg,
1852
2037
  ushort sgitg) {
1853
- const int nr = N_DST;
1854
- const int nsg = N_SIMDGROUP;
1855
- const int nw = N_SIMDWIDTH;
1856
-
1857
2038
  const int nb = args.ne00/QK8_0;
2039
+
1858
2040
  const int r0 = tgpig.x;
1859
2041
  const int r1 = tgpig.y;
1860
2042
  const int im = tgpig.z;
1861
2043
 
1862
- const int first_row = (r0*nsg + sgitg)*nr;
2044
+ const int first_row = (r0 * nsg + sgitg) * nr0;
1863
2045
 
1864
2046
  const uint i12 = im%args.ne12;
1865
2047
  const uint i13 = im/args.ne12;
@@ -1871,15 +2053,15 @@ void kernel_mul_mv_q8_0_f32_impl(
1871
2053
  device const float * y = (device const float *) (src1 + offset1);
1872
2054
 
1873
2055
  // pointers to src0 rows
1874
- device const block_q8_0 * ax[nr];
1875
- for (int row = 0; row < nr; ++row) {
2056
+ device const block_q8_0 * ax[nr0];
2057
+ for (int row = 0; row < nr0; ++row) {
1876
2058
  const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
1877
2059
 
1878
2060
  ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
1879
2061
  }
1880
2062
 
1881
2063
  float yl[NB_Q8_0];
1882
- float sumf[nr] = { 0.f };
2064
+ float sumf[nr0] = { 0.f };
1883
2065
 
1884
2066
  const short ix = tiisg/4;
1885
2067
  const short il = tiisg%4;
@@ -1892,7 +2074,7 @@ void kernel_mul_mv_q8_0_f32_impl(
1892
2074
  yl[i] = yb[i];
1893
2075
  }
1894
2076
 
1895
- for (int row = 0; row < nr; row++) {
2077
+ for (short row = 0; row < nr0; row++) {
1896
2078
  device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0;
1897
2079
  float sumq = 0.f;
1898
2080
  for (short iq = 0; iq < NB_Q8_0; ++iq) {
@@ -1906,7 +2088,7 @@ void kernel_mul_mv_q8_0_f32_impl(
1906
2088
 
1907
2089
  device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
1908
2090
 
1909
- for (int row = 0; row < nr; ++row) {
2091
+ for (int row = 0; row < nr0; ++row) {
1910
2092
  const float tot = simd_sum(sumf[row]);
1911
2093
 
1912
2094
  if (tiisg == 0 && first_row + row < args.ne01) {
@@ -1924,7 +2106,7 @@ kernel void kernel_mul_mv_q8_0_f32(
1924
2106
  uint3 tgpig[[threadgroup_position_in_grid]],
1925
2107
  ushort tiisg[[thread_index_in_simdgroup]],
1926
2108
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
1927
- kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
2109
+ kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SG_Q8_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
1928
2110
  }
1929
2111
 
1930
2112
  // mat-vec kernel processing in chunks of float4
@@ -2261,9 +2443,9 @@ void kernel_mul_mv_impl(
2261
2443
  sumf += (T0) x[i] * (T1) y[i];
2262
2444
  }
2263
2445
 
2264
- float all_sum = simd_sum(sumf);
2446
+ float sum_all = simd_sum(sumf);
2265
2447
  if (tiisg == 0) {
2266
- dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
2448
+ dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
2267
2449
  }
2268
2450
  }
2269
2451
  } else {
@@ -2284,10 +2466,10 @@ void kernel_mul_mv_impl(
2284
2466
  sumf += dot((float4) x4[i], (float4) y4[i]);
2285
2467
  }
2286
2468
 
2287
- float all_sum = simd_sum(sumf);
2469
+ float sum_all = simd_sum(sumf);
2288
2470
  if (tiisg == 0) {
2289
- for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]);
2290
- dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
2471
+ for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]);
2472
+ dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
2291
2473
  }
2292
2474
  }
2293
2475
  }
@@ -2349,9 +2531,9 @@ kernel void kernel_mul_mv_1row(
2349
2531
  for (int i = tiisg; i < args.ne00; i += 32) {
2350
2532
  sumf += (float) x[i] * (float) y[i];
2351
2533
  }
2352
- float all_sum = simd_sum(sumf);
2534
+ float sum_all = simd_sum(sumf);
2353
2535
  if (tiisg == 0) {
2354
- dst_f32[r0] = all_sum;
2536
+ dst_f32[r0] = sum_all;
2355
2537
  }
2356
2538
  } else {
2357
2539
  device const T4 * x4 = (device const T4 *) x;
@@ -2361,11 +2543,11 @@ kernel void kernel_mul_mv_1row(
2361
2543
  sumf += dot((float4) x4[i], y4[i]);
2362
2544
  }
2363
2545
 
2364
- float all_sum = simd_sum(sumf);
2546
+ float sum_all = simd_sum(sumf);
2365
2547
 
2366
2548
  if (tiisg == 0) {
2367
- for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]);
2368
- dst_f32[r0] = all_sum;
2549
+ for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]);
2550
+ dst_f32[r0] = sum_all;
2369
2551
  }
2370
2552
  }
2371
2553
  }
@@ -2410,9 +2592,9 @@ kernel void kernel_mul_mv_l4(
2410
2592
  sumf += dot((float4) x4[i], y4[i]);
2411
2593
  }
2412
2594
 
2413
- float all_sum = simd_sum(sumf);
2595
+ float sum_all = simd_sum(sumf);
2414
2596
  if (tiisg == 0) {
2415
- dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
2597
+ dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
2416
2598
  }
2417
2599
  }
2418
2600
  }
@@ -2568,8 +2750,148 @@ kernel void kernel_rope_neox(
2568
2750
  }
2569
2751
  }
2570
2752
 
2753
+ template<typename T>
2754
+ kernel void kernel_rope_multi(
2755
+ constant ggml_metal_kargs_rope & args,
2756
+ device const char * src0,
2757
+ device const char * src1,
2758
+ device const char * src2,
2759
+ device char * dst,
2760
+ ushort tiitg[[thread_index_in_threadgroup]],
2761
+ ushort3 tptg [[threads_per_threadgroup]],
2762
+ uint3 tgpig[[threadgroup_position_in_grid]]) {
2763
+ const int i3 = tgpig[2];
2764
+ const int i2 = tgpig[1];
2765
+ const int i1 = tgpig[0];
2766
+
2767
+ float corr_dims[2];
2768
+ rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
2769
+
2770
+ device const int32_t * pos = (device const int32_t *) src1;
2771
+
2772
+ const float inv_ndims = -1.f/args.n_dims;
2773
+
2774
+ float cos_theta;
2775
+ float sin_theta;
2776
+
2777
+ for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
2778
+ if (i0 < args.n_dims) {
2779
+ const int ic = i0/2;
2780
+
2781
+ // mrope theta calculations
2782
+ // note: the rest is the same as kernel_rope_neox
2783
+ const int sect_dims = args.sect_0 + args.sect_1 + args.sect_2 + args.sect_3;
2784
+ const int sec_w01 = args.sect_0 + args.sect_1; // end of section 1
2785
+ const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2
2786
+ const int sector = ic % sect_dims;
2787
+
2788
+ float theta_base;
2789
+ if (sector < args.sect_0) {
2790
+ theta_base = (float) pos[i2];
2791
+ } else if (sector < sec_w01) {
2792
+ theta_base = (float) pos[i2 + args.ne02];
2793
+ } else if (sector < sec_w012) {
2794
+ theta_base = (float) pos[i2 + args.ne02 * 2];
2795
+ } else {
2796
+ theta_base = (float) pos[i2 + args.ne02 * 3];
2797
+ }
2798
+ // end of mrope
2799
+
2800
+ const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
2801
+
2802
+ const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
2803
+
2804
+ rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
2805
+
2806
+ device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
2807
+ device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
2808
+
2809
+ const float x0 = src[0];
2810
+ const float x1 = src[args.n_dims/2];
2811
+
2812
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
2813
+ dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
2814
+ } else {
2815
+ device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
2816
+ device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
2817
+
2818
+ dst_data[0] = src[0];
2819
+ dst_data[1] = src[1];
2820
+ }
2821
+ }
2822
+ }
2823
+
2824
+ template<typename T>
2825
+ kernel void kernel_rope_vision(
2826
+ constant ggml_metal_kargs_rope & args,
2827
+ device const char * src0,
2828
+ device const char * src1,
2829
+ device const char * src2,
2830
+ device char * dst,
2831
+ ushort tiitg[[thread_index_in_threadgroup]],
2832
+ ushort3 tptg [[threads_per_threadgroup]],
2833
+ uint3 tgpig[[threadgroup_position_in_grid]]) {
2834
+ const int i3 = tgpig[2];
2835
+ const int i2 = tgpig[1];
2836
+ const int i1 = tgpig[0];
2837
+
2838
+ float corr_dims[2];
2839
+ rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
2840
+
2841
+ device const int32_t * pos = (device const int32_t *) src1;
2842
+
2843
+ const float inv_ndims = -1.f/args.n_dims;
2844
+
2845
+ float cos_theta;
2846
+ float sin_theta;
2847
+
2848
+ for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
2849
+ if (i0 < 2*args.n_dims) { // different from kernel_rope_multi
2850
+ const int ic = i0/2;
2851
+
2852
+ // mrope theta calculations (only support 2 dimensions)
2853
+ const int sect_dims = args.sect_0 + args.sect_1;
2854
+ const int sector = ic % sect_dims;
2855
+
2856
+ float p;
2857
+ float theta_base;
2858
+ if (sector < args.sect_1) {
2859
+ p = (float) sector;
2860
+ theta_base = (float) pos[i2];
2861
+ } else {
2862
+ p = (float) sector - args.sect_0;
2863
+ theta_base = (float) pos[i2 + args.ne02];
2864
+ }
2865
+
2866
+ const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
2867
+ // end of mrope
2868
+
2869
+ const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
2870
+
2871
+ rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
2872
+
2873
+ device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
2874
+ device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
2875
+
2876
+ const float x0 = src[0];
2877
+ const float x1 = src[args.n_dims]; // different from kernel_rope_multi
2878
+
2879
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
2880
+ dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi
2881
+ } else {
2882
+ device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
2883
+ device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
2884
+
2885
+ dst_data[0] = src[0];
2886
+ dst_data[1] = src[1];
2887
+ }
2888
+ }
2889
+ }
2890
+
2571
2891
  typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
2572
2892
  typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
2893
+ typedef decltype(kernel_rope_multi<float>) kernel_rope_multi_t;
2894
+ typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t;
2573
2895
 
2574
2896
  template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
2575
2897
  template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
@@ -2577,20 +2899,16 @@ template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_
2577
2899
  template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
2578
2900
  template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
2579
2901
 
2902
+ template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi<float>;
2903
+ template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi<half>;
2904
+
2905
+ template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
2906
+ template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
2907
+
2580
2908
  typedef void (im2col_t)(
2581
2909
  device const float * x,
2582
2910
  device char * dst,
2583
- constant int32_t & ofs0,
2584
- constant int32_t & ofs1,
2585
- constant int32_t & IW,
2586
- constant int32_t & IH,
2587
- constant int32_t & CHW,
2588
- constant int32_t & s0,
2589
- constant int32_t & s1,
2590
- constant int32_t & p0,
2591
- constant int32_t & p1,
2592
- constant int32_t & d0,
2593
- constant int32_t & d1,
2911
+ constant ggml_metal_kargs_im2col & args,
2594
2912
  uint3 tgpig[[threadgroup_position_in_grid]],
2595
2913
  uint3 tgpg[[threadgroups_per_grid]],
2596
2914
  uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2600,17 +2918,7 @@ template <typename T>
2600
2918
  kernel void kernel_im2col(
2601
2919
  device const float * x,
2602
2920
  device char * dst,
2603
- constant int32_t & ofs0,
2604
- constant int32_t & ofs1,
2605
- constant int32_t & IW,
2606
- constant int32_t & IH,
2607
- constant int32_t & CHW,
2608
- constant int32_t & s0,
2609
- constant int32_t & s1,
2610
- constant int32_t & p0,
2611
- constant int32_t & p1,
2612
- constant int32_t & d0,
2613
- constant int32_t & d1,
2921
+ constant ggml_metal_kargs_im2col & args,
2614
2922
  uint3 tgpig[[threadgroup_position_in_grid]],
2615
2923
  uint3 tgpg[[threadgroups_per_grid]],
2616
2924
  uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2631,17 +2939,17 @@ kernel void kernel_im2col(
2631
2939
  const int64_t ioh = tgpig[1];
2632
2940
  const int64_t iow = tgpig[2];
2633
2941
 
2634
- const int64_t iiw = iow*s0 + ikw*d0 - p0;
2635
- const int64_t iih = ioh*s1 + ikh*d1 - p1;
2942
+ const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0;
2943
+ const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1;
2636
2944
 
2637
- const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*CHW + (iic*(KH*KW) + ikh*KW + ikw);
2945
+ const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw);
2638
2946
 
2639
2947
  device T * pdst = (device T *) (dst);
2640
2948
 
2641
- if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
2949
+ if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
2642
2950
  pdst[offset_dst] = 0.0f;
2643
2951
  } else {
2644
- const int64_t offset_src = in*ofs0 + iic*ofs1 + iih*IW + iiw;
2952
+ const int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw;
2645
2953
  pdst[offset_dst] = x[offset_src];
2646
2954
  }
2647
2955
  }
@@ -2652,20 +2960,7 @@ template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
2652
2960
  typedef void (im2col_ext_t)(
2653
2961
  device const float * x,
2654
2962
  device char * dst,
2655
- constant int32_t & ofs0,
2656
- constant int32_t & ofs1,
2657
- constant int32_t & IW,
2658
- constant int32_t & IH,
2659
- constant int32_t & CHW,
2660
- constant int32_t & s0,
2661
- constant int32_t & s1,
2662
- constant int32_t & p0,
2663
- constant int32_t & p1,
2664
- constant int32_t & d0,
2665
- constant int32_t & d1,
2666
- constant int32_t & N,
2667
- constant int32_t & KH,
2668
- constant int32_t & KW,
2963
+ constant ggml_metal_kargs_im2col & args,
2669
2964
  uint3 tgpig[[threadgroup_position_in_grid]],
2670
2965
  uint3 tgpg[[threadgroups_per_grid]],
2671
2966
  uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2675,53 +2970,40 @@ template <typename T>
2675
2970
  kernel void kernel_im2col_ext(
2676
2971
  device const float * x,
2677
2972
  device char * dst,
2678
- constant int32_t & ofs0,
2679
- constant int32_t & ofs1,
2680
- constant int32_t & IW,
2681
- constant int32_t & IH,
2682
- constant int32_t & CHW,
2683
- constant int32_t & s0,
2684
- constant int32_t & s1,
2685
- constant int32_t & p0,
2686
- constant int32_t & p1,
2687
- constant int32_t & d0,
2688
- constant int32_t & d1,
2689
- constant int32_t & N,
2690
- constant int32_t & KH,
2691
- constant int32_t & KW,
2973
+ constant ggml_metal_kargs_im2col & args,
2692
2974
  uint3 tgpig[[threadgroup_position_in_grid]],
2693
2975
  uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
2694
2976
  uint3 tpitg[[thread_position_in_threadgroup]],
2695
2977
  uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
2696
- const int64_t KHW = KH * KW; // KHW == ntg[1] * ntg[2], KW == ntg[2]
2978
+ const int64_t KHW = (int64_t)args.KHW;
2697
2979
 
2698
- const int64_t d = tgpig[0] / CHW;
2699
- const int64_t chw = tgpig[0] % CHW;
2980
+ const int64_t d = tgpig[0] / args.CHW;
2981
+ const int64_t chw = tgpig[0] % args.CHW;
2700
2982
  const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
2701
2983
  const int64_t HW = tgpig[0] % KHW;
2702
2984
 
2703
2985
  const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
2704
- if (tpitg_0 >= N) {
2986
+ if (tpitg_0 >= args.N) {
2705
2987
  return;
2706
2988
  }
2707
2989
 
2708
- const int64_t tpitg_1 = HW / KW;
2709
- const int64_t tpitg_2 = HW % KW;
2990
+ const int64_t tpitg_1 = HW / args.KW;
2991
+ const int64_t tpitg_2 = HW % args.KW;
2710
2992
 
2711
- const int64_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0;
2712
- const int64_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1;
2993
+ const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
2994
+ const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
2713
2995
 
2714
2996
  const int64_t offset_dst =
2715
- (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
2716
- (tgpig_0 * KHW + tpitg_1 * KW + tpitg_2);
2997
+ (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
2998
+ (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
2717
2999
 
2718
3000
  device T * pdst = (device T *) (dst);
2719
3001
 
2720
- if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
3002
+ if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
2721
3003
  pdst[offset_dst] = 0.0f;
2722
3004
  } else {
2723
- const int64_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1;
2724
- pdst[offset_dst] = x[offset_src + iih * IW + iiw];
3005
+ const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
3006
+ pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
2725
3007
  }
2726
3008
  }
2727
3009
 
@@ -2732,12 +3014,7 @@ typedef void (conv_transpose_1d_t)(
2732
3014
  device const float * src0,
2733
3015
  device const float * src1,
2734
3016
  device char * dst,
2735
- constant int32_t & IC,
2736
- constant int32_t & IL,
2737
- constant int32_t & K,
2738
- constant int32_t & s0,
2739
- constant uint64_t & nb0,
2740
- constant uint64_t & nb1,
3017
+ constant ggml_metal_kargs_conv_transpose_1d & args,
2741
3018
  uint3 tgpig[[threadgroup_position_in_grid]],
2742
3019
  uint3 tgpg[[threadgroups_per_grid]]);
2743
3020
 
@@ -2746,29 +3023,24 @@ kernel void kernel_conv_transpose_1d(
2746
3023
  device const T * src0,
2747
3024
  device const float * src1,
2748
3025
  device char * dst,
2749
- constant int32_t & IC,
2750
- constant int32_t & IL,
2751
- constant int32_t & K,
2752
- constant int32_t & s0,
2753
- constant uint64_t & nb0,
2754
- constant uint64_t & nb1,
3026
+ constant ggml_metal_kargs_conv_transpose_1d & args,
2755
3027
  uint3 tgpig[[threadgroup_position_in_grid]],
2756
3028
  uint3 tgpg[[threadgroups_per_grid]]) {
2757
3029
 
2758
3030
  float v = 0.0f;
2759
3031
 
2760
- for (int64_t c = 0; c < IC; c++) {
2761
- const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1];
2762
- const int32_t input_offset = c * IL;
3032
+ for (int64_t c = 0; c < args.IC; c++) {
3033
+ const int32_t kernel_offset = c * tgpg[1] * args.K + args.K * tgpig[1];
3034
+ const int32_t input_offset = c * args.IL;
2763
3035
 
2764
- for (int64_t i = 0; i < IL; i++) {
2765
- if (tgpig[0] >= i * s0 && tgpig[0] < i * s0 + K) {
2766
- v += src0[kernel_offset + tgpig[0] - i * s0] * src1[input_offset + i];
3036
+ for (int64_t i = 0; i < args.IL; i++) {
3037
+ if (tgpig[0] >= i * args.s0 && tgpig[0] < i * args.s0 + args.K) {
3038
+ v += src0[kernel_offset + tgpig[0] - i * args.s0] * src1[input_offset + i];
2767
3039
  }
2768
3040
  }
2769
3041
  }
2770
3042
 
2771
- device float * dst_ptr = (device float *) (dst + tgpig[0] * nb0 + tgpig[1] * nb1);
3043
+ device float * dst_ptr = (device float *) (dst + tgpig[0] * args.nb0 + tgpig[1] * args.nb1);
2772
3044
 
2773
3045
  dst_ptr[0] = v;
2774
3046
  }
@@ -2778,12 +3050,7 @@ kernel void kernel_conv_transpose_1d<float>(
2778
3050
  device const float * src0,
2779
3051
  device const float * src1,
2780
3052
  device char * dst,
2781
- constant int32_t & IC,
2782
- constant int32_t & IL,
2783
- constant int32_t & K,
2784
- constant int32_t & s0,
2785
- constant uint64_t & nb0,
2786
- constant uint64_t & nb1,
3053
+ constant ggml_metal_kargs_conv_transpose_1d & args,
2787
3054
  uint3 tgpig[[threadgroup_position_in_grid]],
2788
3055
  uint3 tgpg[[threadgroups_per_grid]]);
2789
3056
 
@@ -2792,38 +3059,14 @@ kernel void kernel_conv_transpose_1d<half>(
2792
3059
  device const half * src0,
2793
3060
  device const float * src1,
2794
3061
  device char * dst,
2795
- constant int32_t & IC,
2796
- constant int32_t & IL,
2797
- constant int32_t & K,
2798
- constant int32_t & s0,
2799
- constant uint64_t & nb0,
2800
- constant uint64_t & nb1,
3062
+ constant ggml_metal_kargs_conv_transpose_1d & args,
2801
3063
  uint3 tgpig[[threadgroup_position_in_grid]],
2802
3064
  uint3 tgpg[[threadgroups_per_grid]]);
2803
3065
 
2804
3066
  kernel void kernel_upscale_f32(
2805
3067
  device const char * src0,
2806
3068
  device char * dst,
2807
- constant int64_t & ne00,
2808
- constant int64_t & ne01,
2809
- constant int64_t & ne02,
2810
- constant int64_t & ne03,
2811
- constant uint64_t & nb00,
2812
- constant uint64_t & nb01,
2813
- constant uint64_t & nb02,
2814
- constant uint64_t & nb03,
2815
- constant int64_t & ne0,
2816
- constant int64_t & ne1,
2817
- constant int64_t & ne2,
2818
- constant int64_t & ne3,
2819
- constant uint64_t & nb0,
2820
- constant uint64_t & nb1,
2821
- constant uint64_t & nb2,
2822
- constant uint64_t & nb3,
2823
- constant float & sf0,
2824
- constant float & sf1,
2825
- constant float & sf2,
2826
- constant float & sf3,
3069
+ constant ggml_metal_kargs_upscale & args,
2827
3070
  uint3 tgpig[[threadgroup_position_in_grid]],
2828
3071
  uint3 tpitg[[thread_position_in_threadgroup]],
2829
3072
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -2832,15 +3075,15 @@ kernel void kernel_upscale_f32(
2832
3075
  const int64_t i2 = tgpig.y;
2833
3076
  const int64_t i1 = tgpig.x;
2834
3077
 
2835
- const int64_t i03 = i3/sf3;
2836
- const int64_t i02 = i2/sf2;
2837
- const int64_t i01 = i1/sf1;
3078
+ const int64_t i03 = i3/args.sf3;
3079
+ const int64_t i02 = i2/args.sf2;
3080
+ const int64_t i01 = i1/args.sf1;
2838
3081
 
2839
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
2840
- const int64_t i00 = i0/sf0;
3082
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
3083
+ const int64_t i00 = i0/args.sf0;
2841
3084
 
2842
- device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2843
- device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
3085
+ device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
3086
+ device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
2844
3087
 
2845
3088
  dst_ptr[0] = src0_ptr[0];
2846
3089
  }
@@ -2849,22 +3092,7 @@ kernel void kernel_upscale_f32(
2849
3092
  kernel void kernel_pad_f32(
2850
3093
  device const char * src0,
2851
3094
  device char * dst,
2852
- constant int64_t & ne00,
2853
- constant int64_t & ne01,
2854
- constant int64_t & ne02,
2855
- constant int64_t & ne03,
2856
- constant uint64_t & nb00,
2857
- constant uint64_t & nb01,
2858
- constant uint64_t & nb02,
2859
- constant uint64_t & nb03,
2860
- constant int64_t & ne0,
2861
- constant int64_t & ne1,
2862
- constant int64_t & ne2,
2863
- constant int64_t & ne3,
2864
- constant uint64_t & nb0,
2865
- constant uint64_t & nb1,
2866
- constant uint64_t & nb2,
2867
- constant uint64_t & nb3,
3095
+ constant ggml_metal_kargs_pad & args,
2868
3096
  uint3 tgpig[[threadgroup_position_in_grid]],
2869
3097
  uint3 tpitg[[thread_position_in_threadgroup]],
2870
3098
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -2877,12 +3105,12 @@ kernel void kernel_pad_f32(
2877
3105
  const int64_t i02 = i2;
2878
3106
  const int64_t i01 = i1;
2879
3107
 
2880
- device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
2881
- device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
3108
+ device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
3109
+ device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
2882
3110
 
2883
- if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
2884
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
2885
- if (i0 < ne00) {
3111
+ if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
3112
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
3113
+ if (i0 < args.ne00) {
2886
3114
  dst_ptr[i0] = src0_ptr[i0];
2887
3115
  } else {
2888
3116
  dst_ptr[i0] = 0.0f;
@@ -2892,7 +3120,7 @@ kernel void kernel_pad_f32(
2892
3120
  return;
2893
3121
  }
2894
3122
 
2895
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
3123
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
2896
3124
  dst_ptr[i0] = 0.0f;
2897
3125
  }
2898
3126
  }
@@ -2900,21 +3128,7 @@ kernel void kernel_pad_f32(
2900
3128
  kernel void kernel_pad_reflect_1d_f32(
2901
3129
  device const char * src0,
2902
3130
  device char * dst,
2903
- constant int64_t & ne00,
2904
- constant int64_t & ne01,
2905
- constant int64_t & ne02,
2906
- constant int64_t & ne03,
2907
- constant int64_t & ne0,
2908
- constant uint64_t & nb00,
2909
- constant uint64_t & nb01,
2910
- constant uint64_t & nb02,
2911
- constant uint64_t & nb03,
2912
- constant uint64_t & nb0,
2913
- constant uint64_t & nb1,
2914
- constant uint64_t & nb2,
2915
- constant uint64_t & nb3,
2916
- constant int32_t & p0,
2917
- constant int32_t & p1,
3131
+ constant ggml_metal_kargs_pad_reflect_1d & args,
2918
3132
  uint3 tgpig[[threadgroup_position_in_grid]],
2919
3133
  uint3 tgpg[[threadgroups_per_grid]],
2920
3134
  uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2928,17 +3142,17 @@ kernel void kernel_pad_reflect_1d_f32(
2928
3142
  const int64_t i02 = i2;
2929
3143
  const int64_t i01 = i1;
2930
3144
 
2931
- device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
2932
- device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
3145
+ device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
3146
+ device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
2933
3147
 
2934
- if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
2935
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
2936
- if (i0 < p0) {
2937
- dst_ptr[i0] = src0_ptr[p0 - i0];
2938
- } else if (i0 < ne0 - p1) {
2939
- dst_ptr[i0] = src0_ptr[i0 - p0];
3148
+ if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
3149
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
3150
+ if (i0 < args.p0) {
3151
+ dst_ptr[i0] = src0_ptr[args.p0 - i0];
3152
+ } else if (i0 < args.ne0 - args.p1) {
3153
+ dst_ptr[i0] = src0_ptr[i0 - args.p0];
2940
3154
  } else {
2941
- dst_ptr[i0] = src0_ptr[(ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1];
3155
+ dst_ptr[i0] = src0_ptr[(args.ne0 - args.p1 - args.p0) - (args.p1 + 1 - (args.ne0 - i0)) - 1];
2942
3156
  }
2943
3157
  }
2944
3158
  }
@@ -2946,44 +3160,40 @@ kernel void kernel_pad_reflect_1d_f32(
2946
3160
 
2947
3161
  kernel void kernel_arange_f32(
2948
3162
  device char * dst,
2949
- constant int64_t & ne0,
2950
- constant float & start,
2951
- constant float & step,
3163
+ constant ggml_metal_kargs_arange & args,
2952
3164
  uint3 tgpig[[threadgroup_position_in_grid]],
2953
3165
  uint3 tpitg[[thread_position_in_threadgroup]],
2954
3166
  uint3 ntg[[threads_per_threadgroup]]) {
2955
3167
 
2956
3168
  device float * dst_ptr = (device float *) dst;
2957
3169
 
2958
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
2959
- dst_ptr[i0] = start + step * i0;
3170
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
3171
+ dst_ptr[i0] = args.start + args.step * i0;
2960
3172
  }
2961
3173
  }
2962
3174
 
2963
3175
  kernel void kernel_timestep_embedding_f32(
2964
3176
  device const char * src0,
2965
3177
  device char * dst,
2966
- constant uint64_t & nb1,
2967
- constant int & dim,
2968
- constant int & max_period,
3178
+ constant ggml_metal_kargs_timestep_embedding & args,
2969
3179
  uint3 tgpig[[threadgroup_position_in_grid]],
2970
3180
  uint3 tpitg[[thread_position_in_threadgroup]],
2971
3181
  uint3 ntg[[threads_per_threadgroup]]) {
2972
3182
 
2973
3183
  int i = tgpig.x;
2974
- device float * embed_data = (device float *)(dst + i*nb1);
3184
+ device float * embed_data = (device float *)(dst + i*args.nb1);
2975
3185
 
2976
- int half_ = dim / 2;
3186
+ int half_ = args.dim / 2;
2977
3187
  for (int j = tpitg.x; j < half_; j += ntg.x) {
2978
3188
  float timestep = ((device float *)src0)[i];
2979
- float freq = (float)exp(-log((float)max_period) * j / half_);
3189
+ float freq = (float)exp(-log((float)args.max_period) * j / half_);
2980
3190
  float arg = timestep * freq;
2981
3191
  embed_data[j ] = cos(arg);
2982
3192
  embed_data[j + half_] = sin(arg);
2983
3193
  }
2984
3194
 
2985
- if (dim % 2 != 0 && tpitg.x == 0) {
2986
- embed_data[dim] = 0.f;
3195
+ if (args.dim % 2 != 0 && tpitg.x == 0) {
3196
+ embed_data[args.dim] = 0.f;
2987
3197
  }
2988
3198
  }
2989
3199
 
@@ -2991,8 +3201,7 @@ kernel void kernel_timestep_embedding_f32(
2991
3201
  typedef void (argsort_t)(
2992
3202
  device const float * x,
2993
3203
  device int32_t * dst,
2994
- constant int64_t & ncols,
2995
- constant int64_t & ncols_pad,
3204
+ constant ggml_metal_kargs_argsort & args,
2996
3205
  threadgroup int32_t * shared_values [[threadgroup(0)]],
2997
3206
  uint3 tgpig[[threadgroup_position_in_grid]],
2998
3207
  uint3 tpitg[[thread_position_in_threadgroup]]);
@@ -3001,8 +3210,7 @@ template<ggml_sort_order order>
3001
3210
  kernel void kernel_argsort_f32_i32(
3002
3211
  device const float * x,
3003
3212
  device int32_t * dst,
3004
- constant int64_t & ncols,
3005
- constant int64_t & ncols_pad,
3213
+ constant ggml_metal_kargs_argsort & args,
3006
3214
  threadgroup int32_t * shared_values [[threadgroup(0)]],
3007
3215
  uint3 tgpig[[threadgroup_position_in_grid]],
3008
3216
  uint3 tpitg[[thread_position_in_threadgroup]]) {
@@ -3010,9 +3218,9 @@ kernel void kernel_argsort_f32_i32(
3010
3218
  int col = tpitg[0];
3011
3219
  int row = tgpig[1];
3012
3220
 
3013
- if (col >= ncols_pad) return;
3221
+ if (col >= args.ncols_pad) return;
3014
3222
 
3015
- device const float * x_row = x + row * ncols;
3223
+ device const float * x_row = x + row * args.ncols;
3016
3224
  threadgroup int32_t * dst_row = shared_values;
3017
3225
 
3018
3226
  // initialize indices
@@ -3020,21 +3228,21 @@ kernel void kernel_argsort_f32_i32(
3020
3228
 
3021
3229
  threadgroup_barrier(mem_flags::mem_threadgroup);
3022
3230
 
3023
- for (int k = 2; k <= ncols_pad; k *= 2) {
3231
+ for (int k = 2; k <= args.ncols_pad; k *= 2) {
3024
3232
  for (int j = k / 2; j > 0; j /= 2) {
3025
3233
  int ixj = col ^ j;
3026
3234
  if (ixj > col) {
3027
3235
  if ((col & k) == 0) {
3028
- if (dst_row[col] >= ncols ||
3029
- (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
3236
+ if (dst_row[col] >= args.ncols ||
3237
+ (dst_row[ixj] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
3030
3238
  x_row[dst_row[col]] > x_row[dst_row[ixj]] :
3031
3239
  x_row[dst_row[col]] < x_row[dst_row[ixj]]))
3032
3240
  ) {
3033
3241
  SWAP(dst_row[col], dst_row[ixj]);
3034
3242
  }
3035
3243
  } else {
3036
- if (dst_row[ixj] >= ncols ||
3037
- (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
3244
+ if (dst_row[ixj] >= args.ncols ||
3245
+ (dst_row[col] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
3038
3246
  x_row[dst_row[col]] < x_row[dst_row[ixj]] :
3039
3247
  x_row[dst_row[col]] > x_row[dst_row[ixj]]))
3040
3248
  ) {
@@ -3047,8 +3255,8 @@ kernel void kernel_argsort_f32_i32(
3047
3255
  }
3048
3256
 
3049
3257
  // copy the result to dst without the padding
3050
- if (col < ncols) {
3051
- dst[row * ncols + col] = dst_row[col];
3258
+ if (col < args.ncols) {
3259
+ dst[row * args.ncols + col] = dst_row[col];
3052
3260
  }
3053
3261
  }
3054
3262
 
@@ -3058,9 +3266,9 @@ template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_ar
3058
3266
  kernel void kernel_leaky_relu_f32(
3059
3267
  device const float * src0,
3060
3268
  device float * dst,
3061
- constant float & slope,
3269
+ constant ggml_metal_kargs_leaky_relu & args,
3062
3270
  uint tpig[[thread_position_in_grid]]) {
3063
- dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
3271
+ dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * args.slope;
3064
3272
  }
3065
3273
 
3066
3274
  // ref: https://arxiv.org/pdf/2307.08691.pdf
@@ -3084,10 +3292,11 @@ template<
3084
3292
  typename kd4x4_t, // key type in device memory
3085
3293
  short nl_k,
3086
3294
  void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
3087
- typename vd4x4_t, // key type in device memory
3295
+ typename vd4x4_t, // value type in device memory
3088
3296
  short nl_v,
3089
3297
  void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
3090
- short D, // head size
3298
+ short DK, // K head size
3299
+ short DV, // V head size
3091
3300
  short Q = 8, // queries per threadgroup
3092
3301
  short KV = 8, // key/value processed per each simdgroup
3093
3302
  short C = 32> // cache items per threadgroup
@@ -3109,20 +3318,24 @@ kernel void kernel_flash_attn_ext(
3109
3318
  const int iq2 = tgpig[1];
3110
3319
  const int iq1 = tgpig[0]*Q;
3111
3320
 
3112
- const short D4 = D/4;
3113
- const short D8 = D/8;
3114
- const short D16 = D/16;
3115
- const short NW = N_SIMDWIDTH;
3116
- const short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
3321
+ constexpr short DK4 = DK/4;
3322
+ constexpr short DK8 = DK/8;
3323
+ constexpr short DK16 = DK/16;
3324
+ constexpr short DV4 = DV/4;
3325
+ constexpr short DV8 = DV/8;
3326
+ constexpr short DV16 = DV/16;
3327
+
3328
+ constexpr short NW = N_SIMDWIDTH;
3329
+ constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
3117
3330
 
3118
3331
  const short TS = nsg*SH; // shared memory size per query in (s_t == float)
3119
- const short T = D + 2*TS; // shared memory size per query in (half)
3332
+ const short T = DK + 2*TS; // shared memory size per query in (half)
3120
3333
 
3121
- threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*D); // holds the query data
3122
- threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*D); // same as above but in q4_t
3123
- threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*D); // reuse query data for accumulation
3124
- threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*D); // same as above but in o4_t
3125
- threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix
3334
+ threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
3335
+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
3336
+ threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*DK); // reuse query data for accumulation
3337
+ threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t
3338
+ threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*DK); // scratch buffer for attention, mask and diagonal matrix
3126
3339
 
3127
3340
  threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
3128
3341
  threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
@@ -3131,23 +3344,23 @@ kernel void kernel_flash_attn_ext(
3131
3344
  threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
3132
3345
 
3133
3346
  // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
3134
- o8x8_t lo[D8];
3347
+ o8x8_t lo[DV8];
3135
3348
 
3136
3349
  // load heads from Q to shared memory
3137
3350
  for (short j = sgitg; j < Q; j += nsg) {
3138
3351
  device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
3139
3352
 
3140
- for (short i = tiisg; i < D4; i += NW) {
3353
+ for (short i = tiisg; i < DK4; i += NW) {
3141
3354
  if (iq1 + j < args.ne01) {
3142
- sq4[j*D4 + i] = (q4_t) q4[i];
3355
+ sq4[j*DK4 + i] = (q4_t) q4[i];
3143
3356
  } else {
3144
- sq4[j*D4 + i] = (q4_t) 0.0f;
3357
+ sq4[j*DK4 + i] = (q4_t) 0.0f;
3145
3358
  }
3146
3359
  }
3147
3360
  }
3148
3361
 
3149
3362
  // zero out lo
3150
- for (short i = 0; i < D8; ++i) {
3363
+ for (short i = 0; i < DV8; ++i) {
3151
3364
  lo[i] = make_filled_simdgroup_matrix<o_t, 8>((o_t) 0.0f);
3152
3365
  }
3153
3366
 
@@ -3161,8 +3374,8 @@ kernel void kernel_flash_attn_ext(
3161
3374
  threadgroup_barrier(mem_flags::mem_threadgroup);
3162
3375
 
3163
3376
  {
3164
- half S[Q] = { [0 ... Q-1] = 0.0f };
3165
- half M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
3377
+ float S[Q] = { [0 ... Q-1] = 0.0f };
3378
+ float M[Q] = { [0 ... Q-1] = -__FLT_MAX__/2 };
3166
3379
 
3167
3380
  // thread indices inside the simdgroup
3168
3381
  // TODO: see if we can utilize quad-group functions for better performance
@@ -3177,22 +3390,15 @@ kernel void kernel_flash_attn_ext(
3177
3390
  const short ikv2 = iq2/(args.ne02/args.ne_12_2);
3178
3391
  const short ikv3 = iq3/(args.ne03/args.ne_12_3);
3179
3392
 
3180
- // load the queries from shared memory into local memory
3181
- q8x8_t mq[D8];
3182
-
3183
- for (short i = 0; i < D8; ++i) {
3184
- simdgroup_load(mq[i], sq + i*8, D);
3185
- }
3186
-
3187
3393
  const bool has_mask = mask != q;
3188
3394
 
3189
- half slope = 1.0f;
3395
+ float slope = 1.0f;
3190
3396
 
3191
3397
  // ALiBi
3192
3398
  if (args.max_bias > 0.0f) {
3193
3399
  const short h = iq2;
3194
3400
 
3195
- const half base = h < args.n_head_log2 ? args.m0 : args.m1;
3401
+ const float base = h < args.n_head_log2 ? args.m0 : args.m1;
3196
3402
  const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
3197
3403
 
3198
3404
  slope = pow(base, exph);
@@ -3208,14 +3414,14 @@ kernel void kernel_flash_attn_ext(
3208
3414
 
3209
3415
  if (has_mask) {
3210
3416
  // used to detect blocks full of -INF
3211
- half smax = -INFINITY;
3417
+ float smax = -INFINITY;
3212
3418
 
3213
3419
  // load the mask in shared memory
3214
3420
  #pragma unroll(Q)
3215
3421
  for (short j = 0; j < Q; ++j) {
3216
3422
  device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
3217
3423
 
3218
- const half m = pm[ic + tiisg];
3424
+ const float m = pm[ic + tiisg];
3219
3425
 
3220
3426
  ss[j*TS + C + tiisg] = m;
3221
3427
  smax = max(smax, m);
@@ -3236,20 +3442,22 @@ kernel void kernel_flash_attn_ext(
3236
3442
  // this is compile-time check, so it does not have runtime overhead
3237
3443
  if (is_same<kd4x4_t, k4x4_t>::value) {
3238
3444
  // we can read directly from global memory
3239
- device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
3445
+ device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
3240
3446
 
3241
- #pragma unroll(D8)
3242
- for (short i = 0; i < D8; ++i) {
3447
+ #pragma unroll(DK8)
3448
+ for (short i = 0; i < DK8; ++i) {
3243
3449
  k8x8_t mk;
3244
- simdgroup_load(mk, pk + i*8, args.nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
3450
+ simdgroup_load(mk, pk + i*8, args.nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10
3245
3451
 
3246
- simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
3452
+ q8x8_t mq;
3453
+ simdgroup_load(mq, sq + i*8, DK);
3454
+ simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
3247
3455
  }
3248
3456
  } else {
3249
- for (short ii = 0; ii < D16; ii += 4) {
3250
- device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
3457
+ for (short ii = 0; ii < DK16; ii += 4) {
3458
+ device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
3251
3459
 
3252
- if (D16%4 == 0) {
3460
+ if (DK16%4 == 0) {
3253
3461
  // the head is evenly divisible by 4*16 = 64, so no need for bound checks
3254
3462
  {
3255
3463
  k4x4_t tmp;
@@ -3262,15 +3470,18 @@ kernel void kernel_flash_attn_ext(
3262
3470
  #pragma unroll(4)
3263
3471
  for (short k = 0; k < 4; ++k) {
3264
3472
  k8x8_t mk;
3473
+ q8x8_t mq;
3265
3474
 
3266
3475
  simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
3267
- simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
3476
+ simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
3477
+ simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
3268
3478
 
3269
3479
  simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
3270
- simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
3480
+ simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
3481
+ simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
3271
3482
  }
3272
3483
  } else {
3273
- if (ii + tx < D16) {
3484
+ if (ii + tx < DK16) {
3274
3485
  k4x4_t tmp;
3275
3486
  deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
3276
3487
  sk4x4[4*ty + tx] = tmp;
@@ -3278,14 +3489,17 @@ kernel void kernel_flash_attn_ext(
3278
3489
 
3279
3490
  simdgroup_barrier(mem_flags::mem_threadgroup);
3280
3491
 
3281
- for (short k = 0; k < 4 && ii + k < D16; ++k) {
3492
+ for (short k = 0; k < 4 && ii + k < DK16; ++k) {
3282
3493
  k8x8_t mk;
3494
+ q8x8_t mq;
3283
3495
 
3284
3496
  simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
3285
- simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
3497
+ simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
3498
+ simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
3286
3499
 
3287
3500
  simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
3288
- simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
3501
+ simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
3502
+ simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
3289
3503
  }
3290
3504
  }
3291
3505
  }
@@ -3303,10 +3517,10 @@ kernel void kernel_flash_attn_ext(
3303
3517
  // online softmax
3304
3518
  {
3305
3519
  for (ushort j = 0; j < Q; ++j) {
3306
- const half m = M[j];
3520
+ const float m = M[j];
3307
3521
 
3308
3522
  // scale and apply the logitcap / mask
3309
- half s = ss[j*TS + tiisg]*args.scale;
3523
+ float s = ss[j*TS + tiisg]*args.scale;
3310
3524
 
3311
3525
  if (args.logit_softcap != 0.0f) {
3312
3526
  s = args.logit_softcap*precise::tanh(s);
@@ -3317,8 +3531,8 @@ kernel void kernel_flash_attn_ext(
3317
3531
 
3318
3532
  M[j] = simd_max(max(M[j], s));
3319
3533
 
3320
- const half ms = exp(m - M[j]);
3321
- const half vs = exp(s - M[j]);
3534
+ const float ms = exp(m - M[j]);
3535
+ const float vs = exp(s - M[j]);
3322
3536
 
3323
3537
  S[j] = S[j]*ms + simd_sum(vs);
3324
3538
 
@@ -3337,8 +3551,8 @@ kernel void kernel_flash_attn_ext(
3337
3551
  s8x8_t mm;
3338
3552
  simdgroup_load(mm, ss + 2*C, TS, 0, false);
3339
3553
 
3340
- #pragma unroll(D8)
3341
- for (short i = 0; i < D8; ++i) {
3554
+ #pragma unroll(DV8)
3555
+ for (short i = 0; i < DV8; ++i) {
3342
3556
  simdgroup_multiply(lo[i], mm, lo[i]);
3343
3557
  }
3344
3558
  }
@@ -3351,20 +3565,20 @@ kernel void kernel_flash_attn_ext(
3351
3565
 
3352
3566
  if (is_same<vd4x4_t, v4x4_t>::value) {
3353
3567
  // we can read directly from global memory
3354
- device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
3568
+ device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
3355
3569
 
3356
- #pragma unroll(D8)
3357
- for (short i = 0; i < D8; ++i) {
3570
+ #pragma unroll(DV8)
3571
+ for (short i = 0; i < DV8; ++i) {
3358
3572
  v8x8_t mv;
3359
- simdgroup_load(mv, pv + i*8, args.nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
3573
+ simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20
3360
3574
 
3361
3575
  simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
3362
3576
  }
3363
3577
  } else {
3364
- for (short ii = 0; ii < D16; ii += 4) {
3365
- device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
3578
+ for (short ii = 0; ii < DV16; ii += 4) {
3579
+ device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
3366
3580
 
3367
- if (D16%4 == 0) {
3581
+ if (DV16%4 == 0) {
3368
3582
  // no need for bound checks
3369
3583
  {
3370
3584
  v4x4_t tmp;
@@ -3385,7 +3599,7 @@ kernel void kernel_flash_attn_ext(
3385
3599
  simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
3386
3600
  }
3387
3601
  } else {
3388
- if (ii + tx < D16) {
3602
+ if (ii + tx < DV16) {
3389
3603
  v4x4_t tmp;
3390
3604
  deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
3391
3605
  sv4x4[4*ty + tx] = tmp;
@@ -3393,7 +3607,7 @@ kernel void kernel_flash_attn_ext(
3393
3607
 
3394
3608
  simdgroup_barrier(mem_flags::mem_threadgroup);
3395
3609
 
3396
- for (short k = 0; k < 4 && ii + k < D16; ++k) {
3610
+ for (short k = 0; k < 4 && ii + k < DV16; ++k) {
3397
3611
  v8x8_t mv;
3398
3612
 
3399
3613
  simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
@@ -3420,15 +3634,15 @@ kernel void kernel_flash_attn_ext(
3420
3634
 
3421
3635
  // reduce the warps sequentially
3422
3636
  for (ushort sg = 1; sg < nsg; ++sg) {
3423
- half S = { 0.0f };
3424
- half M = { -__FLT16_MAX__/2 };
3637
+ float S = { 0.0f };
3638
+ float M = { -__FLT_MAX__/2 };
3425
3639
 
3426
3640
  threadgroup_barrier(mem_flags::mem_threadgroup);
3427
3641
 
3428
3642
  // each simdgroup stores its output to shared memory, reusing sq
3429
3643
  if (sgitg == sg) {
3430
- for (short i = 0; i < D8; ++i) {
3431
- simdgroup_store(lo[i], so + i*8, D, 0, false);
3644
+ for (short i = 0; i < DV8; ++i) {
3645
+ simdgroup_store(lo[i], so + i*8, DV, 0, false);
3432
3646
  }
3433
3647
  }
3434
3648
 
@@ -3437,16 +3651,16 @@ kernel void kernel_flash_attn_ext(
3437
3651
  // the first simdgroup accumulates the results from the other simdgroups
3438
3652
  if (sgitg == 0) {
3439
3653
  for (short j = 0; j < Q; ++j) {
3440
- const half S0 = ss[j*TS + 0];
3441
- const half S1 = ss[j*TS + sg*SH + 0];
3654
+ const float S0 = ss[j*TS + 0];
3655
+ const float S1 = ss[j*TS + sg*SH + 0];
3442
3656
 
3443
- const half M0 = ss[j*TS + 1];
3444
- const half M1 = ss[j*TS + sg*SH + 1];
3657
+ const float M0 = ss[j*TS + 1];
3658
+ const float M1 = ss[j*TS + sg*SH + 1];
3445
3659
 
3446
3660
  M = max(M0, M1);
3447
3661
 
3448
- const half ms0 = exp(M0 - M);
3449
- const half ms1 = exp(M1 - M);
3662
+ const float ms0 = exp(M0 - M);
3663
+ const float ms1 = exp(M1 - M);
3450
3664
 
3451
3665
  S = S0*ms0 + S1*ms1;
3452
3666
 
@@ -3467,11 +3681,11 @@ kernel void kernel_flash_attn_ext(
3467
3681
  simdgroup_load(ms0, ss + 2*C, TS, 0, false);
3468
3682
  simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
3469
3683
 
3470
- #pragma unroll(D8)
3471
- for (short i = 0; i < D8; ++i) {
3684
+ #pragma unroll(DV8)
3685
+ for (short i = 0; i < DV8; ++i) {
3472
3686
  o8x8_t t;
3473
3687
 
3474
- simdgroup_load (t, so + i*8, D, 0, false);
3688
+ simdgroup_load (t, so + i*8, DV, 0, false);
3475
3689
  simdgroup_multiply(t, ms1, t);
3476
3690
 
3477
3691
  simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
@@ -3482,8 +3696,8 @@ kernel void kernel_flash_attn_ext(
3482
3696
 
3483
3697
  // store result to shared memory (reuse sq)
3484
3698
  if (sgitg == 0) {
3485
- for (short i = 0; i < D8; ++i) {
3486
- simdgroup_store(lo[i], so + i*8, D, 0, false);
3699
+ for (short i = 0; i < DV8; ++i) {
3700
+ simdgroup_store(lo[i], so + i*8, DV, 0, false);
3487
3701
  }
3488
3702
  }
3489
3703
 
@@ -3494,8 +3708,8 @@ kernel void kernel_flash_attn_ext(
3494
3708
  for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) {
3495
3709
  const float S = ss[j*TS + 0];
3496
3710
 
3497
- for (short i = tiisg; i < D4; i += NW) {
3498
- dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
3711
+ for (short i = tiisg; i < DV4; i += NW) {
3712
+ dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4 + i] = (float4) so4[j*DV4 + i]/S;
3499
3713
  }
3500
3714
  }
3501
3715
  }
@@ -3512,80 +3726,101 @@ kernel void kernel_flash_attn_ext(
3512
3726
  float, simdgroup_float8x8, \
3513
3727
  half, half4, simdgroup_half8x8
3514
3728
 
3515
- typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_t;
3729
+ typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
3516
3730
 
3517
- template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>;
3518
- template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80>;
3519
- template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96>;
3520
- template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 112>;
3521
- template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>;
3522
- template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256>;
3731
+ template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
3732
+ template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
3733
+ template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96, 96>;
3734
+ template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 112, 112>;
3735
+ template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128, 128>;
3736
+ template [[host_name("kernel_flash_attn_ext_f16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 192>;
3737
+ template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 128>;
3738
+ template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256, 256>;
3739
+ template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
3523
3740
 
3524
3741
  #if defined(GGML_METAL_USE_BF16)
3525
- template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64>;
3526
- template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80>;
3527
- template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96>;
3528
- template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112>;
3529
- template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128>;
3530
- template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256>;
3742
+ template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
3743
+ template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
3744
+ template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
3745
+ template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
3746
+ template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128, 128>;
3747
+ template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
3748
+ template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
3749
+ template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
3750
+ template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
3531
3751
  #endif
3532
3752
 
3533
- template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64>;
3534
- template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80>;
3535
- template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96>;
3536
- template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112>;
3537
- template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128>;
3538
- template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256>;
3539
-
3540
- template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64>;
3541
- template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80>;
3542
- template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96>;
3543
- template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112>;
3544
- template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128>;
3545
- template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256>;
3546
-
3547
- template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64>;
3548
- template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80>;
3549
- template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96>;
3550
- template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112>;
3551
- template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128>;
3552
- template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256>;
3553
-
3554
- template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64>;
3555
- template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80>;
3556
- template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96>;
3557
- template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112>;
3558
- template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128>;
3559
- template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256>;
3560
-
3561
- template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64>;
3562
- template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80>;
3563
- template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96>;
3564
- template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112>;
3565
- template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128>;
3566
- template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256>;
3753
+ template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
3754
+ template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
3755
+ template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96, 96>;
3756
+ template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112, 112>;
3757
+ template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128, 128>;
3758
+ template [[host_name("kernel_flash_attn_ext_q4_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 192>;
3759
+ template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 128>;
3760
+ template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
3761
+ template [[host_name("kernel_flash_attn_ext_q4_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;
3762
+
3763
+ template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
3764
+ template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
3765
+ template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96, 96>;
3766
+ template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112, 112>;
3767
+ template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128, 128>;
3768
+ template [[host_name("kernel_flash_attn_ext_q4_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 192>;
3769
+ template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 128>;
3770
+ template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
3771
+ template [[host_name("kernel_flash_attn_ext_q4_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;
3772
+
3773
+ template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
3774
+ template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
3775
+ template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96, 96>;
3776
+ template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112, 112>;
3777
+ template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128, 128>;
3778
+ template [[host_name("kernel_flash_attn_ext_q5_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 192>;
3779
+ template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 128>;
3780
+ template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
3781
+ template [[host_name("kernel_flash_attn_ext_q5_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;
3782
+
3783
+ template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
3784
+ template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
3785
+ template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96, 96>;
3786
+ template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112, 112>;
3787
+ template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128, 128>;
3788
+ template [[host_name("kernel_flash_attn_ext_q5_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 192>;
3789
+ template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 128>;
3790
+ template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
3791
+ template [[host_name("kernel_flash_attn_ext_q5_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;
3792
+
3793
+ template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
3794
+ template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
3795
+ template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96, 96>;
3796
+ template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112, 112>;
3797
+ template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128, 128>;
3798
+ template [[host_name("kernel_flash_attn_ext_q8_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 192>;
3799
+ template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 128>;
3800
+ template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256, 256>;
3801
+ template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;
3567
3802
 
3568
3803
  #undef FA_TYPES
3569
3804
 
3570
3805
  template<
3571
- typename q4_t, // query types in shared memory
3572
- typename q4x4_t,
3573
- typename k4x4_t, // key types in shared memory
3574
- typename v4x4_t, // value types in shared memory
3575
- typename qk_t, // Q*K types
3576
- typename s_t, // soft-max types
3806
+ typename q4_t, // query types in shared memory
3807
+ typename k4_t, // key types in shared memory
3808
+ typename v4_t, // value types in shared memory
3809
+ typename qk_t, // Q*K types
3810
+ typename s_t, // soft-max types
3577
3811
  typename s4_t,
3578
- typename s4x4_t,
3579
- typename o4x4_t, // attention accumulation types
3580
- typename kd4x4_t, // key type in device memory
3812
+ typename o4_t, // attention accumulation types
3813
+ typename kd4_t, // key type in device memory
3581
3814
  short nl_k,
3582
- void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
3583
- typename vd4x4_t, // key type in device memory
3815
+ void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &),
3816
+ typename vd4_t, // value type in device memory
3584
3817
  short nl_v,
3585
- void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
3586
- short D, // head size
3587
- short Q = 1, // queries per threadgroup
3588
- short C = 32> // cache items per threadgroup
3818
+ void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
3819
+ short DK, // K head size
3820
+ short DV, // V head size
3821
+ short NE = 4, // head elements per thread
3822
+ short Q = 1, // queries per threadgroup
3823
+ short C = 32> // cache items per threadgroup
3589
3824
  kernel void kernel_flash_attn_ext_vec(
3590
3825
  constant ggml_metal_kargs_flash_attn_ext & args,
3591
3826
  device const char * q,
@@ -3604,29 +3839,28 @@ kernel void kernel_flash_attn_ext_vec(
3604
3839
  const int iq2 = tgpig[1];
3605
3840
  const int iq1 = tgpig[0];
3606
3841
 
3607
- const short D4 = D/4;
3608
- const short D16 = D/16;
3609
- const short NW = N_SIMDWIDTH;
3610
- const short NL = NW/4; // note: this can be adjusted to support D%64 == 0 and D%32 == 0
3611
- const short SH = 2*C; // shared memory per simdgroup
3842
+ constexpr short DK4 = DK/4;
3843
+ constexpr short DV4 = DV/4;
3844
+ constexpr short NW = N_SIMDWIDTH;
3845
+ constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
3846
+ constexpr short SH = 4*C; // shared memory per simdgroup
3612
3847
 
3613
- const short T = D + nsg*SH; // shared memory size per query in (half)
3848
+ const short T = DK + nsg*SH; // shared memory size per query in (half)
3614
3849
 
3615
- //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*D); // holds the query data
3616
- threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*D); // same as above but in q4_t
3617
- threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shmem_f16 + 0*D); // same as above but in q4x4_t
3618
- threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*D); // scratch buffer for attention
3619
- threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*D); // same as above but in s4_t
3620
- threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*D); // scratch buffer for mask
3621
- threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shmem_f16 + sgitg*D + Q*T); // scratch buffer for the results
3850
+ //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
3851
+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
3852
+ threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
3853
+ threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
3854
+ threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
3855
+ threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
3622
3856
 
3623
- // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
3624
- o4x4_t lo[D16/NL];
3857
+ // store the result for all queries in local memory (the O matrix from the paper)
3858
+ o4_t lo[DV4/NL];
3625
3859
 
3626
3860
  // load heads from Q to shared memory
3627
3861
  device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
3628
3862
 
3629
- for (short i = tiisg; i < D4; i += NW) {
3863
+ for (short i = tiisg; i < DK4; i += NW) {
3630
3864
  if (iq1 < args.ne01) {
3631
3865
  sq4[i] = (q4_t) q4[i];
3632
3866
  } else {
@@ -3635,8 +3869,8 @@ kernel void kernel_flash_attn_ext_vec(
3635
3869
  }
3636
3870
 
3637
3871
  // zero out lo
3638
- for (short i = 0; i < D16/NL; ++i) {
3639
- lo[i] = (o4x4_t) 0.0f;
3872
+ for (short i = 0; i < DV4/NL; ++i) {
3873
+ lo[i] = (o4_t) 0.0f;
3640
3874
  }
3641
3875
 
3642
3876
  // zero out shared memory SH
@@ -3647,8 +3881,8 @@ kernel void kernel_flash_attn_ext_vec(
3647
3881
  threadgroup_barrier(mem_flags::mem_threadgroup);
3648
3882
 
3649
3883
  {
3650
- half S = 0.0f;
3651
- half M = -__FLT16_MAX__/2;
3884
+ float S = 0.0f;
3885
+ float M = -__FLT_MAX__/2;
3652
3886
 
3653
3887
  // thread indices inside the simdgroup
3654
3888
  const short tx = tiisg%NL;
@@ -3661,26 +3895,18 @@ kernel void kernel_flash_attn_ext_vec(
3661
3895
  const short ikv2 = iq2/(args.ne02/args.ne_12_2);
3662
3896
  const short ikv3 = iq3/(args.ne03/args.ne_12_3);
3663
3897
 
3664
- // load the queries from shared memory into local memory
3665
- q4x4_t mq[D16/NL];
3666
-
3667
- #pragma unroll(D16/NL)
3668
- for (short ii = 0; ii < D16; ii += NL) {
3669
- mq[ii/NL] = sq4x4[ii + tx];
3670
- }
3671
-
3672
- const bool has_mask = mask != q;
3898
+ const bool has_mask = mask != q;
3673
3899
 
3674
3900
  // pointer to the mask
3675
3901
  device const half * pm = (device const half *) (mask + iq1*args.nb31);
3676
3902
 
3677
- half slope = 1.0f;
3903
+ float slope = 1.0f;
3678
3904
 
3679
3905
  // ALiBi
3680
3906
  if (args.max_bias > 0.0f) {
3681
3907
  const short h = iq2;
3682
3908
 
3683
- const half base = h < args.n_head_log2 ? args.m0 : args.m1;
3909
+ const float base = h < args.n_head_log2 ? args.m0 : args.m1;
3684
3910
  const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
3685
3911
 
3686
3912
  slope = pow(base, exph);
@@ -3698,45 +3924,63 @@ kernel void kernel_flash_attn_ext_vec(
3698
3924
  sm[tiisg] = pm[ic + tiisg];
3699
3925
  }
3700
3926
 
3927
+ // skip -INF blocks
3928
+ if (simd_max(sm[tiisg]) == -INFINITY) {
3929
+ continue;
3930
+ }
3931
+
3701
3932
  // Q*K^T
3702
3933
  {
3703
- // each simdgroup processes 1 query and 4 (NW/NL) keys
3704
- for (short cc = 0; cc < C/4; ++cc) {
3705
- qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 };
3934
+ // each simdgroup processes 1 query and NE (NW/NL) head elements
3935
+ for (short cc = 0; cc < C/NE; ++cc) {
3936
+ qk_t mqk = 0.0f;
3706
3937
 
3707
- device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
3938
+ device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
3708
3939
 
3709
- #pragma unroll(D16/NL)
3710
- for (short ii = 0; ii < D16; ii += NL) {
3940
+ #pragma unroll(DK4/NL)
3941
+ for (short ii = 0; ii < DK4; ii += NL) {
3711
3942
  const short i = ii + tx;
3712
3943
 
3713
- k4x4_t mk;
3714
- deq_k(pk + i/nl_k, i%nl_k, mk);
3944
+ k4_t mk;
3945
+ deq_k_t4(pk + i/nl_k, i%nl_k, mk);
3715
3946
 
3716
3947
  // note: this is less precise than the version below
3717
- //mqka[0] += dot(mq[ii/NL][0], mk[0]);
3718
- //mqka[1] += dot(mq[ii/NL][1], mk[1]);
3719
- //mqka[2] += dot(mq[ii/NL][2], mk[2]);
3720
- //mqka[3] += dot(mq[ii/NL][3], mk[3]);
3721
-
3722
- mqka[0] += dot((float4) mq[ii/NL][0], (float4) mk[0]);
3723
- mqka[1] += dot((float4) mq[ii/NL][1], (float4) mk[1]);
3724
- mqka[2] += dot((float4) mq[ii/NL][2], (float4) mk[2]);
3725
- mqka[3] += dot((float4) mq[ii/NL][3], (float4) mk[3]);
3948
+ //mqka[0] += dot(mq[0], mk[0]);
3949
+ //mqka[1] += dot(mq[1], mk[1]);
3950
+ //mqka[2] += dot(mq[2], mk[2]);
3951
+ //mqka[3] += dot(mq[3], mk[3]);
3952
+
3953
+ //q4x4_t mq = sq4x4[i];
3954
+ //mqka[0] += dot((float4) mq[0], (float4) mk[0]);
3955
+ //mqka[1] += dot((float4) mq[1], (float4) mk[1]);
3956
+ //mqka[2] += dot((float4) mq[2], (float4) mk[2]);
3957
+ //mqka[3] += dot((float4) mq[3], (float4) mk[3]);
3958
+
3959
+ mqk += dot((float4) mk, (float4) sq4[i]);
3726
3960
  }
3727
3961
 
3728
- qk_t mqk = mqka[0] + mqka[1] + mqka[2] + mqka[3];
3962
+ static_assert(NE > 1, "NE must be > 1"); // note: not sure why NE == 1 fails
3729
3963
 
3730
- // simdgroup reduce
3964
+ // simdgroup reduce (NE = 4)
3731
3965
  // [ 0 .. 7] -> [ 0]
3732
3966
  // [ 8 .. 15] -> [ 8]
3733
3967
  // [16 .. 23] -> [16]
3734
3968
  // [24 .. 31] -> [24]
3735
- //mqk += simd_shuffle_down(mqk, 16);
3736
- //mqk += simd_shuffle_down(mqk, 8);
3737
- mqk += simd_shuffle_down(mqk, 4);
3738
- mqk += simd_shuffle_down(mqk, 2);
3739
- mqk += simd_shuffle_down(mqk, 1);
3969
+ if (NE <= 1) {
3970
+ mqk += simd_shuffle_down(mqk, 16);
3971
+ }
3972
+ if (NE <= 2) {
3973
+ mqk += simd_shuffle_down(mqk, 8);
3974
+ }
3975
+ if (NE <= 4) {
3976
+ mqk += simd_shuffle_down(mqk, 4);
3977
+ }
3978
+ if (NE <= 8) {
3979
+ mqk += simd_shuffle_down(mqk, 2);
3980
+ }
3981
+ if (NE <= 16) {
3982
+ mqk += simd_shuffle_down(mqk, 1);
3983
+ }
3740
3984
 
3741
3985
  // mqk = mqk*scale + mask*slope
3742
3986
  if (tx == 0) {
@@ -3746,9 +3990,9 @@ kernel void kernel_flash_attn_ext_vec(
3746
3990
  mqk = args.logit_softcap*precise::tanh(mqk);
3747
3991
  }
3748
3992
 
3749
- mqk += sm[4*cc + ty]*slope;
3993
+ mqk += sm[NE*cc + ty]*slope;
3750
3994
 
3751
- ss[4*cc + ty] = mqk;
3995
+ ss[NE*cc + ty] = mqk;
3752
3996
  }
3753
3997
  }
3754
3998
  }
@@ -3757,13 +4001,13 @@ kernel void kernel_flash_attn_ext_vec(
3757
4001
 
3758
4002
  // online softmax
3759
4003
  {
3760
- const half m = M;
3761
- const half s = ss[tiisg];
4004
+ const float m = M;
4005
+ const float s = ss[tiisg];
3762
4006
 
3763
4007
  M = simd_max(max(M, s));
3764
4008
 
3765
- const half ms = exp(m - M);
3766
- const half vs = exp(s - M);
4009
+ const float ms = exp(m - M);
4010
+ const float vs = exp(s - M);
3767
4011
 
3768
4012
  S = S*ms + simd_sum(vs);
3769
4013
 
@@ -3771,8 +4015,8 @@ kernel void kernel_flash_attn_ext_vec(
3771
4015
  ss[tiisg] = vs;
3772
4016
 
3773
4017
  // O = diag(ms)*O
3774
- #pragma unroll(D16/NL)
3775
- for (short ii = 0; ii < D16; ii += NL) {
4018
+ #pragma unroll(DV4/NL)
4019
+ for (short ii = 0; ii < DV4; ii += NL) {
3776
4020
  lo[ii/NL] *= ms;
3777
4021
  }
3778
4022
  }
@@ -3781,19 +4025,20 @@ kernel void kernel_flash_attn_ext_vec(
3781
4025
 
3782
4026
  // O = O + (Q*K^T)*V
3783
4027
  {
3784
- for (short cc = 0; cc < C/4; ++cc) {
3785
- device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
4028
+ //#pragma unroll(C/NE)
4029
+ for (short cc = 0; cc < C/NE; ++cc) {
4030
+ device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
3786
4031
 
3787
- const s4x4_t ms(ss[4*cc + ty]);
4032
+ const s4_t ms(ss[NE*cc + ty]);
3788
4033
 
3789
- #pragma unroll(D16/NL)
3790
- for (short ii = 0; ii < D16; ii += NL) {
4034
+ #pragma unroll(DV4/NL)
4035
+ for (short ii = 0; ii < DV4; ii += NL) {
3791
4036
  const short i = ii + tx;
3792
4037
 
3793
- v4x4_t mv;
3794
- deq_v(pv4 + i/nl_v, i%nl_v, mv);
4038
+ v4_t mv;
4039
+ deq_v_t4(pv4 + i/nl_v, i%nl_v, mv);
3795
4040
 
3796
- lo[ii/NL] += mv*ms;
4041
+ lo[ii/NL] += o4_t(float4(mv)*float4(ms));
3797
4042
  }
3798
4043
  }
3799
4044
  }
@@ -3806,7 +4051,7 @@ kernel void kernel_flash_attn_ext_vec(
3806
4051
  }
3807
4052
  }
3808
4053
 
3809
- // simdgroup reduce
4054
+ // simdgroup reduce (NE = 4)
3810
4055
  // [ 0, 8, 16, 24] -> [ 0]
3811
4056
  // [ 1, 9, 17, 25] -> [ 1]
3812
4057
  // [ 2, 10, 18, 26] -> [ 2]
@@ -3815,37 +4060,48 @@ kernel void kernel_flash_attn_ext_vec(
3815
4060
  // [ 5, 13, 21, 29] -> [ 5]
3816
4061
  // [ 6, 14, 22, 30] -> [ 6]
3817
4062
  // [ 7, 15, 23, 31] -> [ 7]
3818
- for (short ii = 0; ii < D16; ii += NL) {
3819
- lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16);
3820
- lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8);
3821
- //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4);
3822
- //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2);
3823
- //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1);
3824
-
3825
- lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16);
3826
- lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8);
3827
- //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4);
3828
- //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2);
3829
- //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1);
3830
-
3831
- lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16);
3832
- lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8);
3833
- //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4);
3834
- //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2);
3835
- //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1);
3836
-
3837
- lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16);
3838
- lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 8);
3839
- //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4);
3840
- //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2);
3841
- //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1);
4063
+ for (short ii = 0; ii < DV4; ii += NL) {
4064
+ if (NE > 1) {
4065
+ lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16);
4066
+ lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16);
4067
+ lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16);
4068
+ lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16);
4069
+ }
4070
+
4071
+ if (NE > 2) {
4072
+ lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8);
4073
+ lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8);
4074
+ lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8);
4075
+ lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 8);
4076
+ }
4077
+
4078
+ if (NE > 4) {
4079
+ lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4);
4080
+ lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4);
4081
+ lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4);
4082
+ lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4);
4083
+ }
4084
+
4085
+ if (NE > 8) {
4086
+ lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2);
4087
+ lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2);
4088
+ lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2);
4089
+ lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2);
4090
+ }
4091
+
4092
+ if (NE > 16) {
4093
+ lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1);
4094
+ lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1);
4095
+ lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1);
4096
+ lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1);
4097
+ }
3842
4098
  }
3843
4099
 
3844
4100
  threadgroup_barrier(mem_flags::mem_threadgroup);
3845
4101
 
3846
4102
  // store results to shared memory
3847
- for (short i = tiisg; i < D16; i += NL) {
3848
- sr4x4[i] = lo[i/NL];
4103
+ for (short i = tiisg; i < DV4; i += NL) {
4104
+ sr4[i] = lo[i/NL];
3849
4105
  }
3850
4106
 
3851
4107
  threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -3853,18 +4109,18 @@ kernel void kernel_flash_attn_ext_vec(
3853
4109
  // parallel reduce
3854
4110
  for (short r = nsg/2; r > 0; r >>= 1) {
3855
4111
  if (sgitg < r) {
3856
- const half S0 = ss[ 0];
3857
- const half S1 = ss[r*SH + 0];
4112
+ const float S0 = ss[ 0];
4113
+ const float S1 = ss[r*(SH/2) + 0];
3858
4114
 
3859
- const half M0 = ss[ 1];
3860
- const half M1 = ss[r*SH + 1];
4115
+ const float M0 = ss[ 1];
4116
+ const float M1 = ss[r*(SH/2) + 1];
3861
4117
 
3862
- const half M = max(M0, M1);
4118
+ const float M = max(M0, M1);
3863
4119
 
3864
- const half ms0 = exp(M0 - M);
3865
- const half ms1 = exp(M1 - M);
4120
+ const float ms0 = exp(M0 - M);
4121
+ const float ms1 = exp(M1 - M);
3866
4122
 
3867
- const half S = S0*ms0 + S1*ms1;
4123
+ const float S = S0*ms0 + S1*ms1;
3868
4124
 
3869
4125
  if (tiisg == 0) {
3870
4126
  ss[0] = S;
@@ -3872,22 +4128,22 @@ kernel void kernel_flash_attn_ext_vec(
3872
4128
  }
3873
4129
 
3874
4130
  // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
3875
- for (short i = tiisg; i < D16; i += NW) {
3876
- sr4x4[i] = sr4x4[i]*ms0 + sr4x4[i + r*D16]*ms1;
4131
+ for (short i = tiisg; i < DV4; i += NW) {
4132
+ sr4[i] = sr4[i]*ms0 + sr4[i + r*DV4]*ms1;
3877
4133
  }
3878
4134
  }
3879
4135
 
3880
4136
  threadgroup_barrier(mem_flags::mem_threadgroup);
3881
4137
  }
3882
4138
 
3883
- device float4x4 * dst44 = (device float4x4 *) dst;
4139
+ device float4 * dst4 = (device float4 *) dst;
3884
4140
 
3885
4141
  // final rescale with 1/S and store to global memory
3886
4142
  if (sgitg == 0) {
3887
4143
  const float S = ss[0];
3888
4144
 
3889
- for (short i = tiisg; i < D16; i += NW) {
3890
- dst44[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
4145
+ for (short i = tiisg; i < DV4; i += NW) {
4146
+ dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*DV4 + i] = (float4) sr4[i]/S;
3891
4147
  }
3892
4148
  }
3893
4149
  }
@@ -3896,34 +4152,84 @@ kernel void kernel_flash_attn_ext_vec(
3896
4152
  // in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
3897
4153
  //
3898
4154
  #define FA_TYPES \
3899
- half4, half4x4, \
3900
- half4x4, \
3901
- half4x4, \
3902
- float, \
3903
- half, half4, half4x4, \
3904
- half4x4
4155
+ half4, \
4156
+ half4, \
4157
+ half4, \
4158
+ float, \
4159
+ float, float4, \
4160
+ half4
4161
+
4162
+ typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
4163
+
4164
+ template [[host_name("kernel_flash_attn_ext_vec_f16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 8>;
4165
+ #if defined(GGML_METAL_USE_BF16)
4166
+ template [[host_name("kernel_flash_attn_ext_vec_bf16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 8>;
4167
+ #endif
4168
+ template [[host_name("kernel_flash_attn_ext_vec_q4_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 64, 64, 8>;
4169
+ template [[host_name("kernel_flash_attn_ext_vec_q4_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 64, 64, 8>;
4170
+ template [[host_name("kernel_flash_attn_ext_vec_q5_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 64, 64, 8>;
4171
+ template [[host_name("kernel_flash_attn_ext_vec_q5_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 64, 64, 8>;
4172
+ template [[host_name("kernel_flash_attn_ext_vec_q8_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 64, 64, 8>;
4173
+
4174
+ template [[host_name("kernel_flash_attn_ext_vec_f16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 96, 96, 4>;
4175
+ #if defined(GGML_METAL_USE_BF16)
4176
+ template [[host_name("kernel_flash_attn_ext_vec_bf16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 96, 96, 4>;
4177
+ #endif
4178
+ template [[host_name("kernel_flash_attn_ext_vec_q4_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 96, 96, 4>;
4179
+ template [[host_name("kernel_flash_attn_ext_vec_q4_1_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 96, 96, 4>;
4180
+ template [[host_name("kernel_flash_attn_ext_vec_q5_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 96, 96, 4>;
4181
+ template [[host_name("kernel_flash_attn_ext_vec_q5_1_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 96, 96, 4>;
4182
+ template [[host_name("kernel_flash_attn_ext_vec_q8_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 96, 96, 4>;
4183
+
4184
+ template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>;
4185
+ #if defined(GGML_METAL_USE_BF16)
4186
+ template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 128, 128, 4>;
4187
+ #endif
4188
+ template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 128, 128, 4>;
4189
+ template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 128, 128, 4>;
4190
+ template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 128, 128, 4>;
4191
+ template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 128, 128, 4>;
4192
+ template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 128, 128, 4>;
3905
4193
 
3906
- typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>) flash_attn_ext_vec_t;
4194
+ template [[host_name("kernel_flash_attn_ext_vec_f16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 192, 4>;
4195
+ #if defined(GGML_METAL_USE_BF16)
4196
+ template [[host_name("kernel_flash_attn_ext_vec_bf16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 192, 4>;
4197
+ #endif
4198
+ template [[host_name("kernel_flash_attn_ext_vec_q4_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 192, 4>;
4199
+ template [[host_name("kernel_flash_attn_ext_vec_q4_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 192, 4>;
4200
+ template [[host_name("kernel_flash_attn_ext_vec_q5_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 192, 4>;
4201
+ template [[host_name("kernel_flash_attn_ext_vec_q5_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 192, 4>;
4202
+ template [[host_name("kernel_flash_attn_ext_vec_q8_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 192, 4>;
4203
+
4204
+ template [[host_name("kernel_flash_attn_ext_vec_f16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 128, 4>;
4205
+ #if defined(GGML_METAL_USE_BF16)
4206
+ template [[host_name("kernel_flash_attn_ext_vec_bf16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 128, 4>;
4207
+ #endif
4208
+ template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 128, 4>;
4209
+ template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 128, 4>;
4210
+ template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 128, 4>;
4211
+ template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 128, 4>;
4212
+ template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 128, 4>;
3907
4213
 
3908
- template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>;
4214
+ template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 256, 256, 4>;
3909
4215
  #if defined(GGML_METAL_USE_BF16)
3910
- template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128>;
4216
+ template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 256, 256, 4>;
3911
4217
  #endif
3912
- template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128>;
3913
- template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128>;
3914
- template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128>;
3915
- template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128>;
3916
- template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128>;
4218
+ template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 256, 256, 4>;
4219
+ template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 256, 256, 4>;
4220
+ template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 256, 256, 4>;
4221
+ template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 4>;
4222
+ template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 4>;
3917
4223
 
3918
- template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256>;
4224
+ template [[host_name("kernel_flash_attn_ext_vec_f16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
3919
4225
  #if defined(GGML_METAL_USE_BF16)
3920
- template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256>;
4226
+ template [[host_name("kernel_flash_attn_ext_vec_bf16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 576, 512, 2>;
3921
4227
  #endif
3922
- template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256>;
3923
- template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256>;
3924
- template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256>;
3925
- template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256>;
3926
- template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256>;
4228
+ template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 576, 512, 2>;
4229
+ template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 576, 512, 2>;
4230
+ template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 576, 512, 2>;
4231
+ template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 576, 512, 2>;
4232
+ template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 576, 512, 2>;
3927
4233
 
3928
4234
  #undef FA_TYPES
3929
4235
 
@@ -4298,7 +4604,7 @@ kernel void kernel_cpy_f32_iq4_nl(
4298
4604
  float amax = 0.0f; // absolute max
4299
4605
  float max = 0.0f;
4300
4606
 
4301
- for (int j = 0; j < QK4_0; j++) {
4607
+ for (int j = 0; j < QK4_NL; j++) {
4302
4608
  const float v = src[j];
4303
4609
  if (amax < fabs(v)) {
4304
4610
  amax = fabs(v);
@@ -4332,6 +4638,49 @@ kernel void kernel_cpy_f32_iq4_nl(
4332
4638
  }
4333
4639
  }
4334
4640
 
4641
+ template<typename T4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
4642
+ kernel void kernel_cpy_q_f32(
4643
+ constant ggml_metal_kargs_cpy & args,
4644
+ device const char * src0,
4645
+ device char * dst,
4646
+ uint3 tgpig[[threadgroup_position_in_grid]],
4647
+ ushort3 tpitg[[thread_position_in_threadgroup]],
4648
+ ushort3 ntg[[threads_per_threadgroup]]) {
4649
+ const int i03 = tgpig[2];
4650
+ const int i02 = tgpig[1];
4651
+ const int i01 = tgpig[0];
4652
+
4653
+ const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
4654
+
4655
+ const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
4656
+ const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
4657
+ const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
4658
+ const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
4659
+
4660
+ device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
4661
+ device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
4662
+
4663
+ for (int64_t i00 = tpitg.x; i00 < args.ne00/16; i00 += ntg.x) {
4664
+ T4x4 temp;
4665
+ dequantize_func(src_data + i00/nl, i00%nl, temp);
4666
+ dst_data[i00] = temp;
4667
+ }
4668
+ }
4669
+
4670
+ typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t;
4671
+
4672
+ template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>;
4673
+ template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>;
4674
+ template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>;
4675
+ template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>;
4676
+ template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>;
4677
+
4678
+ template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>;
4679
+ template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>;
4680
+ template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>;
4681
+ template [[host_name("kernel_cpy_q5_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_1, 2, dequantize_q5_1>;
4682
+ template [[host_name("kernel_cpy_q8_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q8_0, 2, dequantize_q8_0>;
4683
+
4335
4684
  kernel void kernel_concat(
4336
4685
  constant ggml_metal_kargs_concat & args,
4337
4686
  device const char * src0,
@@ -4363,7 +4712,7 @@ kernel void kernel_concat(
4363
4712
  }
4364
4713
  }
4365
4714
 
4366
- template<typename args_t>
4715
+ template<int nr0, int nsg, int nw, typename args_t>
4367
4716
  void kernel_mul_mv_q2_K_f32_impl(
4368
4717
  args_t args,
4369
4718
  device const char * src0,
@@ -4379,7 +4728,7 @@ void kernel_mul_mv_q2_K_f32_impl(
4379
4728
  const int r1 = tgpig.y;
4380
4729
  const int im = tgpig.z;
4381
4730
 
4382
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
4731
+ const int first_row = (r0 * nsg + sgitg) * nr0;
4383
4732
 
4384
4733
  const uint i12 = im%args.ne12;
4385
4734
  const uint i13 = im/args.ne12;
@@ -4391,20 +4740,19 @@ void kernel_mul_mv_q2_K_f32_impl(
4391
4740
  device const float * y = (device const float *) (src1 + offset1);
4392
4741
 
4393
4742
  float yl[32];
4394
- float sumf[N_DST]={0.f}, all_sum;
4743
+ float sumf[nr0]={0.f};
4395
4744
 
4396
- const int ix = tiisg/8; // 0...3
4397
- const int it = tiisg%8; // 0...7
4398
- const int iq = it/4; // 0 or 1
4399
- const int ir = it%4; // 0...3
4400
- const int is = (8*ir)/16;// 0 or 1
4745
+ const short ix = tiisg/8; // 0...3
4746
+ const short it = tiisg%8; // 0...7
4747
+ const short iq = it/4; // 0 or 1
4748
+ const short ir = it%4; // 0...3
4749
+ const short is = (8*ir)/16;// 0 or 1
4401
4750
 
4402
4751
  device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
4403
4752
 
4404
4753
  for (int ib = ix; ib < nb; ib += 4) {
4405
-
4406
4754
  float4 sumy = {0.f, 0.f, 0.f, 0.f};
4407
- for (int i = 0; i < 8; ++i) {
4755
+ for (short i = 0; i < 8; ++i) {
4408
4756
  yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
4409
4757
  yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
4410
4758
  yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
@@ -4415,8 +4763,7 @@ void kernel_mul_mv_q2_K_f32_impl(
4415
4763
  device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
4416
4764
  device const half * dh = &x[ib].d;
4417
4765
 
4418
- for (int row = 0; row < N_DST; row++) {
4419
-
4766
+ for (short row = 0; row < nr0; row++) {
4420
4767
  float4 acc1 = {0.f, 0.f, 0.f, 0.f};
4421
4768
  float4 acc2 = {0.f, 0.f, 0.f, 0.f};
4422
4769
  for (int i = 0; i < 8; i += 2) {
@@ -4447,10 +4794,10 @@ void kernel_mul_mv_q2_K_f32_impl(
4447
4794
 
4448
4795
  device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
4449
4796
 
4450
- for (int row = 0; row < N_DST; ++row) {
4451
- all_sum = simd_sum(sumf[row]);
4797
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
4798
+ float sum_all = simd_sum(sumf[row]);
4452
4799
  if (tiisg == 0) {
4453
- dst_f32[first_row + row] = all_sum;
4800
+ dst_f32[first_row + row] = sum_all;
4454
4801
  }
4455
4802
  }
4456
4803
  }
@@ -4465,10 +4812,10 @@ kernel void kernel_mul_mv_q2_K_f32(
4465
4812
  ushort tiisg[[thread_index_in_simdgroup]],
4466
4813
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
4467
4814
 
4468
- kernel_mul_mv_q2_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
4815
+ kernel_mul_mv_q2_K_f32_impl<N_R0_Q2_K, N_SG_Q2_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
4469
4816
  }
4470
4817
 
4471
- template<typename args_t>
4818
+ template<int nr0, int nsg, int nw, typename args_t>
4472
4819
  void kernel_mul_mv_q3_K_f32_impl(
4473
4820
  args_t args,
4474
4821
  device const char * src0,
@@ -4485,7 +4832,7 @@ void kernel_mul_mv_q3_K_f32_impl(
4485
4832
  const int r1 = tgpig.y;
4486
4833
  const int im = tgpig.z;
4487
4834
 
4488
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
4835
+ const int first_row = (r0 * nsg + sgitg) * nr0;
4489
4836
 
4490
4837
  const uint i12 = im%args.ne12;
4491
4838
  const uint i13 = im/args.ne12;
@@ -4501,13 +4848,12 @@ void kernel_mul_mv_q3_K_f32_impl(
4501
4848
  //const uint16_t kmask1 = 0x3030;
4502
4849
  //const uint16_t kmask2 = 0x0f0f;
4503
4850
 
4504
- const int tid = tiisg/4;
4505
- const int ix = tiisg%4;
4506
- const int ip = tid/4; // 0 or 1
4507
- const int il = 2*((tid%4)/2); // 0 or 2
4508
- const int ir = tid%2;
4509
- const int n = 8;
4510
- const int l0 = n*ir;
4851
+ const short tid = tiisg/4;
4852
+ const short ix = tiisg%4;
4853
+ const short ip = tid/4; // 0 or 1
4854
+ const short il = 2*((tid%4)/2); // 0 or 2
4855
+ const short ir = tid%2;
4856
+ const short l0 = 8*ir;
4511
4857
 
4512
4858
  // One would think that the Metal compiler would figure out that ip and il can only have
4513
4859
  // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
@@ -4532,8 +4878,8 @@ void kernel_mul_mv_q3_K_f32_impl(
4532
4878
  const uint16_t s_shift1 = 4*ip;
4533
4879
  const uint16_t s_shift2 = s_shift1 + il;
4534
4880
 
4535
- const int q_offset = 32*ip + l0;
4536
- const int y_offset = 128*ip + 32*il + l0;
4881
+ const short q_offset = 32*ip + l0;
4882
+ const short y_offset = 128*ip + 32*il + l0;
4537
4883
 
4538
4884
  device const float * y1 = yy + ix*QK_K + y_offset;
4539
4885
 
@@ -4541,10 +4887,11 @@ void kernel_mul_mv_q3_K_f32_impl(
4541
4887
  thread uint16_t * scales16 = (thread uint16_t *)&scales32;
4542
4888
  thread const int8_t * scales = (thread const int8_t *)&scales32;
4543
4889
 
4544
- float sumf1[2] = {0.f};
4545
- float sumf2[2] = {0.f};
4890
+ float sumf1[nr0] = {0.f};
4891
+ float sumf2[nr0] = {0.f};
4892
+
4546
4893
  for (int i = ix; i < nb; i += 4) {
4547
- for (int l = 0; l < 8; ++l) {
4894
+ for (short l = 0; l < 8; ++l) {
4548
4895
  yl[l+ 0] = y1[l+ 0];
4549
4896
  yl[l+ 8] = y1[l+16];
4550
4897
  yl[l+16] = y1[l+32];
@@ -4556,7 +4903,7 @@ void kernel_mul_mv_q3_K_f32_impl(
4556
4903
  device const uint16_t * a = (device const uint16_t *)(x[i].scales);
4557
4904
  device const half * dh = &x[i].d;
4558
4905
 
4559
- for (int row = 0; row < 2; ++row) {
4906
+ for (short row = 0; row < nr0; ++row) {
4560
4907
  const float d_all = (float)dh[0];
4561
4908
 
4562
4909
  scales16[0] = a[4];
@@ -4567,7 +4914,7 @@ void kernel_mul_mv_q3_K_f32_impl(
4567
4914
  scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
4568
4915
 
4569
4916
  float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
4570
- for (int l = 0; l < n; l += 2) {
4917
+ for (short l = 0; l < 8; l += 2) {
4571
4918
  const int32_t qs = q[l/2];
4572
4919
  s1 += yl[l+0] * (qs & qm[il/2][0]);
4573
4920
  s2 += yl[l+1] * (qs & qm[il/2][1]);
@@ -4582,7 +4929,7 @@ void kernel_mul_mv_q3_K_f32_impl(
4582
4929
  sumf2[row] += d2 * (scales[2] - 32);
4583
4930
 
4584
4931
  s1 = s2 = s3 = s4 = s5 = s6 = 0;
4585
- for (int l = 0; l < n; l += 2) {
4932
+ for (short l = 0; l < 8; l += 2) {
4586
4933
  const int32_t qs = q[l/2+8];
4587
4934
  s1 += yl[l+8] * (qs & qm[il/2][0]);
4588
4935
  s2 += yl[l+9] * (qs & qm[il/2][1]);
@@ -4605,7 +4952,7 @@ void kernel_mul_mv_q3_K_f32_impl(
4605
4952
  y1 += 4 * QK_K;
4606
4953
  }
4607
4954
 
4608
- for (int row = 0; row < 2; ++row) {
4955
+ for (int row = 0; row < nr0; ++row) {
4609
4956
  const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
4610
4957
  sumf1[row] = simd_sum(sumf);
4611
4958
  }
@@ -4613,7 +4960,7 @@ void kernel_mul_mv_q3_K_f32_impl(
4613
4960
  device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
4614
4961
 
4615
4962
  if (tiisg == 0) {
4616
- for (int row = 0; row < 2; ++row) {
4963
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
4617
4964
  dst_f32[first_row + row] = sumf1[row];
4618
4965
  }
4619
4966
  }
@@ -4629,10 +4976,10 @@ kernel void kernel_mul_mv_q3_K_f32(
4629
4976
  ushort tiisg[[thread_index_in_simdgroup]],
4630
4977
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
4631
4978
 
4632
- kernel_mul_mv_q3_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
4979
+ kernel_mul_mv_q3_K_f32_impl<N_R0_Q3_K, N_SG_Q3_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
4633
4980
  }
4634
4981
 
4635
- template<typename args_t>
4982
+ template<int nr0, int nsg, int nw, typename args_t>
4636
4983
  void kernel_mul_mv_q4_K_f32_impl(
4637
4984
  args_t args,
4638
4985
  device const char * src0,
@@ -4642,22 +4989,22 @@ void kernel_mul_mv_q4_K_f32_impl(
4642
4989
  uint3 tgpig,
4643
4990
  ushort tiisg,
4644
4991
  ushort sgitg) {
4645
-
4646
4992
  const uint16_t kmask1 = 0x3f3f;
4647
4993
  const uint16_t kmask2 = 0x0f0f;
4648
4994
  const uint16_t kmask3 = 0xc0c0;
4649
4995
 
4650
- const int ix = tiisg/8; // 0...3
4651
- const int it = tiisg%8; // 0...7
4652
- const int iq = it/4; // 0 or 1
4653
- const int ir = it%4; // 0...3
4996
+ const short ix = tiisg/8; // 0...3
4997
+ const short it = tiisg%8; // 0...7
4998
+ const short iq = it/4; // 0 or 1
4999
+ const short ir = it%4; // 0...3
4654
5000
 
4655
5001
  const int nb = args.ne00/QK_K;
5002
+
4656
5003
  const int r0 = tgpig.x;
4657
5004
  const int r1 = tgpig.y;
4658
5005
  const int im = tgpig.z;
4659
- //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
4660
- const int first_row = r0 * N_DST;
5006
+
5007
+ const int first_row = (r0 * nsg + sgitg) * nr0;
4661
5008
 
4662
5009
  const uint i12 = im%args.ne12;
4663
5010
  const uint i13 = im/args.ne12;
@@ -4670,7 +5017,8 @@ void kernel_mul_mv_q4_K_f32_impl(
4670
5017
 
4671
5018
  float yl[16];
4672
5019
  float yh[16];
4673
- float sumf[N_DST]={0.f}, all_sum;
5020
+
5021
+ float sumf[nr0]={0.f};
4674
5022
 
4675
5023
  device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
4676
5024
 
@@ -4679,7 +5027,8 @@ void kernel_mul_mv_q4_K_f32_impl(
4679
5027
 
4680
5028
  for (int ib = ix; ib < nb; ib += 4) {
4681
5029
  float4 sumy = {0.f, 0.f, 0.f, 0.f};
4682
- for (int i = 0; i < 8; ++i) {
5030
+
5031
+ for (short i = 0; i < 8; ++i) {
4683
5032
  yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
4684
5033
  yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
4685
5034
  yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
@@ -4690,7 +5039,7 @@ void kernel_mul_mv_q4_K_f32_impl(
4690
5039
  device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
4691
5040
  device const half * dh = &x[ib].d;
4692
5041
 
4693
- for (int row = 0; row < N_DST; row++) {
5042
+ for (short row = 0; row < nr0; row++) {
4694
5043
  sc16[0] = sc[0] & kmask1;
4695
5044
  sc16[1] = sc[2] & kmask1;
4696
5045
  sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
@@ -4700,19 +5049,21 @@ void kernel_mul_mv_q4_K_f32_impl(
4700
5049
 
4701
5050
  float4 acc1 = {0.f, 0.f, 0.f, 0.f};
4702
5051
  float4 acc2 = {0.f, 0.f, 0.f, 0.f};
4703
- for (int i = 0; i < 8; i += 2) {
4704
- acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
4705
- acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
4706
- acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
4707
- acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
4708
- acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
4709
- acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
4710
- acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
4711
- acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
5052
+
5053
+ for (short i = 0; i < 4; ++i) {
5054
+ acc1[0] += yl[2*i + 0] * (q1[i] & 0x000F);
5055
+ acc1[1] += yl[2*i + 1] * (q1[i] & 0x0F00);
5056
+ acc1[2] += yl[2*i + 8] * (q1[i] & 0x00F0);
5057
+ acc1[3] += yl[2*i + 9] * (q1[i] & 0xF000);
5058
+ acc2[0] += yh[2*i + 0] * (q2[i] & 0x000F);
5059
+ acc2[1] += yh[2*i + 1] * (q2[i] & 0x0F00);
5060
+ acc2[2] += yh[2*i + 8] * (q2[i] & 0x00F0);
5061
+ acc2[3] += yh[2*i + 9] * (q2[i] & 0xF000);
4712
5062
  }
4713
5063
 
4714
5064
  float dall = dh[0];
4715
5065
  float dmin = dh[1];
5066
+
4716
5067
  sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
4717
5068
  (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
4718
5069
  (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
@@ -4729,10 +5080,10 @@ void kernel_mul_mv_q4_K_f32_impl(
4729
5080
 
4730
5081
  device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;
4731
5082
 
4732
- for (int row = 0; row < N_DST; ++row) {
4733
- all_sum = simd_sum(sumf[row]);
5083
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
5084
+ float sum_all = simd_sum(sumf[row]);
4734
5085
  if (tiisg == 0) {
4735
- dst_f32[first_row + row] = all_sum;
5086
+ dst_f32[first_row + row] = sum_all;
4736
5087
  }
4737
5088
  }
4738
5089
  }
@@ -4747,10 +5098,10 @@ kernel void kernel_mul_mv_q4_K_f32(
4747
5098
  ushort tiisg[[thread_index_in_simdgroup]],
4748
5099
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
4749
5100
 
4750
- kernel_mul_mv_q4_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
5101
+ kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K, N_SG_Q4_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
4751
5102
  }
4752
5103
 
4753
- template<typename args_t>
5104
+ template<int nr0, int nsg, int nw, typename args_t>
4754
5105
  void kernel_mul_mv_q5_K_f32_impl(
4755
5106
  args_t args,
4756
5107
  device const char * src0,
@@ -4767,7 +5118,7 @@ void kernel_mul_mv_q5_K_f32_impl(
4767
5118
  const int r1 = tgpig.y;
4768
5119
  const int im = tgpig.z;
4769
5120
 
4770
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
5121
+ const int first_row = (r0 * nsg + sgitg) * nr0;
4771
5122
 
4772
5123
  const uint i12 = im%args.ne12;
4773
5124
  const uint i13 = im/args.ne12;
@@ -4778,7 +5129,7 @@ void kernel_mul_mv_q5_K_f32_impl(
4778
5129
  device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0);
4779
5130
  device const float * yy = (device const float *) (src1 + offset1);
4780
5131
 
4781
- float sumf[2]={0.f};
5132
+ float sumf[nr0]={0.f};
4782
5133
 
4783
5134
  float yl[16], yh[16];
4784
5135
 
@@ -4786,15 +5137,14 @@ void kernel_mul_mv_q5_K_f32_impl(
4786
5137
  const uint16_t kmask2 = 0x0f0f;
4787
5138
  const uint16_t kmask3 = 0xc0c0;
4788
5139
 
4789
- const int tid = tiisg/4;
4790
- const int ix = tiisg%4;
4791
- const int iq = tid/4;
4792
- const int ir = tid%4;
4793
- const int n = 8;
5140
+ const short tid = tiisg/4;
5141
+ const short ix = tiisg%4;
5142
+ const short iq = tid/4;
5143
+ const short ir = tid%4;
4794
5144
 
4795
- const int l0 = n*ir;
4796
- const int q_offset = 32*iq + l0;
4797
- const int y_offset = 64*iq + l0;
5145
+ const short l0 = 8*ir;
5146
+ const short q_offset = 32*iq + l0;
5147
+ const short y_offset = 64*iq + l0;
4798
5148
 
4799
5149
  const uint8_t hm1 = 1u << (2*iq);
4800
5150
  const uint8_t hm2 = hm1 << 1;
@@ -4814,14 +5164,14 @@ void kernel_mul_mv_q5_K_f32_impl(
4814
5164
 
4815
5165
  device const float * y2 = y1 + 128;
4816
5166
  float4 sumy = {0.f, 0.f, 0.f, 0.f};
4817
- for (int l = 0; l < 8; ++l) {
5167
+ for (short l = 0; l < 8; ++l) {
4818
5168
  yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
4819
5169
  yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
4820
5170
  yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
4821
5171
  yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
4822
5172
  }
4823
5173
 
4824
- for (int row = 0; row < 2; ++row) {
5174
+ for (short row = 0; row < nr0; ++row) {
4825
5175
  device const uint8_t * q2 = q1 + 64;
4826
5176
 
4827
5177
  sc16[0] = a[0] & kmask1;
@@ -4831,7 +5181,7 @@ void kernel_mul_mv_q5_K_f32_impl(
4831
5181
 
4832
5182
  float4 acc1 = {0.f};
4833
5183
  float4 acc2 = {0.f};
4834
- for (int l = 0; l < n; ++l) {
5184
+ for (short l = 0; l < 8; ++l) {
4835
5185
  uint8_t h = qh[l];
4836
5186
  acc1[0] += yl[l+0] * (q1[l] & 0x0F);
4837
5187
  acc1[1] += yl[l+8] * (q1[l] & 0xF0);
@@ -4861,7 +5211,7 @@ void kernel_mul_mv_q5_K_f32_impl(
4861
5211
 
4862
5212
  device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
4863
5213
 
4864
- for (int row = 0; row < 2; ++row) {
5214
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
4865
5215
  const float tot = simd_sum(sumf[row]);
4866
5216
  if (tiisg == 0) {
4867
5217
  dst_f32[first_row + row] = tot;
@@ -4879,10 +5229,10 @@ kernel void kernel_mul_mv_q5_K_f32(
4879
5229
  ushort tiisg[[thread_index_in_simdgroup]],
4880
5230
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
4881
5231
 
4882
- kernel_mul_mv_q5_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
5232
+ kernel_mul_mv_q5_K_f32_impl<N_R0_Q5_K, N_SG_Q5_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
4883
5233
  }
4884
5234
 
4885
- template <typename args_t>
5235
+ template<int nr0, int nsg, int nw, typename args_t>
4886
5236
  void kernel_mul_mv_q6_K_f32_impl(
4887
5237
  args_t args,
4888
5238
  device const char * src0,
@@ -4904,58 +5254,77 @@ void kernel_mul_mv_q6_K_f32_impl(
4904
5254
  const int r1 = tgpig.y;
4905
5255
  const int im = tgpig.z;
4906
5256
 
4907
- const int row = 2*r0 + sgitg;
5257
+ const int first_row = (r0 * nsg + sgitg) * nr0;
4908
5258
 
4909
5259
  const uint i12 = im%args.ne12;
4910
5260
  const uint i13 = im/args.ne12;
4911
5261
 
4912
- const uint64_t offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
4913
- const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
5262
+ const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
5263
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
4914
5264
 
4915
5265
  device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0);
4916
5266
  device const float * yy = (device const float *) (src1 + offset1);
4917
5267
 
4918
- float sumf = 0;
5268
+ float sumf[nr0] = { 0.f };
4919
5269
 
4920
- const int tid = tiisg/2;
4921
- const int ix = tiisg%2;
4922
- const int ip = tid/8; // 0 or 1
4923
- const int il = tid%8;
4924
- const int n = 4;
4925
- const int l0 = n*il;
4926
- const int is = 8*ip + l0/16;
5270
+ float yl[16];
5271
+
5272
+ const short tid = tiisg/2;
5273
+ const short ix = tiisg%2;
5274
+ const short ip = tid/8; // 0 or 1
5275
+ const short il = tid%8;
5276
+ const short l0 = 4*il;
5277
+ const short is = 8*ip + l0/16;
4927
5278
 
4928
- const int y_offset = 128*ip + l0;
4929
- const int q_offset_l = 64*ip + l0;
4930
- const int q_offset_h = 32*ip + l0;
5279
+ const short y_offset = 128*ip + l0;
5280
+ const short q_offset_l = 64*ip + l0;
5281
+ const short q_offset_h = 32*ip + l0;
4931
5282
 
4932
5283
  for (int i = ix; i < nb; i += 2) {
4933
5284
  device const uint8_t * q1 = x[i].ql + q_offset_l;
4934
5285
  device const uint8_t * q2 = q1 + 32;
4935
5286
  device const uint8_t * qh = x[i].qh + q_offset_h;
4936
5287
  device const int8_t * sc = x[i].scales + is;
5288
+ device const half * dh = &x[i].d;
4937
5289
 
4938
5290
  device const float * y = yy + i * QK_K + y_offset;
4939
5291
 
4940
- const float dall = x[i].d;
4941
-
4942
- float4 sums = {0.f, 0.f, 0.f, 0.f};
4943
- for (int l = 0; l < n; ++l) {
4944
- sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
4945
- sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
4946
- sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
4947
- sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
5292
+ for (short l = 0; l < 4; ++l) {
5293
+ yl[4*l + 0] = y[l + 0];
5294
+ yl[4*l + 1] = y[l + 32];
5295
+ yl[4*l + 2] = y[l + 64];
5296
+ yl[4*l + 3] = y[l + 96];
4948
5297
  }
4949
5298
 
4950
- sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
5299
+ for (short row = 0; row < nr0; ++row) {
5300
+ const float dall = dh[0];
4951
5301
 
5302
+ float4 sums = {0.f, 0.f, 0.f, 0.f};
5303
+
5304
+ for (short l = 0; l < 4; ++l) {
5305
+ sums[0] += yl[4*l + 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
5306
+ sums[1] += yl[4*l + 1] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
5307
+ sums[2] += yl[4*l + 2] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
5308
+ sums[3] += yl[4*l + 3] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
5309
+ }
5310
+
5311
+ sumf[row] += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
5312
+
5313
+ q1 += args.nb01;
5314
+ q2 += args.nb01;
5315
+ qh += args.nb01;
5316
+ sc += args.nb01;
5317
+ dh += args.nb01/2;
5318
+ }
4952
5319
  }
4953
5320
 
4954
5321
  device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
4955
5322
 
4956
- const float tot = simd_sum(sumf);
4957
- if (tiisg == 0) {
4958
- dst_f32[row] = tot;
5323
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
5324
+ float sum_all = simd_sum(sumf[row]);
5325
+ if (tiisg == 0) {
5326
+ dst_f32[first_row + row] = sum_all;
5327
+ }
4959
5328
  }
4960
5329
  }
4961
5330
 
@@ -4969,12 +5338,12 @@ kernel void kernel_mul_mv_q6_K_f32(
4969
5338
  ushort tiisg[[thread_index_in_simdgroup]],
4970
5339
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
4971
5340
 
4972
- kernel_mul_mv_q6_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
5341
+ kernel_mul_mv_q6_K_f32_impl<N_R0_Q6_K, N_SG_Q6_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
4973
5342
  }
4974
5343
 
4975
5344
  // ======================= "True" 2-bit
4976
5345
 
4977
- template<typename args_t>
5346
+ template<int nr0, int nsg, int nw, typename args_t>
4978
5347
  void kernel_mul_mv_iq2_xxs_f32_impl(
4979
5348
  args_t args,
4980
5349
  device const char * src0,
@@ -4990,7 +5359,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
4990
5359
  const int r1 = tgpig.y;
4991
5360
  const int im = tgpig.z;
4992
5361
 
4993
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
5362
+ const int first_row = (r0 * nsg + sgitg) * nr0;
4994
5363
 
4995
5364
  const uint i12 = im%args.ne12;
4996
5365
  const uint i13 = im/args.ne12;
@@ -5002,7 +5371,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
5002
5371
  device const float * y = (device const float *) (src1 + offset1);
5003
5372
 
5004
5373
  float yl[32];
5005
- float sumf[N_DST]={0.f}, all_sum;
5374
+ float sumf[nr0]={0.f};
5006
5375
 
5007
5376
  const int nb32 = nb * (QK_K / 32);
5008
5377
 
@@ -5023,8 +5392,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
5023
5392
  device const float * y4 = y + 32 * ix;
5024
5393
 
5025
5394
  for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
5026
-
5027
- for (int i = 0; i < 32; ++i) {
5395
+ for (short i = 0; i < 32; ++i) {
5028
5396
  yl[i] = y4[i];
5029
5397
  }
5030
5398
 
@@ -5035,18 +5403,17 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
5035
5403
  device const uint16_t * q2 = xr->qs + 4 * ib;
5036
5404
  device const half * dh = &xr->d;
5037
5405
 
5038
- for (int row = 0; row < N_DST; row++) {
5039
-
5406
+ for (short row = 0; row < nr0; row++) {
5040
5407
  const float db = dh[0];
5041
5408
  device const uint8_t * aux8 = (device const uint8_t *)q2;
5042
5409
  const uint32_t aux32 = q2[2] | (q2[3] << 16);
5043
5410
  const float d = db * (0.5f + (aux32 >> 28));
5044
5411
 
5045
5412
  float sum = 0;
5046
- for (int l = 0; l < 4; ++l) {
5413
+ for (short l = 0; l < 4; ++l) {
5047
5414
  const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]);
5048
5415
  const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
5049
- for (int j = 0; j < 8; ++j) {
5416
+ for (short j = 0; j < 8; ++j) {
5050
5417
  sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
5051
5418
  }
5052
5419
  }
@@ -5061,10 +5428,10 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
5061
5428
 
5062
5429
  device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
5063
5430
 
5064
- for (int row = 0; row < N_DST; ++row) {
5065
- all_sum = simd_sum(sumf[row]);
5431
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
5432
+ float sum_all = simd_sum(sumf[row]);
5066
5433
  if (tiisg == 0) {
5067
- dst_f32[first_row + row] = all_sum * 0.25f;
5434
+ dst_f32[first_row + row] = sum_all * 0.25f;
5068
5435
  }
5069
5436
  }
5070
5437
  }
@@ -5079,10 +5446,10 @@ kernel void kernel_mul_mv_iq2_xxs_f32(
5079
5446
  uint3 tgpig[[threadgroup_position_in_grid]],
5080
5447
  ushort tiisg[[thread_index_in_simdgroup]],
5081
5448
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
5082
- kernel_mul_mv_iq2_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
5449
+ kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SG_IQ2_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
5083
5450
  }
5084
5451
 
5085
- template<typename args_t>
5452
+ template<int nr0, int nsg, int nw, typename args_t>
5086
5453
  void kernel_mul_mv_iq2_xs_f32_impl(
5087
5454
  args_t args,
5088
5455
  device const char * src0,
@@ -5098,7 +5465,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
5098
5465
  const int r1 = tgpig.y;
5099
5466
  const int im = tgpig.z;
5100
5467
 
5101
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
5468
+ const int first_row = (r0 * nsg + sgitg) * nr0;
5102
5469
 
5103
5470
  const uint i12 = im%args.ne12;
5104
5471
  const uint i13 = im/args.ne12;
@@ -5110,7 +5477,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
5110
5477
  device const float * y = (device const float *) (src1 + offset1);
5111
5478
 
5112
5479
  float yl[32];
5113
- float sumf[N_DST]={0.f}, all_sum;
5480
+ float sumf[nr0]={0.f};
5114
5481
 
5115
5482
  const int nb32 = nb * (QK_K / 32);
5116
5483
 
@@ -5131,8 +5498,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
5131
5498
  device const float * y4 = y + 32 * ix;
5132
5499
 
5133
5500
  for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
5134
-
5135
- for (int i = 0; i < 32; ++i) {
5501
+ for (short i = 0; i < 32; ++i) {
5136
5502
  yl[i] = y4[i];
5137
5503
  }
5138
5504
 
@@ -5144,8 +5510,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
5144
5510
  device const uint8_t * sc = xr->scales + ib;
5145
5511
  device const half * dh = &xr->d;
5146
5512
 
5147
- for (int row = 0; row < N_DST; row++) {
5148
-
5513
+ for (short row = 0; row < nr0; row++) {
5149
5514
  const float db = dh[0];
5150
5515
  const uint8_t ls1 = sc[0] & 0xf;
5151
5516
  const uint8_t ls2 = sc[0] >> 4;
@@ -5153,17 +5518,17 @@ void kernel_mul_mv_iq2_xs_f32_impl(
5153
5518
  const float d2 = db * (0.5f + ls2);
5154
5519
 
5155
5520
  float sum1 = 0, sum2 = 0;
5156
- for (int l = 0; l < 2; ++l) {
5521
+ for (short l = 0; l < 2; ++l) {
5157
5522
  const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
5158
5523
  const uint8_t signs = ssigns[(q2[l] >> 9)];
5159
- for (int j = 0; j < 8; ++j) {
5524
+ for (short j = 0; j < 8; ++j) {
5160
5525
  sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
5161
5526
  }
5162
5527
  }
5163
- for (int l = 2; l < 4; ++l) {
5528
+ for (short l = 2; l < 4; ++l) {
5164
5529
  const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
5165
5530
  const uint8_t signs = ssigns[(q2[l] >> 9)];
5166
- for (int j = 0; j < 8; ++j) {
5531
+ for (short j = 0; j < 8; ++j) {
5167
5532
  sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
5168
5533
  }
5169
5534
  }
@@ -5179,10 +5544,10 @@ void kernel_mul_mv_iq2_xs_f32_impl(
5179
5544
 
5180
5545
  device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
5181
5546
 
5182
- for (int row = 0; row < N_DST; ++row) {
5183
- all_sum = simd_sum(sumf[row]);
5547
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
5548
+ float sum_all = simd_sum(sumf[row]);
5184
5549
  if (tiisg == 0) {
5185
- dst_f32[first_row + row] = all_sum * 0.25f;
5550
+ dst_f32[first_row + row] = sum_all * 0.25f;
5186
5551
  }
5187
5552
  }
5188
5553
  }
@@ -5198,10 +5563,10 @@ kernel void kernel_mul_mv_iq2_xs_f32(
5198
5563
  ushort tiisg[[thread_index_in_simdgroup]],
5199
5564
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
5200
5565
 
5201
- kernel_mul_mv_iq2_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
5566
+ kernel_mul_mv_iq2_xs_f32_impl<N_R0_IQ2_XS, N_SG_IQ2_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
5202
5567
  }
5203
5568
 
5204
- template <typename args_t>
5569
+ template<int nr0, int nsg, int nw, typename args_t>
5205
5570
  void kernel_mul_mv_iq3_xxs_f32_impl(
5206
5571
  args_t args,
5207
5572
  device const char * src0,
@@ -5217,7 +5582,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
5217
5582
  const int r1 = tgpig.y;
5218
5583
  const int im = tgpig.z;
5219
5584
 
5220
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
5585
+ const int first_row = (r0 * nsg + sgitg) * nr0;
5221
5586
 
5222
5587
  const uint i12 = im%args.ne12;
5223
5588
  const uint i13 = im/args.ne12;
@@ -5229,7 +5594,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
5229
5594
  device const float * y = (device const float *) (src1 + offset1);
5230
5595
 
5231
5596
  float yl[32];
5232
- float sumf[N_DST]={0.f}, all_sum;
5597
+ float sumf[nr0]={0.f};
5233
5598
 
5234
5599
  const int nb32 = nb * (QK_K / 32);
5235
5600
 
@@ -5250,7 +5615,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
5250
5615
  device const float * y4 = y + 32 * ix;
5251
5616
 
5252
5617
  for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
5253
- for (int i = 0; i < 32; ++i) {
5618
+ for (short i = 0; i < 32; ++i) {
5254
5619
  yl[i] = y4[i];
5255
5620
  }
5256
5621
 
@@ -5262,17 +5627,17 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
5262
5627
  device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib;
5263
5628
  device const half * dh = &xr->d;
5264
5629
 
5265
- for (int row = 0; row < N_DST; row++) {
5630
+ for (short row = 0; row < nr0; row++) {
5266
5631
  const float db = dh[0];
5267
5632
  const uint32_t aux32 = gas[0] | (gas[1] << 16);
5268
5633
  const float d = db * (0.5f + (aux32 >> 28));
5269
5634
 
5270
5635
  float2 sum = {0};
5271
- for (int l = 0; l < 4; ++l) {
5636
+ for (short l = 0; l < 4; ++l) {
5272
5637
  const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]);
5273
5638
  const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]);
5274
5639
  const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
5275
- for (int j = 0; j < 4; ++j) {
5640
+ for (short j = 0; j < 4; ++j) {
5276
5641
  sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
5277
5642
  sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
5278
5643
  }
@@ -5289,10 +5654,10 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
5289
5654
 
5290
5655
  device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
5291
5656
 
5292
- for (int row = 0; row < N_DST; ++row) {
5293
- all_sum = simd_sum(sumf[row]);
5657
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
5658
+ float sum_all = simd_sum(sumf[row]);
5294
5659
  if (tiisg == 0) {
5295
- dst_f32[first_row + row] = all_sum * 0.5f;
5660
+ dst_f32[first_row + row] = sum_all * 0.5f;
5296
5661
  }
5297
5662
  }
5298
5663
  }
@@ -5308,10 +5673,10 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
5308
5673
  ushort tiisg[[thread_index_in_simdgroup]],
5309
5674
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
5310
5675
 
5311
- kernel_mul_mv_iq3_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
5676
+ kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, N_SG_IQ3_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
5312
5677
  }
5313
5678
 
5314
- template<typename args_t>
5679
+ template<int nr0, int nsg, int nw, typename args_t>
5315
5680
  void kernel_mul_mv_iq3_s_f32_impl(
5316
5681
  args_t args,
5317
5682
  device const char * src0,
@@ -5327,7 +5692,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
5327
5692
  const int r1 = tgpig.y;
5328
5693
  const int im = tgpig.z;
5329
5694
 
5330
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
5695
+ const int first_row = (r0 * nsg + sgitg) * nr0;
5331
5696
 
5332
5697
  const uint i12 = im%args.ne12;
5333
5698
  const uint i13 = im/args.ne12;
@@ -5339,7 +5704,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
5339
5704
  device const float * y = (device const float *) (src1 + offset1);
5340
5705
 
5341
5706
  float yl[32];
5342
- float sumf[N_DST]={0.f}, all_sum;
5707
+ float sumf[nr0]={0.f};
5343
5708
 
5344
5709
  const int nb32 = nb * (QK_K / 32);
5345
5710
 
@@ -5356,8 +5721,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
5356
5721
  device const float * y4 = y + 32 * ix;
5357
5722
 
5358
5723
  for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
5359
-
5360
- for (int i = 0; i < 32; ++i) {
5724
+ for (short i = 0; i < 32; ++i) {
5361
5725
  yl[i] = y4[i];
5362
5726
  }
5363
5727
 
@@ -5371,18 +5735,17 @@ void kernel_mul_mv_iq3_s_f32_impl(
5371
5735
  device const uint8_t * signs = xr->signs + 4 * ib;
5372
5736
  device const half * dh = &xr->d;
5373
5737
 
5374
- for (int row = 0; row < N_DST; row++) {
5375
-
5738
+ for (short row = 0; row < nr0; row++) {
5376
5739
  const float db = dh[0];
5377
5740
  const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf));
5378
5741
 
5379
5742
  float2 sum = {0};
5380
- for (int l = 0; l < 4; ++l) {
5743
+ for (short l = 0; l < 4; ++l) {
5381
5744
  const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues;
5382
5745
  const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? svalues + 256 : svalues;
5383
5746
  const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
5384
5747
  const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
5385
- for (int j = 0; j < 4; ++j) {
5748
+ for (short j = 0; j < 4; ++j) {
5386
5749
  sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
5387
5750
  sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
5388
5751
  }
@@ -5401,10 +5764,10 @@ void kernel_mul_mv_iq3_s_f32_impl(
5401
5764
 
5402
5765
  device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
5403
5766
 
5404
- for (int row = 0; row < N_DST; ++row) {
5405
- all_sum = simd_sum(sumf[row]);
5767
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
5768
+ float sum_all = simd_sum(sumf[row]);
5406
5769
  if (tiisg == 0) {
5407
- dst_f32[first_row + row] = all_sum;
5770
+ dst_f32[first_row + row] = sum_all;
5408
5771
  }
5409
5772
  }
5410
5773
  }
@@ -5420,10 +5783,10 @@ kernel void kernel_mul_mv_iq3_s_f32(
5420
5783
  ushort tiisg[[thread_index_in_simdgroup]],
5421
5784
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
5422
5785
 
5423
- kernel_mul_mv_iq3_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
5786
+ kernel_mul_mv_iq3_s_f32_impl<N_R0_IQ3_S, N_SG_IQ3_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
5424
5787
  }
5425
5788
 
5426
- template <typename args_t>
5789
+ template<int nr0, int nsg, int nw, typename args_t>
5427
5790
  void kernel_mul_mv_iq2_s_f32_impl(
5428
5791
  args_t args,
5429
5792
  device const char * src0,
@@ -5439,7 +5802,7 @@ void kernel_mul_mv_iq2_s_f32_impl(
5439
5802
  const int r1 = tgpig.y;
5440
5803
  const int im = tgpig.z;
5441
5804
 
5442
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
5805
+ const int first_row = (r0 * nsg + sgitg) * nr0;
5443
5806
 
5444
5807
  const uint i12 = im%args.ne12;
5445
5808
  const uint i13 = im/args.ne12;
@@ -5451,7 +5814,7 @@ void kernel_mul_mv_iq2_s_f32_impl(
5451
5814
  device const float * y = (device const float *) (src1 + offset1);
5452
5815
 
5453
5816
  float yl[32];
5454
- float sumf[N_DST]={0.f}, all_sum;
5817
+ float sumf[nr0]={0.f};
5455
5818
 
5456
5819
  const int nb32 = nb * (QK_K / 32);
5457
5820
 
@@ -5463,13 +5826,12 @@ void kernel_mul_mv_iq2_s_f32_impl(
5463
5826
  // threadgroup_barrier(mem_flags::mem_threadgroup);
5464
5827
  //}
5465
5828
 
5466
- const int ix = tiisg;
5829
+ const short ix = tiisg;
5467
5830
 
5468
5831
  device const float * y4 = y + 32 * ix;
5469
5832
 
5470
5833
  for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
5471
-
5472
- for (int i = 0; i < 32; ++i) {
5834
+ for (short i = 0; i < 32; ++i) {
5473
5835
  yl[i] = y4[i];
5474
5836
  }
5475
5837
 
@@ -5483,19 +5845,18 @@ void kernel_mul_mv_iq2_s_f32_impl(
5483
5845
  device const uint8_t * signs = qs + QK_K/8;
5484
5846
  device const half * dh = &xr->d;
5485
5847
 
5486
- for (int row = 0; row < N_DST; row++) {
5487
-
5848
+ for (short row = 0; row < nr0; row++) {
5488
5849
  const float db = dh[0];
5489
5850
  const float d1 = db * (0.5f + (sc[0] & 0xf));
5490
5851
  const float d2 = db * (0.5f + (sc[0] >> 4));
5491
5852
 
5492
5853
  float2 sum = {0};
5493
- for (int l = 0; l < 2; ++l) {
5854
+ for (short l = 0; l < 2; ++l) {
5494
5855
  //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
5495
5856
  //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
5496
5857
  constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
5497
5858
  constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
5498
- for (int j = 0; j < 8; ++j) {
5859
+ for (short j = 0; j < 8; ++j) {
5499
5860
  sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]);
5500
5861
  sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]);
5501
5862
  }
@@ -5514,10 +5875,10 @@ void kernel_mul_mv_iq2_s_f32_impl(
5514
5875
 
5515
5876
  device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
5516
5877
 
5517
- for (int row = 0; row < N_DST; ++row) {
5518
- all_sum = simd_sum(sumf[row]);
5878
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
5879
+ float sum_all = simd_sum(sumf[row]);
5519
5880
  if (tiisg == 0) {
5520
- dst_f32[first_row + row] = all_sum * 0.25f;
5881
+ dst_f32[first_row + row] = sum_all * 0.25f;
5521
5882
  }
5522
5883
  }
5523
5884
  }
@@ -5533,10 +5894,10 @@ kernel void kernel_mul_mv_iq2_s_f32(
5533
5894
  ushort tiisg[[thread_index_in_simdgroup]],
5534
5895
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
5535
5896
 
5536
- kernel_mul_mv_iq2_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
5897
+ kernel_mul_mv_iq2_s_f32_impl<N_R0_IQ2_S, N_SG_IQ2_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
5537
5898
  }
5538
5899
 
5539
- template<typename args_t>
5900
+ template<int nr0, int nsg, int nw, typename args_t>
5540
5901
  void kernel_mul_mv_iq1_s_f32_impl(
5541
5902
  args_t args,
5542
5903
  device const char * src0,
@@ -5552,7 +5913,7 @@ void kernel_mul_mv_iq1_s_f32_impl(
5552
5913
  const int r1 = tgpig.y;
5553
5914
  const int im = tgpig.z;
5554
5915
 
5555
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
5916
+ const int first_row = (r0 * nsg + sgitg) * nr0;
5556
5917
 
5557
5918
  const uint i12 = im%args.ne12;
5558
5919
  const uint i13 = im/args.ne12;
@@ -5564,18 +5925,17 @@ void kernel_mul_mv_iq1_s_f32_impl(
5564
5925
  device const float * y = (device const float *) (src1 + offset1);
5565
5926
 
5566
5927
  float yl[32];
5567
- float sumf[N_DST]={0.f}, all_sum;
5928
+ float sumf[nr0]={0.f};
5568
5929
 
5569
5930
  const int nb32 = nb * (QK_K / 32);
5570
5931
 
5571
- const int ix = tiisg;
5932
+ const short ix = tiisg;
5572
5933
 
5573
5934
  device const float * y4 = y + 32 * ix;
5574
5935
 
5575
5936
  for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
5576
-
5577
5937
  float sumy = 0;
5578
- for (int i = 0; i < 32; ++i) {
5938
+ for (short i = 0; i < 32; ++i) {
5579
5939
  yl[i] = y4[i];
5580
5940
  sumy += yl[i];
5581
5941
  }
@@ -5588,15 +5948,14 @@ void kernel_mul_mv_iq1_s_f32_impl(
5588
5948
  device const uint16_t * qh = xr->qh + ib;
5589
5949
  device const half * dh = &xr->d;
5590
5950
 
5591
- for (int row = 0; row < N_DST; row++) {
5592
-
5951
+ for (short row = 0; row < nr0; row++) {
5593
5952
  constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
5594
5953
  constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700)));
5595
5954
  constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700)));
5596
5955
  constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700)));
5597
5956
 
5598
5957
  float sum = 0;
5599
- for (int j = 0; j < 4; ++j) {
5958
+ for (short j = 0; j < 4; ++j) {
5600
5959
  sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
5601
5960
  + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4)
5602
5961
  + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
@@ -5614,15 +5973,28 @@ void kernel_mul_mv_iq1_s_f32_impl(
5614
5973
 
5615
5974
  device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
5616
5975
 
5617
- for (int row = 0; row < N_DST; ++row) {
5618
- all_sum = simd_sum(sumf[row]);
5976
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
5977
+ float sum_all = simd_sum(sumf[row]);
5619
5978
  if (tiisg == 0) {
5620
- dst_f32[first_row + row] = all_sum;
5979
+ dst_f32[first_row + row] = sum_all;
5621
5980
  }
5622
5981
  }
5623
5982
  }
5624
5983
 
5625
- template <typename args_t>
5984
+ [[host_name("kernel_mul_mv_iq1_s_f32")]]
5985
+ kernel void kernel_mul_mv_iq1_s_f32(
5986
+ constant ggml_metal_kargs_mul_mv & args,
5987
+ device const char * src0,
5988
+ device const char * src1,
5989
+ device char * dst,
5990
+ uint3 tgpig[[threadgroup_position_in_grid]],
5991
+ ushort tiisg[[thread_index_in_simdgroup]],
5992
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
5993
+
5994
+ kernel_mul_mv_iq1_s_f32_impl<N_R0_IQ1_S, N_SG_IQ1_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
5995
+ }
5996
+
5997
+ template<int nr0, int nsg, int nw, typename args_t>
5626
5998
  void kernel_mul_mv_iq1_m_f32_impl(
5627
5999
  args_t args,
5628
6000
  device const char * src0,
@@ -5634,11 +6006,12 @@ void kernel_mul_mv_iq1_m_f32_impl(
5634
6006
  ushort sgitg) {
5635
6007
 
5636
6008
  const int nb = args.ne00/QK_K;
6009
+
5637
6010
  const int r0 = tgpig.x;
5638
6011
  const int r1 = tgpig.y;
5639
6012
  const int im = tgpig.z;
5640
6013
 
5641
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
6014
+ const int first_row = (r0 * nsg + sgitg) * nr0;
5642
6015
 
5643
6016
  const uint i12 = im%args.ne12;
5644
6017
  const uint i13 = im/args.ne12;
@@ -5650,20 +6023,19 @@ void kernel_mul_mv_iq1_m_f32_impl(
5650
6023
  device const float * y = (device const float *) (src1 + offset1);
5651
6024
 
5652
6025
  float yl[32];
5653
- float sumf[N_DST]={0.f}, all_sum;
6026
+ float sumf[nr0]={0.f};
5654
6027
 
5655
6028
  const int nb32 = nb * (QK_K / 32);
5656
6029
 
5657
- const int ix = tiisg;
6030
+ const short ix = tiisg;
5658
6031
 
5659
6032
  device const float * y4 = y + 32 * ix;
5660
6033
 
5661
6034
  iq1m_scale_t scale;
5662
6035
 
5663
6036
  for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
5664
-
5665
6037
  float4 sumy = {0.f};
5666
- for (int i = 0; i < 8; ++i) {
6038
+ for (short i = 0; i < 8; ++i) {
5667
6039
  yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
5668
6040
  yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
5669
6041
  yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];
@@ -5678,7 +6050,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
5678
6050
  device const uint8_t * qh = xr->qh + 2 * ib;
5679
6051
  device const uint16_t * sc = (device const uint16_t *)xr->scales;
5680
6052
 
5681
- for (int row = 0; row < N_DST; row++) {
6053
+ for (short row = 0; row < nr0; row++) {
5682
6054
  scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
5683
6055
 
5684
6056
  constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
@@ -5687,7 +6059,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
5687
6059
  constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));
5688
6060
 
5689
6061
  float2 sum = {0.f};
5690
- for (int j = 0; j < 4; ++j) {
6062
+ for (short j = 0; j < 4; ++j) {
5691
6063
  sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
5692
6064
  + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
5693
6065
  sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
@@ -5709,15 +6081,28 @@ void kernel_mul_mv_iq1_m_f32_impl(
5709
6081
 
5710
6082
  device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
5711
6083
 
5712
- for (int row = 0; row < N_DST; ++row) {
5713
- all_sum = simd_sum(sumf[row]);
6084
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
6085
+ float sum_all = simd_sum(sumf[row]);
5714
6086
  if (tiisg == 0) {
5715
- dst_f32[first_row + row] = all_sum;
6087
+ dst_f32[first_row + row] = sum_all;
5716
6088
  }
5717
6089
  }
5718
6090
  }
5719
6091
 
5720
- template<typename args_t>
6092
+ [[host_name("kernel_mul_mv_iq1_m_f32")]]
6093
+ kernel void kernel_mul_mv_iq1_m_f32(
6094
+ constant ggml_metal_kargs_mul_mv & args,
6095
+ device const char * src0,
6096
+ device const char * src1,
6097
+ device char * dst,
6098
+ uint3 tgpig[[threadgroup_position_in_grid]],
6099
+ ushort tiisg[[thread_index_in_simdgroup]],
6100
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
6101
+
6102
+ kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, N_SG_IQ1_M, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
6103
+ }
6104
+
6105
+ template<int nr0, int nsg, int nw, typename args_t>
5721
6106
  void kernel_mul_mv_iq4_nl_f32_impl(
5722
6107
  args_t args,
5723
6108
  device const char * src0,
@@ -5730,10 +6115,12 @@ void kernel_mul_mv_iq4_nl_f32_impl(
5730
6115
 
5731
6116
  threadgroup float * shmem_f32 = (threadgroup float *) shmem;
5732
6117
  const int nb = args.ne00/QK4_NL;
6118
+
5733
6119
  const int r0 = tgpig.x;
5734
6120
  const int r1 = tgpig.y;
5735
6121
  const int im = tgpig.z;
5736
- const int first_row = (r0 * 2 + sgitg) * 2;
6122
+
6123
+ const int first_row = (r0 * nsg + sgitg) * nr0;
5737
6124
 
5738
6125
  const uint i12 = im%args.ne12;
5739
6126
  const uint i13 = im/args.ne12;
@@ -5744,14 +6131,14 @@ void kernel_mul_mv_iq4_nl_f32_impl(
5744
6131
  device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
5745
6132
  device const float * y = (device const float *) (src1 + offset1);
5746
6133
 
5747
- const int ix = tiisg/2; // 0...15
5748
- const int it = tiisg%2; // 0 or 1
6134
+ const short ix = tiisg/2; // 0...15
6135
+ const short it = tiisg%2; // 0 or 1
5749
6136
 
5750
6137
  shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
5751
6138
  threadgroup_barrier(mem_flags::mem_threadgroup);
5752
6139
 
5753
6140
  float4 yl[4];
5754
- float sumf[2]={0.f}, all_sum;
6141
+ float sumf[nr0]={0.f};
5755
6142
 
5756
6143
  device const float * yb = y + ix * QK4_NL + it * 8;
5757
6144
 
@@ -5761,12 +6148,13 @@ void kernel_mul_mv_iq4_nl_f32_impl(
5761
6148
  float4 qf1, qf2;
5762
6149
 
5763
6150
  for (int ib = ix; ib < nb; ib += 16) {
5764
-
5765
6151
  device const float4 * y4 = (device const float4 *)yb;
5766
- yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
5767
-
5768
- for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) {
6152
+ yl[0] = y4[0];
6153
+ yl[1] = y4[4];
6154
+ yl[2] = y4[1];
6155
+ yl[3] = y4[5];
5769
6156
 
6157
+ for (short row = 0; row < nr0; row++) {
5770
6158
  device const block_iq4_nl & xb = x[row*nb + ib];
5771
6159
  device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
5772
6160
 
@@ -5791,7 +6179,6 @@ void kernel_mul_mv_iq4_nl_f32_impl(
5791
6179
  acc1 += acc2;
5792
6180
 
5793
6181
  sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
5794
-
5795
6182
  }
5796
6183
 
5797
6184
  yb += 16 * QK4_NL;
@@ -5799,15 +6186,29 @@ void kernel_mul_mv_iq4_nl_f32_impl(
5799
6186
 
5800
6187
  device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
5801
6188
 
5802
- for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) {
5803
- all_sum = simd_sum(sumf[row]);
6189
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
6190
+ float sum_all = simd_sum(sumf[row]);
5804
6191
  if (tiisg == 0) {
5805
- dst_f32[first_row + row] = all_sum;
6192
+ dst_f32[first_row + row] = sum_all;
5806
6193
  }
5807
6194
  }
5808
6195
  }
5809
6196
 
5810
- template<typename args_t>
6197
+ [[host_name("kernel_mul_mv_iq4_nl_f32")]]
6198
+ kernel void kernel_mul_mv_iq4_nl_f32(
6199
+ constant ggml_metal_kargs_mul_mv & args,
6200
+ device const char * src0,
6201
+ device const char * src1,
6202
+ device char * dst,
6203
+ threadgroup char * shmem [[threadgroup(0)]],
6204
+ uint3 tgpig[[threadgroup_position_in_grid]],
6205
+ ushort tiisg[[thread_index_in_simdgroup]],
6206
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
6207
+
6208
+ kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, N_SG_IQ4_NL, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
6209
+ }
6210
+
6211
+ template<int nr0, int nsg, int nw, typename args_t>
5811
6212
  void kernel_mul_mv_iq4_xs_f32_impl(
5812
6213
  args_t args,
5813
6214
  device const char * src0,
@@ -5823,7 +6224,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
5823
6224
  const int r0 = tgpig.x;
5824
6225
  const int r1 = tgpig.y;
5825
6226
  const int im = tgpig.z;
5826
- const int first_row = (r0 * 2 + sgitg) * 2;
6227
+ const int first_row = (r0 * nsg + sgitg) * nr0;
5827
6228
 
5828
6229
  const uint i12 = im%args.ne12;
5829
6230
  const uint i13 = im/args.ne12;
@@ -5834,16 +6235,16 @@ void kernel_mul_mv_iq4_xs_f32_impl(
5834
6235
  device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
5835
6236
  device const float * y = (device const float *) (src1 + offset1);
5836
6237
 
5837
- const int ix = tiisg/16; // 0 or 1
5838
- const int it = tiisg%16; // 0...15
5839
- const int ib = it/2;
5840
- const int il = it%2;
6238
+ const short ix = tiisg/16; // 0 or 1
6239
+ const short it = tiisg%16; // 0...15
6240
+ const short ib = it/2;
6241
+ const short il = it%2;
5841
6242
 
5842
6243
  shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
5843
6244
  threadgroup_barrier(mem_flags::mem_threadgroup);
5844
6245
 
5845
6246
  float4 yl[4];
5846
- float sumf[2]={0.f}, all_sum;
6247
+ float sumf[nr0]={0.f};
5847
6248
 
5848
6249
  device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
5849
6250
 
@@ -5854,9 +6255,12 @@ void kernel_mul_mv_iq4_xs_f32_impl(
5854
6255
 
5855
6256
  for (int ibl = ix; ibl < nb; ibl += 2) {
5856
6257
  device const float4 * y4 = (device const float4 *)yb;
5857
- yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
6258
+ yl[0] = y4[0];
6259
+ yl[1] = y4[4];
6260
+ yl[2] = y4[1];
6261
+ yl[3] = y4[5];
5858
6262
 
5859
- for (int row = 0; row < 2; ++row) {
6263
+ for (short row = 0; row < nr0; ++row) {
5860
6264
  device const block_iq4_xs & xb = x[row*nb + ibl];
5861
6265
  device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
5862
6266
 
@@ -5880,7 +6284,6 @@ void kernel_mul_mv_iq4_xs_f32_impl(
5880
6284
 
5881
6285
  const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32;
5882
6286
  sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
5883
-
5884
6287
  }
5885
6288
 
5886
6289
  yb += 2 * QK_K;
@@ -5888,54 +6291,14 @@ void kernel_mul_mv_iq4_xs_f32_impl(
5888
6291
 
5889
6292
  device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
5890
6293
 
5891
- for (int row = 0; row < 2; ++row) {
5892
- all_sum = simd_sum(sumf[row]);
6294
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
6295
+ float sum_all = simd_sum(sumf[row]);
5893
6296
  if (tiisg == 0) {
5894
- dst_f32[first_row + row] = all_sum;
6297
+ dst_f32[first_row + row] = sum_all;
5895
6298
  }
5896
6299
  }
5897
6300
  }
5898
6301
 
5899
- [[host_name("kernel_mul_mv_iq1_s_f32")]]
5900
- kernel void kernel_mul_mv_iq1_s_f32(
5901
- constant ggml_metal_kargs_mul_mv & args,
5902
- device const char * src0,
5903
- device const char * src1,
5904
- device char * dst,
5905
- uint3 tgpig[[threadgroup_position_in_grid]],
5906
- ushort tiisg[[thread_index_in_simdgroup]],
5907
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
5908
-
5909
- kernel_mul_mv_iq1_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
5910
- }
5911
-
5912
- [[host_name("kernel_mul_mv_iq1_m_f32")]]
5913
- kernel void kernel_mul_mv_iq1_m_f32(
5914
- constant ggml_metal_kargs_mul_mv & args,
5915
- device const char * src0,
5916
- device const char * src1,
5917
- device char * dst,
5918
- uint3 tgpig[[threadgroup_position_in_grid]],
5919
- ushort tiisg[[thread_index_in_simdgroup]],
5920
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
5921
-
5922
- kernel_mul_mv_iq1_m_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
5923
- }
5924
-
5925
- [[host_name("kernel_mul_mv_iq4_nl_f32")]]
5926
- kernel void kernel_mul_mv_iq4_nl_f32(
5927
- constant ggml_metal_kargs_mul_mv & args,
5928
- device const char * src0,
5929
- device const char * src1,
5930
- device char * dst,
5931
- threadgroup char * shmem [[threadgroup(0)]],
5932
- uint3 tgpig[[threadgroup_position_in_grid]],
5933
- ushort tiisg[[thread_index_in_simdgroup]],
5934
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
5935
-
5936
- kernel_mul_mv_iq4_nl_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
5937
- }
5938
-
5939
6302
  [[host_name("kernel_mul_mv_iq4_xs_f32")]]
5940
6303
  kernel void kernel_mul_mv_iq4_xs_f32(
5941
6304
  constant ggml_metal_kargs_mul_mv & args,
@@ -5947,7 +6310,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
5947
6310
  ushort tiisg[[thread_index_in_simdgroup]],
5948
6311
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
5949
6312
 
5950
- kernel_mul_mv_iq4_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
6313
+ kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, N_SG_IQ4_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
5951
6314
  }
5952
6315
 
5953
6316
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
@@ -5955,28 +6318,21 @@ kernel void kernel_get_rows_q(
5955
6318
  device const void * src0,
5956
6319
  device const void * src1,
5957
6320
  device float * dst,
5958
- constant int64_t & ne00,
5959
- constant uint64_t & nb01,
5960
- constant uint64_t & nb02,
5961
- constant int64_t & ne10,
5962
- constant uint64_t & nb10,
5963
- constant uint64_t & nb11,
5964
- constant uint64_t & nb1,
5965
- constant uint64_t & nb2,
6321
+ constant ggml_metal_kargs_get_rows & args,
5966
6322
  uint3 tgpig[[threadgroup_position_in_grid]],
5967
6323
  uint tiitg[[thread_index_in_threadgroup]],
5968
6324
  uint3 tptg [[threads_per_threadgroup]]) {
5969
6325
  const int64_t i10 = tgpig.x;
5970
6326
  const int64_t i11 = tgpig.y;
5971
6327
 
5972
- const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
6328
+ const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
5973
6329
 
5974
6330
  const int64_t i02 = i11;
5975
6331
 
5976
- for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
6332
+ for (int64_t ind = tiitg; ind < args.ne00/16; ind += tptg.x) {
5977
6333
  float4x4 temp;
5978
- dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
5979
- *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
6334
+ dequantize_func(((device const block_q *) ((const device char *) src0 + r*args.nb01 + i02*args.nb02)) + ind/nl, ind%nl, temp);
6335
+ *(((device float4x4 *) ((device char *) dst + i11*args.nb2 + i10*args.nb1)) + ind) = temp;
5980
6336
  }
5981
6337
  }
5982
6338
 
@@ -5985,27 +6341,20 @@ kernel void kernel_get_rows_f(
5985
6341
  device const void * src0,
5986
6342
  device const void * src1,
5987
6343
  device float * dst,
5988
- constant int64_t & ne00,
5989
- constant uint64_t & nb01,
5990
- constant uint64_t & nb02,
5991
- constant int64_t & ne10,
5992
- constant uint64_t & nb10,
5993
- constant uint64_t & nb11,
5994
- constant uint64_t & nb1,
5995
- constant uint64_t & nb2,
6344
+ constant ggml_metal_kargs_get_rows & args,
5996
6345
  uint3 tgpig[[threadgroup_position_in_grid]],
5997
6346
  uint tiitg[[thread_index_in_threadgroup]],
5998
6347
  uint3 tptg [[threads_per_threadgroup]]) {
5999
6348
  const int64_t i10 = tgpig.x;
6000
6349
  const int64_t i11 = tgpig.y;
6001
6350
 
6002
- const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
6351
+ const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
6003
6352
 
6004
6353
  const int64_t i02 = i11;
6005
6354
 
6006
- for (int ind = tiitg; ind < ne00; ind += tptg.x) {
6007
- (( device float *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
6008
- ((const device T *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
6355
+ for (int ind = tiitg; ind < args.ne00; ind += tptg.x) {
6356
+ (( device float *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] =
6357
+ ((const device T *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind];
6009
6358
  }
6010
6359
  }
6011
6360
 
@@ -6013,27 +6362,20 @@ kernel void kernel_get_rows_i32(
6013
6362
  device const void * src0,
6014
6363
  device const void * src1,
6015
6364
  device int32_t * dst,
6016
- constant int64_t & ne00,
6017
- constant uint64_t & nb01,
6018
- constant uint64_t & nb02,
6019
- constant int64_t & ne10,
6020
- constant uint64_t & nb10,
6021
- constant uint64_t & nb11,
6022
- constant uint64_t & nb1,
6023
- constant uint64_t & nb2,
6365
+ constant ggml_metal_kargs_get_rows & args,
6024
6366
  uint3 tgpig[[threadgroup_position_in_grid]],
6025
6367
  uint tiitg[[thread_index_in_threadgroup]],
6026
6368
  uint3 tptg [[threads_per_threadgroup]]) {
6027
6369
  const int64_t i10 = tgpig.x;
6028
6370
  const int64_t i11 = tgpig.y;
6029
6371
 
6030
- const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
6372
+ const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
6031
6373
 
6032
6374
  const int64_t i02 = i11;
6033
6375
 
6034
- for (int ind = tiitg; ind < ne00; ind += tptg.x) {
6035
- (( device int32_t *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
6036
- ((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
6376
+ for (int ind = tiitg; ind < args.ne00; ind += tptg.x) {
6377
+ (( device int32_t *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] =
6378
+ ((const device int32_t *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind];
6037
6379
  }
6038
6380
  }
6039
6381
 
@@ -6192,127 +6534,219 @@ kernel void kernel_mul_mm(
6192
6534
  }
6193
6535
  }
6194
6536
 
6195
- // same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids
6196
- // TODO: this kernel needs to be reimplemented from scratch for better performance
6197
- template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
6198
- void kernel_mul_mm_id_impl(
6199
- int32_t ne00,
6200
- int32_t ne02,
6201
- uint64_t nb01,
6202
- uint64_t nb02,
6203
- int32_t ne11,
6204
- int32_t ne12,
6205
- uint64_t nb10,
6206
- uint64_t nb11,
6207
- uint64_t nb12,
6208
- int32_t ne0,
6209
- int32_t ne1,
6210
- int64_t ne0ne1,
6211
- device const char * src0,
6212
- device const char * src1,
6213
- threadgroup ushort2 * rowids,
6214
- device char * dst,
6215
- threadgroup char * shmem,
6537
+ template<typename T4>
6538
+ kernel void kernel_mul_mm_id_map0(
6539
+ constant ggml_metal_kargs_mul_mm_id_map0 & args,
6540
+ device const char * src1,
6541
+ device const char * src2,
6542
+ device char * hsrc1,
6543
+ device char * htpe,
6544
+ device char * hids,
6545
+ uint3 tgpig[[threadgroup_position_in_grid]],
6546
+ ushort3 tpitg[[thread_position_in_threadgroup]],
6547
+ ushort3 ntg[[threads_per_threadgroup]]) {
6548
+ const int ide = tgpig[0]; // expert id
6549
+
6550
+ int n_all = 0;
6551
+
6552
+ device int32_t * ids_i32 = (device int32_t *) (hids);
6553
+
6554
+ for (int i21 = 0; i21 < args.neh11; i21++) { // n_tokens
6555
+ device const int32_t * src2_i32 = (device const int32_t *) (src2 + i21*args.nb21);
6556
+
6557
+ for (int i20 = 0; i20 < args.ne20; i20++) { // n_expert_used
6558
+ if (src2_i32[i20] != ide) {
6559
+ continue;
6560
+ }
6561
+
6562
+ device const float4 * src1_f32x4 = (device const float4 *) ( src1 + i21*args.nb12 + (i20%args.ne11)*args.nb11);
6563
+ device T4 * hsrc1_f32x4 = (device T4 *) (hsrc1 + (ide*args.neh11 + n_all)*args.nbh11);
6564
+
6565
+ for (int64_t i00 = tpitg.x; i00 < args.ne10/4; i00 += ntg.x) {
6566
+ hsrc1_f32x4[i00] = (T4) (src1_f32x4[i00]);
6567
+ }
6568
+
6569
+ if (tpitg.x == 0) {
6570
+ ids_i32[i21*args.ne20 + i20] = ide*args.neh11 + n_all;
6571
+ }
6572
+
6573
+ ++n_all;
6574
+ }
6575
+ }
6576
+
6577
+ if (tpitg.x == 0) {
6578
+ device int32_t * tpe_i32 = (device int32_t *) (htpe);
6579
+ tpe_i32[ide] = n_all;
6580
+ }
6581
+ }
6582
+
6583
+ typedef decltype(kernel_mul_mm_id_map0<half4>) kernel_mul_mm_id_map0_t;
6584
+
6585
+ template [[host_name("kernel_mul_mm_id_map0_f16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<half4>;
6586
+
6587
+ template<typename T>
6588
+ kernel void kernel_mul_mm_id_map1(
6589
+ constant ggml_metal_kargs_mul_mm_id_map1 & args,
6590
+ device const char * hdst,
6591
+ device const char * hids,
6592
+ device char * dst,
6593
+ uint3 tgpig[[threadgroup_position_in_grid]],
6594
+ ushort3 tpitg[[thread_position_in_threadgroup]],
6595
+ ushort3 ntg[[threads_per_threadgroup]]) {
6596
+ const int i20 = tgpig[0]; // used expert
6597
+ const int i21 = tgpig[1]; // token
6598
+
6599
+ device const int32_t * ids_i32 = (device const int32_t *) (hids);
6600
+ device float4 * dst_f32x4 = (device float4 *) (dst + i20*args.nb1 + i21*args.nb2);
6601
+
6602
+ const int id = ids_i32[i21*args.ne20 + i20];
6603
+
6604
+ const int ide = id / args.neh1;
6605
+ const int idt = id % args.neh1;
6606
+
6607
+ device const float4 * hdst_f32x4 = (device const float4 *) (hdst + idt*args.nbh1 + ide*args.nbh2);
6608
+
6609
+ for (int64_t i0 = tpitg.x; i0 < args.neh0/4; i0 += ntg.x) {
6610
+ dst_f32x4[i0] = hdst_f32x4[i0];
6611
+ }
6612
+ }
6613
+
6614
+ typedef decltype(kernel_mul_mm_id_map1<float>) kernel_mul_mm_id_map1_t;
6615
+
6616
+ template [[host_name("kernel_mul_mm_id_map1_f32")]] kernel kernel_mul_mm_id_map1_t kernel_mul_mm_id_map1<float>;
6617
+
6618
+ template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
6619
+ kernel void kernel_mul_mm_id(
6620
+ constant ggml_metal_kargs_mul_mm_id & args,
6621
+ device const char * src0,
6622
+ device const char * src1,
6623
+ device const char * tpe,
6624
+ device char * dst,
6625
+ threadgroup char * shmem [[threadgroup(0)]],
6216
6626
  uint3 tgpig[[threadgroup_position_in_grid]],
6217
6627
  ushort tiitg[[thread_index_in_threadgroup]],
6218
6628
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
6219
6629
 
6220
- threadgroup half * sa = (threadgroup half *)(shmem);
6221
- threadgroup float * sb = (threadgroup float *)(shmem + 4096);
6630
+ threadgroup T * sa = (threadgroup T *)(shmem);
6631
+ threadgroup half * sb = (threadgroup half *)(shmem + 4096);
6222
6632
 
6223
6633
  const int r0 = tgpig.y;
6224
6634
  const int r1 = tgpig.x;
6635
+ const int im = tgpig.z;
6225
6636
 
6226
- if (r1*BLOCK_SIZE_N >= ne1) return;
6637
+ device const int32_t * tpe_i32 = (device const int32_t *) (tpe);
6638
+
6639
+ const int neh1 = tpe_i32[im];
6640
+
6641
+ if (r1*BLOCK_SIZE_N >= neh1) {
6642
+ return;
6643
+ }
6227
6644
 
6228
6645
  // if this block is of 64x32 shape or smaller
6229
- short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
6230
- short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
6646
+ const short n_rows = (args.neh0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.neh0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
6647
+ const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
6231
6648
 
6232
6649
  // a thread shouldn't load data outside of the matrix
6233
- short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
6234
- short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
6650
+ const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
6651
+ const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
6235
6652
 
6236
- simdgroup_half8x8 ma[4];
6237
- simdgroup_float8x8 mb[2];
6653
+ simdgroup_T8x8 ma[4];
6654
+ simdgroup_half8x8 mb[2];
6238
6655
  simdgroup_float8x8 mc[8];
6239
- for (int i = 0; i < 8; i++){
6656
+
6657
+ for (short i = 0; i < 8; i++){
6240
6658
  mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
6241
6659
  }
6660
+
6242
6661
  short il = (tiitg % THREAD_PER_ROW);
6243
6662
 
6244
- ushort offset1 = il/nl;
6663
+ const int i12 = im%args.neh12;
6664
+ const int i13 = im/args.neh12;
6245
6665
 
6246
- threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col];
6666
+ const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
6667
+ const short offset1 = il/nl;
6247
6668
 
6248
- device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1;
6249
- device const float * y = (device const float *)(src1
6250
- + nb12 * id[1]
6251
- + nb11 * (id[0] % ne11)
6252
- + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
6669
+ device const block_q * x = (device const block_q *)(src0
6670
+ + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
6671
+
6672
+ device const half * y = (device const half *)(src1
6673
+ + args.nbh13*i13
6674
+ + args.nbh12*i12
6675
+ + args.nbh11*(r1*BLOCK_SIZE_N + thread_col)
6676
+ + args.nbh10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
6253
6677
 
6254
- for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
6678
+ for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
6255
6679
  // load data and store to threadgroup memory
6256
- half4x4 temp_a;
6680
+ T4x4 temp_a;
6257
6681
  dequantize_func(x, il, temp_a);
6682
+
6258
6683
  threadgroup_barrier(mem_flags::mem_threadgroup);
6259
6684
 
6260
- for (int i = 0; i < 16; i++) {
6261
- *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
6262
- + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
6263
- + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
6685
+ #pragma unroll(16)
6686
+ for (short i = 0; i < 16; i++) {
6687
+ *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
6688
+ + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
6689
+ + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
6264
6690
  }
6265
6691
 
6266
- *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
6692
+ *(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device half2x4 *) y);
6267
6693
 
6268
6694
  il = (il + 2 < nl) ? il + 2 : il % 2;
6269
- x = (il < 2) ? x + (2+nl-1)/nl : x;
6695
+ x = (il < 2) ? x + (2 + nl - 1)/nl : x;
6270
6696
  y += BLOCK_SIZE_K;
6271
6697
 
6272
6698
  threadgroup_barrier(mem_flags::mem_threadgroup);
6273
6699
 
6274
6700
  // load matrices from threadgroup memory and conduct outer products
6275
- threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
6276
- threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
6701
+ threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
6702
+ threadgroup const half * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
6277
6703
 
6278
- #pragma unroll(BLOCK_SIZE_K/8)
6279
- for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
6704
+ #pragma unroll(4)
6705
+ for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) {
6280
6706
  #pragma unroll(4)
6281
- for (int i = 0; i < 4; i++) {
6707
+ for (short i = 0; i < 4; i++) {
6282
6708
  simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
6283
6709
  }
6710
+
6284
6711
  simdgroup_barrier(mem_flags::mem_none);
6712
+
6285
6713
  #pragma unroll(2)
6286
- for (int i = 0; i < 2; i++) {
6714
+ for (short i = 0; i < 2; i++) {
6287
6715
  simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
6288
6716
  }
6289
6717
 
6290
- lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
6291
- lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
6292
-
6293
6718
  #pragma unroll(8)
6294
- for (int i = 0; i < 8; i++){
6719
+ for (short i = 0; i < 8; i++){
6295
6720
  simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
6296
6721
  }
6722
+
6723
+ lsma += (BLOCK_SIZE_M/SG_MAT_ROW)*SG_MAT_SIZE;
6724
+ lsmb += (BLOCK_SIZE_N/SG_MAT_ROW)*SG_MAT_SIZE;
6297
6725
  }
6298
6726
  }
6299
6727
 
6300
- {
6728
+ if ((r0 + 1) * BLOCK_SIZE_M <= args.neh0 && (r1 + 1) * BLOCK_SIZE_N <= neh1) {
6729
+ device float * C = (device float *) dst +
6730
+ (BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \
6731
+ (BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.neh0 + im*args.neh1*args.neh0;
6732
+
6733
+ for (short i = 0; i < 8; i++) {
6734
+ simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.neh0 * (i/4), args.neh0);
6735
+ }
6736
+ } else {
6737
+ // block is smaller than 64x32, we should avoid writing data outside of the matrix
6301
6738
  threadgroup_barrier(mem_flags::mem_threadgroup);
6302
6739
  threadgroup float * temp_str = ((threadgroup float *) shmem) \
6303
- + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
6304
- for (int i = 0; i < 8; i++) {
6305
- simdgroup_store(mc[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
6740
+ + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
6741
+ for (short i = 0; i < 8; i++) {
6742
+ simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
6306
6743
  }
6307
6744
 
6308
6745
  threadgroup_barrier(mem_flags::mem_threadgroup);
6309
6746
 
6310
6747
  if (sgitg == 0) {
6311
6748
  for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
6312
- threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j];
6313
- int64_t joff = jid[0]*ne0 + jid[1]*ne0ne1;
6314
-
6315
- device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + joff;
6749
+ device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.neh0 + im*args.neh1*args.neh0;
6316
6750
  device float4 * D4 = (device float4 *) D;
6317
6751
 
6318
6752
  threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
@@ -6332,66 +6766,6 @@ void kernel_mul_mm_id_impl(
6332
6766
  }
6333
6767
  }
6334
6768
 
6335
- template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
6336
- kernel void kernel_mul_mm_id(
6337
- constant ggml_metal_kargs_mul_mm_id & args,
6338
- device const char * src0s,
6339
- device const char * src1,
6340
- device char * dst,
6341
- device const char * ids,
6342
- threadgroup char * shmem [[threadgroup(0)]],
6343
- uint3 tgpig[[threadgroup_position_in_grid]],
6344
- ushort tiitg[[thread_index_in_threadgroup]],
6345
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
6346
-
6347
- const int32_t i02 = tgpig.z;
6348
-
6349
- tgpig.z = 0;
6350
-
6351
- device const char * src0 = src0s + i02*args.nb02;
6352
-
6353
- // row indices
6354
- threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192);
6355
-
6356
- // TODO: parallelize this loop
6357
- int32_t _ne1 = 0;
6358
- for (ushort ii1 = 0; ii1 < args.nei1; ii1++) {
6359
- for (ushort ii0 = 0; ii0 < args.nei0; ii0++) {
6360
- int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0];
6361
- if (id == i02) {
6362
- if (tiitg == 0) {
6363
- rowids[_ne1] = ushort2(ii0, ii1);
6364
- }
6365
- _ne1++;
6366
- }
6367
- }
6368
- }
6369
-
6370
- threadgroup_barrier(mem_flags::mem_threadgroup);
6371
-
6372
- kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
6373
- args.ne00,
6374
- args.ne02,
6375
- args.nb01,
6376
- args.nb02,
6377
- args.ne11,
6378
- args.ne12,
6379
- args.nb10,
6380
- args.nb11,
6381
- args.nb12,
6382
- args.ne0,
6383
- _ne1,
6384
- (int64_t)args.ne0*args.ne1,
6385
- src0,
6386
- src1,
6387
- rowids,
6388
- dst,
6389
- shmem,
6390
- tgpig,
6391
- tiitg,
6392
- sgitg);
6393
- }
6394
-
6395
6769
  #define QK_NL 16
6396
6770
 
6397
6771
  //
@@ -6432,63 +6806,64 @@ template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get
6432
6806
  // matrix-matrix multiplication
6433
6807
  //
6434
6808
 
6435
- typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mat_mm_t;
6809
+ typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mul_mm_t;
6436
6810
 
6437
- template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
6438
- template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
6811
+ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
6812
+ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
6439
6813
  #if defined(GGML_METAL_USE_BF16)
6440
- template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mat_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
6814
+ template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
6441
6815
  #endif
6442
- template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
6443
- template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
6444
- template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
6445
- template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
6446
- template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
6447
- template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
6448
- template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
6449
- template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
6450
- template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
6451
- template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
6452
- template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
6453
- template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
6454
- template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
6455
- template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
6456
- template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
6457
- template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
6458
- template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
6459
- template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
6460
- template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6816
+ template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
6817
+ template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
6818
+ template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
6819
+ template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
6820
+ template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
6821
+ template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
6822
+ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
6823
+ template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
6824
+ template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
6825
+ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
6826
+ template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
6827
+ template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
6828
+ template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
6829
+ template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
6830
+ template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
6831
+ template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
6832
+ template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
6833
+ template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
6834
+ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6461
6835
 
6462
6836
  //
6463
6837
  // indirect matrix-matrix multiplication
6464
6838
  //
6465
6839
 
6466
- typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>) mat_mm_id_t;
6840
+ typedef decltype(kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mul_mm_id;
6467
6841
 
6468
- template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
6469
- template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
6842
+ template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
6843
+ template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
6470
6844
  #if defined(GGML_METAL_USE_BF16)
6471
- template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<bfloat4x4, 1, dequantize_bf16>;
6845
+ template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
6472
6846
  #endif
6473
- template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
6474
- template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
6475
- template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
6476
- template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
6477
- template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
6478
- template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
6479
- template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
6480
- template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
6481
- template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
6482
- template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
6483
- template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
6484
- template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
6485
- template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
6486
- template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s, QK_NL, dequantize_iq3_s>;
6487
- template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
6488
- template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
6489
- template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m, QK_NL, dequantize_iq1_m>;
6490
- template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
6491
- template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6847
+ template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
6848
+ template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
6849
+ template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
6850
+ template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
6851
+ template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
6852
+ template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
6853
+ template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
6854
+ template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
6855
+ template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
6856
+ template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
6857
+ template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
6858
+ template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
6859
+ template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
6860
+ template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
6861
+ template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
6862
+ template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
6863
+ template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
6864
+ template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
6865
+ template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6866
+
6492
6867
 
6493
6868
  //
6494
6869
  // matrix-vector multiplication
@@ -6612,121 +6987,103 @@ template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t
6612
6987
  #if defined(GGML_METAL_USE_BF16)
6613
6988
  template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<bfloat, bfloat4, float, float4>>>;
6614
6989
  #endif
6615
- template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
6616
- template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6617
- template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6618
- template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6619
- template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
6620
- template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>;
6621
- template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>;
6622
- template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>;
6623
- template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>;
6624
- template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>;
6625
- template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl>>;
6626
- template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl>>;
6627
- template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl>>;
6628
- template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl>>;
6629
- template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl>>;
6630
- template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl>>;
6631
- template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
6632
- template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
6633
- template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
6990
+ template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SG_Q8_0, N_SIMDWIDTH>>>;
6991
+
6992
+ template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SG_Q4_0, N_SIMDWIDTH>>>;
6993
+ template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SG_Q4_1, N_SIMDWIDTH>>>;
6994
+ template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SG_Q5_0, N_SIMDWIDTH>>>;
6995
+ template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SG_Q5_1, N_SIMDWIDTH>>>;
6996
+
6997
+ template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl <N_R0_Q2_K, N_SG_Q2_K, N_SIMDWIDTH>>>;
6998
+ template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl <N_R0_Q3_K, N_SG_Q3_K, N_SIMDWIDTH>>>;
6999
+ template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl <N_R0_Q4_K, N_SG_Q4_K, N_SIMDWIDTH>>>;
7000
+ template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl <N_R0_Q5_K, N_SG_Q5_K, N_SIMDWIDTH>>>;
7001
+ template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl <N_R0_Q6_K, N_SG_Q6_K, N_SIMDWIDTH>>>;
7002
+ template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl <N_R0_IQ1_S, N_SG_IQ1_S, N_SIMDWIDTH>>>;
7003
+ template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl <N_R0_IQ1_M, N_SG_IQ1_M, N_SIMDWIDTH>>>;
7004
+ template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SG_IQ2_XXS, N_SIMDWIDTH>>>;
7005
+ template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl <N_R0_IQ2_XS, N_SG_IQ2_XS, N_SIMDWIDTH>>>;
7006
+ template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, N_SG_IQ3_XXS, N_SIMDWIDTH>>>;
7007
+ template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl <N_R0_IQ3_S, N_SG_IQ3_S, N_SIMDWIDTH>>>;
7008
+ template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl <N_R0_IQ2_S, N_SG_IQ2_S, N_SIMDWIDTH>>>;
7009
+ template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl <N_R0_IQ4_NL, N_SG_IQ4_NL, N_SIMDWIDTH>>>;
7010
+ template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl <N_R0_IQ4_XS, N_SG_IQ4_XS, N_SIMDWIDTH>>>;
6634
7011
 
6635
7012
  kernel void kernel_pool_2d_max_f32(
6636
7013
  device const float * src0,
6637
7014
  device float * dst,
6638
- constant int32_t & k0,
6639
- constant int32_t & k1,
6640
- constant int32_t & s0,
6641
- constant int32_t & s1,
6642
- constant int32_t & p0,
6643
- constant int32_t & p1,
6644
- constant int64_t & IH,
6645
- constant int64_t & IW,
6646
- constant int64_t & OH,
6647
- constant int64_t & OW,
6648
- constant int64_t & parallel_elements,
7015
+ constant ggml_metal_kargs_pool_2d & args,
6649
7016
  uint gid[[thread_position_in_grid]]) {
6650
7017
 
6651
- if (gid >= parallel_elements) {
7018
+ if (gid >= args.parallel_elements) {
6652
7019
  return;
6653
7020
  }
6654
7021
 
6655
7022
  const int idx = gid;
6656
- const int I_HW = IH * IW;
6657
- const int O_HW = OH * OW;
7023
+ const int I_HW = args.IH * args.IW;
7024
+ const int O_HW = args.OH * args.OW;
6658
7025
  const int nc = idx / O_HW;
6659
- const int cur_oh = idx % O_HW / OW;
6660
- const int cur_ow = idx % O_HW % OW;
7026
+ const int cur_oh = idx % O_HW / args.OW;
7027
+ const int cur_ow = idx % O_HW % args.OW;
6661
7028
 
6662
7029
  device const float * i_ptr = src0 + nc * I_HW;
6663
7030
  device float * o_ptr = dst + nc * O_HW;
6664
7031
 
6665
- const int start_h = cur_oh * s1 - p1;
7032
+ const int start_h = cur_oh * args.s1 - args.p1;
6666
7033
  const int bh = MAX(0, start_h);
6667
- const int eh = MIN(IH, start_h + k1);
6668
- const int start_w = cur_ow * s0 - p0;
7034
+ const int eh = MIN(args.IH, start_h + args.k1);
7035
+ const int start_w = cur_ow * args.s0 - args.p0;
6669
7036
  const int bw = MAX(0, start_w);
6670
- const int ew = MIN(IW, start_w + k0);
7037
+ const int ew = MIN(args.IW, start_w + args.k0);
6671
7038
 
6672
7039
  float res = -INFINITY;
6673
7040
 
6674
7041
  for (int i = bh; i < eh; i += 1) {
6675
7042
  for (int j = bw; j < ew; j += 1) {
6676
- res = MAX(res, i_ptr[i * IW + j]);
7043
+ res = MAX(res, i_ptr[i * args.IW + j]);
6677
7044
  }
6678
7045
  }
6679
7046
 
6680
- o_ptr[cur_oh * OW + cur_ow] = res;
7047
+ o_ptr[cur_oh * args.OW + cur_ow] = res;
6681
7048
  }
6682
7049
 
6683
7050
  kernel void kernel_pool_2d_avg_f32(
6684
7051
  device const float * src0,
6685
7052
  device float * dst,
6686
- constant int32_t & k0,
6687
- constant int32_t & k1,
6688
- constant int32_t & s0,
6689
- constant int32_t & s1,
6690
- constant int32_t & p0,
6691
- constant int32_t & p1,
6692
- constant int64_t & IH,
6693
- constant int64_t & IW,
6694
- constant int64_t & OH,
6695
- constant int64_t & OW,
6696
- constant int64_t & parallel_elements,
7053
+ constant ggml_metal_kargs_pool_2d & args,
6697
7054
  uint gid[[thread_position_in_grid]]) {
6698
7055
 
6699
- if (gid >= parallel_elements) {
7056
+ if (gid >= args.parallel_elements) {
6700
7057
  return;
6701
7058
  }
6702
7059
 
6703
7060
  const int idx = gid;
6704
- const int I_HW = IH * IW;
6705
- const int O_HW = OH * OW;
7061
+ const int I_HW = args.IH * args.IW;
7062
+ const int O_HW = args.OH * args.OW;
6706
7063
  const int nc = idx / O_HW;
6707
- const int cur_oh = idx % O_HW / OW;
6708
- const int cur_ow = idx % O_HW % OW;
7064
+ const int cur_oh = idx % O_HW / args.OW;
7065
+ const int cur_ow = idx % O_HW % args.OW;
6709
7066
 
6710
7067
  device const float * i_ptr = src0 + nc * I_HW;
6711
7068
  device float * o_ptr = dst + nc * O_HW;
6712
7069
 
6713
- const int start_h = cur_oh * s1 - p1;
7070
+ const int start_h = cur_oh * args.s1 - args.p1;
6714
7071
  const int bh = MAX(0, start_h);
6715
- const int eh = MIN(IH, start_h + k1);
6716
- const int start_w = cur_ow * s0 - p0;
7072
+ const int eh = MIN(args.IH, start_h + args.k1);
7073
+ const int start_w = cur_ow * args.s0 - args.p0;
6717
7074
  const int bw = MAX(0, start_w);
6718
- const int ew = MIN(IW, start_w + k0);
7075
+ const int ew = MIN(args.IW, start_w + args.k0);
6719
7076
  // const float scale = 1. / ((eh - bh) * (ew - bw));
6720
- const float scale = 1. / (k0 * k1);
7077
+ const float scale = 1. / (args.k0 * args.k1);
6721
7078
 
6722
7079
  float res = 0;
6723
7080
 
6724
7081
  for (int i = bh; i < eh; i += 1) {
6725
7082
  for (int j = bw; j < ew; j += 1) {
6726
- float cur = i_ptr[i * IW + j];
7083
+ float cur = i_ptr[i * args.IW + j];
6727
7084
  res += cur * scale;
6728
7085
  }
6729
7086
  }
6730
7087
 
6731
- o_ptr[cur_oh * OW + cur_ow] = res;
7088
+ o_ptr[cur_oh * args.OW + cur_ow] = res;
6732
7089
  }