whispercpp 1.3.4 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (630)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +47 -23
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  23. data/ext/sources/examples/cli/cli.cpp +121 -112
  24. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  25. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  26. data/ext/sources/examples/server/server.cpp +10 -11
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
  31. data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
  33. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  34. data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
  35. data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
  36. data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
  37. data/ext/sources/examples/talk-llama/llama-context.h +57 -9
  38. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  40. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  41. data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
  42. data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
  43. data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
  44. data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
  45. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  46. data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
  49. data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
  50. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  51. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  52. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
  53. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  54. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  55. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  56. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
  57. data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
  58. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  59. data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
  60. data/ext/sources/examples/talk-llama/llama-model.h +44 -3
  61. data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
  62. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
  63. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
  66. data/ext/sources/examples/talk-llama/llama.cpp +729 -2
  67. data/ext/sources/examples/talk-llama/llama.h +152 -14
  68. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  69. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  70. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  71. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  72. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  73. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  74. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  75. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  76. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  77. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  78. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  79. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  80. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  81. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  82. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  83. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  84. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  85. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  86. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  88. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  89. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  90. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  91. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  92. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  108. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  109. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  110. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  111. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  112. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  113. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  114. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  116. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  117. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  118. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  119. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  120. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  121. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  122. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  123. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  124. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  125. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  126. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  127. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  128. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  129. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  130. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  131. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  132. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  133. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  134. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  135. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  136. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  137. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  138. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  139. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  140. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  141. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  142. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  143. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  144. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  145. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  146. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  147. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  148. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  149. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  150. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  151. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  153. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  154. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  155. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  156. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  157. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  158. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  159. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  160. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  161. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  162. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  163. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  165. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  166. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  167. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  168. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  169. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  170. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  171. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  172. data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
  173. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  174. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  175. data/ext/sources/ggml/CMakeLists.txt +82 -54
  176. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  177. data/ext/sources/ggml/include/ggml-backend.h +4 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +1 -0
  179. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  180. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  181. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  182. data/ext/sources/ggml/include/ggml.h +190 -12
  183. data/ext/sources/ggml/src/CMakeLists.txt +82 -11
  184. data/ext/sources/ggml/src/ggml-alloc.c +124 -41
  185. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  186. data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
  187. data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
  188. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  189. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
  190. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
  191. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  192. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
  193. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
  194. data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
  195. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
  196. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
  197. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
  198. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  199. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
  200. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
  201. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  202. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  203. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
  204. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  205. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
  206. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
  207. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
  208. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
  209. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  210. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
  211. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  212. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
  213. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  214. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
  215. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  216. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
  217. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
  218. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
  219. data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
  220. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
  221. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  222. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  223. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  224. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
  225. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  226. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  227. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  228. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  229. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  230. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
  231. data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
  232. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  233. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
  235. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  236. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  237. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  238. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  239. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
  241. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
  242. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  243. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
  244. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
  245. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
  246. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  247. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
  248. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  249. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  250. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
  251. data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
  252. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
  253. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
  254. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
  255. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  256. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  258. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
  259. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
  260. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
  261. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
  262. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
  264. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
  265. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  266. data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
  267. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  268. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  269. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  270. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  271. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
  272. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  273. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  274. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
  275. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  276. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  278. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  279. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  280. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  281. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  282. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  283. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  284. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
  286. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  287. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  288. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
  289. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
  290. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  291. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  292. data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
  293. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
  294. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  295. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  296. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
  297. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  317. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  321. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  322. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  323. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  324. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  325. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  326. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  327. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  328. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  329. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
  330. data/ext/sources/ggml/src/ggml-impl.h +67 -6
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
  335. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
  336. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
  337. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
  338. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
  339. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
  340. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
  341. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  342. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
  343. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
  365. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  366. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  367. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  368. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  369. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
  370. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
  371. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  372. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  373. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  374. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
  375. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  376. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  377. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  378. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  379. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
  380. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
  381. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
  382. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
  383. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  384. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
  385. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  386. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  387. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  388. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  389. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  390. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  391. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  392. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  393. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  394. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  395. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
  396. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  397. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  398. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  399. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  400. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  401. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  402. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
  403. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  404. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  484. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  485. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  486. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  487. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  488. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
  489. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  490. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  491. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  492. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  493. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  494. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  495. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  496. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  497. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  498. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  499. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  500. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  501. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  502. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  503. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  504. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  505. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  506. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
  507. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
  508. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  509. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  510. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
  511. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  512. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  513. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  514. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  515. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  516. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  517. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  518. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  519. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  520. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  521. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  522. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  523. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  524. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  525. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
  526. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  527. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  528. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  529. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  530. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  531. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
  532. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
  533. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  560. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  561. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  562. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  563. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  564. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  565. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  566. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  567. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  568. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  569. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
  570. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  571. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  572. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  573. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
  574. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  584. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  585. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  586. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  587. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  588. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  589. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  590. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  591. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  592. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  593. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  594. data/ext/sources/ggml/src/ggml.c +425 -33
  595. data/ext/sources/include/whisper.h +1 -0
  596. data/ext/sources/src/CMakeLists.txt +3 -1
  597. data/ext/sources/src/whisper.cpp +101 -35
  598. data/ext/sources/tests/CMakeLists.txt +2 -2
  599. data/ext/sources/tests/test-vad-full.cpp +4 -2
  600. data/ext/sources/tests/test-vad.cpp +1 -1
  601. data/extsources.rb +1 -0
  602. data/lib/whisper/model/uri.rb +17 -18
  603. data/sig/whisper.rbs +119 -2
  604. data/test/test_params.rb +16 -8
  605. data/test/test_segment.rb +0 -1
  606. data/test/test_token.rb +70 -0
  607. data/test/test_vad.rb +1 -1
  608. data/test/test_vad_context.rb +50 -0
  609. data/test/test_vad_segment.rb +19 -0
  610. data/test/test_vad_segments.rb +16 -0
  611. data/test/test_whisper.rb +7 -0
  612. data/whispercpp.gemspec +1 -1
  613. metadata +287 -34
  614. data/ext/sources/build-xcframework.sh +0 -571
  615. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  618. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  619. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  620. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  621. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  622. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  623. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  624. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  625. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  626. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  627. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  628. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  629. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  630. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -9,6 +9,12 @@ __embed_ggml-common.h__
9
9
 
10
10
  #include <metal_stdlib>
11
11
 
12
+ #ifdef GGML_METAL_HAS_TENSOR
13
+ #include <metal_tensor>
14
+
15
+ #include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>
16
+ #endif
17
+
12
18
  using namespace metal;
13
19
 
14
20
  #define MAX(x, y) ((x) > (y) ? (x) : (y))
@@ -1243,6 +1249,22 @@ kernel void kernel_scale_f32_4(
1243
1249
  dst[tpig] = src0[tpig] * args.scale + args.bias;
1244
1250
  }
1245
1251
 
1252
+ kernel void kernel_fill_f32(
1253
+ constant ggml_metal_kargs_fill & args,
1254
+ device const float * src0,
1255
+ device float * dst,
1256
+ uint tpig[[thread_position_in_grid]]) {
1257
+ dst[tpig] = args.val;
1258
+ }
1259
+
1260
+ kernel void kernel_fill_f32_4(
1261
+ constant ggml_metal_kargs_fill & args,
1262
+ device const float4 * src0,
1263
+ device float4 * dst,
1264
+ uint tpig[[thread_position_in_grid]]) {
1265
+ dst[tpig] = args.val;
1266
+ }
1267
+
1246
1268
  kernel void kernel_clamp_f32(
1247
1269
  constant ggml_metal_kargs_clamp & args,
1248
1270
  device const float * src0,
@@ -1589,6 +1611,36 @@ kernel void kernel_exp_f32_4(
1589
1611
  dst[tpig] = exp(src0[tpig]);
1590
1612
  }
1591
1613
 
1614
+ kernel void kernel_softplus_f32(
1615
+ device const float * src0,
1616
+ device float * dst,
1617
+ uint tpig[[thread_position_in_grid]]) {
1618
+ device const float & x = src0[tpig];
1619
+ dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
1620
+ }
1621
+
1622
+ kernel void kernel_softplus_f32_4(
1623
+ device const float4 * src0,
1624
+ device float4 * dst,
1625
+ uint tpig[[thread_position_in_grid]]) {
1626
+ device const float4 & x = src0[tpig];
1627
+ dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
1628
+ }
1629
+
1630
+ kernel void kernel_expm1_f32(
1631
+ device const float * src0,
1632
+ device float * dst,
1633
+ uint tpig[[thread_position_in_grid]]) {
1634
+ dst[tpig] = exp(src0[tpig]) - 1.0f;
1635
+ }
1636
+
1637
+ kernel void kernel_expm1_f32_4(
1638
+ device const float4 * src0,
1639
+ device float4 * dst,
1640
+ uint tpig[[thread_position_in_grid]]) {
1641
+ dst[tpig] = exp(src0[tpig]) - 1.0f;
1642
+ }
1643
+
1592
1644
  kernel void kernel_reglu_f32(
1593
1645
  constant ggml_metal_kargs_glu & args,
1594
1646
  device const char * src0,
@@ -1723,6 +1775,55 @@ kernel void kernel_geglu_quick_f32(
1723
1775
  }
1724
1776
  }
1725
1777
 
1778
+ kernel void kernel_op_sum_f32(
1779
+ constant ggml_metal_kargs_sum & args,
1780
+ device const float * src0,
1781
+ device float * dst,
1782
+ threadgroup float * shmem_f32 [[threadgroup(0)]],
1783
+ uint3 tgpig[[threadgroup_position_in_grid]],
1784
+ ushort3 tpitg[[thread_position_in_threadgroup]],
1785
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
1786
+ ushort tiisg[[thread_index_in_simdgroup]],
1787
+ ushort3 ntg[[threads_per_threadgroup]]) {
1788
+
1789
+ if (args.np == 0) {
1790
+ return;
1791
+ }
1792
+
1793
+ // TODO: become function constant
1794
+ const uint nsg = (ntg.x + 31) / 32;
1795
+
1796
+ float sumf = 0;
1797
+
1798
+ for (uint64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {
1799
+ sumf += src0[i0];
1800
+ }
1801
+
1802
+ sumf = simd_sum(sumf);
1803
+
1804
+ if (tiisg == 0) {
1805
+ shmem_f32[sgitg] = sumf;
1806
+ }
1807
+
1808
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1809
+
1810
+ float total = 0;
1811
+
1812
+ if (sgitg == 0) {
1813
+ float v = 0;
1814
+
1815
+ if (tpitg.x < nsg) {
1816
+ v = shmem_f32[tpitg.x];
1817
+ }
1818
+
1819
+ total = simd_sum(v);
1820
+
1821
+ if (tpitg.x == 0) {
1822
+ dst[0] = total;
1823
+ }
1824
+ }
1825
+ }
1826
+
1726
1827
  template <bool norm>
1727
1828
  kernel void kernel_sum_rows(
1728
1829
  constant ggml_metal_kargs_sum_rows & args,
@@ -1778,6 +1879,186 @@ typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
1778
1879
  template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
1779
1880
  template [[host_name("kernel_mean_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
1780
1881
 
1882
+ template<typename T>
1883
+ kernel void kernel_cumsum_blk(
1884
+ constant ggml_metal_kargs_cumsum_blk & args,
1885
+ device const char * src0,
1886
+ device char * tmp,
1887
+ device char * dst,
1888
+ threadgroup char * shmem [[threadgroup(0)]],
1889
+ uint3 tgpig[[threadgroup_position_in_grid]],
1890
+ ushort3 tpitg[[thread_position_in_threadgroup]],
1891
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
1892
+ ushort tiisg[[thread_index_in_simdgroup]],
1893
+ ushort3 ntg[[threads_per_threadgroup]]) {
1894
+ const int ib = tgpig[0]/args.ne01;
1895
+
1896
+ const int i00 = ib*ntg.x;
1897
+ const int i01 = tgpig[0]%args.ne01;
1898
+ const int i02 = tgpig[1];
1899
+ const int i03 = tgpig[2];
1900
+
1901
+ device const float * src0_row = (device const float *) (src0 +
1902
+ args.nb01*i01 +
1903
+ args.nb02*i02 +
1904
+ args.nb03*i03);
1905
+
1906
+ threadgroup float * shmem_f32 = (threadgroup float *) shmem;
1907
+
1908
+ float v = 0.0f;
1909
+
1910
+ if (i00 + tpitg.x < args.ne00) {
1911
+ v = src0_row[i00 + tpitg.x];
1912
+ }
1913
+
1914
+ float s = simd_prefix_inclusive_sum(v);
1915
+
1916
+ if (tiisg == N_SIMDWIDTH - 1) {
1917
+ shmem_f32[sgitg] = s;
1918
+ }
1919
+
1920
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1921
+
1922
+ if (sgitg == 0) {
1923
+ shmem_f32[tiisg] = simd_prefix_exclusive_sum(shmem_f32[tiisg]);
1924
+ }
1925
+
1926
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1927
+
1928
+ s += shmem_f32[sgitg];
1929
+
1930
+ device float * dst_row = (device float *) dst +
1931
+ args.ne00*i01 +
1932
+ args.ne00*args.ne01*i02 +
1933
+ args.ne00*args.ne01*args.ne02*i03;
1934
+
1935
+ if (i00 + tpitg.x < args.ne00) {
1936
+ dst_row[i00 + tpitg.x] = s;
1937
+ }
1938
+
1939
+ if (args.outb && tpitg.x == ntg.x - 1) {
1940
+ device float * tmp_row = (device float *) tmp +
1941
+ args.net0*i01 +
1942
+ args.net0*args.net1*i02 +
1943
+ args.net0*args.net1*args.net2*i03;
1944
+
1945
+ tmp_row[ib] = s;
1946
+ }
1947
+ }
1948
+
1949
+ typedef decltype(kernel_cumsum_blk<float>) kernel_cumsum_blk_t;
1950
+
1951
+ template [[host_name("kernel_cumsum_blk_f32")]] kernel kernel_cumsum_blk_t kernel_cumsum_blk<float>;
1952
+
1953
+ template<typename T>
1954
+ kernel void kernel_cumsum_add(
1955
+ constant ggml_metal_kargs_cumsum_add & args,
1956
+ device const char * tmp,
1957
+ device char * dst,
1958
+ uint3 tgpig[[threadgroup_position_in_grid]],
1959
+ ushort3 tpitg[[thread_position_in_threadgroup]],
1960
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
1961
+ ushort tiisg[[thread_index_in_simdgroup]],
1962
+ ushort3 ntg[[threads_per_threadgroup]]) {
1963
+ const int ib = tgpig[0]/args.ne01;
1964
+
1965
+ if (ib == 0) {
1966
+ return;
1967
+ }
1968
+
1969
+ const int i00 = ib*ntg.x;
1970
+ const int i01 = tgpig[0]%args.ne01;
1971
+ const int i02 = tgpig[1];
1972
+ const int i03 = tgpig[2];
1973
+
1974
+ device const float * tmp_row = (device const float *) (tmp +
1975
+ args.nbt1*i01 +
1976
+ args.nbt2*i02 +
1977
+ args.nbt3*i03);
1978
+
1979
+ device float * dst_row = (device float *) dst +
1980
+ args.ne00*i01 +
1981
+ args.ne00*args.ne01*i02 +
1982
+ args.ne00*args.ne01*args.ne02*i03;
1983
+
1984
+ if (i00 + tpitg.x < args.ne00) {
1985
+ dst_row[i00 + tpitg.x] += tmp_row[ib - 1];
1986
+ }
1987
+ }
1988
+
1989
+ typedef decltype(kernel_cumsum_add<float>) kernel_cumsum_add_t;
1990
+
1991
+ template [[host_name("kernel_cumsum_add_f32")]] kernel kernel_cumsum_add_t kernel_cumsum_add<float>;
1992
+
1993
+
1994
+ template<uint32_t ttype>
1995
+ bool _ggml_vec_tri_cmp(const int i, const int r);
1996
+
1997
+ template<>
1998
+ bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER */ 3>(const int i, const int r) {
1999
+ return i < r;
2000
+ }
2001
+
2002
+ template<>
2003
+ bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER_DIAG */ 2>(const int i, const int r) {
2004
+ return i <= r;
2005
+ }
2006
+
2007
+ template<>
2008
+ bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER */ 1>(const int i, const int r) {
2009
+ return i > r;
2010
+ }
2011
+
2012
+ template<>
2013
+ bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER_DIAG */ 0>(const int i, const int r) {
2014
+ return i >= r;
2015
+ }
2016
+
2017
+ template<typename T, int ttype>
2018
+ kernel void kernel_tri(
2019
+ constant ggml_metal_kargs_tri & args,
2020
+ device const char * src0,
2021
+ device const char * dst,
2022
+ uint3 tgpig[[threadgroup_position_in_grid]],
2023
+ ushort3 tpitg[[thread_position_in_threadgroup]],
2024
+ ushort3 ntg[[threads_per_threadgroup]]) {
2025
+ const int i3 = tgpig.z;
2026
+ const int i2 = tgpig.y;
2027
+ const int i1 = tgpig.x;
2028
+
2029
+ if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
2030
+ return;
2031
+ }
2032
+
2033
+ device const T * src_row = (device const T *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
2034
+ device T * dst_row = (device T *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
2035
+
2036
+ // Each thread is a single element of the row if ne00 < max threads per
2037
+ // threadgroup, so this will loop once for each index that this thread is
2038
+ // responsible for
2039
+ for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
2040
+ // Use the comparison as a mask for branchless
2041
+ dst_row[i0] = static_cast<T>(_ggml_vec_tri_cmp<ttype>(i0, i1)) * src_row[i0];
2042
+ }
2043
+ }
2044
+
2045
+ typedef decltype(kernel_tri<float, 0>) kernel_tri_t;
2046
+
2047
+ template [[host_name("kernel_tri_f32_0")]] kernel kernel_tri_t kernel_tri<float, 0>;
2048
+ template [[host_name("kernel_tri_f32_1")]] kernel kernel_tri_t kernel_tri<float, 1>;
2049
+ template [[host_name("kernel_tri_f32_2")]] kernel kernel_tri_t kernel_tri<float, 2>;
2050
+ template [[host_name("kernel_tri_f32_3")]] kernel kernel_tri_t kernel_tri<float, 3>;
2051
+ template [[host_name("kernel_tri_f16_0")]] kernel kernel_tri_t kernel_tri<half, 0>;
2052
+ template [[host_name("kernel_tri_f16_1")]] kernel kernel_tri_t kernel_tri<half, 1>;
2053
+ template [[host_name("kernel_tri_f16_2")]] kernel kernel_tri_t kernel_tri<half, 2>;
2054
+ template [[host_name("kernel_tri_f16_3")]] kernel kernel_tri_t kernel_tri<half, 3>;
2055
+ #if defined(GGML_METAL_HAS_BF16)
2056
+ template [[host_name("kernel_tri_bf16_0")]] kernel kernel_tri_t kernel_tri<bfloat, 0>;
2057
+ template [[host_name("kernel_tri_bf16_1")]] kernel kernel_tri_t kernel_tri<bfloat, 1>;
2058
+ template [[host_name("kernel_tri_bf16_2")]] kernel kernel_tri_t kernel_tri<bfloat, 2>;
2059
+ template [[host_name("kernel_tri_bf16_3")]] kernel kernel_tri_t kernel_tri<bfloat, 3>;
2060
+ #endif
2061
+
1781
2062
  template<typename T>
1782
2063
  kernel void kernel_soft_max(
1783
2064
  constant ggml_metal_kargs_soft_max & args,
@@ -2032,124 +2313,134 @@ kernel void kernel_ssm_conv_f32_f32(
2032
2313
  x[0] = sumf;
2033
2314
  }
2034
2315
 
2035
- // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
2036
- kernel void kernel_ssm_scan_f32(
2037
- constant ggml_metal_kargs_ssm_scan & args,
2038
- device const void * src0,
2039
- device const void * src1,
2040
- device const void * src2,
2041
- device const void * src3,
2042
- device const void * src4,
2043
- device const void * src5,
2044
- device const void * src6,
2045
- device float * dst,
2046
- threadgroup float * shared [[threadgroup(0)]],
2047
- uint3 tgpig[[threadgroup_position_in_grid]],
2048
- uint3 tpitg[[thread_position_in_threadgroup]],
2049
- ushort sgitg[[simdgroup_index_in_threadgroup]],
2050
- ushort tiisg[[thread_index_in_simdgroup]],
2051
- ushort sgptg[[simdgroups_per_threadgroup]],
2052
- uint3 tgpg[[threadgroups_per_grid]]) {
2316
+ kernel void kernel_ssm_conv_f32_f32_4(
2317
+ constant ggml_metal_kargs_ssm_conv & args,
2318
+ device const void * src0,
2319
+ device const void * src1,
2320
+ device float * dst,
2321
+ uint3 tgpig[[threadgroup_position_in_grid]],
2322
+ uint3 tpitg[[thread_position_in_threadgroup]],
2323
+ uint3 ntg[[threads_per_threadgroup]]) {
2324
+ const int64_t ir = tgpig.x;
2325
+ const int64_t i2 = tgpig.y;
2326
+ const int64_t i3 = tgpig.z;
2327
+
2328
+ const int64_t nc = args.ne10;
2329
+ //const int64_t ncs = args.ne00;
2330
+ //const int64_t nr = args.ne01;
2331
+ //const int64_t n_t = args.ne1;
2332
+ //const int64_t n_s = args.ne2;
2053
2333
 
2054
- const int64_t i0 = tpitg.x;
2055
- const int64_t i1 = 0;
2056
- const int64_t ir = tgpig.x; // current head
2057
- const int64_t i3 = tgpig.y; // current seq
2334
+ device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
2335
+ device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11);
2336
+ device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
2058
2337
 
2059
- const uint64_t nb00 = sizeof(float);
2060
- const uint64_t nb10 = sizeof(float);
2061
- const uint64_t nb20 = sizeof(float);
2338
+ float sumf = 0.0f;
2062
2339
 
2063
- const int64_t nc = args.d_state;
2064
- const int64_t nr = args.d_inner;
2065
- const int64_t nh = args.n_head;
2066
- const int64_t ng = args.n_group;
2067
- const int64_t n_t = args.n_seq_tokens;
2340
+ for (int64_t i0 = 0; i0 < nc/4; ++i0) {
2341
+ sumf += dot(s[i0], c[i0]);
2342
+ }
2068
2343
 
2069
- const int64_t s_off = args.s_off;
2344
+ x[0] = sumf;
2345
+ }
2070
2346
 
2071
- device const int32_t * ids = (device const int32_t *) src6;
2347
+ constant short FC_ssm_conv_bs [[function_constant(FC_SSM_CONV + 0)]];
2072
2348
 
2073
- device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
2074
- device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
2075
- const int64_t i = i0 + i1*nc;
2076
- const int64_t g = ir / (nh / ng); // repeat_interleave
2077
- float s0 = s0_buff[i];
2078
- float s = s_buff[i];
2079
-
2080
- device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
2081
- device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
2082
- device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
2083
- device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
2084
- device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
2085
- device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
2086
-
2087
- for (int64_t i2 = 0; i2 < n_t; ++i2) {
2088
- device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns}
2089
- device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
2090
- device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns}
2091
- device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns}
2092
- device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
2093
-
2094
- const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
2095
- const float x_dt = x[0] * dt_soft_plus;
2096
-
2097
- const float state = (s0 * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
2098
- s = state;
2099
-
2100
- // Parallel sum: This relies on the fact that this kernel will be
2101
- // dispatched with each threadgroup having (d_state, 1, 1) threads which
2102
- // are subdivided into SIMD groups of size `sgptg`. The goal is to
2103
- // compute y = sum({state * C[i] for i in range(d_state)}).
2104
- // To parallelize this effectively, we first use simd_sum over each SIMD
2105
- // group to compute the sum of each SIMD group, then place the result in
2106
- // the SIMD group's indexed bucket in the shared memory. We then sum
2107
- // over the individual group sums to compute the final sum.
2108
-
2109
- // Computed for each thread
2110
- float sumf = state * C[i0];
2111
-
2112
- // Sum the threads in the simd group => simd sum
2113
- sumf = simd_sum(sumf);
2114
-
2115
- if (sgptg > 1) {
2116
-
2117
- // Once per simd group, place the group sum into the shared buffer
2118
- if (tiisg == 0) {
2119
- shared[sgitg] = sumf;
2120
- }
2349
+ // Batched version: each threadgroup processes multiple tokens for better efficiency
2350
+ // Thread layout: each thread handles one token, threadgroup covers BATCH_SIZE tokens
2351
+ kernel void kernel_ssm_conv_f32_f32_batched(
2352
+ constant ggml_metal_kargs_ssm_conv & args,
2353
+ device const void * src0,
2354
+ device const void * src1,
2355
+ device float * dst,
2356
+ uint3 tgpig[[threadgroup_position_in_grid]],
2357
+ uint3 tpitg[[thread_position_in_threadgroup]],
2358
+ uint3 ntg[[threads_per_threadgroup]]) {
2359
+ // tgpig.x = row index (ir)
2360
+ // tgpig.y = batch of tokens (i2_base / BATCH_SIZE)
2361
+ // tgpig.z = sequence index (i3)
2362
+ // tpitg.x = thread within batch (0..BATCH_SIZE-1)
2363
+ const short BATCH_SIZE = FC_ssm_conv_bs;
2364
+
2365
+ const int64_t ir = tgpig.x;
2366
+ const int64_t i2_base = tgpig.y * BATCH_SIZE;
2367
+ const int64_t i3 = tgpig.z;
2368
+ const int64_t i2_off = tpitg.x;
2369
+ const int64_t i2 = i2_base + i2_off;
2370
+
2371
+ const int64_t nc = args.ne10; // conv kernel size (typically 4)
2372
+ const int64_t n_t = args.ne1; // number of tokens
2373
+
2374
+ // Bounds check for partial batches at the end
2375
+ if (i2 >= n_t) {
2376
+ return;
2377
+ }
2121
2378
 
2122
- // Wait for all threads in the threadgroup to reach this point. This
2123
- // ensures that all elements of the shared buffer are populated with the
2124
- // sum of the individual simd groups.
2125
- threadgroup_barrier(mem_flags::mem_threadgroup);
2379
+ // Load conv weights (shared across all tokens for this row)
2380
+ device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
2126
2381
 
2127
- // For simd group 0 at indices < num simd groups, extract the shared
2128
- // simd sum
2129
- sumf = 0.0f;
2130
- if (sgitg == 0) {
2131
- if (tiisg < sgptg) {
2132
- sumf = shared[tiisg];
2133
- }
2134
- sumf = simd_sum(sumf);
2135
- if (tiisg == 0) {
2136
- y[0] = sumf;
2137
- }
2138
- }
2139
- } else if (tiisg == 0) {
2140
- y[0] = sumf;
2141
- }
2382
+ // Load source for this specific token
2383
+ device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
2384
+
2385
+ // Output location for this token
2386
+ device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
2142
2387
 
2143
- // recurse
2144
- s0 = s;
2388
+ float sumf = 0.0f;
2389
+ for (int64_t i0 = 0; i0 < nc; ++i0) {
2390
+ sumf += s[i0] * c[i0];
2145
2391
  }
2146
2392
 
2147
- // Assign the final state to the output buffer
2148
- s_buff[i] = s;
2393
+ x[0] = sumf;
2394
+ }
2395
+
2396
+ kernel void kernel_ssm_conv_f32_f32_batched_4(
2397
+ constant ggml_metal_kargs_ssm_conv & args,
2398
+ device const void * src0,
2399
+ device const void * src1,
2400
+ device float * dst,
2401
+ uint3 tgpig[[threadgroup_position_in_grid]],
2402
+ uint3 tpitg[[thread_position_in_threadgroup]],
2403
+ uint3 ntg[[threads_per_threadgroup]]) {
2404
+ // tgpig.x = row index (ir)
2405
+ // tgpig.y = batch of tokens (i2_base / BATCH_SIZE)
2406
+ // tgpig.z = sequence index (i3)
2407
+ // tpitg.x = thread within batch (0..BATCH_SIZE-1)
2408
+ const short BATCH_SIZE = FC_ssm_conv_bs;
2409
+
2410
+ const int64_t ir = tgpig.x;
2411
+ const int64_t i2_base = tgpig.y * BATCH_SIZE;
2412
+ const int64_t i3 = tgpig.z;
2413
+ const int64_t i2_off = tpitg.x;
2414
+ const int64_t i2 = i2_base + i2_off;
2415
+
2416
+ const int64_t nc = args.ne10; // conv kernel size (typically 4)
2417
+ const int64_t n_t = args.ne1; // number of tokens
2418
+
2419
+ // Bounds check for partial batches at the end
2420
+ if (i2 >= n_t) {
2421
+ return;
2422
+ }
2423
+
2424
+ // Load conv weights (shared across all tokens for this row)
2425
+ device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11);
2426
+
2427
+ // Load source for this specific token
2428
+ device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
2429
+
2430
+ // Output location for this token
2431
+ device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
2432
+
2433
+ float sumf = 0.0f;
2434
+ for (int64_t i0 = 0; i0 < nc/4; ++i0) {
2435
+ sumf += dot(s[i0], c[i0]);
2436
+ }
2437
+
2438
+ x[0] = sumf;
2149
2439
  }
2150
2440
 
2151
2441
  // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
2152
- kernel void kernel_ssm_scan_group_f32(
2442
+ // Optimized version: reduces redundant memory loads by having one thread load shared values
2443
+ kernel void kernel_ssm_scan_f32(
2153
2444
  constant ggml_metal_kargs_ssm_scan & args,
2154
2445
  device const void * src0,
2155
2446
  device const void * src1,
@@ -2160,103 +2451,111 @@ kernel void kernel_ssm_scan_group_f32(
2160
2451
  device const void * src6,
2161
2452
  device float * dst,
2162
2453
  threadgroup float * shared [[threadgroup(0)]],
2163
- uint3 tgpig[[threadgroup_position_in_grid]],
2164
- uint3 tpitg[[thread_position_in_threadgroup]],
2165
- ushort sgitg[[simdgroup_index_in_threadgroup]],
2166
- ushort tiisg[[thread_index_in_simdgroup]],
2167
- ushort sgptg[[simdgroups_per_threadgroup]],
2168
- uint3 tgpg[[threadgroups_per_grid]]) {
2454
+ uint3 tgpig[[threadgroup_position_in_grid]],
2455
+ ushort3 tpitg[[thread_position_in_threadgroup]],
2456
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
2457
+ ushort tiisg[[thread_index_in_simdgroup]],
2458
+ ushort sgptg[[simdgroups_per_threadgroup]],
2459
+ uint3 tgpg[[threadgroups_per_grid]]) {
2460
+ constexpr short NW = N_SIMDWIDTH;
2169
2461
 
2170
- const int64_t i0 = tpitg.x;
2171
- const int64_t i1 = tgpig.x;
2172
- const int64_t ir = tgpig.y; // current head
2173
- const int64_t i3 = tgpig.z; // current seq
2462
+ // Shared memory layout:
2463
+ // [0..sgptg*NW-1]: partial sums for reduction (existing)
2464
+ // [sgptg*NW..sgptg*NW+sgptg-1]: pre-computed x_dt values for each token in batch
2465
+ // [sgptg*NW+sgptg..sgptg*NW+2*sgptg-1]: pre-computed dA values for each token in batch
2466
+ threadgroup float * shared_sums = shared;
2467
+ threadgroup float * shared_x_dt = shared + sgptg * NW;
2468
+ threadgroup float * shared_dA = shared + sgptg * NW + sgptg;
2469
+
2470
+ shared_sums[tpitg.x] = 0.0f;
2174
2471
 
2175
- const uint64_t nb00 = sizeof(float);
2176
- const uint64_t nb10 = sizeof(float);
2177
- const uint64_t nb20 = sizeof(float);
2472
+ const int32_t i0 = tpitg.x;
2473
+ const int32_t i1 = tgpig.x;
2474
+ const int32_t ir = tgpig.y; // current head
2475
+ const int32_t i3 = tgpig.z; // current seq
2178
2476
 
2179
- const int64_t nc = args.d_state;
2180
- const int64_t nr = args.d_inner;
2181
- const int64_t nh = args.n_head;
2182
- const int64_t ng = args.n_group;
2183
- const int64_t n_t = args.n_seq_tokens;
2477
+ const int32_t nc = args.d_state;
2478
+ const int32_t nr = args.d_inner;
2479
+ const int32_t nh = args.n_head;
2480
+ const int32_t ng = args.n_group;
2481
+ const int32_t n_t = args.n_seq_tokens;
2184
2482
 
2185
- const int64_t s_off = args.s_off;
2483
+ const int32_t s_off = args.s_off;
2186
2484
 
2187
2485
  device const int32_t * ids = (device const int32_t *) src6;
2188
2486
 
2189
2487
  device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
2190
2488
  device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
2191
- const int64_t i = i0 + i1*nc;
2192
- const int64_t g = ir / (nh / ng); // repeat_interleave
2489
+
2490
+ const int32_t i = i0 + i1*nc;
2491
+ const int32_t g = ir / (nh / ng); // repeat_interleave
2492
+
2193
2493
  float s0 = s0_buff[i];
2194
- float s = s_buff[i];
2195
-
2196
- device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
2197
- device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
2198
- device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
2199
- device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
2200
- device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
2201
- device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
2202
-
2203
- for (int64_t i2 = 0; i2 < n_t; ++i2) {
2204
- device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns}
2205
- device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
2206
- device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns}
2207
- device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns}
2208
- device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
2209
-
2210
- const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
2211
- const float x_dt = x[0] * dt_soft_plus;
2212
- const float dA = exp(dt_soft_plus * A[0]);
2213
-
2214
- const float state = (s0 * dA) + (B[i0] * x_dt);
2215
- s = state;
2216
-
2217
- // Parallel sum: This relies on the fact that this kernel will be
2218
- // dispatched with each threadgroup having (d_state, 1, 1) threads which
2219
- // are subdivided into SIMD groups of size `sgptg`. The goal is to
2220
- // compute y = sum({state * C[i] for i in range(d_state)}).
2221
- // To parallelize this effectively, we first use simd_sum over each SIMD
2222
- // group to compute the sum of each SIMD group, then place the result in
2223
- // the SIMD group's indexed bucket in the shared memory. We then sum
2224
- // over the individual group sums to compute the final sum.
2225
-
2226
- // Computed for each thread
2227
- float sumf = state * C[i0];
2228
-
2229
- // Sum the threads in the simd group => simd sum
2230
- sumf = simd_sum(sumf);
2231
-
2232
- // Once per simd group, place the group sum into the shared buffer
2233
- if (tiisg == 0) {
2234
- shared[sgitg] = sumf;
2494
+ float s = 0.0f;
2495
+
2496
+ device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {ne30, nh}
2497
+
2498
+ const float A0 = A[i0%args.ne30];
2499
+
2500
+ device const float * x = (device const float *)((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i3*args.nb13); // {dim, nh, nt, ns}
2501
+ device const float * dt = (device const float *)((device const char *) src2 + ir*args.nb20 + i3*args.nb22); // {nh, nt, ns}
2502
+ device const float * B = (device const float *)((device const char *) src4 + g*args.nb41 + i3*args.nb43); // {d_state, ng, nt, ns}
2503
+ device const float * C = (device const float *)((device const char *) src5 + g*args.nb51 + i3*args.nb53); // {d_state, ng, nt, ns}
2504
+
2505
+ device float * y = dst + (i1 + ir*(nr) + i3*(n_t*nh*nr)); // {dim, nh, nt, ns}
2506
+
2507
+ for (int i2 = 0; i2 < n_t; i2 += sgptg) {
2508
+ threadgroup_barrier(mem_flags::mem_threadgroup);
2509
+
2510
+ // Pre-compute x_dt and dA for this batch of tokens
2511
+ // Only first sgptg threads do the loads and expensive math
2512
+ if (i0 < sgptg && i2 + i0 < n_t) {
2513
+ // ns12 and ns21 are element strides (nb12/nb10, nb21/nb20)
2514
+ device const float * x_t = x + i0 * args.ns12;
2515
+ device const float * dt_t = dt + i0 * args.ns21;
2516
+
2517
+ const float dt0 = dt_t[0];
2518
+ const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0;
2519
+ shared_x_dt[i0] = x_t[0] * dtsp;
2520
+ shared_dA[i0] = dtsp; // Store dtsp, compute exp(dtsp * A0) per-thread since A0 varies
2235
2521
  }
2236
2522
 
2237
- // Wait for all threads in the threadgroup to reach this point. This
2238
- // ensures that all elements of the shared buffer are populated with the
2239
- // sum of the individual simd groups.
2240
2523
  threadgroup_barrier(mem_flags::mem_threadgroup);
2241
2524
 
2242
- // For simd group 0 at indices < num simd groups, extract the shared
2243
- // simd sum
2244
- sumf = 0.0f;
2245
- if (sgitg == 0) {
2246
- if (tiisg < sgptg) {
2247
- sumf = shared[tiisg];
2248
- }
2249
- sumf = simd_sum(sumf);
2525
+ for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
2526
+ const float x_dt = shared_x_dt[t];
2527
+ const float dA = exp(shared_dA[t] * A0);
2528
+
2529
+ s = (s0 * dA) + (B[i0] * x_dt);
2530
+
2531
+ const float sumf = simd_sum(s * C[i0]);
2532
+
2250
2533
  if (tiisg == 0) {
2251
- y[0] = sumf;
2534
+ shared_sums[t*NW + sgitg] = sumf;
2252
2535
  }
2536
+
2537
+ // recurse
2538
+ s0 = s;
2539
+
2540
+ B += args.ns42;
2541
+ C += args.ns52;
2542
+ }
2543
+
2544
+ // Advance pointers for next batch
2545
+ x += sgptg * args.ns12;
2546
+ dt += sgptg * args.ns21;
2547
+
2548
+ threadgroup_barrier(mem_flags::mem_threadgroup);
2549
+
2550
+ const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]);
2551
+
2552
+ if (tiisg == 0 && i2 + sgitg < n_t) {
2553
+ y[sgitg*nh*nr] = sumf;
2253
2554
  }
2254
2555
 
2255
- // recurse
2256
- s0 = s;
2556
+ y += sgptg*nh*nr;
2257
2557
  }
2258
2558
 
2259
- // Assign the final state to the output buffer
2260
2559
  s_buff[i] = s;
2261
2560
  }
2262
2561
 
@@ -3761,6 +4060,8 @@ template [[host_name("kernel_mul_mv_bf16_f32_short")]] kernel mul_mv_t_t_short_
3761
4060
  template [[host_name("kernel_mul_mv_bf16_bf16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<bfloat, bfloat>;
3762
4061
  #endif
3763
4062
 
4063
+ constant bool FC_rope_is_imrope [[function_constant(FC_ROPE + 0)]];
4064
+
3764
4065
  static float rope_yarn_ramp(const float low, const float high, const int i0) {
3765
4066
  const float y = (i0 / 2 - low) / max(0.001f, high - low);
3766
4067
  return 1.0f - min(1.0f, max(0.0f, y));
@@ -3830,7 +4131,7 @@ kernel void kernel_rope_norm(
3830
4131
 
3831
4132
  const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
3832
4133
 
3833
- const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
4134
+ const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
3834
4135
 
3835
4136
  rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
3836
4137
 
@@ -3883,7 +4184,7 @@ kernel void kernel_rope_neox(
3883
4184
 
3884
4185
  const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
3885
4186
 
3886
- const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
4187
+ const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
3887
4188
 
3888
4189
  rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
3889
4190
 
@@ -3941,20 +4242,32 @@ kernel void kernel_rope_multi(
3941
4242
  const int sector = ic % sect_dims;
3942
4243
 
3943
4244
  float theta_base;
3944
- if (sector < args.sect_0) {
3945
- theta_base = (float) pos[i2];
3946
- } else if (sector < sec_w01) {
3947
- theta_base = (float) pos[i2 + args.ne02];
3948
- } else if (sector < sec_w012) {
3949
- theta_base = (float) pos[i2 + args.ne02 * 2];
4245
+ if (FC_rope_is_imrope) {
4246
+ if (sector % 3 == 1 && sector < 3 * args.sect_1) { // h
4247
+ theta_base = (float) pos[i2 + args.ne02 * 1];
4248
+ } else if (sector % 3 == 2 && sector < 3 * args.sect_2) { // w
4249
+ theta_base = (float) pos[i2 + args.ne02 * 2];
4250
+ } else if (sector % 3 == 0 && sector < 3 * args.sect_0) { // t
4251
+ theta_base = (float) pos[i2 + args.ne02 * 0];
4252
+ } else { // e
4253
+ theta_base = (float) pos[i2 + args.ne02 * 3];
4254
+ }
3950
4255
  } else {
3951
- theta_base = (float) pos[i2 + args.ne02 * 3];
4256
+ if (sector < args.sect_0) {
4257
+ theta_base = (float) pos[i2];
4258
+ } else if (sector < sec_w01) {
4259
+ theta_base = (float) pos[i2 + args.ne02 * 1];
4260
+ } else if (sector < sec_w012) {
4261
+ theta_base = (float) pos[i2 + args.ne02 * 2];
4262
+ } else {
4263
+ theta_base = (float) pos[i2 + args.ne02 * 3];
4264
+ }
3952
4265
  }
3953
4266
  // end of mrope
3954
4267
 
3955
4268
  const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
3956
4269
 
3957
- const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
4270
+ const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
3958
4271
 
3959
4272
  rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
3960
4273
 
@@ -4021,7 +4334,7 @@ kernel void kernel_rope_vision(
4021
4334
  const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
4022
4335
  // end of mrope
4023
4336
 
4024
- const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
4337
+ const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
4025
4338
 
4026
4339
  rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
4027
4340
 
@@ -4178,6 +4491,120 @@ template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
4178
4491
  //template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
4179
4492
  //template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
4180
4493
 
4494
+ template <typename TK>
4495
+ kernel void kernel_conv_2d(
4496
+ constant ggml_metal_kargs_conv_2d & args,
4497
+ device const char * weights,
4498
+ device const char * src,
4499
+ device char * dst,
4500
+ uint3 tgpig[[threadgroup_position_in_grid]],
4501
+ uint3 tgpg[[threadgroups_per_grid]],
4502
+ uint3 tpitg[[thread_position_in_threadgroup]],
4503
+ uint3 ntg[[threads_per_threadgroup]]) {
4504
+
4505
+ const uint threads_per_tg = ntg.x * ntg.y * ntg.z;
4506
+ const uint tg_index = (tgpig.z * tgpg.y + tgpig.y) * tgpg.x + tgpig.x;
4507
+ const uint local_thread = tpitg.z * (ntg.x * ntg.y) + tpitg.y * ntg.x + tpitg.x;
4508
+ const uint thread_index = tg_index * threads_per_tg + local_thread;
4509
+ const uint64_t total_threads = (uint64_t) threads_per_tg * tgpg.x * tgpg.y * tgpg.z;
4510
+ const uint64_t total_outputs = (uint64_t) args.N * args.OC * args.OH * args.OW;
4511
+
4512
+ for (uint64_t index = thread_index; index < total_outputs; index += total_threads) {
4513
+ uint64_t tmp = index;
4514
+
4515
+ const int32_t ow = tmp % args.OW; tmp /= args.OW;
4516
+ const int32_t oh = tmp % args.OH; tmp /= args.OH;
4517
+ const int32_t oc = tmp % args.OC; tmp /= args.OC;
4518
+ const int32_t n = tmp;
4519
+
4520
+ float acc = 0.0f;
4521
+
4522
+ const int32_t base_x = ow*args.s0 - args.p0;
4523
+ const int32_t base_y = oh*args.s1 - args.p1;
4524
+
4525
+ int32_t ky_start = 0;
4526
+ if (base_y < 0) {
4527
+ ky_start = (-base_y + args.d1 - 1)/args.d1;
4528
+ }
4529
+ int32_t ky_end = args.KH;
4530
+ const int32_t y_max = args.IH - 1 - base_y;
4531
+ if (y_max < 0) {
4532
+ ky_end = ky_start;
4533
+ } else if (base_y + (args.KH - 1)*args.d1 >= args.IH) {
4534
+ ky_end = min(ky_end, y_max/args.d1 + 1);
4535
+ }
4536
+
4537
+ int32_t kx_start = 0;
4538
+ if (base_x < 0) {
4539
+ kx_start = (-base_x + args.d0 - 1)/args.d0;
4540
+ }
4541
+ int32_t kx_end = args.KW;
4542
+ const int32_t x_max = args.IW - 1 - base_x;
4543
+ if (x_max < 0) {
4544
+ kx_end = kx_start;
4545
+ } else if (base_x + (args.KW - 1)*args.d0 >= args.IW) {
4546
+ kx_end = min(kx_end, x_max/args.d0 + 1);
4547
+ }
4548
+
4549
+ if (ky_start < ky_end && kx_start < kx_end) {
4550
+ const uint64_t src_base_n = (uint64_t) n * args.nb13;
4551
+ const uint64_t w_base_oc = (uint64_t) oc * args.nb03;
4552
+
4553
+ for (int32_t ic = 0; ic < args.IC; ++ic) {
4554
+ const uint64_t src_base_nc = src_base_n + (uint64_t) ic * args.nb12;
4555
+ const uint64_t w_base_ocic = w_base_oc + (uint64_t) ic * args.nb02;
4556
+
4557
+ for (int32_t ky = ky_start; ky < ky_end; ++ky) {
4558
+ const int32_t iy = base_y + ky*args.d1;
4559
+ const uint64_t src_base_row = src_base_nc + (uint64_t) iy * args.nb11;
4560
+ const uint64_t w_base_row = w_base_ocic + (uint64_t) ky * args.nb01;
4561
+
4562
+ for (int32_t kx = kx_start; kx < kx_end; ++kx) {
4563
+ const int32_t ix = base_x + kx*args.d0;
4564
+ const uint64_t src_offs = src_base_row + (uint64_t) ix * args.nb10;
4565
+ const uint64_t w_offs = w_base_row + (uint64_t) kx * args.nb00;
4566
+
4567
+ const float x = *(device const float *)(src + src_offs);
4568
+ const float w = (float) (*(device const TK *)(weights + w_offs));
4569
+
4570
+ acc += x * w;
4571
+ }
4572
+ }
4573
+ }
4574
+ }
4575
+
4576
+ const uint64_t dst_offs =
4577
+ (uint64_t) n * args.nb3 +
4578
+ (uint64_t) oc * args.nb2 +
4579
+ (uint64_t) oh * args.nb1 +
4580
+ (uint64_t) ow * args.nb0;
4581
+
4582
+ *(device float *)(dst + dst_offs) = acc;
4583
+ }
4584
+ }
4585
+
4586
+ template [[host_name("kernel_conv_2d_f32_f32")]]
4587
+ kernel void kernel_conv_2d<float>(
4588
+ constant ggml_metal_kargs_conv_2d & args,
4589
+ device const char * weights,
4590
+ device const char * src,
4591
+ device char * dst,
4592
+ uint3 tgpig[[threadgroup_position_in_grid]],
4593
+ uint3 tgpg[[threadgroups_per_grid]],
4594
+ uint3 tpitg[[thread_position_in_threadgroup]],
4595
+ uint3 ntg[[threads_per_threadgroup]]);
4596
+
4597
+ template [[host_name("kernel_conv_2d_f16_f32")]]
4598
+ kernel void kernel_conv_2d<half>(
4599
+ constant ggml_metal_kargs_conv_2d & args,
4600
+ device const char * weights,
4601
+ device const char * src,
4602
+ device char * dst,
4603
+ uint3 tgpig[[threadgroup_position_in_grid]],
4604
+ uint3 tgpg[[threadgroups_per_grid]],
4605
+ uint3 tpitg[[thread_position_in_threadgroup]],
4606
+ uint3 ntg[[threads_per_threadgroup]]);
4607
+
4181
4608
  typedef void (conv_transpose_1d_t)(
4182
4609
  constant ggml_metal_kargs_conv_transpose_1d & args,
4183
4610
  device const float * src0,
@@ -4231,6 +4658,97 @@ kernel void kernel_conv_transpose_1d<half>(
4231
4658
  uint3 tgpig[[threadgroup_position_in_grid]],
4232
4659
  uint3 tgpg[[threadgroups_per_grid]]);
4233
4660
 
4661
+
4662
+ typedef void (conv_transpose_2d_t)(
4663
+ constant ggml_metal_kargs_conv_transpose_2d & args,
4664
+ device const float * src0,
4665
+ device const float * src1,
4666
+ device char * dst,
4667
+ uint3 tgpig[[threadgroup_position_in_grid]],
4668
+ uint3 tgpg[[threadgroups_per_grid]]);
4669
+
4670
+ template <typename T>
4671
+ kernel void kernel_conv_transpose_2d(
4672
+ constant ggml_metal_kargs_conv_transpose_2d & args,
4673
+ device const T * src0,
4674
+ device const float * src1,
4675
+ device char * dst,
4676
+ threadgroup float * shared_sum [[threadgroup(0)]],
4677
+ uint3 tgpig[[threadgroup_position_in_grid]],
4678
+ uint3 tpitg[[thread_position_in_threadgroup]],
4679
+ uint3 ntg[[threads_per_threadgroup]]) {
4680
+
4681
+ const int64_t out_x = tgpig[0];
4682
+ const int64_t out_y = tgpig[1];
4683
+ const int64_t out_c = tgpig[2];
4684
+
4685
+ const int64_t kw = tpitg[0];
4686
+ const int64_t kh = tpitg[1];
4687
+
4688
+ float v = 0.0f;
4689
+
4690
+ for (int64_t in_c = 0; in_c < args.IC; in_c++) {
4691
+ int64_t in_y = out_y - kh;
4692
+
4693
+ if (in_y < 0 || in_y % args.s0) continue;
4694
+
4695
+ in_y /= args.s0;
4696
+
4697
+ if (in_y >= args.IH) continue;
4698
+
4699
+ int64_t in_x = out_x - kw;
4700
+
4701
+ if (in_x < 0 || in_x % args.s0) continue;
4702
+
4703
+ in_x /= args.s0;
4704
+
4705
+ if (in_x >= args.IW) continue;
4706
+
4707
+ const int64_t input_idx = (args.IW * args.IH) * in_c + (args.IW) * in_y + in_x;
4708
+ const int64_t kernel_idx = (args.KH * args.KW * args.OC) * in_c + (args.KH * args.KW) * out_c + (args.KW) * kh + kw;
4709
+
4710
+ v += (float)src0[kernel_idx] * src1[input_idx];
4711
+ }
4712
+
4713
+ const uint tid = tpitg.y * ntg.x + tpitg.x;
4714
+ shared_sum[tid] = v;
4715
+
4716
+ threadgroup_barrier(mem_flags::mem_threadgroup);
4717
+
4718
+ if (tid == 0) {
4719
+ float total = 0.0f;
4720
+ const uint num_threads = ntg.x * ntg.y;
4721
+ for (uint i = 0; i < num_threads; i++) {
4722
+ total += shared_sum[i];
4723
+ }
4724
+
4725
+ device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y * args.nb1 + out_c*args.nb2);
4726
+ dst_ptr[0] = total;
4727
+ }
4728
+ }
4729
+
4730
+ template [[host_name("kernel_conv_transpose_2d_f32_f32")]]
4731
+ kernel void kernel_conv_transpose_2d<float>(
4732
+ constant ggml_metal_kargs_conv_transpose_2d & args,
4733
+ device const float * src0,
4734
+ device const float * src1,
4735
+ device char * dst,
4736
+ threadgroup float * shared_sum [[threadgroup(0)]],
4737
+ uint3 tgpig[[threadgroup_position_in_grid]],
4738
+ uint3 tpitg[[thread_position_in_threadgroup]],
4739
+ uint3 ntg[[threads_per_threadgroup]]);
4740
+
4741
+ template [[host_name("kernel_conv_transpose_2d_f16_f32")]]
4742
+ kernel void kernel_conv_transpose_2d<half>(
4743
+ constant ggml_metal_kargs_conv_transpose_2d & args,
4744
+ device const half * src0,
4745
+ device const float * src1,
4746
+ device char * dst,
4747
+ threadgroup float * shared_sum [[threadgroup(0)]],
4748
+ uint3 tgpig[[threadgroup_position_in_grid]],
4749
+ uint3 tpitg[[thread_position_in_threadgroup]],
4750
+ uint3 ntg[[threads_per_threadgroup]]);
4751
+
4234
4752
  kernel void kernel_upscale_f32(
4235
4753
  constant ggml_metal_kargs_upscale & args,
4236
4754
  device const char * src0,
@@ -4368,69 +4886,234 @@ kernel void kernel_timestep_embedding_f32(
4368
4886
  // bitonic sort implementation following the CUDA kernels as reference
4369
4887
  typedef void (argsort_t)(
4370
4888
  constant ggml_metal_kargs_argsort & args,
4371
- device const float * x,
4889
+ device const char * src0,
4372
4890
  device int32_t * dst,
4373
- threadgroup int32_t * shared_values [[threadgroup(0)]],
4374
- uint3 tgpig[[threadgroup_position_in_grid]],
4375
- uint3 tpitg[[thread_position_in_threadgroup]]);
4891
+ threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
4892
+ uint3 tgpig[[threadgroup_position_in_grid]],
4893
+ ushort3 tpitg[[thread_position_in_threadgroup]],
4894
+ ushort3 ntg[[threads_per_threadgroup]]);
4376
4895
 
4377
4896
  template<ggml_sort_order order>
4378
4897
  kernel void kernel_argsort_f32_i32(
4379
4898
  constant ggml_metal_kargs_argsort & args,
4380
- device const float * x,
4899
+ device const char * src0,
4381
4900
  device int32_t * dst,
4382
- threadgroup int32_t * shared_values [[threadgroup(0)]],
4383
- uint3 tgpig[[threadgroup_position_in_grid]],
4384
- uint3 tpitg[[thread_position_in_threadgroup]]) {
4901
+ threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
4902
+ uint3 tgpig[[threadgroup_position_in_grid]],
4903
+ ushort3 tpitg[[thread_position_in_threadgroup]],
4904
+ ushort3 ntg[[threads_per_threadgroup]]) {
4385
4905
  // bitonic sort
4386
- int col = tpitg[0];
4387
- int row = tgpig[1];
4906
+ const int col = tpitg[0];
4907
+ const int ib = tgpig[0] / args.ne01;
4388
4908
 
4389
- if (col >= args.ncols_pad) return;
4909
+ const int i00 = ib*ntg.x;
4910
+ const int i01 = tgpig[0] % args.ne01;
4911
+ const int i02 = tgpig[1];
4912
+ const int i03 = tgpig[2];
4390
4913
 
4391
- device const float * x_row = x + row * args.ncols;
4392
- threadgroup int32_t * dst_row = shared_values;
4914
+ device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
4393
4915
 
4394
4916
  // initialize indices
4395
- dst_row[col] = col;
4917
+ shmem_i32[col] = i00 + col;
4396
4918
 
4397
4919
  threadgroup_barrier(mem_flags::mem_threadgroup);
4398
4920
 
4399
- for (int k = 2; k <= args.ncols_pad; k *= 2) {
4921
+ for (int k = 2; k <= ntg.x; k *= 2) {
4400
4922
  for (int j = k / 2; j > 0; j /= 2) {
4401
4923
  int ixj = col ^ j;
4402
4924
  if (ixj > col) {
4403
4925
  if ((col & k) == 0) {
4404
- if (dst_row[col] >= args.ncols ||
4405
- (dst_row[ixj] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
4406
- x_row[dst_row[col]] > x_row[dst_row[ixj]] :
4407
- x_row[dst_row[col]] < x_row[dst_row[ixj]]))
4926
+ if (shmem_i32[col] >= args.ne00 ||
4927
+ (shmem_i32[ixj] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
4928
+ src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]] :
4929
+ src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]]))
4408
4930
  ) {
4409
- SWAP(dst_row[col], dst_row[ixj]);
4931
+ SWAP(shmem_i32[col], shmem_i32[ixj]);
4410
4932
  }
4411
4933
  } else {
4412
- if (dst_row[ixj] >= args.ncols ||
4413
- (dst_row[col] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
4414
- x_row[dst_row[col]] < x_row[dst_row[ixj]] :
4415
- x_row[dst_row[col]] > x_row[dst_row[ixj]]))
4934
+ if (shmem_i32[ixj] >= args.ne00 ||
4935
+ (shmem_i32[col] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
4936
+ src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]] :
4937
+ src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]]))
4416
4938
  ) {
4417
- SWAP(dst_row[col], dst_row[ixj]);
4939
+ SWAP(shmem_i32[col], shmem_i32[ixj]);
4418
4940
  }
4419
4941
  }
4420
4942
  }
4943
+
4421
4944
  threadgroup_barrier(mem_flags::mem_threadgroup);
4422
4945
  }
4423
4946
  }
4424
4947
 
4948
+ const int64_t i0 = ib*args.top_k;
4949
+
4425
4950
  // copy the result to dst without the padding
4426
- if (col < args.ncols) {
4427
- dst[row * args.ncols + col] = dst_row[col];
4951
+ if (i0 + col < args.ne0 && col < args.top_k) {
4952
+ dst += i0 + args.ne0*i01 + args.ne0*args.ne1*i02 + args.ne0*args.ne1*args.ne2*i03;
4953
+
4954
+ dst[col] = shmem_i32[col];
4428
4955
  }
4429
4956
  }
4430
4957
 
4431
4958
  template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
4432
4959
  template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
4433
4960
 
4961
+ typedef void (argsort_merge_t)(
4962
+ constant ggml_metal_kargs_argsort_merge & args,
4963
+ device const char * src0,
4964
+ device const int32_t * tmp,
4965
+ device int32_t * dst,
4966
+ uint3 tgpig[[threadgroup_position_in_grid]],
4967
+ ushort3 tpitg[[thread_position_in_threadgroup]],
4968
+ ushort3 ntg[[threads_per_threadgroup]]);
4969
+
4970
+ template<ggml_sort_order order>
4971
+ kernel void kernel_argsort_merge_f32_i32(
4972
+ constant ggml_metal_kargs_argsort_merge & args,
4973
+ device const char * src0,
4974
+ device const int32_t * tmp,
4975
+ device int32_t * dst,
4976
+ uint3 tgpig[[threadgroup_position_in_grid]],
4977
+ ushort3 tpitg[[thread_position_in_threadgroup]],
4978
+ ushort3 ntg[[threads_per_threadgroup]]) {
4979
+
4980
+ const int im = tgpig[0] / args.ne01;
4981
+ const int i01 = tgpig[0] % args.ne01;
4982
+ const int i02 = tgpig[1];
4983
+ const int i03 = tgpig[2];
4984
+
4985
+ const int start = im * (2 * args.len);
4986
+
4987
+ const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start)));
4988
+ const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len)));
4989
+
4990
+ const int total = len0 + len1;
4991
+
4992
+ device const int32_t * tmp0 = tmp + start
4993
+ + i01*args.ne0
4994
+ + i02*args.ne0*args.ne01
4995
+ + i03*args.ne0*args.ne01*args.ne02;
4996
+
4997
+ device const int32_t * tmp1 = tmp0 + args.len;
4998
+
4999
+ dst += start
5000
+ + i01*args.top_k
5001
+ + i02*args.top_k*args.ne01
5002
+ + i03*args.top_k*args.ne01*args.ne02;
5003
+
5004
+ device const float * src0_row = (device const float *)(src0
5005
+ + args.nb01*i01
5006
+ + args.nb02*i02
5007
+ + args.nb03*i03);
5008
+
5009
+ if (total == 0) {
5010
+ return;
5011
+ }
5012
+
5013
+ const int chunk = (total + ntg.x - 1) / ntg.x;
5014
+
5015
+ const int k0 = tpitg.x * chunk;
5016
+ const int k1 = MIN(MIN(k0 + chunk, total), args.top_k);
5017
+
5018
+ if (k0 >= args.top_k) {
5019
+ return;
5020
+ }
5021
+
5022
+ if (k0 >= total) {
5023
+ return;
5024
+ }
5025
+
5026
+ int low = k0 > len1 ? k0 - len1 : 0;
5027
+ int high = MIN(k0, len0);
5028
+
5029
+ // binary-search partition (i, j) such that i + j = k
5030
+ while (low < high) {
5031
+ const int mid = (low + high) >> 1;
5032
+
5033
+ const int32_t idx0 = tmp0[mid];
5034
+ const int32_t idx1 = tmp1[k0 - mid - 1];
5035
+
5036
+ const float val0 = src0_row[idx0];
5037
+ const float val1 = src0_row[idx1];
5038
+
5039
+ bool take_left;
5040
+ if (order == GGML_SORT_ORDER_ASC) {
5041
+ take_left = (val0 <= val1);
5042
+ } else {
5043
+ take_left = (val0 >= val1);
5044
+ }
5045
+
5046
+ if (take_left) {
5047
+ low = mid + 1;
5048
+ } else {
5049
+ high = mid;
5050
+ }
5051
+ }
5052
+
5053
+ int i = low;
5054
+ int j = k0 - i;
5055
+
5056
+ // keep the merge fronts into registers
5057
+ int32_t idx0 = 0;
5058
+ float val0 = 0.0f;
5059
+ if (i < len0) {
5060
+ idx0 = tmp0[i];
5061
+ val0 = src0_row[idx0];
5062
+ }
5063
+
5064
+ int32_t idx1 = 0;
5065
+ float val1 = 0.0f;
5066
+ if (j < len1) {
5067
+ idx1 = tmp1[j];
5068
+ val1 = src0_row[idx1];
5069
+ }
5070
+
5071
+ for (int k = k0; k < k1; ++k) {
5072
+ int32_t out_idx;
5073
+
5074
+ if (i >= len0) {
5075
+ while (k < k1) {
5076
+ dst[k++] = tmp1[j++];
5077
+ }
5078
+ break;
5079
+ } else if (j >= len1) {
5080
+ while (k < k1) {
5081
+ dst[k++] = tmp0[i++];
5082
+ }
5083
+ break;
5084
+ } else {
5085
+ bool take_left;
5086
+
5087
+ if (order == GGML_SORT_ORDER_ASC) {
5088
+ take_left = (val0 <= val1);
5089
+ } else {
5090
+ take_left = (val0 >= val1);
5091
+ }
5092
+
5093
+ if (take_left) {
5094
+ out_idx = idx0;
5095
+ ++i;
5096
+ if (i < len0) {
5097
+ idx0 = tmp0[i];
5098
+ val0 = src0_row[idx0];
5099
+ }
5100
+ } else {
5101
+ out_idx = idx1;
5102
+ ++j;
5103
+ if (j < len1) {
5104
+ idx1 = tmp1[j];
5105
+ val1 = src0_row[idx1];
5106
+ }
5107
+ }
5108
+ }
5109
+
5110
+ dst[k] = out_idx;
5111
+ }
5112
+ }
5113
+
5114
+ template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
5115
+ template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
5116
+
4434
5117
  kernel void kernel_leaky_relu_f32(
4435
5118
  constant ggml_metal_kargs_leaky_relu & args,
4436
5119
  device const float * src0,
@@ -4449,10 +5132,142 @@ kernel void kernel_leaky_relu_f32_4(
4449
5132
  dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope);
4450
5133
  }
4451
5134
 
5135
+ // compile-time function constants, specialized by the host when building the pipeline:
+ // - has_mask: whether a mask tensor is provided and must also be padded
+ // - ncpsg: the attention chunk size C (cache items per block)
+ constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]];
5136
+
5137
+ constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]];
5138
+
5139
+ // pad the last chunk of C elements of k and v into an extra pad buffer
+ // dst layout: [k pad | v pad | mask pad] — the mask section is written only when has_mask
5140
+ kernel void kernel_flash_attn_ext_pad(
5141
+ constant ggml_metal_kargs_flash_attn_ext_pad & args,
5142
+ device const char * k,
5143
+ device const char * v,
5144
+ device const char * mask,
5145
+ device char * dst,
5146
+ uint3 tgpig[[threadgroup_position_in_grid]],
5147
+ ushort tiitg[[thread_index_in_threadgroup]],
5148
+ ushort3 ntg[[threads_per_threadgroup]]) {
5149
+ const int32_t C = FC_flash_attn_ext_pad_ncpsg;
5150
+
+ // carve the three pad regions out of dst (sizes are one full C-chunk per 2/3 dim)
5151
+ device char * k_pad = dst;
5152
+ device char * v_pad = k_pad + args.nb11*C*args.ne_12_2*args.ne_12_3;
5153
+ device char * mask_pad = v_pad + args.nb21*C*args.ne_12_2*args.ne_12_3;
5154
+
+ // icp: number of leftover KV items in the partial chunk; ic0: index where that chunk starts
5155
+ const int32_t icp = args.ne11 % C;
5156
+ const int32_t ic0 = args.ne11 - icp;
5157
+
5158
+ const int32_t i1 = tgpig[0];
5159
+ const int32_t i2 = tgpig[1];
5160
+ const int32_t i3 = tgpig[2];
5161
+
5162
+ if (i2 < args.ne_12_2 && i3 < args.ne_12_3) {
5163
+ device const char * k_src = k + args.nb11*(ic0 + i1) + args.nb12*i2 + args.nb13*i3;
5164
+ device const char * v_src = v + args.nb21*(ic0 + i1) + args.nb22*i2 + args.nb23*i3;
5165
+
5166
+ device char * k_dst = k_pad + args.nb11*i1 + args.nb11*C*i2 + args.nb11*C*args.ne_12_2*i3;
5167
+ device char * v_dst = v_pad + args.nb21*i1 + args.nb21*C*i2 + args.nb21*C*args.ne_12_2*i3;
5168
+
+ // rows beyond the real data are zero-filled; real rows are copied byte-wise from the source
5169
+ if (i1 >= icp) {
5170
+ // here it is not important the exact value that will be used as we rely on masking out the scores in the attention
5171
+ for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) {
5172
+ k_dst[i] = 0;
5173
+ }
5174
+ for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) {
5175
+ v_dst[i] = 0;
5176
+ }
5177
+ } else {
5178
+ for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) {
5179
+ k_dst[i] = k_src[i];
5180
+ }
5181
+ for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) {
5182
+ v_dst[i] = v_src[i];
5183
+ }
5184
+ }
5185
+ }
5186
+
+ // pad the mask rows to a full chunk of C columns; columns past the real data
+ // get -MAXHALF so the corresponding scores are masked out in the attention
5187
+ if (FC_flash_attn_ext_pad_has_mask) {
5188
+ if (i2 < args.ne32 && i3 < args.ne33) {
5189
+ for (int ib = i1; ib < args.ne31; ib += C) {
5190
+ device const half * mask_src = (device const half *)(mask + args.nb31*ib + args.nb32*i2 + args.nb33*i3) + ic0;
5191
+ device half * mask_dst = (device half *)(mask_pad) + C*ib + C*args.ne31*i2 + C*args.ne31*args.ne32*i3;
5192
+
5193
+ for (int i = tiitg; i < C; i += ntg.x) {
5194
+ if (i >= icp) {
5195
+ mask_dst[i] = -MAXHALF;
5196
+ } else {
5197
+ mask_dst[i] = mask_src[i];
5198
+ }
5199
+ }
5200
+ }
5201
+ }
5202
+ }
5203
+ }
5204
+
5205
+ // compile-time function constants: Q (queries per block) and C (cache items per block)
+ constant int32_t FC_flash_attn_ext_blk_nqptg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 24)]];
5206
+ constant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 25)]];
5207
+
5208
+ // scan the blocks of the mask that are not masked
+ // writes one byte per Q x C mask block into dst:
5209
+ // 0 - masked (i.e. full of -INF, skip)
5210
+ // 1 - not masked (i.e. at least one element of the mask is not -INF)
5211
+ kernel void kernel_flash_attn_ext_blk(
5212
+ constant ggml_metal_kargs_flash_attn_ext_blk & args,
5213
+ device const char * mask,
5214
+ device char * dst,
5215
+ uint3 tgpig[[threadgroup_position_in_grid]],
5216
+ ushort tiisg[[thread_index_in_simdgroup]]) {
5217
+ // block size C x Q
5218
+ const int32_t Q = FC_flash_attn_ext_blk_nqptg;
5219
+ const int32_t C = FC_flash_attn_ext_blk_ncpsg;
5220
+
5221
+ constexpr short NW = N_SIMDWIDTH;
5222
+
5223
+ const int32_t i3 = tgpig[2]/args.ne32;
5224
+ const int32_t i2 = tgpig[2]%args.ne32;
5225
+ const int32_t i1 = tgpig[1];
5226
+ const int32_t i0 = tgpig[0];
5227
+
+ // blocks that extend past the end of the mask are conservatively marked as not masked
5228
+ char res = i0*C + C > args.ne30 ? 1 : 0;
5229
+
5230
+ device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg;
5231
+
5232
+ // fast route: each lane reads one element, so simd_max inspects NW elements at once
5233
+ if (res == 0) {
5234
+ if (simd_max(*mask_src) > -MAXHALF/2) {
5235
+ res = 1;
5236
+ }
5237
+ }
5238
+
5239
+ // detailed check of the elements of the block
5240
+ if ((C > NW || Q > 1) && res == 0) {
5241
+ half m = -MAXHALF;
5242
+
5243
+ FOR_UNROLL (short j = 0; j < Q; ++j) {
5244
+ FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) {
5245
+ m = max(m, mask_src[ii*NW]);
5246
+ }
5247
+
+ // advance one mask row; nb31 is in bytes, /2 converts to half elements
5248
+ mask_src += args.nb31/2;
5249
+ }
5250
+
5251
+ if (simd_max(m) > -MAXHALF/2) {
5252
+ res = 1;
5253
+ }
5254
+ }
5255
+
+ // number of Q x C blocks along the query and cache dimensions
5256
+ const int32_t nblk1 = ((args.ne01 + Q - 1)/Q);
5257
+ const int32_t nblk0 = ((args.ne30 + C - 1)/C);
5258
+
+ // one result byte per block, written by the first lane of the simdgroup
5259
+ if (tiisg == 0) {
5260
+ dst[((i3*args.ne32 + i2)*nblk1 + i1)*nblk0 + i0] = res;
5261
+ }
5262
+ }
5263
+
4452
5264
  constant bool FC_flash_attn_ext_has_mask [[function_constant(FC_FLASH_ATTN_EXT + 0)]];
4453
5265
  constant bool FC_flash_attn_ext_has_sinks [[function_constant(FC_FLASH_ATTN_EXT + 1)]];
4454
5266
  constant bool FC_flash_attn_ext_has_bias [[function_constant(FC_FLASH_ATTN_EXT + 2)]];
4455
5267
  constant bool FC_flash_attn_ext_has_scap [[function_constant(FC_FLASH_ATTN_EXT + 3)]];
5268
+ constant bool FC_flash_attn_ext_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT + 4)]];
5269
+
5270
+ constant bool FC_flash_attn_ext_bc_mask [[function_constant(FC_FLASH_ATTN_EXT + 10)]];
4456
5271
 
4457
5272
  //constant float FC_flash_attn_ext_scale [[function_constant(FC_FLASH_ATTN_EXT + 10)]];
4458
5273
  //constant float FC_flash_attn_ext_max_bias [[function_constant(FC_FLASH_ATTN_EXT + 11)]];
@@ -4499,6 +5314,8 @@ void kernel_flash_attn_ext_impl(
4499
5314
  device const char * v,
4500
5315
  device const char * mask,
4501
5316
  device const char * sinks,
5317
+ device const char * pad,
5318
+ device const char * blk,
4502
5319
  device char * dst,
4503
5320
  threadgroup half * shmem_f16,
4504
5321
  uint3 tgpig,
@@ -4564,6 +5381,13 @@ void kernel_flash_attn_ext_impl(
4564
5381
  pm2[jj] = (device const half2 *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
4565
5382
  }
4566
5383
 
5384
+ {
5385
+ const int32_t nblk1 = ((args.ne01 + Q - 1)/Q);
5386
+ const int32_t nblk0 = ((args.ne11 + C - 1)/C);
5387
+
5388
+ blk += (((iq3%args.ne33)*args.ne32 + (iq2%args.ne32))*nblk1 + iq1/Q)*nblk0;
5389
+ }
5390
+
4567
5391
  {
4568
5392
  q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03;
4569
5393
 
@@ -4623,16 +5447,75 @@ void kernel_flash_attn_ext_impl(
4623
5447
 
4624
5448
  // loop over the KV cache
4625
5449
  // each simdgroup handles blocks of Q rows and C columns
4626
- for (int ic = 0; ic < args.ne11; ic += C) {
4627
- // read the mask into shared mem
4628
- if (FC_flash_attn_ext_has_mask) {
4629
- FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
4630
- const short j = jj*NSG + sgitg;
4631
-
4632
- sm2[j*SH + tiisg] = pm2[jj][tiisg];
5450
+ for (int ic0 = 0; ; ++ic0) {
5451
+ int ic = ic0*C;
5452
+ if (ic >= args.ne11) {
5453
+ break;
5454
+ }
5455
+
5456
+ // the last partial chunk uses the pad buffer as source
5457
+ if (FC_flash_attn_ext_has_kvpad && ic + C > args.ne11) {
5458
+ k = pad;
5459
+ v = k + args.nb11*C*args.ne_12_2*args.ne_12_3;
5460
+ mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3;
5461
+
5462
+ const short ikv2 = iq2/(args.ne02/args.ne_12_2);
5463
+ const short ikv3 = iq3/(args.ne03/args.ne_12_3);
5464
+
5465
+ k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C;
5466
+ v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C;
5467
+
5468
+ if (!FC_flash_attn_ext_has_mask) {
5469
+ threadgroup half * sm = (threadgroup half *) (sm2);
5470
+
5471
+ FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
5472
+ const short j = jj*NSG + sgitg;
5473
+
5474
+ for (short i = tiisg; i < C; i += NW) {
5475
+ if (ic + i >= args.ne11) {
5476
+ sm[2*j*SH + i] = -MAXHALF;
5477
+ }
5478
+ }
5479
+ }
5480
+ } else {
5481
+ FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
5482
+ const short j = jj*NSG + sgitg;
5483
+
5484
+ pm2[jj] = (device const half2 *) ((device const half *) mask +
5485
+ (iq1 + j)*C +
5486
+ (iq2%args.ne32)*(C*args.ne31) +
5487
+ (iq3%args.ne33)*(C*args.ne31*args.ne32));
5488
+ }
5489
+ }
5490
+
5491
+ ic = 0;
5492
+ }
5493
+
5494
+ // read the mask into shared mem
5495
+ if (FC_flash_attn_ext_has_mask) {
5496
+ if (blk[ic0] == 0) {
5497
+ FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
5498
+ pm2[jj] += NW;
5499
+ }
5500
+
5501
+ continue;
5502
+ }
5503
+
5504
+ FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
5505
+ const short j = jj*NSG + sgitg;
5506
+
5507
+ if (FC_flash_attn_ext_bc_mask) {
5508
+ sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
5509
+ } else {
5510
+ sm2[j*SH + tiisg] = pm2[jj][tiisg];
5511
+ }
5512
+
4633
5513
  pm2[jj] += NW;
4634
5514
  }
4635
5515
 
5516
+ #if 0
5517
+ // note: old -INF block optimization - obsoleted by pre-computing non-masked blocks
5518
+
4636
5519
  threadgroup_barrier(mem_flags::mem_threadgroup);
4637
5520
 
4638
5521
  // used to detect blocks full of -INF
@@ -4651,13 +5534,14 @@ void kernel_flash_attn_ext_impl(
4651
5534
 
4652
5535
  continue;
4653
5536
  }
5537
+ #endif
4654
5538
  }
4655
5539
 
4656
5540
  // Q*K^T
4657
5541
  // this is compile-time check, so it does not have runtime overhead
4658
5542
  if (is_same<kd4x4_t, k4x4_t>::value) {
4659
5543
  // we can read directly from global memory
4660
- device const k_t * pk = (device const k_t *) ((device const char *) k + ic*args.nb11);
5544
+ device const k_t * pk = (device const k_t *) (k + ic*args.nb11);
4661
5545
  threadgroup const q_t * pq = sq;
4662
5546
  threadgroup s_t * ps = ss;
4663
5547
 
@@ -4668,26 +5552,24 @@ void kernel_flash_attn_ext_impl(
4668
5552
 
4669
5553
  constexpr short NC = (C/8)/NSG;
4670
5554
 
4671
- // TODO: not good to unroll for large contexts - not sure why?
5555
+ // note: do not unroll for large heads
5556
+ #pragma unroll (DK <= 64 ? NC : 1)
4672
5557
  for (short cc = 0; cc < NC; ++cc) {
4673
5558
  qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
4674
5559
 
4675
- if (DK8 % 16 != 0) {
5560
+ if (DK % 16 != 0) {
4676
5561
  k8x8_t mk;
4677
5562
  q8x8_t mq;
4678
5563
 
4679
5564
  FOR_UNROLL (short i = 0; i < DK8; ++i) {
4680
5565
  simdgroup_barrier(mem_flags::mem_none);
4681
5566
 
4682
- simdgroup_load(mk, pk, NS10, 0, true);
4683
- simdgroup_load(mq, pq, DK);
5567
+ simdgroup_load(mk, pk + 8*i, NS10, 0, true);
5568
+ simdgroup_load(mq, pq + 8*i, DK);
4684
5569
 
4685
5570
  simdgroup_barrier(mem_flags::mem_none);
4686
5571
 
4687
5572
  simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
4688
-
4689
- pk += 8;
4690
- pq += 8;
4691
5573
  }
4692
5574
  } else {
4693
5575
  k8x8_t mk[2];
@@ -4696,26 +5578,22 @@ void kernel_flash_attn_ext_impl(
4696
5578
  FOR_UNROLL (short i = 0; i < DK8/2; ++i) {
4697
5579
  simdgroup_barrier(mem_flags::mem_none);
4698
5580
 
4699
- simdgroup_load(mk[0], pk + 0*8, NS10, 0, true);
4700
- simdgroup_load(mk[1], pk + 1*8, NS10, 0, true);
5581
+ simdgroup_load(mq[0], pq + 0*8 + 16*i, DK);
5582
+ simdgroup_load(mq[1], pq + 1*8 + 16*i, DK);
4701
5583
 
4702
- simdgroup_load(mq[0], pq + 0*8, DK);
4703
- simdgroup_load(mq[1], pq + 1*8, DK);
5584
+ simdgroup_load(mk[0], pk + 0*8 + 16*i, NS10, 0, true);
5585
+ simdgroup_load(mk[1], pk + 1*8 + 16*i, NS10, 0, true);
4704
5586
 
4705
5587
  simdgroup_barrier(mem_flags::mem_none);
4706
5588
 
4707
5589
  simdgroup_multiply_accumulate(mqk, mq[0], mk[0], mqk);
4708
5590
  simdgroup_multiply_accumulate(mqk, mq[1], mk[1], mqk);
4709
-
4710
- pk += 16;
4711
- pq += 16;
4712
5591
  }
4713
5592
  }
4714
5593
 
4715
5594
  simdgroup_store(mqk, ps, SH, 0, false);
4716
5595
 
4717
- pk += 8*(NSG*NS10 - DK8);
4718
- pq += 8*(NSG*0 - DK8);
5596
+ pk += 8*(NSG*NS10);
4719
5597
  ps += 8*(NSG);
4720
5598
  }
4721
5599
  } else {
@@ -4729,7 +5607,7 @@ void kernel_flash_attn_ext_impl(
4729
5607
  qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
4730
5608
 
4731
5609
  for (short ii = 0; ii < DK16; ii += 4) {
4732
- device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11));
5610
+ device const kd4x4_t * pk4x4 = (device const kd4x4_t *) (k + ((ic + 8*cc + ty)*args.nb11));
4733
5611
 
4734
5612
  if (DK16%4 == 0) {
4735
5613
  // the head is evenly divisible by 4*16 = 64, so no need for bound checks
@@ -4849,27 +5727,50 @@ void kernel_flash_attn_ext_impl(
4849
5727
  }
4850
5728
 
4851
5729
  {
4852
- auto sst = ss;
4853
-
4854
- device const v_t * pv = (device const v_t *) ((device const char *) v + ic*args.nb21);
5730
+ device const v_t * pv = (device const v_t *) (v + ic*args.nb21);
4855
5731
 
4856
5732
  pv += 8*sgitg;
4857
5733
 
4858
- FOR_UNROLL (short cc = 0; cc < C/8; ++cc) {
4859
- s8x8_t vs;
4860
- simdgroup_load(vs, sst, SH, 0, false);
5734
+ if (DV <= 64) {
5735
+ FOR_UNROLL (short cc = 0; cc < C/8; ++cc) {
5736
+ s8x8_t vs;
5737
+ simdgroup_load(vs, ss + 8*cc, SH, 0, false);
4861
5738
 
4862
- FOR_UNROLL (short ii = 0; ii < NO; ++ii) {
4863
- v8x8_t mv;
5739
+ FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) {
5740
+ v8x8_t mv[2];
4864
5741
 
4865
- simdgroup_load(mv, pv, NS20, 0, false);
4866
- simdgroup_multiply_accumulate(lo[ii], vs, mv, lo[ii]);
5742
+ simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG, NS20, 0, false);
5743
+ simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG, NS20, 0, false);
4867
5744
 
4868
- pv += 8*NSG;
5745
+ simdgroup_multiply_accumulate(lo[2*ii + 0], vs, mv[0], lo[2*ii + 0]);
5746
+ simdgroup_multiply_accumulate(lo[2*ii + 1], vs, mv[1], lo[2*ii + 1]);
5747
+ }
5748
+
5749
+ pv += 8*NS20;
4869
5750
  }
5751
+ } else {
5752
+ FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) {
5753
+ s8x8_t vs[2];
5754
+
5755
+ simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false);
5756
+ simdgroup_load(vs[1], ss + 16*cc + 8, SH, 0, false);
4870
5757
 
4871
- pv += 8*(NS20 - NO*NSG);
4872
- sst += 8;
5758
+ FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) {
5759
+ v8x8_t mv[4];
5760
+
5761
+ simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false);
5762
+ simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false);
5763
+ simdgroup_load(mv[2], pv + 0*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false);
5764
+ simdgroup_load(mv[3], pv + 8*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false);
5765
+
5766
+ simdgroup_multiply_accumulate(lo[2*ii + 0], vs[0], mv[0], lo[2*ii + 0]);
5767
+ simdgroup_multiply_accumulate(lo[2*ii + 1], vs[0], mv[1], lo[2*ii + 1]);
5768
+ simdgroup_multiply_accumulate(lo[2*ii + 0], vs[1], mv[2], lo[2*ii + 0]);
5769
+ simdgroup_multiply_accumulate(lo[2*ii + 1], vs[1], mv[3], lo[2*ii + 1]);
5770
+ }
5771
+
5772
+ pv += 2*8*NS20;
5773
+ }
4873
5774
  }
4874
5775
  }
4875
5776
 
@@ -4893,7 +5794,7 @@ void kernel_flash_attn_ext_impl(
4893
5794
  simdgroup_load(vs, ss + 8*cc, SH, 0, false);
4894
5795
 
4895
5796
  for (short ii = 4*sgitg; ii < DV16; ii += 4*NSG) {
4896
- device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21));
5797
+ device const vd4x4_t * pv4x4 = (device const vd4x4_t *) (v + ((ic + 8*cc + ty)*args.nb21));
4897
5798
 
4898
5799
  if (DV16%4 == 0) {
4899
5800
  // no need for bound checks
@@ -4983,7 +5884,7 @@ void kernel_flash_attn_ext_impl(
4983
5884
 
4984
5885
  device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4;
4985
5886
 
4986
- const float scale = 1.0f/S[jj];
5887
+ const float scale = S[jj] == 0.0 ? 0.0f : 1.0f/S[jj];
4987
5888
 
4988
5889
  if (DV4 % NW == 0) {
4989
5890
  FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) {
@@ -5028,8 +5929,8 @@ template<
5028
5929
  void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
5029
5930
  short DK, // K head size
5030
5931
  short DV, // V head size
5031
- short Q = 8, // queries per threadgroup
5032
- short C = 64> // cache items per threadgroup
5932
+ short Q = OP_FLASH_ATTN_EXT_NQPTG, // queries per threadgroup
5933
+ short C = OP_FLASH_ATTN_EXT_NCPSG> // cache items per threadgroup
5033
5934
  kernel void kernel_flash_attn_ext(
5034
5935
  constant ggml_metal_kargs_flash_attn_ext & args,
5035
5936
  device const char * q,
@@ -5037,13 +5938,15 @@ kernel void kernel_flash_attn_ext(
5037
5938
  device const char * v,
5038
5939
  device const char * mask,
5039
5940
  device const char * sinks,
5941
+ device const char * pad,
5942
+ device const char * blk,
5040
5943
  device char * dst,
5041
5944
  threadgroup half * shmem_f16 [[threadgroup(0)]],
5042
5945
  uint3 tgpig[[threadgroup_position_in_grid]],
5043
5946
  ushort tiisg[[thread_index_in_simdgroup]],
5044
5947
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
5045
5948
  #define FWD_TMPL q_t, q4_t, q8x8_t, k_t, k4x4_t, k8x8_t, v_t, v4x4_t, v8x8_t, qk_t, qk8x8_t, s_t, s2_t, s8x8_t, o_t, o4_t, o8x8_t, kd4x4_t, nl_k, deq_k, vd4x4_t, nl_v, deq_v, DK, DV, Q, C
5046
- #define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg
5949
+ #define FWD_ARGS args, q, k, v, mask, sinks, pad, blk, dst, shmem_f16, tgpig, tiisg, sgitg
5047
5950
  switch (FC_flash_attn_ext_nsg) {
5048
5951
  // note: disabled cases to reduce library load time
5049
5952
  //case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
@@ -5075,10 +5978,36 @@ kernel void kernel_flash_attn_ext(
5075
5978
  half, half4, simdgroup_half8x8
5076
5979
  //float, float4, simdgroup_float8x8
5077
5980
 
5981
+ #define FA_TYPES_F32 \
5982
+ half, half4, simdgroup_half8x8, \
5983
+ float, float4x4, simdgroup_float8x8, \
5984
+ float, float4x4, simdgroup_float8x8, \
5985
+ float, simdgroup_float8x8, \
5986
+ float, float2, simdgroup_float8x8, \
5987
+ float, float4, simdgroup_float8x8
5988
+ //half, half4, simdgroup_half8x8
5989
+
5078
5990
  typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
5079
5991
 
5992
+ template [[host_name("kernel_flash_attn_ext_f32_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 32, 32>;
5993
+ template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 40, 40>;
5994
+ template [[host_name("kernel_flash_attn_ext_f32_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 48, 48>;
5995
+ template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 64, 64>;
5996
+ template [[host_name("kernel_flash_attn_ext_f32_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 72, 72>;
5997
+ template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 80, 80>;
5998
+ template [[host_name("kernel_flash_attn_ext_f32_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 96, 96>;
5999
+ template [[host_name("kernel_flash_attn_ext_f32_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 112, 112>;
6000
+ template [[host_name("kernel_flash_attn_ext_f32_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 128, 128>;
6001
+ template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 192>;
6002
+ template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 128>;
6003
+ template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 256, 256>;
6004
+ template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 576, 512>;
6005
+
6006
+ template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 32, 32>;
5080
6007
  template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40, 40>;
6008
+ template [[host_name("kernel_flash_attn_ext_f16_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 48, 48>;
5081
6009
  template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
6010
+ template [[host_name("kernel_flash_attn_ext_f16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 72, 72>;
5082
6011
  template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
5083
6012
  template [[host_name("kernel_flash_attn_ext_f16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96, 96>;
5084
6013
  template [[host_name("kernel_flash_attn_ext_f16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 112, 112>;
@@ -5089,8 +6018,11 @@ template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_at
5089
6018
  template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
5090
6019
 
5091
6020
  #if defined(GGML_METAL_HAS_BF16)
6021
+ template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 32, 32>;
5092
6022
  template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40, 40>;
6023
+ template [[host_name("kernel_flash_attn_ext_bf16_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 48, 48>;
5093
6024
  template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
6025
+ template [[host_name("kernel_flash_attn_ext_bf16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 72, 72>;
5094
6026
  template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
5095
6027
  template [[host_name("kernel_flash_attn_ext_bf16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
5096
6028
  template [[host_name("kernel_flash_attn_ext_bf16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
@@ -5101,8 +6033,11 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_at
5101
6033
  template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
5102
6034
  #endif
5103
6035
 
6036
+ template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 32, 32>;
5104
6037
  template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40, 40>;
6038
+ template [[host_name("kernel_flash_attn_ext_q4_0_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 48, 48>;
5105
6039
  template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
6040
+ template [[host_name("kernel_flash_attn_ext_q4_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 72, 72>;
5106
6041
  template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
5107
6042
  template [[host_name("kernel_flash_attn_ext_q4_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96, 96>;
5108
6043
  template [[host_name("kernel_flash_attn_ext_q4_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112, 112>;
@@ -5112,8 +6047,11 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_at
5112
6047
  template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
5113
6048
  template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;
5114
6049
 
6050
+ template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32, 32>;
5115
6051
  template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40, 40>;
6052
+ template [[host_name("kernel_flash_attn_ext_q4_1_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 48, 48>;
5116
6053
  template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
6054
+ template [[host_name("kernel_flash_attn_ext_q4_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 72, 72>;
5117
6055
  template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
5118
6056
  template [[host_name("kernel_flash_attn_ext_q4_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96, 96>;
5119
6057
  template [[host_name("kernel_flash_attn_ext_q4_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112, 112>;
@@ -5123,8 +6061,11 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_at
5123
6061
  template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
5124
6062
  template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;
5125
6063
 
6064
+ template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32, 32>;
5126
6065
  template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40, 40>;
6066
+ template [[host_name("kernel_flash_attn_ext_q5_0_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 48, 48>;
5127
6067
  template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
6068
+ template [[host_name("kernel_flash_attn_ext_q5_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 72, 72>;
5128
6069
  template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
5129
6070
  template [[host_name("kernel_flash_attn_ext_q5_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96, 96>;
5130
6071
  template [[host_name("kernel_flash_attn_ext_q5_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112, 112>;
@@ -5134,8 +6075,11 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_at
5134
6075
  template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
5135
6076
  template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;
5136
6077
 
6078
+ template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32, 32>;
5137
6079
  template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40, 40>;
6080
+ template [[host_name("kernel_flash_attn_ext_q5_1_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 48, 48>;
5138
6081
  template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
6082
+ template [[host_name("kernel_flash_attn_ext_q5_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 72, 72>;
5139
6083
  template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
5140
6084
  template [[host_name("kernel_flash_attn_ext_q5_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96, 96>;
5141
6085
  template [[host_name("kernel_flash_attn_ext_q5_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112, 112>;
@@ -5145,8 +6089,11 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_at
5145
6089
  template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
5146
6090
  template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;
5147
6091
 
6092
+ template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32, 32>;
5148
6093
  template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40, 40>;
6094
+ template [[host_name("kernel_flash_attn_ext_q8_0_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 48, 48>;
5149
6095
  template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
6096
+ template [[host_name("kernel_flash_attn_ext_q8_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 72, 72>;
5150
6097
  template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
5151
6098
  template [[host_name("kernel_flash_attn_ext_q8_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96, 96>;
5152
6099
  template [[host_name("kernel_flash_attn_ext_q8_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112, 112>;
@@ -5158,11 +6105,13 @@ template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_at
5158
6105
 
5159
6106
  #undef FA_TYPES
5160
6107
  #undef FA_TYPES_BF
6108
+ #undef FA_TYPES_F32
5161
6109
 
5162
6110
  constant bool FC_flash_attn_ext_vec_has_mask [[function_constant(FC_FLASH_ATTN_EXT_VEC + 0)]];
5163
6111
  constant bool FC_flash_attn_ext_vec_has_sinks [[function_constant(FC_FLASH_ATTN_EXT_VEC + 1)]];
5164
6112
  constant bool FC_flash_attn_ext_vec_has_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 2)]];
5165
6113
  constant bool FC_flash_attn_ext_vec_has_scap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 3)]];
6114
+ constant bool FC_flash_attn_ext_vec_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT_VEC + 4)]];
5166
6115
 
5167
6116
  //constant float FC_flash_attn_ext_vec_scale [[function_constant(FC_FLASH_ATTN_EXT_VEC + 10)]];
5168
6117
  //constant float FC_flash_attn_ext_vec_max_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 11)]];
@@ -5189,9 +6138,9 @@ template<
5189
6138
  void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
5190
6139
  short DK, // K head size
5191
6140
  short DV, // V head size
5192
- short NE = 4, // head elements per thread
5193
- short Q = 1, // queries per threadgroup
5194
- short C = 32, // cache items per threadgroup
6141
+ short NE, // head elements per thread
6142
+ short Q, // queries per threadgroup
6143
+ short C, // cache items per threadgroup
5195
6144
  short NSG> // number of simd groups
5196
6145
  void kernel_flash_attn_ext_vec_impl(
5197
6146
  constant ggml_metal_kargs_flash_attn_ext_vec & args,
@@ -5200,6 +6149,7 @@ void kernel_flash_attn_ext_vec_impl(
5200
6149
  device const char * v,
5201
6150
  device const char * mask,
5202
6151
  device const char * sinks,
6152
+ device const char * pad,
5203
6153
  device char * dst,
5204
6154
  threadgroup half * shmem_f16 [[threadgroup(0)]],
5205
6155
  uint3 tgpig[[threadgroup_position_in_grid]],
@@ -5305,12 +6255,38 @@ void kernel_flash_attn_ext_vec_impl(
5305
6255
 
5306
6256
  // loop over the KV cache
5307
6257
  // each simdgroup handles blocks of Q rows and C columns
5308
- for (int ic0 = (int) iwg*C*NSG; ic0 < args.ne11; ic0 += (int) NWG*C*NSG) {
5309
- const int ic = ic0 + C*sgitg;
6258
+ for (int ic0 = iwg*NSG + sgitg; ; ic0 += NWG*NSG) {
6259
+ int ic = ic0*C;
5310
6260
  if (ic >= args.ne11) {
5311
6261
  break;
5312
6262
  }
5313
6263
 
6264
+ // the last partial chunk uses the pad buffer as source
6265
+ if (FC_flash_attn_ext_vec_has_kvpad && ic + C > args.ne11) {
6266
+ k = pad;
6267
+ v = k + args.nb11*C*args.ne_12_2*args.ne_12_3;
6268
+ mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3;
6269
+
6270
+ const short ikv2 = iq2/(args.ne02/args.ne_12_2);
6271
+ const short ikv3 = iq3/(args.ne03/args.ne_12_3);
6272
+
6273
+ k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C;
6274
+ v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C;
6275
+
6276
+ if (!FC_flash_attn_ext_vec_has_mask) {
6277
+ if (ic + tiisg >= args.ne11) {
6278
+ sm[tiisg] = -MAXHALF;
6279
+ }
6280
+ } else {
6281
+ pm = (device const half *) (mask) +
6282
+ iq1*C +
6283
+ (iq2%args.ne32)*(C*args.ne31) +
6284
+ (iq3%args.ne33)*(C*args.ne31*args.ne32);
6285
+ }
6286
+
6287
+ ic = 0;
6288
+ }
6289
+
5314
6290
  if (FC_flash_attn_ext_vec_has_mask) {
5315
6291
  sm[tiisg] = pm[ic + tiisg];
5316
6292
  }
@@ -5322,7 +6298,7 @@ void kernel_flash_attn_ext_vec_impl(
5322
6298
 
5323
6299
  // Q*K^T
5324
6300
  {
5325
- device const k4_t * pk4 = (device const k4_t *) ((device const char *) k + ic*args.nb11);
6301
+ device const k4_t * pk4 = (device const k4_t *) (k + ic*args.nb11);
5326
6302
  threadgroup const q4_t * pq4 = sq4;
5327
6303
 
5328
6304
  pk4 += ty*NS10/4 + tx;
@@ -5337,7 +6313,7 @@ void kernel_flash_attn_ext_vec_impl(
5337
6313
  mqk[cc] += dot((float4) pk4[cc*NE*NS10/4 + ii*NL], (float4) pq4[ii*NL]);
5338
6314
  }
5339
6315
  } else {
5340
- device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11));
6316
+ device const kd4_t * pk = (device const kd4_t *) (k + ((ic + NE*cc + ty)*args.nb11));
5341
6317
 
5342
6318
  k4_t mk;
5343
6319
 
@@ -5435,7 +6411,7 @@ void kernel_flash_attn_ext_vec_impl(
5435
6411
  }
5436
6412
 
5437
6413
  if (is_same<vd4_t, v4_t>::value) {
5438
- device const v4_t * pv4 = (device const v4_t *) ((device const char *) v + ic*args.nb21);
6414
+ device const v4_t * pv4 = (device const v4_t *) (v + ic*args.nb21);
5439
6415
 
5440
6416
  pv4 += ty*NS20/4 + tx;
5441
6417
 
@@ -5448,7 +6424,7 @@ void kernel_flash_attn_ext_vec_impl(
5448
6424
  }
5449
6425
  } else {
5450
6426
  FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) {
5451
- device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21));
6427
+ device const vd4_t * pv4 = (device const vd4_t *) (v + ((ic + NE*cc + ty)*args.nb21));
5452
6428
 
5453
6429
  FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
5454
6430
  const short i = ii*NL + tx;
@@ -5573,7 +6549,7 @@ void kernel_flash_attn_ext_vec_impl(
5573
6549
  device float4 * dst4 = (device float4 *) dst;
5574
6550
  device float * dst1 = (device float *) dst + nrows*DV*NWG; // the S and M are stored after the results
5575
6551
 
5576
- const float S = NWG == 1 ? 1.0f/ss[0] : 1.0f;
6552
+ const float S = NWG == 1 ? (ss[0] == 0.0f ? 0.0f : 1.0f/ss[0]) : 1.0f;
5577
6553
 
5578
6554
  // interleave the workgroup data
5579
6555
  for (short i = tiisg; i < DV4; i += NW) {
@@ -5611,8 +6587,8 @@ template<
5611
6587
  short DK, // K head size
5612
6588
  short DV, // V head size
5613
6589
  short NE = 4, // head elements per thread
5614
- short Q = 1, // queries per threadgroup
5615
- short C = 32> // cache items per threadgroup
6590
+ short Q = OP_FLASH_ATTN_EXT_VEC_NQPTG, // queries per threadgroup
6591
+ short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup
5616
6592
  kernel void kernel_flash_attn_ext_vec(
5617
6593
  constant ggml_metal_kargs_flash_attn_ext_vec & args,
5618
6594
  device const char * q,
@@ -5620,13 +6596,14 @@ kernel void kernel_flash_attn_ext_vec(
5620
6596
  device const char * v,
5621
6597
  device const char * mask,
5622
6598
  device const char * sinks,
6599
+ device const char * pad,
5623
6600
  device char * dst,
5624
6601
  threadgroup half * shmem_f16 [[threadgroup(0)]],
5625
6602
  uint3 tgpig[[threadgroup_position_in_grid]],
5626
6603
  ushort tiisg[[thread_index_in_simdgroup]],
5627
6604
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
5628
6605
  #define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C
5629
- #define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg
6606
+ #define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg
5630
6607
  switch (FC_flash_attn_ext_vec_nsg) {
5631
6608
  // note: disabled cases to reduce library load time
5632
6609
  case 1: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 1>(FWD_ARGS); break;
@@ -5651,79 +6628,106 @@ kernel void kernel_flash_attn_ext_vec(
5651
6628
  float, float4, \
5652
6629
  float4
5653
6630
 
6631
+ #define FA_TYPES_F32 \
6632
+ half4, \
6633
+ float4, \
6634
+ float4, \
6635
+ float, \
6636
+ float, float4, \
6637
+ float4
6638
+
5654
6639
  typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
5655
6640
 
5656
- template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 2>;
6641
+ template [[host_name("kernel_flash_attn_ext_vec_f32_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 32, 32, 4>;
6642
+ template [[host_name("kernel_flash_attn_ext_vec_f16_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 32, 32, 4>;
5657
6643
  #if defined(GGML_METAL_HAS_BF16)
5658
- template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 2>;
6644
+ template [[host_name("kernel_flash_attn_ext_vec_bf16_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 32, 32, 4>;
5659
6645
  #endif
5660
- template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 64, 64, 2>;
5661
- template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 64, 64, 2>;
5662
- template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 64, 64, 2>;
5663
- template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 64, 64, 2>;
5664
- template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 64, 64, 2>;
5665
-
5666
- template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 96, 96, 4>;
6646
+ template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 32, 32, 4>;
6647
+ template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 32, 32, 4>;
6648
+ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 32, 32, 4>;
6649
+ template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 32, 32, 4>;
6650
+ template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 32, 32, 4>;
6651
+
6652
+ template [[host_name("kernel_flash_attn_ext_vec_f32_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 64, 64, 2>;
6653
+ template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 2>;
5667
6654
  #if defined(GGML_METAL_HAS_BF16)
5668
- template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 96, 96, 4>;
6655
+ template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 2>;
5669
6656
  #endif
5670
- template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 96, 96, 4>;
5671
- template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 96, 96, 4>;
5672
- template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 96, 96, 4>;
5673
- template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 96, 96, 4>;
5674
- template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 96, 96, 4>;
5675
-
5676
- template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 1>;
6657
+ template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 64, 64, 2>;
6658
+ template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 64, 64, 2>;
6659
+ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 64, 64, 2>;
6660
+ template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 64, 64, 2>;
6661
+ template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 64, 64, 2>;
6662
+
6663
+ template [[host_name("kernel_flash_attn_ext_vec_f32_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 96, 96, 4>;
6664
+ template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 96, 96, 4>;
5677
6665
  #if defined(GGML_METAL_HAS_BF16)
5678
- template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 128, 128, 1>;
6666
+ template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 96, 96, 4>;
5679
6667
  #endif
5680
- template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 128, 128, 1>;
5681
- template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 128, 128, 1>;
5682
- template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 128, 128, 1>;
5683
- template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 128, 128, 1>;
5684
- template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 128, 128, 1>;
5685
-
5686
- template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 192, 2>;
6668
+ template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 96, 96, 4>;
6669
+ template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 96, 96, 4>;
6670
+ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 96, 96, 4>;
6671
+ template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 96, 96, 4>;
6672
+ template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 96, 96, 4>;
6673
+
6674
+ template [[host_name("kernel_flash_attn_ext_vec_f32_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 128, 128, 1>;
6675
+ template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 1>;
5687
6676
  #if defined(GGML_METAL_HAS_BF16)
5688
- template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 192, 2>;
6677
+ template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 128, 128, 1>;
5689
6678
  #endif
5690
- template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 192, 2>;
5691
- template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 192, 2>;
5692
- template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 192, 2>;
5693
- template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 192, 2>;
5694
- template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 192, 2>;
5695
-
5696
- template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 128, 2>;
6679
+ template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 128, 128, 1>;
6680
+ template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 128, 128, 1>;
6681
+ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 128, 128, 1>;
6682
+ template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 128, 128, 1>;
6683
+ template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 128, 128, 1>;
6684
+
6685
+ template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 192, 192, 2>;
6686
+ template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 192, 2>;
5697
6687
  #if defined(GGML_METAL_HAS_BF16)
5698
- template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 128, 2>;
6688
+ template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 192, 2>;
5699
6689
  #endif
5700
- template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 128, 2>;
5701
- template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 128, 2>;
5702
- template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 128, 2>;
5703
- template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 128, 2>;
5704
- template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 128, 2>;
5705
-
5706
- template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 256, 256, 1>;
6690
+ template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 192, 2>;
6691
+ template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 192, 2>;
6692
+ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 192, 2>;
6693
+ template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 192, 2>;
6694
+ template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 192, 2>;
6695
+
6696
+ template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 192, 128, 2>;
6697
+ template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 128, 2>;
5707
6698
  #if defined(GGML_METAL_HAS_BF16)
5708
- template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 256, 256, 1>;
6699
+ template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 128, 2>;
5709
6700
  #endif
5710
- template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 256, 256, 1>;
5711
- template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 256, 256, 1>;
5712
- template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 256, 256, 1>;
5713
- template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 1>;
5714
- template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 1>;
5715
-
5716
- template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
6701
+ template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 128, 2>;
6702
+ template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 128, 2>;
6703
+ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 128, 2>;
6704
+ template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 128, 2>;
6705
+ template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 128, 2>;
6706
+
6707
+ template [[host_name("kernel_flash_attn_ext_vec_f32_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 256, 256, 1>;
6708
+ template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 256, 256, 1>;
5717
6709
  #if defined(GGML_METAL_HAS_BF16)
5718
- template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 576, 512, 2>;
6710
+ template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 256, 256, 1>;
5719
6711
  #endif
5720
- template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 576, 512, 2>;
5721
- template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 576, 512, 2>;
5722
- template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 576, 512, 2>;
5723
- template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 576, 512, 2>;
5724
- template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 576, 512, 2>;
6712
+ template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 256, 256, 1>;
6713
+ template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 256, 256, 1>;
6714
+ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 256, 256, 1>;
6715
+ template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 1>;
6716
+ template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 1>;
6717
+
6718
+ template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 576, 512, 2>;
6719
+ template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
6720
+ #if defined(GGML_METAL_HAS_BF16)
6721
+ template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 576, 512, 2>;
6722
+ #endif
6723
+ template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 576, 512, 2>;
6724
+ template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 576, 512, 2>;
6725
+ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 576, 512, 2>;
6726
+ template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 576, 512, 2>;
6727
+ template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 576, 512, 2>;
5725
6728
 
5726
6729
  #undef FA_TYPES
6730
+ #undef FA_TYPES_F32
5727
6731
 
5728
6732
  constant int32_t FC_flash_attn_ext_vec_reduce_DV [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 0)]];
5729
6733
  constant int32_t FC_flash_attn_ext_vec_reduce_NWG [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 1)]];
@@ -5750,7 +6754,8 @@ kernel void kernel_flash_attn_ext_vec_reduce(
5750
6754
  const float m = simd_max(M);
5751
6755
  const float ms = exp(M - m);
5752
6756
 
5753
- S = 1.0f/simd_sum(S*ms);
6757
+ S = simd_sum(S*ms);
6758
+ S = S == 0.0f ? 0.0f : 1.0f/S;
5754
6759
 
5755
6760
  const short DV4 = DV/4;
5756
6761
 
@@ -5770,21 +6775,17 @@ kernel void kernel_flash_attn_ext_vec_reduce(
5770
6775
  }
5771
6776
 
5772
6777
  template<typename T0, typename T1>
5773
- kernel void kernel_cpy(
6778
+ kernel void kernel_cpy_t_t(
5774
6779
  constant ggml_metal_kargs_cpy & args,
5775
6780
  device const char * src0,
5776
6781
  device char * dst,
5777
6782
  uint3 tgpig[[threadgroup_position_in_grid]],
5778
- uint tiitg[[thread_index_in_threadgroup]],
5779
- ushort3 tpitg[[thread_position_in_threadgroup]],
5780
- ushort3 tptg[[threads_per_threadgroup]]) {
6783
+ ushort tiitg[[thread_index_in_threadgroup]],
6784
+ ushort3 ntg[[threads_per_threadgroup]]) {
5781
6785
  const int i03 = tgpig[2];
5782
6786
  const int i02 = tgpig[1];
5783
- const int i01 = tgpig[0]*tptg.y + tiitg/tptg.x;
5784
-
5785
- if (i01 >= args.ne01) {
5786
- return;
5787
- }
6787
+ const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
6788
+ const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
5788
6789
 
5789
6790
  const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
5790
6791
 
@@ -5795,190 +6796,71 @@ kernel void kernel_cpy(
5795
6796
 
5796
6797
  device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
5797
6798
 
5798
- for (int64_t i00 = tiitg%tptg.x; i00 < args.ne00; i00 += tptg.x) {
6799
+ for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.ne00; ) {
5799
6800
  device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
5800
6801
  dst_data[i00] = (T1) src[0];
6802
+ break;
5801
6803
  }
5802
6804
  }
5803
6805
 
5804
- typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
6806
+ typedef decltype(kernel_cpy_t_t<float, float>) kernel_cpy_t;
5805
6807
 
5806
- template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
5807
- template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
5808
- template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy<float, int32_t>;
5809
- template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy<int32_t, float>;
6808
+ template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<float, float>;
6809
+ template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, half>;
6810
+ template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<float, int32_t>;
6811
+ template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, float>;
6812
+ template [[host_name("kernel_cpy_i32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, int32_t>;
5810
6813
  #if defined(GGML_METAL_HAS_BF16)
5811
- template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy<float, bfloat>;
6814
+ template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, bfloat>;
5812
6815
  #endif
5813
- template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy<half, float>;
5814
- template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
6816
+ template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<half, float>;
6817
+ template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<half, half>;
5815
6818
  #if defined(GGML_METAL_HAS_BF16)
5816
- template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy<bfloat, float>;
5817
- template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>;
6819
+ template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<bfloat, float>;
6820
+ template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<bfloat, bfloat>;
5818
6821
  #endif
5819
6822
 
5820
- // TODO: templetify these kernels
5821
- kernel void kernel_cpy_f32_q8_0(
6823
+ template<short QK,
6824
+ typename block_q,
6825
+ void (*quantize_func)(device const float *, device block_q &)>
6826
+ kernel void kernel_cpy_f32_q(
5822
6827
  constant ggml_metal_kargs_cpy & args,
5823
6828
  device const char * src0,
5824
- device char * dst,
6829
+ device char * dst,
5825
6830
  uint3 tgpig[[threadgroup_position_in_grid]],
5826
- ushort3 tpitg[[thread_position_in_threadgroup]],
5827
- ushort3 ntg[[threads_per_threadgroup]]) {
5828
- const int i03 = tgpig[2];
5829
- const int i02 = tgpig[1];
5830
- const int i01 = tgpig[0];
5831
-
5832
- const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
5833
-
5834
- const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
5835
- const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
5836
- const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
5837
- const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK8_0;
5838
-
5839
- device block_q8_0 * dst_data = (device block_q8_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
5840
-
5841
- for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) {
5842
- device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
5843
-
5844
- quantize_q8_0(src, dst_data[i00/QK8_0]);
5845
- }
5846
- }
5847
-
5848
- kernel void kernel_cpy_f32_q4_0(
5849
- constant ggml_metal_kargs_cpy & args,
5850
- device const char * src0,
5851
- device char * dst,
5852
- uint3 tgpig[[threadgroup_position_in_grid]],
5853
- ushort3 tpitg[[thread_position_in_threadgroup]],
5854
- ushort3 ntg[[threads_per_threadgroup]]) {
5855
- const int i03 = tgpig[2];
5856
- const int i02 = tgpig[1];
5857
- const int i01 = tgpig[0];
5858
-
5859
- const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
5860
-
5861
- const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
5862
- const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
5863
- const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
5864
- const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_0;
5865
-
5866
- device block_q4_0 * dst_data = (device block_q4_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
5867
-
5868
- for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) {
5869
- device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
5870
-
5871
- quantize_q4_0(src, dst_data[i00/QK4_0]);
5872
- }
5873
- }
5874
-
5875
- kernel void kernel_cpy_f32_q4_1(
5876
- constant ggml_metal_kargs_cpy & args,
5877
- device const char * src0,
5878
- device char * dst,
5879
- uint3 tgpig[[threadgroup_position_in_grid]],
5880
- ushort3 tpitg[[thread_position_in_threadgroup]],
6831
+ ushort tiitg[[thread_index_in_threadgroup]],
5881
6832
  ushort3 ntg[[threads_per_threadgroup]]) {
5882
6833
  const int i03 = tgpig[2];
5883
6834
  const int i02 = tgpig[1];
5884
- const int i01 = tgpig[0];
6835
+ const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
6836
+ const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
5885
6837
 
5886
6838
  const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
5887
6839
 
5888
6840
  const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
5889
6841
  const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
5890
6842
  const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
5891
- const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_1;
5892
-
5893
- device block_q4_1 * dst_data = (device block_q4_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
5894
-
5895
- for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) {
5896
- device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
5897
-
5898
- quantize_q4_1(src, dst_data[i00/QK4_1]);
5899
- }
5900
- }
5901
-
5902
- kernel void kernel_cpy_f32_q5_0(
5903
- constant ggml_metal_kargs_cpy & args,
5904
- device const char * src0,
5905
- device char * dst,
5906
- uint3 tgpig[[threadgroup_position_in_grid]],
5907
- ushort3 tpitg[[thread_position_in_threadgroup]],
5908
- ushort3 ntg[[threads_per_threadgroup]]) {
5909
- const int i03 = tgpig[2];
5910
- const int i02 = tgpig[1];
5911
- const int i01 = tgpig[0];
6843
+ const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK;
5912
6844
 
5913
- const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
6845
+ device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
5914
6846
 
5915
- const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
5916
- const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
5917
- const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
5918
- const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_0;
5919
-
5920
- device block_q5_0 * dst_data = (device block_q5_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
6847
+ for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {
6848
+ device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00);
5921
6849
 
5922
- for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) {
5923
- device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
6850
+ quantize_func(src, dst_data[i00]);
5924
6851
 
5925
- quantize_q5_0(src, dst_data[i00/QK5_0]);
6852
+ break;
5926
6853
  }
5927
6854
  }
5928
6855
 
5929
- kernel void kernel_cpy_f32_q5_1(
5930
- constant ggml_metal_kargs_cpy & args,
5931
- device const char * src0,
5932
- device char * dst,
5933
- uint3 tgpig[[threadgroup_position_in_grid]],
5934
- ushort3 tpitg[[thread_position_in_threadgroup]],
5935
- ushort3 ntg[[threads_per_threadgroup]]) {
5936
- const int i03 = tgpig[2];
5937
- const int i02 = tgpig[1];
5938
- const int i01 = tgpig[0];
5939
-
5940
- const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
6856
+ typedef decltype(kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>) cpy_f_q_t;
5941
6857
 
5942
- const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
5943
- const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
5944
- const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
5945
- const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_1;
5946
-
5947
- device block_q5_1 * dst_data = (device block_q5_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
5948
-
5949
- for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) {
5950
- device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
5951
-
5952
- quantize_q5_1(src, dst_data[i00/QK5_1]);
5953
- }
5954
- }
5955
-
5956
- kernel void kernel_cpy_f32_iq4_nl(
5957
- constant ggml_metal_kargs_cpy & args,
5958
- device const char * src0,
5959
- device char * dst,
5960
- uint3 tgpig[[threadgroup_position_in_grid]],
5961
- ushort3 tpitg[[thread_position_in_threadgroup]],
5962
- ushort3 ntg[[threads_per_threadgroup]]) {
5963
- const int i03 = tgpig[2];
5964
- const int i02 = tgpig[1];
5965
- const int i01 = tgpig[0];
5966
-
5967
- const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
5968
-
5969
- const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
5970
- const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
5971
- const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
5972
- const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_NL;
5973
-
5974
- device block_iq4_nl * dst_data = (device block_iq4_nl *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
5975
-
5976
- for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) {
5977
- device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
5978
-
5979
- quantize_iq4_nl(src, dst_data[i00/QK4_NL]);
5980
- }
5981
- }
6858
+ template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>;
6859
+ template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_0, block_q4_0, quantize_q4_0>;
6860
+ template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_1, block_q4_1, quantize_q4_1>;
6861
+ template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_0, block_q5_0, quantize_q5_0>;
6862
+ template [[host_name("kernel_cpy_f32_q5_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_1, block_q5_1, quantize_q5_1>;
6863
+ template [[host_name("kernel_cpy_f32_iq4_nl")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_NL, block_iq4_nl, quantize_iq4_nl>;
5982
6864
 
5983
6865
  template<typename T4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
5984
6866
  kernel void kernel_cpy_q_f32(
@@ -5986,11 +6868,12 @@ kernel void kernel_cpy_q_f32(
5986
6868
  device const char * src0,
5987
6869
  device char * dst,
5988
6870
  uint3 tgpig[[threadgroup_position_in_grid]],
5989
- ushort3 tpitg[[thread_position_in_threadgroup]],
6871
+ ushort tiitg[[thread_index_in_threadgroup]],
5990
6872
  ushort3 ntg[[threads_per_threadgroup]]) {
5991
6873
  const int i03 = tgpig[2];
5992
6874
  const int i02 = tgpig[1];
5993
- const int i01 = tgpig[0];
6875
+ const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
6876
+ const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
5994
6877
 
5995
6878
  const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
5996
6879
 
@@ -6002,10 +6885,12 @@ kernel void kernel_cpy_q_f32(
6002
6885
  device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
6003
6886
  device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
6004
6887
 
6005
- for (int64_t i00 = tpitg.x; i00 < args.ne00/16; i00 += ntg.x) {
6888
+ for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {
6006
6889
  T4x4 temp;
6007
6890
  dequantize_func(src_data + i00/nl, i00%nl, temp);
6008
6891
  dst_data[i00] = temp;
6892
+
6893
+ break;
6009
6894
  }
6010
6895
  }
6011
6896
 
@@ -7458,7 +8343,7 @@ kernel void kernel_mul_mv_iq1_m_f32(
7458
8343
  kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
7459
8344
  }
7460
8345
 
7461
- template<int nr0, typename args_t>
8346
+ template<int NR0, typename args_t>
7462
8347
  void kernel_mul_mv_iq4_nl_f32_impl(
7463
8348
  args_t args,
7464
8349
  device const char * src0,
@@ -7471,13 +8356,12 @@ void kernel_mul_mv_iq4_nl_f32_impl(
7471
8356
  const short NSG = FC_mul_mv_nsg;
7472
8357
 
7473
8358
  threadgroup float * shmem_f32 = (threadgroup float *) shmem;
7474
- const int nb = args.ne00/QK4_NL;
7475
8359
 
7476
8360
  const int r0 = tgpig.x;
7477
8361
  const int r1 = tgpig.y;
7478
8362
  const int im = tgpig.z;
7479
8363
 
7480
- const int first_row = (r0 * NSG + sgitg) * nr0;
8364
+ const int first_row = (r0 * NSG + sgitg) * NR0;
7481
8365
 
7482
8366
  const uint i12 = im%args.ne12;
7483
8367
  const uint i13 = im/args.ne12;
@@ -7488,6 +8372,9 @@ void kernel_mul_mv_iq4_nl_f32_impl(
7488
8372
  device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
7489
8373
  device const float * y = (device const float *) (src1 + offset1);
7490
8374
 
8375
+ const int nb = args.ne00/QK4_NL;
8376
+ const int ns01 = args.nb01/args.nb00;
8377
+
7491
8378
  const short ix = tiisg/2; // 0...15
7492
8379
  const short it = tiisg%2; // 0 or 1
7493
8380
 
@@ -7495,24 +8382,25 @@ void kernel_mul_mv_iq4_nl_f32_impl(
7495
8382
  threadgroup_barrier(mem_flags::mem_threadgroup);
7496
8383
 
7497
8384
  float4 yl[4];
7498
- float sumf[nr0]={0.f};
8385
+ float sumf[NR0]={0.f};
7499
8386
 
7500
- device const float * yb = y + ix * QK4_NL + it * 8;
8387
+ device const float * yb = y + ix*QK4_NL + it*8;
7501
8388
 
7502
8389
  uint32_t aux32[2];
7503
8390
  thread const uint8_t * q8 = (thread const uint8_t *)aux32;
7504
8391
 
7505
8392
  float4 qf1, qf2;
7506
8393
 
7507
- for (int ib = ix; ib < nb; ib += 16) {
8394
+ // [TAG_MUL_MV_WEIRD]
8395
+ for (int ib = ix; ib < nb && ib < ns01; ib += 16) {
7508
8396
  device const float4 * y4 = (device const float4 *)yb;
7509
8397
  yl[0] = y4[0];
7510
8398
  yl[1] = y4[4];
7511
8399
  yl[2] = y4[1];
7512
8400
  yl[3] = y4[5];
7513
8401
 
7514
- for (short row = 0; row < nr0; row++) {
7515
- device const block_iq4_nl & xb = x[row*nb + ib];
8402
+ for (short row = 0; row < NR0; row++) {
8403
+ device const block_iq4_nl & xb = x[row*ns01 + ib];
7516
8404
  device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
7517
8405
 
7518
8406
  float4 acc1 = {0.f}, acc2 = {0.f};
@@ -7543,7 +8431,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
7543
8431
 
7544
8432
  device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
7545
8433
 
7546
- for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
8434
+ for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
7547
8435
  float sum_all = simd_sum(sumf[row]);
7548
8436
  if (tiisg == 0) {
7549
8437
  dst_f32[first_row + row] = sum_all;
@@ -7565,7 +8453,7 @@ kernel void kernel_mul_mv_iq4_nl_f32(
7565
8453
  kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
7566
8454
  }
7567
8455
 
7568
- template<int nr0, typename args_t>
8456
+ template<int NR0, typename args_t>
7569
8457
  void kernel_mul_mv_iq4_xs_f32_impl(
7570
8458
  args_t args,
7571
8459
  device const char * src0,
@@ -7578,12 +8466,11 @@ void kernel_mul_mv_iq4_xs_f32_impl(
7578
8466
  const short NSG = FC_mul_mv_nsg;
7579
8467
 
7580
8468
  threadgroup float * shmem_f32 = (threadgroup float *) shmem;
7581
- const int nb = args.ne00/QK_K;
7582
8469
 
7583
8470
  const int r0 = tgpig.x;
7584
8471
  const int r1 = tgpig.y;
7585
8472
  const int im = tgpig.z;
7586
- const int first_row = (r0 * NSG + sgitg) * nr0;
8473
+ const int first_row = (r0 * NSG + sgitg) * NR0;
7587
8474
 
7588
8475
  const uint i12 = im%args.ne12;
7589
8476
  const uint i13 = im/args.ne12;
@@ -7594,6 +8481,9 @@ void kernel_mul_mv_iq4_xs_f32_impl(
7594
8481
  device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
7595
8482
  device const float * y = (device const float *) (src1 + offset1);
7596
8483
 
8484
+ const int nb = args.ne00/QK_K;
8485
+ const int ns01 = args.nb01/args.nb00;
8486
+
7597
8487
  const short ix = tiisg/16; // 0 or 1
7598
8488
  const short it = tiisg%16; // 0...15
7599
8489
  const short ib = it/2;
@@ -7603,7 +8493,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
7603
8493
  threadgroup_barrier(mem_flags::mem_threadgroup);
7604
8494
 
7605
8495
  float4 yl[4];
7606
- float sumf[nr0]={0.f};
8496
+ float sumf[NR0]={0.f};
7607
8497
 
7608
8498
  device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
7609
8499
 
@@ -7612,15 +8502,16 @@ void kernel_mul_mv_iq4_xs_f32_impl(
7612
8502
 
7613
8503
  float4 qf1, qf2;
7614
8504
 
7615
- for (int ibl = ix; ibl < nb; ibl += 2) {
8505
+ // [TAG_MUL_MV_WEIRD]
8506
+ for (int ibl = ix; ibl < nb && ibl < ns01; ibl += 2) {
7616
8507
  device const float4 * y4 = (device const float4 *)yb;
7617
8508
  yl[0] = y4[0];
7618
8509
  yl[1] = y4[4];
7619
8510
  yl[2] = y4[1];
7620
8511
  yl[3] = y4[5];
7621
8512
 
7622
- for (short row = 0; row < nr0; ++row) {
7623
- device const block_iq4_xs & xb = x[row*nb + ibl];
8513
+ for (short row = 0; row < NR0; ++row) {
8514
+ device const block_iq4_xs & xb = x[row*ns01 + ibl];
7624
8515
  device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
7625
8516
 
7626
8517
  float4 acc1 = {0.f}, acc2 = {0.f};
@@ -7650,7 +8541,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
7650
8541
 
7651
8542
  device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
7652
8543
 
7653
- for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
8544
+ for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
7654
8545
  float sum_all = simd_sum(sumf[row]);
7655
8546
  if (tiisg == 0) {
7656
8547
  dst_f32[first_row + row] = sum_all;
@@ -7672,7 +8563,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
7672
8563
  kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
7673
8564
  }
7674
8565
 
7675
- template<int nr0, typename args_t>
8566
+ template<int NR0, typename args_t>
7676
8567
  void kernel_mul_mv_mxfp4_f32_impl(
7677
8568
  args_t args,
7678
8569
  device const char * src0,
@@ -7685,13 +8576,12 @@ void kernel_mul_mv_mxfp4_f32_impl(
7685
8576
  const short NSG = FC_mul_mv_nsg;
7686
8577
 
7687
8578
  threadgroup float * shmem_f32 = (threadgroup float *) shmem;
7688
- const int nb = args.ne00/QK_MXFP4;
7689
8579
 
7690
8580
  const int r0 = tgpig.x;
7691
8581
  const int r1 = tgpig.y;
7692
8582
  const int im = tgpig.z;
7693
8583
 
7694
- const int first_row = (r0 * NSG + sgitg) * nr0;
8584
+ const int first_row = (r0 * NSG + sgitg) * NR0;
7695
8585
 
7696
8586
  const uint i12 = im%args.ne12;
7697
8587
  const uint i13 = im/args.ne12;
@@ -7702,6 +8592,9 @@ void kernel_mul_mv_mxfp4_f32_impl(
7702
8592
  device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0);
7703
8593
  device const float * y = (device const float *) (src1 + offset1);
7704
8594
 
8595
+ const int nb = args.ne00/QK_MXFP4;
8596
+ const int ns01 = args.nb01/args.nb00; // this can be larger than nb for permuted src0 tensors
8597
+
7705
8598
  const short ix = tiisg/2; // 0...15
7706
8599
  const short it = tiisg%2; // 0 or 1
7707
8600
 
@@ -7709,20 +8602,22 @@ void kernel_mul_mv_mxfp4_f32_impl(
7709
8602
  threadgroup_barrier(mem_flags::mem_threadgroup);
7710
8603
 
7711
8604
  float4 yl[4];
7712
- float sumf[nr0]={0.f};
8605
+ float sumf[NR0]={0.f};
7713
8606
 
7714
- device const float * yb = y + ix * QK_MXFP4 + it * 8;
8607
+ device const float * yb = y + ix*QK_MXFP4 + it*8;
8608
+
8609
+ // note: just the check `ib < nb` is enough, but adding the redundant `&& ib < ns01` check makes the kernel a bit faster
8610
+ // no idea why that is - needs some deeper investigation [TAG_MUL_MV_WEIRD]
8611
+ for (int ib = ix; ib < nb && ib < ns01; ib += 16) {
8612
+ device const float4 * y4 = (device const float4 *) yb;
7715
8613
 
7716
- for (int ib = ix; ib < nb; ib += 16) {
7717
- device const float4 * y4 = (device const float4 *)yb;
7718
8614
  yl[0] = y4[0];
7719
8615
  yl[1] = y4[4];
7720
8616
  yl[2] = y4[1];
7721
8617
  yl[3] = y4[5];
7722
8618
 
7723
- #pragma unroll(nr0)
7724
- for (short row = 0; row < nr0; row++) {
7725
- device const block_mxfp4 & xb = x[row*nb + ib];
8619
+ FOR_UNROLL (short row = 0; row < NR0; row++) {
8620
+ device const block_mxfp4 & xb = x[row*ns01 + ib];
7726
8621
  device const uint8_t * q2 = (device const uint8_t *)(xb.qs + 8*it);
7727
8622
 
7728
8623
  float4 acc1 = yl[0]*float4(shmem_f32[q2[0] & 0x0F], shmem_f32[q2[1] & 0x0F], shmem_f32[q2[2] & 0x0F], shmem_f32[q2[3] & 0x0F]);
@@ -7740,7 +8635,7 @@ void kernel_mul_mv_mxfp4_f32_impl(
7740
8635
 
7741
8636
  device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
7742
8637
 
7743
- for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
8638
+ for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
7744
8639
  float sum_all = simd_sum(sumf[row]);
7745
8640
  if (tiisg == 0) {
7746
8641
  dst_f32[first_row + row] = sum_all;
@@ -7765,66 +8660,60 @@ kernel void kernel_mul_mv_mxfp4_f32(
7765
8660
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
7766
8661
  kernel void kernel_get_rows_q(
7767
8662
  constant ggml_metal_kargs_get_rows & args,
7768
- device const void * src0,
7769
- device const void * src1,
7770
- device float * dst,
7771
- uint3 tgpig[[threadgroup_position_in_grid]],
7772
- uint tiitg[[thread_index_in_threadgroup]],
7773
- uint3 tptg [[threads_per_threadgroup]]) {
7774
- const int64_t i10 = tgpig.x;
7775
- const int64_t i11 = tgpig.y;
8663
+ device const void * src0,
8664
+ device const void * src1,
8665
+ device void * dst,
8666
+ uint3 tgpig[[threadgroup_position_in_grid]],
8667
+ ushort tiitg[[thread_index_in_threadgroup]],
8668
+ ushort3 ntg [[threads_per_threadgroup]]) {
8669
+ const int32_t iw0 = tgpig.x/args.ne10;
8670
+ const int32_t i10 = tgpig.x%args.ne10;
8671
+ const int32_t i11 = tgpig.y;
8672
+ const int32_t i12 = tgpig.z;
8673
+
8674
+ const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];
7776
8675
 
7777
- const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
8676
+ const int32_t i02 = i11;
8677
+ const int32_t i03 = i12;
7778
8678
 
7779
- const int64_t i02 = i11;
8679
+ auto psrc = (device const block_q *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01);
8680
+ auto pdst = (device float4x4 *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1);
7780
8681
 
7781
- for (int64_t ind = tiitg; ind < args.ne00/16; ind += tptg.x) {
8682
+ for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) {
7782
8683
  float4x4 temp;
7783
- dequantize_func(((device const block_q *) ((const device char *) src0 + r*args.nb01 + i02*args.nb02)) + ind/nl, ind%nl, temp);
7784
- *(((device float4x4 *) ((device char *) dst + i11*args.nb2 + i10*args.nb1)) + ind) = temp;
8684
+ dequantize_func(psrc + ind/nl, ind%nl, temp);
8685
+ pdst[ind] = temp;
8686
+
8687
+ break;
7785
8688
  }
7786
8689
  }
7787
8690
 
7788
- template<typename T>
8691
+ template<typename T0, typename T>
7789
8692
  kernel void kernel_get_rows_f(
7790
8693
  constant ggml_metal_kargs_get_rows & args,
7791
- device const void * src0,
7792
- device const void * src1,
7793
- device float * dst,
7794
- uint3 tgpig[[threadgroup_position_in_grid]],
7795
- uint tiitg[[thread_index_in_threadgroup]],
7796
- uint3 tptg [[threads_per_threadgroup]]) {
7797
- const int64_t i10 = tgpig.x;
7798
- const int64_t i11 = tgpig.y;
7799
-
7800
- const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
7801
-
7802
- const int64_t i02 = i11;
8694
+ device const void * src0,
8695
+ device const void * src1,
8696
+ device void * dst,
8697
+ uint3 tgpig[[threadgroup_position_in_grid]],
8698
+ ushort tiitg[[thread_index_in_threadgroup]],
8699
+ ushort3 ntg [[threads_per_threadgroup]]) {
8700
+ const int32_t iw0 = tgpig.x/args.ne10;
8701
+ const int32_t i10 = tgpig.x%args.ne10;
8702
+ const int32_t i11 = tgpig.y;
8703
+ const int32_t i12 = tgpig.z;
7803
8704
 
7804
- for (int ind = tiitg; ind < args.ne00; ind += tptg.x) {
7805
- (( device float *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] =
7806
- ((const device T *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind];
7807
- }
7808
- }
8705
+ const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];
7809
8706
 
7810
- kernel void kernel_get_rows_i32(
7811
- constant ggml_metal_kargs_get_rows & args,
7812
- device const void * src0,
7813
- device const void * src1,
7814
- device int32_t * dst,
7815
- uint3 tgpig[[threadgroup_position_in_grid]],
7816
- uint tiitg[[thread_index_in_threadgroup]],
7817
- uint3 tptg [[threads_per_threadgroup]]) {
7818
- const int64_t i10 = tgpig.x;
7819
- const int64_t i11 = tgpig.y;
8707
+ const int32_t i02 = i11;
8708
+ const int32_t i03 = i12;
7820
8709
 
7821
- const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
8710
+ auto psrc = (const device T0 *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01);
8711
+ auto pdst = ( device T *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1);
7822
8712
 
7823
- const int64_t i02 = i11;
8713
+ for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) {
8714
+ pdst[ind] = psrc[ind];
7824
8715
 
7825
- for (int ind = tiitg; ind < args.ne00; ind += tptg.x) {
7826
- (( device int32_t *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] =
7827
- ((const device int32_t *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind];
8716
+ break;
7828
8717
  }
7829
8718
  }
7830
8719
 
@@ -7893,17 +8782,6 @@ kernel void kernel_set_rows_f(
7893
8782
  constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]];
7894
8783
  constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]];
7895
8784
 
7896
- #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
7897
- #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
7898
- #define BLOCK_SIZE_K 32
7899
- #define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
7900
- #define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
7901
- #define THREAD_PER_BLOCK 128
7902
- #define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers
7903
- #define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers
7904
- #define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
7905
- #define SG_MAT_ROW 8
7906
-
7907
8785
  // each block_q contains 16*nl weights
7908
8786
  template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
7909
8787
  kernel void kernel_mul_mm(
@@ -7919,18 +8797,48 @@ kernel void kernel_mul_mm(
7919
8797
  threadgroup S0 * sa = (threadgroup S0 *)(shmem);
7920
8798
  threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
7921
8799
 
7922
- const int r0 = tgpig.y;
7923
- const int r1 = tgpig.x;
8800
+ threadgroup float * sc = (threadgroup float *)(shmem);
8801
+
8802
+ constexpr int NR0 = 64;
8803
+ constexpr int NR1 = 32;
8804
+
8805
+ constexpr int NK = 32;
8806
+ constexpr int NL0 = NK/16;
8807
+ constexpr int NL1 = NK/8;
8808
+
7924
8809
  const int im = tgpig.z;
8810
+ const int r0 = tgpig.y*NR0;
8811
+ const int r1 = tgpig.x*NR1;
7925
8812
 
7926
8813
  // if this block is of 64x32 shape or smaller
7927
- const short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
7928
- const short n_cols = (args.ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (args.ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
8814
+ const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
8815
+ const short nr1 = (args.ne1 - r1 < NR1) ? (args.ne1 - r1) : NR1;
7929
8816
 
7930
8817
  // a thread shouldn't load data outside of the matrix
7931
- const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
7932
- const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
8818
+ const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63
8819
+ const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31
8820
+
8821
+ const short il0 = (tiitg % NL0);
8822
+
8823
+ short il = il0;
8824
+
8825
+ const int i12 = im%args.ne12;
8826
+ const int i13 = im/args.ne12;
8827
+
8828
+ const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
8829
+ const short offset1 = il0/nl;
8830
+
8831
+ device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;
7933
8832
 
8833
+ const short iy = 8*(tiitg % NL1);
8834
+
8835
+ device const T1 * y = (device const T1 *)(src1
8836
+ + args.nb13*i13
8837
+ + args.nb12*i12
8838
+ + args.nb11*(r1 + lr1)
8839
+ + args.nb10*iy);
8840
+
8841
+ #ifndef GGML_METAL_HAS_TENSOR
7934
8842
  S0_8x8 ma[4];
7935
8843
  S1_8x8 mb[2];
7936
8844
 
@@ -7939,36 +8847,104 @@ kernel void kernel_mul_mm(
7939
8847
  for (short i = 0; i < 8; i++){
7940
8848
  mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
7941
8849
  }
8850
+ #else
8851
+ auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0));
8852
+ auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK ));
7942
8853
 
7943
- short il = (tiitg % THREAD_PER_ROW);
8854
+ mpp::tensor_ops::matmul2d<
8855
+ mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
8856
+ execution_simdgroups<4>> mm;
7944
8857
 
7945
- const int i12 = im%args.ne12;
7946
- const int i13 = im/args.ne12;
8858
+ auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
8859
+ #endif
7947
8860
 
7948
- const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
7949
- const short offset1 = il/nl;
8861
+ for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
8862
+ #ifndef GGML_METAL_HAS_TENSOR
8863
+ // load data and store to threadgroup memory
8864
+ if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
8865
+ threadgroup_barrier(mem_flags::mem_threadgroup);
7950
8866
 
7951
- device const block_q * x = (device const block_q *)(src0
7952
- + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
8867
+ // no need for dequantization
8868
+ for (short i = 0; i < 16; i++) {
8869
+ const short sx = 2*il0 + i/8;
8870
+ const short sy = (tiitg/NL0)/8;
7953
8871
 
7954
- const short iy = (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL));
8872
+ //const short lx = i%8;
8873
+ //const short ly = (tiitg/NL0)%8;
8874
+ const short lx = (tiitg/NL0)%8;
8875
+ const short ly = i%8;
7955
8876
 
7956
- device const T1 * y = (device const T1 *)(src1
7957
- + args.nb13*i13
7958
- + args.nb12*i12
7959
- + args.nb11*(r1*BLOCK_SIZE_N + thread_col)
7960
- + args.nb10*iy);
8877
+ const short ib = 8*sx + sy;
8878
+
8879
+ *(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
8880
+ }
8881
+ } else {
8882
+ S0_4x4 temp_a;
8883
+ dequantize_func(x, il, temp_a);
8884
+
8885
+ threadgroup_barrier(mem_flags::mem_threadgroup);
8886
+
8887
+ FOR_UNROLL (short i = 0; i < 16; i++) {
8888
+ const short sx = 2*il0 + i/8;
8889
+ const short sy = (tiitg/NL0)/8;
8890
+
8891
+ //const short lx = i%8;
8892
+ //const short ly = (tiitg/NL0)%8;
8893
+ const short lx = (tiitg/NL0)%8;
8894
+ const short ly = i%8;
8895
+
8896
+ const short ib = 8*sx + sy;
8897
+
8898
+ // NOTE: this is massively slower.. WTF?
8899
+ //sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4];
8900
+
8901
+ *(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4];
8902
+ }
8903
+ }
8904
+
8905
+ if (FC_mul_mm_bc_inp) {
8906
+ for (short i = 0; i < 8; ++i) {
8907
+ const short sx = (tiitg%NL1);
8908
+ const short sy = (tiitg/NL1)/8;
8909
+
8910
+ const short lx = i;
8911
+ const short ly = (tiitg/NL1)%8;
8912
+ //const short lx = (tiitg/NL1)%8;
8913
+ //const short ly = i;
8914
+
8915
+ const short ib = 4*sx + sy;
8916
+
8917
+ *(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
8918
+ }
8919
+ } else {
8920
+ const short sx = (tiitg%NL1);
8921
+ const short sy = (tiitg/NL1)/8;
8922
+
8923
+ const short dx = sx;
8924
+ const short dy = sy;
7961
8925
 
7962
- for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
8926
+ const short ly = (tiitg/NL1)%8;
8927
+
8928
+ const short ib = 4*sx + sy;
8929
+
8930
+ *(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
8931
+ }
8932
+ #else
7963
8933
  // load data and store to threadgroup memory
7964
8934
  if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
7965
8935
  threadgroup_barrier(mem_flags::mem_threadgroup);
7966
8936
 
7967
8937
  // no need for dequantization
7968
8938
  for (short i = 0; i < 16; i++) {
7969
- *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
7970
- + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
7971
- + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = loop_k + 16*il + i < args.ne00 ? ((device T0 *) x)[i] : 0;
8939
+ const short sx = 2*il0 + i/8;
8940
+ const short sy = (tiitg/NL0)/8;
8941
+
8942
+ const short lx = i%8;
8943
+ const short ly = (tiitg/NL0)%8;
8944
+ //const short lx = (tiitg/NL0)%8;
8945
+ //const short ly = i%8;
8946
+
8947
+ *(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
7972
8948
  }
7973
8949
  } else {
7974
8950
  S0_4x4 temp_a;
@@ -7977,91 +8953,135 @@ kernel void kernel_mul_mm(
7977
8953
  threadgroup_barrier(mem_flags::mem_threadgroup);
7978
8954
 
7979
8955
  FOR_UNROLL (short i = 0; i < 16; i++) {
7980
- *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
7981
- + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
7982
- + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
8956
+ const short sx = 2*il0 + i/8;
8957
+ const short sy = (tiitg/NL0)/8;
8958
+
8959
+ const short lx = i%8;
8960
+ const short ly = (tiitg/NL0)%8;
8961
+ //const short lx = (tiitg/NL0)%8;
8962
+ //const short ly = i%8;
8963
+
8964
+ *(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4];
7983
8965
  }
7984
8966
  }
7985
8967
 
7986
8968
  if (FC_mul_mm_bc_inp) {
7987
8969
  for (short i = 0; i < 8; ++i) {
7988
- sb[32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL) + i] = loop_k + iy + i < args.ne00 ? (S1) ((device T1 *) y)[i] : 0;
8970
+ const short sx = (tiitg%NL1);
8971
+ const short sy = (tiitg/NL1)/8;
8972
+
8973
+ const short lx = i;
8974
+ const short ly = (tiitg/NL1)%8;
8975
+ //const short lx = (tiitg/NL1)%8;
8976
+ //const short ly = i;
8977
+
8978
+ *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
7989
8979
  }
7990
8980
  } else {
7991
- *(threadgroup S1_2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (S1_2x4)(*((device T1_2x4 *) y));
8981
+ const short sx = (tiitg%NL1);
8982
+ const short sy = (tiitg/NL1)/8;
8983
+
8984
+ //const short lx = i;
8985
+ const short ly = (tiitg/NL1)%8;
8986
+ //const short lx = (tiitg/NL1)%8;
8987
+ //const short ly = i;
8988
+
8989
+ *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));
7992
8990
  }
8991
+ #endif
7993
8992
 
7994
8993
  il = (il + 2 < nl) ? il + 2 : il % 2;
7995
8994
  x = (il < 2) ? x + (2 + nl - 1)/nl : x;
7996
- y += BLOCK_SIZE_K;
8995
+
8996
+ y += NK;
7997
8997
 
7998
8998
  threadgroup_barrier(mem_flags::mem_threadgroup);
7999
8999
 
9000
+ #ifndef GGML_METAL_HAS_TENSOR
8000
9001
  // load matrices from threadgroup memory and conduct outer products
8001
- threadgroup const S0 * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
8002
- threadgroup const S1 * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
9002
+ threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
9003
+ threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
8003
9004
 
8004
- #pragma unroll(4)
8005
- for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) {
9005
+ FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
8006
9006
  simdgroup_barrier(mem_flags::mem_none);
8007
9007
 
8008
- #pragma unroll(4)
8009
- for (short i = 0; i < 4; i++) {
8010
- simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
9008
+ FOR_UNROLL (short i = 0; i < 4; i++) {
9009
+ simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
8011
9010
  }
8012
9011
 
8013
- #pragma unroll(2)
8014
- for (short i = 0; i < 2; i++) {
8015
- simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
9012
+ simdgroup_barrier(mem_flags::mem_none);
9013
+
9014
+ FOR_UNROLL (short i = 0; i < 2; i++) {
9015
+ simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
8016
9016
  }
8017
9017
 
8018
9018
  simdgroup_barrier(mem_flags::mem_none);
8019
9019
 
8020
- #pragma unroll(8)
8021
- for (short i = 0; i < 8; i++){
9020
+ FOR_UNROLL (short i = 0; i < 8; i++){
8022
9021
  simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
8023
9022
  }
8024
9023
 
8025
- lsma += (BLOCK_SIZE_M/SG_MAT_ROW)*SG_MAT_SIZE;
8026
- lsmb += (BLOCK_SIZE_N/SG_MAT_ROW)*SG_MAT_SIZE;
9024
+ lsma += 8*64;
9025
+ lsmb += 4*64;
8027
9026
  }
9027
+ #else
9028
+ auto sA = tA.slice(0, 0);
9029
+ auto sB = tB.slice(0, 0);
9030
+
9031
+ mm.run(sB, sA, cT);
9032
+ #endif
8028
9033
  }
8029
9034
 
8030
- if (!FC_mul_mm_bc_out || ((r0 + 1) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1) * BLOCK_SIZE_N <= args.ne1)) {
9035
+ if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) {
8031
9036
  // if no bounds checks on the output are needed, we can directly write to device memory
9037
+ #ifdef GGML_METAL_HAS_TENSOR
8032
9038
  device float * C = (device float *) dst +
8033
- (BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \
8034
- (BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
9039
+ r0 + \
9040
+ r1 * args.ne0 + im*args.ne1*args.ne0;
9041
+
9042
+ auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(args.ne0, NR1));
9043
+ cT.store(tC);
9044
+ #else
9045
+ device float * C = (device float *) dst +
9046
+ (r0 + 32*(sgitg & 1)) + \
9047
+ (r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
8035
9048
 
8036
9049
  for (short i = 0; i < 8; i++) {
8037
- simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0);
9050
+ simdgroup_store(mc[i], C + 8*(i%4) + 8*args.ne0*(i/4), args.ne0, 0, false);
8038
9051
  }
9052
+ #endif
8039
9053
  } else {
8040
9054
  // block is smaller than 64x32, we should avoid writing data outside of the matrix
8041
9055
  threadgroup_barrier(mem_flags::mem_threadgroup);
8042
- threadgroup float * temp_str = ((threadgroup float *) shmem) \
8043
- + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
9056
+
9057
+ threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
9058
+
9059
+ #ifdef GGML_METAL_HAS_TENSOR
9060
+ auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1));
9061
+ cT.store(tC);
9062
+ #else
8044
9063
  for (short i = 0; i < 8; i++) {
8045
- simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
9064
+ simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
8046
9065
  }
9066
+ #endif
8047
9067
 
8048
9068
  threadgroup_barrier(mem_flags::mem_threadgroup);
8049
9069
 
8050
9070
  if (sgitg == 0) {
8051
- for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
8052
- device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.ne0 + im*args.ne1*args.ne0;
9071
+ for (int j = tiitg; j < nr1; j += NR1) {
9072
+ device float * D = (device float *) dst + r0 + (r1 + j)*args.ne0 + im*args.ne1*args.ne0;
8053
9073
  device float4 * D4 = (device float4 *) D;
8054
9074
 
8055
- threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
9075
+ threadgroup float * C = temp_str + (j*NR0);
8056
9076
  threadgroup float4 * C4 = (threadgroup float4 *) C;
8057
9077
 
8058
9078
  int i = 0;
8059
- for (; i < n_rows/4; i++) {
9079
+ for (; i < nr0/4; i++) {
8060
9080
  *(D4 + i) = *(C4 + i);
8061
9081
  }
8062
9082
 
8063
9083
  i *= 4;
8064
- for (; i < n_rows; i++) {
9084
+ for (; i < nr0; i++) {
8065
9085
  *(D + i) = *(C + i);
8066
9086
  }
8067
9087
  }
@@ -8128,6 +9148,7 @@ typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t;
8128
9148
  template [[host_name("kernel_mul_mm_id_map0_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>;
8129
9149
  template [[host_name("kernel_mul_mm_id_map0_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>;
8130
9150
  template [[host_name("kernel_mul_mm_id_map0_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
9151
+ template [[host_name("kernel_mul_mm_id_map0_ne20_5" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<5>;
8131
9152
  template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
8132
9153
  template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
8133
9154
  template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
@@ -8146,55 +9167,55 @@ kernel void kernel_mul_mm_id(
8146
9167
  ushort tiitg[[thread_index_in_threadgroup]],
8147
9168
  ushort tiisg[[thread_index_in_simdgroup]],
8148
9169
  ushort sgitg[[simdgroup_index_in_threadgroup]]) {
8149
-
8150
9170
  threadgroup S0 * sa = (threadgroup S0 *)(shmem);
8151
9171
  threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
8152
9172
 
8153
- const int r0 = tgpig.y;
8154
- const int r1 = tgpig.x;
9173
+ threadgroup float * sc = (threadgroup float *)(shmem);
9174
+
9175
+ constexpr int NR0 = 64;
9176
+ constexpr int NR1 = 32;
9177
+
9178
+ constexpr int NK = 32;
9179
+ constexpr int NL0 = NK/16;
9180
+ constexpr int NL1 = NK/8;
9181
+
8155
9182
  const int im = tgpig.z; // expert
9183
+ const int r0 = tgpig.y*NR0;
9184
+ const int r1 = tgpig.x*NR1;
8156
9185
 
8157
9186
  device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe);
8158
9187
  device const int32_t * ids_i32 = (device const int32_t *) (hids);
8159
9188
 
8160
9189
  const int32_t neh1 = tpe_u32[im];
8161
9190
 
8162
- if (r1*BLOCK_SIZE_N >= neh1) {
9191
+ if (r1 >= neh1) {
8163
9192
  return;
8164
9193
  }
8165
9194
 
8166
9195
  // if this block is of 64x32 shape or smaller
8167
- const short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
8168
- const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
9196
+ const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
9197
+ const short nr1 = ( neh1 - r1 < NR1) ? ( neh1 - r1) : NR1;
8169
9198
 
8170
9199
  // a thread shouldn't load data outside of the matrix
8171
- const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
8172
- const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
8173
-
8174
- S0_8x8 ma[4];
8175
- S1_8x8 mb[2];
9200
+ const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63
9201
+ const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31
8176
9202
 
8177
- simdgroup_float8x8 mc[8];
9203
+ const short il0 = (tiitg % NL0);
8178
9204
 
8179
- for (short i = 0; i < 8; i++){
8180
- mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
8181
- }
9205
+ short il = il0;
8182
9206
 
8183
- short il = (tiitg % THREAD_PER_ROW);
8184
-
8185
- const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + thread_col];
9207
+ const int id = ids_i32[im*args.ne21 + r1 + lr1];
8186
9208
 
8187
9209
  const short i11 = (id % args.ne20) % args.ne11;
8188
9210
  const short i12 = (id / args.ne20);
8189
9211
  const short i13 = 0;
8190
9212
 
8191
9213
  const uint64_t offset0 = im*args.nb02 + i13*args.nb03;
8192
- const short offset1 = il/nl;
9214
+ const short offset1 = il0/nl;
8193
9215
 
8194
- device const block_q * x = (device const block_q *)(src0
8195
- + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
9216
+ device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;
8196
9217
 
8197
- const short iy = (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL));
9218
+ const short iy = 8*(tiitg % NL1);
8198
9219
 
8199
9220
  device const T1 * y = (device const T1 *)(src1
8200
9221
  + args.nb13*i13
@@ -8202,16 +9223,45 @@ kernel void kernel_mul_mm_id(
8202
9223
  + args.nb11*i11
8203
9224
  + args.nb10*iy);
8204
9225
 
8205
- for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
9226
+ #ifndef GGML_METAL_HAS_TENSOR
9227
+ S0_8x8 ma[4];
9228
+ S1_8x8 mb[2];
9229
+
9230
+ simdgroup_float8x8 mc[8];
9231
+
9232
+ for (short i = 0; i < 8; i++){
9233
+ mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
9234
+ }
9235
+ #else
9236
+ auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0));
9237
+ auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK ));
9238
+
9239
+ mpp::tensor_ops::matmul2d<
9240
+ mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
9241
+ execution_simdgroups<4>> mm;
9242
+
9243
+ auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
9244
+ #endif
9245
+
9246
+ for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
9247
+ #ifndef GGML_METAL_HAS_TENSOR
8206
9248
  // load data and store to threadgroup memory
8207
9249
  if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
8208
9250
  threadgroup_barrier(mem_flags::mem_threadgroup);
8209
9251
 
8210
9252
  // no need for dequantization
8211
9253
  for (short i = 0; i < 16; i++) {
8212
- *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
8213
- + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
8214
- + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = loop_k + 16*il + i < args.ne00 ? ((device T0 *) x)[i] : 0;
9254
+ const short sx = 2*il0 + i/8;
9255
+ const short sy = (tiitg/NL0)/8;
9256
+
9257
+ //const short lx = i%8;
9258
+ //const short ly = (tiitg/NL0)%8;
9259
+ const short lx = (tiitg/NL0)%8;
9260
+ const short ly = i%8;
9261
+
9262
+ const short ib = 8*sx + sy;
9263
+
9264
+ *(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
8215
9265
  }
8216
9266
  } else {
8217
9267
  S0_4x4 temp_a;
@@ -8220,85 +9270,188 @@ kernel void kernel_mul_mm_id(
8220
9270
  threadgroup_barrier(mem_flags::mem_threadgroup);
8221
9271
 
8222
9272
  FOR_UNROLL (short i = 0; i < 16; i++) {
8223
- *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
8224
- + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
8225
- + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
9273
+ const short sx = 2*il0 + i/8;
9274
+ const short sy = (tiitg/NL0)/8;
9275
+
9276
+ //const short lx = i%8;
9277
+ //const short ly = (tiitg/NL0)%8;
9278
+ const short lx = (tiitg/NL0)%8;
9279
+ const short ly = i%8;
9280
+
9281
+ const short ib = 8*sx + sy;
9282
+
9283
+ // NOTE: this is massively slower.. WTF?
9284
+ //sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4];
9285
+
9286
+ *(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4];
8226
9287
  }
8227
9288
  }
8228
9289
 
8229
9290
  if (FC_mul_mm_bc_inp) {
8230
9291
  for (short i = 0; i < 8; ++i) {
8231
- sb[32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL) + i] = loop_k + iy + i < args.ne00 ? (S1) ((device T1 *) y)[i] : 0;
9292
+ const short sx = (tiitg%NL1);
9293
+ const short sy = (tiitg/NL1)/8;
9294
+
9295
+ const short lx = i;
9296
+ const short ly = (tiitg/NL1)%8;
9297
+ //const short lx = (tiitg/NL1)%8;
9298
+ //const short ly = i;
9299
+
9300
+ const short ib = 4*sx + sy;
9301
+
9302
+ *(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
9303
+ }
9304
+ } else {
9305
+ const short sx = (tiitg%NL1);
9306
+ const short sy = (tiitg/NL1)/8;
9307
+
9308
+ const short dx = sx;
9309
+ const short dy = sy;
9310
+
9311
+ const short ly = (tiitg/NL1)%8;
9312
+
9313
+ const short ib = 4*sx + sy;
9314
+
9315
+ *(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
9316
+ }
9317
+ #else
9318
+ // load data and store to threadgroup memory
9319
+ if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
9320
+ threadgroup_barrier(mem_flags::mem_threadgroup);
9321
+
9322
+ // no need for dequantization
9323
+ for (short i = 0; i < 16; i++) {
9324
+ const short sx = 2*il0 + i/8;
9325
+ const short sy = (tiitg/NL0)/8;
9326
+
9327
+ const short lx = i%8;
9328
+ const short ly = (tiitg/NL0)%8;
9329
+ //const short lx = (tiitg/NL0)%8;
9330
+ //const short ly = i%8;
9331
+
9332
+ *(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
8232
9333
  }
8233
9334
  } else {
8234
- *(threadgroup S1_2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (S1_2x4)(*((device T1_2x4 *) y));
9335
+ S0_4x4 temp_a;
9336
+ dequantize_func(x, il, temp_a);
9337
+
9338
+ threadgroup_barrier(mem_flags::mem_threadgroup);
9339
+
9340
+ FOR_UNROLL (short i = 0; i < 16; i++) {
9341
+ const short sx = 2*il0 + i/8;
9342
+ const short sy = (tiitg/NL0)/8;
9343
+
9344
+ const short lx = i%8;
9345
+ const short ly = (tiitg/NL0)%8;
9346
+ //const short lx = (tiitg/NL0)%8;
9347
+ //const short ly = i%8;
9348
+
9349
+ *(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4];
9350
+ }
8235
9351
  }
8236
9352
 
9353
+ if (FC_mul_mm_bc_inp) {
9354
+ for (short i = 0; i < 8; ++i) {
9355
+ const short sx = (tiitg%NL1);
9356
+ const short sy = (tiitg/NL1)/8;
9357
+
9358
+ const short lx = i;
9359
+ const short ly = (tiitg/NL1)%8;
9360
+ //const short lx = (tiitg/NL1)%8;
9361
+ //const short ly = i;
9362
+
9363
+ *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
9364
+ }
9365
+ } else {
9366
+ const short sx = (tiitg%NL1);
9367
+ const short sy = (tiitg/NL1)/8;
9368
+
9369
+ //const short lx = i;
9370
+ const short ly = (tiitg/NL1)%8;
9371
+ //const short lx = (tiitg/NL1)%8;
9372
+ //const short ly = i;
9373
+
9374
+ *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));
9375
+ }
9376
+ #endif
9377
+
8237
9378
  il = (il + 2 < nl) ? il + 2 : il % 2;
8238
9379
  x = (il < 2) ? x + (2 + nl - 1)/nl : x;
8239
- y += BLOCK_SIZE_K;
9380
+
9381
+ y += NK;
8240
9382
 
8241
9383
  threadgroup_barrier(mem_flags::mem_threadgroup);
8242
9384
 
9385
+ #ifndef GGML_METAL_HAS_TENSOR
8243
9386
  // load matrices from threadgroup memory and conduct outer products
8244
- threadgroup const S0 * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
8245
- threadgroup const S1 * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
8246
-
8247
- #pragma unroll(4)
8248
- for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) {
8249
- #pragma unroll(4)
8250
- for (short i = 0; i < 4; i++) {
8251
- simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
9387
+ threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
9388
+ threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
9389
+
9390
+ FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
9391
+ simdgroup_barrier(mem_flags::mem_none);
9392
+
9393
+ FOR_UNROLL (short i = 0; i < 4; i++) {
9394
+ simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
8252
9395
  }
8253
9396
 
8254
9397
  simdgroup_barrier(mem_flags::mem_none);
8255
9398
 
8256
- #pragma unroll(2)
8257
- for (short i = 0; i < 2; i++) {
8258
- simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
9399
+ FOR_UNROLL (short i = 0; i < 2; i++) {
9400
+ simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
8259
9401
  }
8260
9402
 
8261
- #pragma unroll(8)
8262
- for (short i = 0; i < 8; i++){
9403
+ simdgroup_barrier(mem_flags::mem_none);
9404
+
9405
+ FOR_UNROLL (short i = 0; i < 8; i++){
8263
9406
  simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
8264
9407
  }
8265
9408
 
8266
- lsma += (BLOCK_SIZE_M/SG_MAT_ROW)*SG_MAT_SIZE;
8267
- lsmb += (BLOCK_SIZE_N/SG_MAT_ROW)*SG_MAT_SIZE;
9409
+ lsma += 8*64;
9410
+ lsmb += 4*64;
8268
9411
  }
9412
+ #else
9413
+ auto sA = tA.slice(0, 0);
9414
+ auto sB = tB.slice(0, 0);
9415
+
9416
+ mm.run(sB, sA, cT);
9417
+ #endif
8269
9418
  }
8270
9419
 
9420
+ // block is smaller than 64x32, we should avoid writing data outside of the matrix
8271
9421
  threadgroup_barrier(mem_flags::mem_threadgroup);
8272
9422
 
8273
- threadgroup float * temp_str = ((threadgroup float *) shmem) \
8274
- + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
9423
+ #ifdef GGML_METAL_HAS_TENSOR
9424
+ auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1));
9425
+ cT.store(tC);
9426
+ #else
9427
+ threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
8275
9428
 
8276
- #pragma unroll(8)
8277
9429
  for (short i = 0; i < 8; i++) {
8278
- simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
9430
+ simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
8279
9431
  }
9432
+ #endif
8280
9433
 
8281
9434
  threadgroup_barrier(mem_flags::mem_threadgroup);
8282
9435
 
8283
- for (short j = sgitg; j < n_cols; j += 4) {
8284
- const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + j];
9436
+ for (short j = sgitg; j < nr1; j += 4) {
9437
+ const int id = ids_i32[im*args.ne21 + r1 + j];
8285
9438
 
8286
9439
  const short ide = id % args.ne20;
8287
9440
  const short idt = id / args.ne20;
8288
9441
 
8289
- device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + ide*args.ne0 + idt*args.ne1*args.ne0;
9442
+ device float * D = (device float *) dst + r0 + ide*args.ne0 + idt*args.ne1*args.ne0;
8290
9443
  device float4 * D4 = (device float4 *) D;
8291
9444
 
8292
- threadgroup float * C = (threadgroup float *) shmem + (j*BLOCK_SIZE_M);
9445
+ threadgroup float * C = (threadgroup float *) shmem + j*NR0;
8293
9446
  threadgroup float4 * C4 = (threadgroup float4 *) C;
8294
9447
 
8295
9448
  int i = tiisg;
8296
- for (; i < n_rows/4; i += 32) {
9449
+ for (; i < nr0/4; i += 32) {
8297
9450
  *(D4 + i) = *(C4 + i);
8298
9451
  }
8299
9452
 
8300
- i = (4*(n_rows/4)) + tiisg;
8301
- for (; i < n_rows; i += 32) {
9453
+ i = (4*(nr0/4)) + tiisg;
9454
+ for (; i < nr0; i += 32) {
8302
9455
  *(D + i) = *(C + i);
8303
9456
  }
8304
9457
  }
@@ -8310,12 +9463,13 @@ kernel void kernel_mul_mm_id(
8310
9463
  // get rows
8311
9464
  //
8312
9465
 
8313
- typedef decltype(kernel_get_rows_f<float>) get_rows_f_t;
9466
+ typedef decltype(kernel_get_rows_f<float, float>) get_rows_f_t;
8314
9467
 
8315
- template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float>;
8316
- template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half>;
9468
+ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float, float>;
9469
+ template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half, float>;
9470
+ template [[host_name("kernel_get_rows_i32")]] kernel get_rows_f_t kernel_get_rows_f<int32_t, int32_t>;
8317
9471
  #if defined(GGML_METAL_HAS_BF16)
8318
- template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat>;
9472
+ template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat, float>;
8319
9473
  #endif
8320
9474
 
8321
9475
  typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
@@ -8405,9 +9559,6 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_m
8405
9559
 
8406
9560
  template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
8407
9561
  template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
8408
- #if defined(GGML_METAL_HAS_BF16)
8409
- template [[host_name("kernel_mul_mm_bf16_f16")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, half, half2x4, simdgroup_half8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, half, half2x4>;
8410
- #endif
8411
9562
  template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
8412
9563
  template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
8413
9564
  template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
@@ -8463,9 +9614,6 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_m
8463
9614
 
8464
9615
  template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
8465
9616
  template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
8466
- #if defined(GGML_METAL_HAS_BF16)
8467
- template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, half, half2x4, simdgroup_half8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, half, half2x4>;
8468
- #endif
8469
9617
  template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
8470
9618
  template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
8471
9619
  template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
@@ -8720,3 +9868,123 @@ kernel void kernel_pool_2d_avg_f32(
8720
9868
 
8721
9869
  o_ptr[cur_oh * args.OW + cur_ow] = res;
8722
9870
  }
9871
+
9872
+ kernel void kernel_opt_step_adamw_f32(
9873
+ constant ggml_metal_kargs_opt_step_adamw & args,
9874
+ device float * x,
9875
+ device const float * g,
9876
+ device float * g_m,
9877
+ device float * g_v,
9878
+ device const float * pars,
9879
+ uint gid[[thread_position_in_grid]]) {
9880
+
9881
+ if (gid >= args.np) {
9882
+ return;
9883
+ }
9884
+
9885
+ const float alpha = pars[0];
9886
+ const float beta1 = pars[1];
9887
+ const float beta2 = pars[2];
9888
+ const float eps = pars[3];
9889
+ const float wd = pars[4];
9890
+ const float beta1h = pars[5];
9891
+ const float beta2h = pars[6];
9892
+
9893
+ const float gi = g[gid];
9894
+ const float gmi = g_m[gid] * beta1 + gi * (1.0f - beta1);
9895
+ const float gvi = g_v[gid] * beta2 + gi * gi * (1.0f - beta2);
9896
+
9897
+ g_m[gid] = gmi;
9898
+ g_v[gid] = gvi;
9899
+
9900
+ const float mh = gmi * beta1h;
9901
+ const float vh = sqrt(gvi * beta2h) + eps;
9902
+
9903
+ x[gid] = x[gid] * (1.0f - alpha * wd) - alpha * mh / vh;
9904
+ }
9905
+
9906
+ kernel void kernel_opt_step_sgd_f32(
9907
+ constant ggml_metal_kargs_opt_step_sgd & args,
9908
+ device float * x,
9909
+ device const float * g,
9910
+ device const float * pars,
9911
+ uint gid[[thread_position_in_grid]]) {
9912
+
9913
+ if (gid >= args.np) {
9914
+ return;
9915
+ }
9916
+
9917
+ x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];
9918
+ }
9919
+
9920
+ template<typename T>
9921
+ kernel void kernel_memset(
9922
+ constant ggml_metal_kargs_fill & args,
9923
+ device T * dst,
9924
+ uint tpig[[thread_position_in_grid]]) {
9925
+ dst[tpig] = args.val;
9926
+ }
9927
+
9928
+ typedef decltype(kernel_memset<int64_t>) kernel_memset_t;
9929
+
9930
+ template [[host_name("kernel_memset_i64")]] kernel kernel_memset_t kernel_memset<int64_t>;
9931
+
9932
+ constant short FC_count_equal_nsg [[function_constant(FC_COUNT_EQUAL + 0)]];
9933
+
9934
+ template<typename T>
9935
+ kernel void kernel_count_equal(
9936
+ constant ggml_metal_kargs_count_equal & args,
9937
+ device const char * src0,
9938
+ device const char * src1,
9939
+ device atomic_int * dst,
9940
+ threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
9941
+ uint3 tgpig[[threadgroup_position_in_grid]],
9942
+ ushort3 tpitg[[thread_position_in_threadgroup]],
9943
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
9944
+ ushort tiisg[[thread_index_in_simdgroup]],
9945
+ ushort3 ntg[[threads_per_threadgroup]]) {
9946
+ const short NSG = FC_count_equal_nsg;
9947
+
9948
+ const int i3 = tgpig.z;
9949
+ const int i2 = tgpig.y;
9950
+ const int i1 = tgpig.x;
9951
+
9952
+ if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
9953
+ return;
9954
+ }
9955
+
9956
+ int sum = 0;
9957
+
9958
+ device const char * base0 = src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03;
9959
+ device const char * base1 = src1 + i1*args.nb11 + i2*args.nb12 + i3*args.nb13;
9960
+
9961
+ for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
9962
+ const T v0 = *(device const T *)(base0 + i0*args.nb00);
9963
+ const T v1 = *(device const T *)(base1 + i0*args.nb10);
9964
+ sum += (v0 == v1);
9965
+ }
9966
+
9967
+ sum = simd_sum(sum);
9968
+
9969
+ if (tiisg == 0) {
9970
+ shmem_i32[sgitg] = sum;
9971
+ }
9972
+
9973
+ threadgroup_barrier(mem_flags::mem_threadgroup);
9974
+
9975
+ if (sgitg == 0) {
9976
+ float v = 0.0f;
9977
+ if (tpitg.x < NSG) {
9978
+ v = shmem_i32[tpitg.x];
9979
+ }
9980
+
9981
+ float total = simd_sum(v);
9982
+ if (tpitg.x == 0) {
9983
+ atomic_fetch_add_explicit(dst, (int32_t) total, memory_order_relaxed);
9984
+ }
9985
+ }
9986
+ }
9987
+
9988
+ typedef decltype(kernel_count_equal<int32_t>) kernel_count_equal_t;
9989
+
9990
+ template [[host_name("kernel_count_equal_i32")]] kernel kernel_count_equal_t kernel_count_equal<int32_t>;