whispercpp 1.3.5 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (610)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +99 -2
  4. data/ext/extconf.rb +1 -0
  5. data/ext/ruby_whisper.c +20 -4
  6. data/ext/ruby_whisper.h +30 -2
  7. data/ext/ruby_whisper_context.c +216 -124
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +0 -1
  10. data/ext/ruby_whisper_params.c +0 -1
  11. data/ext/ruby_whisper_segment.c +0 -1
  12. data/ext/ruby_whisper_token.c +29 -9
  13. data/ext/ruby_whisper_transcribe.cpp +4 -1
  14. data/ext/ruby_whisper_vad_context.c +48 -1
  15. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +0 -1
  18. data/ext/ruby_whisper_vad_segments.c +0 -1
  19. data/ext/sources/CMakeLists.txt +1 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  22. data/ext/sources/examples/bench/bench.cpp +23 -18
  23. data/ext/sources/examples/cli/cli.cpp +8 -0
  24. data/ext/sources/examples/common-ggml.cpp +2 -0
  25. data/ext/sources/examples/miniaudio.h +4507 -2131
  26. data/ext/sources/examples/server/server.cpp +18 -4
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
  31. data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
  33. data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
  34. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  35. data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
  36. data/ext/sources/examples/talk-llama/llama-context.h +27 -28
  37. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
  38. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
  40. data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
  41. data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
  42. data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
  43. data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
  44. data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
  45. data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
  49. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  50. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  51. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
  52. data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
  53. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
  54. data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
  55. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
  56. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  57. data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
  58. data/ext/sources/examples/talk-llama/llama-model.h +72 -19
  59. data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
  60. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
  61. data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
  62. data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
  63. data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
  64. data/ext/sources/examples/talk-llama/llama.cpp +76 -22
  65. data/ext/sources/examples/talk-llama/llama.h +63 -30
  66. data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
  67. data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
  68. data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
  69. data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
  70. data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
  71. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
  72. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
  73. data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
  74. data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
  75. data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
  76. data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
  77. data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
  78. data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
  79. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
  80. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
  81. data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
  82. data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
  83. data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
  84. data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
  85. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
  86. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  87. data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
  88. data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
  89. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
  90. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
  91. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  92. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
  108. data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
  109. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
  110. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
  111. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
  112. data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
  113. data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
  114. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
  116. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  117. data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
  118. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
  119. data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
  120. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
  121. data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
  122. data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
  123. data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
  124. data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
  125. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
  126. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
  127. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
  128. data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
  129. data/ext/sources/examples/talk-llama/models/models.h +181 -46
  130. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
  131. data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
  132. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
  133. data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
  134. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
  135. data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
  136. data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
  137. data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
  138. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
  139. data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
  140. data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
  141. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  142. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
  143. data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
  144. data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
  145. data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
  146. data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
  147. data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
  148. data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
  149. data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
  150. data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
  151. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
  152. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
  153. data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
  154. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  155. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  156. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
  157. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
  158. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
  159. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
  160. data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
  161. data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
  162. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
  163. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
  164. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
  165. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
  166. data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
  167. data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
  168. data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
  169. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
  170. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  171. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
  172. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
  173. data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
  174. data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
  175. data/ext/sources/ggml/CMakeLists.txt +9 -3
  176. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  177. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  179. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  180. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  181. data/ext/sources/ggml/include/ggml-rpc.h +6 -1
  182. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  183. data/ext/sources/ggml/include/ggml.h +56 -9
  184. data/ext/sources/ggml/src/CMakeLists.txt +3 -0
  185. data/ext/sources/ggml/src/ggml-alloc.c +4 -9
  186. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  187. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  188. data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
  189. data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
  190. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  191. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
  192. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  193. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  194. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
  195. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
  196. data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
  197. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
  198. data/ext/sources/ggml/src/ggml-common.h +11 -0
  199. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
  200. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
  201. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  202. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  203. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
  204. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
  205. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
  206. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  207. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  208. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  209. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
  210. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
  211. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  212. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  213. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
  214. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
  215. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
  216. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  217. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
  218. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
  219. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
  220. data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
  221. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  222. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  223. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
  224. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  225. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  226. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
  227. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  228. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
  229. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
  230. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
  231. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
  232. data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
  233. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
  235. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
  236. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
  237. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
  238. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
  239. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  241. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
  242. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  243. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  244. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
  245. data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
  246. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
  247. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  248. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
  249. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
  250. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
  251. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  252. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
  253. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
  254. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  255. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  256. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
  257. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  258. data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
  259. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  260. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  261. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
  262. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  269. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
  270. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
  271. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
  272. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
  274. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
  278. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  297. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  317. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  321. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
  322. data/ext/sources/ggml/src/ggml-impl.h +62 -0
  323. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  324. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  325. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  326. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  327. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
  328. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
  329. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
  330. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
  335. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
  336. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  367. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  368. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  369. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  370. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  371. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  372. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  373. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  374. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  375. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  376. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  377. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  378. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  379. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  380. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  381. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  382. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  383. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  384. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  385. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  386. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  387. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  388. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  389. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  390. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  391. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  392. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  393. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  394. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  395. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  396. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  397. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  398. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  399. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  400. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  401. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  402. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  403. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  404. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  405. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  406. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  407. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  412. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  413. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  414. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  415. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  416. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
  417. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  418. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  420. data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
  421. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
  422. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  423. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  424. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  425. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
  426. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  427. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  428. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  429. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  430. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  431. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  432. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  433. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  434. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  435. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
  436. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  437. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  438. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  439. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  440. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  441. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  442. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  443. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  444. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  445. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  446. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  447. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  448. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  449. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  450. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  451. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  452. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  453. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  454. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  455. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  456. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  457. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  458. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  459. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  460. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  461. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  462. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  463. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  464. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  465. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  466. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  467. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  468. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  469. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  470. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  471. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  472. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  473. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  474. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  475. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  476. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  477. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  478. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  479. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  480. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  481. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  482. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  483. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  484. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  485. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  486. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  487. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  488. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
  489. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  490. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  491. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  492. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  493. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  494. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  495. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  496. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  497. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  498. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  499. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  500. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  501. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  502. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  503. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  504. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  505. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  506. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  507. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  508. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  509. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  510. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  511. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  512. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  513. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  514. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  515. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  516. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  517. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  518. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  519. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  520. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  521. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  522. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  523. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  524. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  525. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  526. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  527. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  528. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  529. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  530. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  531. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  532. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
  533. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
  560. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
  561. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
  562. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  563. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  564. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  565. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  566. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  567. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  568. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
  569. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  570. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
  571. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
  572. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  573. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
  574. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  584. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  585. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  586. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
  587. data/ext/sources/ggml/src/ggml.c +167 -33
  588. data/ext/sources/ggml/src/gguf.cpp +229 -44
  589. data/ext/sources/src/whisper.cpp +6 -28
  590. data/sig/whisper.rbs +43 -2
  591. data/test/test_context_params.rb +82 -0
  592. data/test/test_token.rb +11 -0
  593. data/test/test_vad_context.rb +58 -8
  594. data/test/test_whisper.rb +20 -0
  595. data/whispercpp.gemspec +1 -1
  596. metadata +240 -28
  597. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  598. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  599. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  600. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  601. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  602. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  603. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  604. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  605. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  606. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  607. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  608. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  609. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  610. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
@@ -77,6 +77,14 @@ static inline float dot(float x, float y) {
77
77
  return x*y;
78
78
  }
79
79
 
80
+ static inline float sum(float x) {
81
+ return x;
82
+ }
83
+
84
+ static inline float sum(float4 x) {
85
+ return x[0] + x[1] + x[2] + x[3];
86
+ }
87
+
80
88
  // NOTE: this is not dequantizing - we are simply fitting the template
81
89
  template <typename type4x4>
82
90
  void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
@@ -895,752 +903,432 @@ enum ggml_sort_order {
895
903
  GGML_SORT_ORDER_DESC,
896
904
  };
897
905
 
898
- // general-purpose kernel for addition, subtraction, multiplication and division of two tensors
899
- // pros: works for non-contiguous tensors, supports broadcast across all dims
900
- // cons: not very efficient
901
- template <int F>
902
- kernel void kernel_add_fuse_impl(
903
- constant ggml_metal_kargs_bin & args,
904
- device const char * src0,
905
- device const char * src1,
906
- device char * dst,
907
- uint3 tgpig[[threadgroup_position_in_grid]],
908
- ushort3 tpitg[[thread_position_in_threadgroup]],
909
- ushort3 ntg[[threads_per_threadgroup]]) {
910
- const int i03 = tgpig.z;
911
- const int i02 = tgpig.y;
912
- const int i01 = tgpig.x;
906
+ constant float GELU_COEF_A = 0.044715f;
907
+ constant float GELU_QUICK_COEF = -1.702f;
908
+ constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
909
+ constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
913
910
 
914
- const int i13 = i03%args.ne13;
915
- const int i12 = i02%args.ne12;
916
- const int i11 = i01%args.ne11;
911
+ // based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
912
+ // ref: https://www.johndcook.com/blog/python_erf/
913
+ constant float p_erf = 0.3275911f;
914
+ constant float a1_erf = 0.254829592f;
915
+ constant float a2_erf = -0.284496736f;
916
+ constant float a3_erf = 1.421413741f;
917
+ constant float a4_erf = -1.453152027f;
918
+ constant float a5_erf = 1.061405429f;
917
919
 
918
- device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
919
- device float * dst_ptr = (device float *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs);
920
+ template<typename T>
921
+ inline T erf_approx(T x) {
922
+ T sign_x = sign(x);
923
+ x = fabs(x);
924
+ T t = 1.0f / (1.0f + p_erf * x);
925
+ T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
926
+ return sign_x * y;
927
+ }
920
928
 
921
- device const float * src1_ptr[F];
922
- for (short j = 0; j < F; ++j) {
923
- src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
924
- }
929
+ template<typename T> T elu_approx(T x);
925
930
 
926
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
927
- const int i10 = i0%args.ne10;
931
+ template<> inline float elu_approx<float>(float x) {
932
+ return (x > 0.f) ? x : (exp(x) - 1);
933
+ }
928
934
 
929
- float res = src0_ptr[i0];
935
+ template<> inline float4 elu_approx<float4>(float4 x) {
936
+ float4 res;
930
937
 
931
- #pragma unroll
932
- for (short j = 0; j < F; ++j) {
933
- res += src1_ptr[j][i10];
934
- }
938
+ res[0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
939
+ res[1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
940
+ res[2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
941
+ res[3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
935
942
 
936
- dst_ptr[i0] = res;
937
- }
943
+ return res;
938
944
  }
939
945
 
940
- typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t;
941
-
942
- template [[host_name("kernel_add_fuse_1")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>;
943
- template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>;
944
- template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>;
945
- template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>;
946
- template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>;
947
- template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>;
948
- template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>;
949
- template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>;
946
+ constant short FC_unary_op [[function_constant(FC_UNARY + 0)]];
947
+ constant bool FC_unary_cnt[[function_constant(FC_UNARY + 1)]];
950
948
 
951
- kernel void kernel_sub_fuse_1(
952
- constant ggml_metal_kargs_bin & args,
949
+ template <typename T0, typename T, typename TC>
950
+ kernel void kernel_unary_impl(
951
+ constant ggml_metal_kargs_unary & args,
953
952
  device const char * src0,
954
- device const char * src1,
955
953
  device char * dst,
956
954
  uint3 tgpig[[threadgroup_position_in_grid]],
957
955
  ushort3 tpitg[[thread_position_in_threadgroup]],
958
956
  ushort3 ntg[[threads_per_threadgroup]]) {
959
- const int i03 = tgpig.z;
960
- const int i02 = tgpig.y;
961
- const int i01 = tgpig.x;
962
-
963
- const int i13 = i03%args.ne13;
964
- const int i12 = i02%args.ne12;
965
- const int i11 = i01%args.ne11;
957
+ #define FC_OP FC_unary_op
958
+ #define FC_CNT FC_unary_cnt
966
959
 
967
- device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
968
- device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
969
- device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
960
+ device const T0 * src0_ptr;
961
+ device T * dst_ptr;
970
962
 
971
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
972
- const int i10 = i0%args.ne10;
973
- *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10));
974
- }
975
- }
963
+ int i0;
976
964
 
977
- kernel void kernel_mul_fuse_1(
978
- constant ggml_metal_kargs_bin & args,
979
- device const char * src0,
980
- device const char * src1,
981
- device char * dst,
982
- uint3 tgpig[[threadgroup_position_in_grid]],
983
- ushort3 tpitg[[thread_position_in_threadgroup]],
984
- ushort3 ntg[[threads_per_threadgroup]]) {
985
- const int i03 = tgpig.z;
986
- const int i02 = tgpig.y;
987
- const int i01 = tgpig.x;
965
+ if (FC_CNT) {
966
+ i0 = tgpig.x;
988
967
 
989
- const int i13 = i03%args.ne13;
990
- const int i12 = i02%args.ne12;
991
- const int i11 = i01%args.ne11;
968
+ src0_ptr = (device const T0 *) (src0);
969
+ dst_ptr = (device T *) (dst);
970
+ } else {
971
+ const int i03 = tgpig.z;
972
+ const int i02 = tgpig.y;
973
+ const int k0 = tgpig.x/args.ne01;
974
+ const int i01 = tgpig.x - k0*args.ne01;
992
975
 
993
- device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
994
- device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
995
- device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
976
+ i0 = k0*ntg.x + tpitg.x;
996
977
 
997
- if (args.ne10 == 1) {
998
- const float x = *((device float *)(src1_ptr));
999
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
1000
- *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
1001
- }
1002
- } else {
1003
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
1004
- const int i10 = i0%args.ne10;
1005
- *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
1006
- }
978
+ src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
979
+ dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 );
1007
980
  }
1008
- }
1009
981
 
1010
- kernel void kernel_div_fuse_1(
1011
- constant ggml_metal_kargs_bin & args,
1012
- device const char * src0,
1013
- device const char * src1,
1014
- device char * dst,
1015
- uint3 tgpig[[threadgroup_position_in_grid]],
1016
- ushort3 tpitg[[thread_position_in_threadgroup]],
1017
- ushort3 ntg[[threads_per_threadgroup]]) {
1018
- const int i03 = tgpig.z;
1019
- const int i02 = tgpig.y;
1020
- const int i01 = tgpig.x;
982
+ {
983
+ //threadgroup_barrier(mem_flags::mem_none);
1021
984
 
1022
- const int i13 = i03%args.ne13;
1023
- const int i12 = i02%args.ne12;
1024
- const int i11 = i01%args.ne11;
985
+ if (!FC_CNT) {
986
+ if (i0 >= args.ne0) {
987
+ return;
988
+ }
989
+ }
1025
990
 
1026
- device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
1027
- device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
1028
- device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
991
+ const TC x = (TC) src0_ptr[i0];
1029
992
 
1030
- if (args.ne10 == 1) {
1031
- const float x = 1.0f / *((device float *)(src1_ptr));
1032
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
1033
- *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
993
+ if (FC_OP == OP_UNARY_NUM_SCALE) {
994
+ dst_ptr[i0] = (T) (args.scale * x + args.bias);
1034
995
  }
1035
- } else {
1036
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
1037
- const int i10 = i0%args.ne10;
1038
- *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
996
+
997
+ if (FC_OP == OP_UNARY_NUM_FILL) {
998
+ dst_ptr[i0] = (T) args.val;
1039
999
  }
1040
- }
1041
- }
1042
1000
 
1043
- kernel void kernel_add_id(
1044
- constant ggml_metal_kargs_add_id & args,
1045
- device const char * src0,
1046
- device const char * src1,
1047
- device const char * src2,
1048
- device char * dst,
1049
- uint3 tgpig[[threadgroup_position_in_grid]],
1050
- ushort3 tpitg[[thread_position_in_threadgroup]],
1051
- ushort3 ntg[[threads_per_threadgroup]]) {
1052
- const int i1 = tgpig.x;
1053
- const int i2 = tgpig.y;
1001
+ if (FC_OP == OP_UNARY_NUM_CLAMP) {
1002
+ dst_ptr[i0] = (T) clamp(x, args.min, args.max);
1003
+ }
1054
1004
 
1055
- const int i11 = *((device const int32_t *) (src2 + i1*sizeof(int32_t) + i2*args.nb21));
1005
+ if (FC_OP == OP_UNARY_NUM_SQR) {
1006
+ dst_ptr[i0] = (T) (x * x);
1007
+ }
1056
1008
 
1057
- const size_t nb1 = args.ne0 * sizeof(float);
1058
- const size_t nb2 = args.ne1 * nb1;
1009
+ if (FC_OP == OP_UNARY_NUM_SQRT) {
1010
+ dst_ptr[i0] = (T) sqrt(x);
1011
+ }
1059
1012
 
1060
- device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2);
1061
- device const float * src0_row = (device const float *)((device char *)src0 + i1*args.nb01 + i2*args.nb02);
1062
- device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);
1013
+ if (FC_OP == OP_UNARY_NUM_SIN) {
1014
+ dst_ptr[i0] = (T) sin(x);
1015
+ }
1063
1016
 
1064
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
1065
- dst_row[i0] = src0_row[i0] + src1_row[i0];
1066
- }
1067
- }
1017
+ if (FC_OP == OP_UNARY_NUM_COS) {
1018
+ dst_ptr[i0] = (T) cos(x);
1019
+ }
1068
1020
 
1069
- template<typename T>
1070
- kernel void kernel_repeat(
1071
- constant ggml_metal_kargs_repeat & args,
1072
- device const char * src0,
1073
- device char * dst,
1074
- uint3 tgpig[[threadgroup_position_in_grid]],
1075
- ushort3 tpitg[[thread_position_in_threadgroup]],
1076
- ushort3 ntg[[threads_per_threadgroup]]) {
1077
- const int i3 = tgpig.z;
1078
- const int i2 = tgpig.y;
1079
- const int i1 = tgpig.x;
1021
+ if (FC_OP == OP_UNARY_NUM_LOG) {
1022
+ dst_ptr[i0] = (T) log(x);
1023
+ }
1080
1024
 
1081
- const int i03 = i3%args.ne03;
1082
- const int i02 = i2%args.ne02;
1083
- const int i01 = i1%args.ne01;
1025
+ if (FC_OP == OP_UNARY_NUM_LEAKY_RELU) {
1026
+ dst_ptr[i0] = (T) (TC(x > 0)*x + TC(x <= 0)*(x * args.slope));
1027
+ }
1084
1028
 
1085
- device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
1086
- device char * dst_ptr = dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1;
1029
+ if (FC_OP == OP_UNARY_NUM_TANH) {
1030
+ dst_ptr[i0] = (T) precise::tanh(x);
1031
+ }
1087
1032
 
1088
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
1089
- const int i00 = i0%args.ne00;
1090
- *((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00));
1091
- }
1092
- }
1033
+ if (FC_OP == OP_UNARY_NUM_RELU) {
1034
+ dst_ptr[i0] = (T) fmax(0, x);
1035
+ }
1093
1036
 
1094
- typedef decltype(kernel_repeat<float>) kernel_repeat_t;
1037
+ if (FC_OP == OP_UNARY_NUM_SIGMOID) {
1038
+ dst_ptr[i0] = (T) (1 / (1 + exp(-x)));
1039
+ }
1095
1040
 
1096
- template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
1097
- template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
1098
- template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
1099
- template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
1041
+ if (FC_OP == OP_UNARY_NUM_GELU) {
1042
+ dst_ptr[i0] = (T) (0.5*x*(1 + precise::tanh(SQRT_2_OVER_PI*x*(1 + GELU_COEF_A*x*x))));
1043
+ }
1100
1044
 
1101
- // assumption: src1 is a row
1102
- // broadcast src1 into src0
1103
- template <short F>
1104
- kernel void kernel_add_row_c4_fuse_impl(
1105
- constant ggml_metal_kargs_bin & args,
1106
- device const char * src0,
1107
- device const char * src1,
1108
- device char * dst,
1109
- uint tpig[[thread_position_in_grid]]) {
1110
- const uint nb = args.ne00/4;
1111
- const uint i = tpig % nb;
1045
+ if (FC_OP == OP_UNARY_NUM_GELU_ERF) {
1046
+ dst_ptr[i0] = (T) (0.5*x*(1 + erf_approx(SQRT_2_INV*x)));
1047
+ }
1112
1048
 
1113
- device const float4 * src0_row = (device const float4 *) (src0);
1114
- device float4 * dst_row = (device float4 *) (dst);
1049
+ if (FC_OP == OP_UNARY_NUM_GELU_QUICK) {
1050
+ dst_ptr[i0] = (T) (x * (1/(1 + exp(GELU_QUICK_COEF*x))));
1051
+ }
1115
1052
 
1116
- float4 res = src0_row[tpig];
1053
+ if (FC_OP == OP_UNARY_NUM_SILU) {
1054
+ dst_ptr[i0] = (T) (x / (1 + exp(-x)));
1055
+ }
1117
1056
 
1118
- #pragma unroll(F)
1119
- for (short j = 0; j < F; ++j) {
1120
- res += ((device const float4 *) (src1 + args.o1[j]))[i];
1121
- }
1057
+ if (FC_OP == OP_UNARY_NUM_ELU) {
1058
+ dst_ptr[i0] = (T) elu_approx(x);
1059
+ }
1122
1060
 
1123
- dst_row[tpig] = res;
1124
- }
1061
+ if (FC_OP == OP_UNARY_NUM_NEG) {
1062
+ dst_ptr[i0] = (T) -x;
1063
+ }
1125
1064
 
1126
- typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t;
1065
+ if (FC_OP == OP_UNARY_NUM_ABS) {
1066
+ dst_ptr[i0] = (T) fabs(x);
1067
+ }
1127
1068
 
1128
- template [[host_name("kernel_add_row_c4_fuse_1")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>;
1129
- template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>;
1130
- template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>;
1131
- template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>;
1132
- template [[host_name("kernel_add_row_c4_fuse_5")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<5>;
1133
- template [[host_name("kernel_add_row_c4_fuse_6")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<6>;
1134
- template [[host_name("kernel_add_row_c4_fuse_7")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<7>;
1135
- template [[host_name("kernel_add_row_c4_fuse_8")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<8>;
1069
+ if (FC_OP == OP_UNARY_NUM_SGN) {
1070
+ dst_ptr[i0] = T(x > 0) - T(x < 0);
1071
+ }
1136
1072
 
1137
- template <short F>
1138
- kernel void kernel_sub_row_c4_fuse_impl(
1139
- constant ggml_metal_kargs_bin & args,
1140
- device const char * src0,
1141
- device const char * src1,
1142
- device char * dst,
1143
- uint tpig[[thread_position_in_grid]]) {
1073
+ if (FC_OP == OP_UNARY_NUM_STEP) {
1074
+ dst_ptr[i0] = T(x > 0);
1075
+ }
1144
1076
 
1145
- const uint nb = args.ne00/4;
1146
- const uint i = tpig % nb;
1077
+ if (FC_OP == OP_UNARY_NUM_HARDSWISH) {
1078
+ dst_ptr[i0] = (T) (x * fmax(0, fmin(1, x/6 + 0.5)));
1079
+ }
1147
1080
 
1148
- device const float4 * src0_row = (device const float4 *) (src0);
1149
- device float4 * dst_row = (device float4 *) (dst);
1081
+ if (FC_OP == OP_UNARY_NUM_HARDSIGMOID) {
1082
+ dst_ptr[i0] = (T) fmax(0, fmin(1, x/6 + 0.5));
1083
+ }
1150
1084
 
1151
- device const float4 * src1_row[F];
1152
- for (short j = 0; j < F; ++j) {
1153
- src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
1154
- }
1085
+ if (FC_OP == OP_UNARY_NUM_EXP) {
1086
+ dst_ptr[i0] = (T) exp(x);
1087
+ }
1155
1088
 
1156
- float4 res = src0_row[tpig];
1089
+ if (FC_OP == OP_UNARY_NUM_SOFTPLUS) {
1090
+ dst_ptr[i0] = (T) select(log(1 + exp(x)), x, x > 20);
1091
+ }
1157
1092
 
1158
- #pragma unroll(F)
1159
- for (short j = 0; j < F; ++j) {
1160
- res -= src1_row[j][i];
1093
+ if (FC_OP == OP_UNARY_NUM_EXPM1) {
1094
+ // TODO: precise implementation
1095
+ dst_ptr[i0] = (T) (exp(x) - 1);
1096
+ }
1161
1097
  }
1162
1098
 
1163
- dst_row[tpig] = res;
1099
+ #undef FC_OP
1100
+ #undef FC_CNT
1164
1101
  }
1165
1102
 
1166
- typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t;
1103
+ typedef decltype(kernel_unary_impl<float, float, float>) kernel_unary_t;
1167
1104
 
1168
- template [[host_name("kernel_sub_row_c4_fuse_1")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>;
1105
+ template [[host_name("kernel_unary_f32_f32")]] kernel kernel_unary_t kernel_unary_impl<float, float, float>;
1106
+ template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl<float4, float4, float4>;
1107
+ template [[host_name("kernel_unary_f16_f16")]] kernel kernel_unary_t kernel_unary_impl<half, half, float>;
1108
+ template [[host_name("kernel_unary_f16_f16_4")]] kernel kernel_unary_t kernel_unary_impl<half4, half4, float4>;
1169
1109
 
1170
- template <short F>
1171
- kernel void kernel_mul_row_c4_fuse_impl(
1172
- constant ggml_metal_kargs_bin & args,
1173
- device const char * src0,
1174
- device const char * src1,
1175
- device char * dst,
1176
- uint tpig[[thread_position_in_grid]]) {
1177
-
1178
- const uint nb = args.ne00/4;
1179
- const uint i = tpig % nb;
1180
-
1181
- device const float4 * src0_row = (device const float4 *) (src0);
1182
- device float4 * dst_row = (device float4 *) (dst);
1110
+ // OP: 0 - add, 1 - sub, 2 - mul, 3 - div
1111
+ constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
1112
+ constant short FC_bin_f [[function_constant(FC_BIN + 1)]];
1113
+ constant bool FC_bin_rb [[function_constant(FC_BIN + 2)]];
1114
+ constant bool FC_bin_cb [[function_constant(FC_BIN + 3)]];
1183
1115
 
1184
- device const float4 * src1_row[F];
1185
- for (short j = 0; j < F; ++j) {
1186
- src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
1187
- }
1188
-
1189
- float4 res = src0_row[tpig];
1190
-
1191
- #pragma unroll(F)
1192
- for (short j = 0; j < F; ++j) {
1193
- res *= src1_row[j][i];
1194
- }
1195
-
1196
- dst_row[tpig] = res;
1197
- }
1198
-
1199
- typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t;
1200
-
1201
- template [[host_name("kernel_mul_row_c4_fuse_1")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>;
1202
-
1203
- template <short F>
1204
- kernel void kernel_div_row_c4_fuse_impl(
1116
+ template <typename T0, typename T1, typename T>
1117
+ kernel void kernel_bin_fuse_impl(
1205
1118
  constant ggml_metal_kargs_bin & args,
1206
1119
  device const char * src0,
1207
1120
  device const char * src1,
1208
1121
  device char * dst,
1209
- uint tpig[[thread_position_in_grid]]) {
1210
-
1211
- const uint nb = args.ne00/4;
1212
- const uint i = tpig % nb;
1213
-
1214
- device const float4 * src0_row = (device const float4 *) (src0);
1215
- device float4 * dst_row = (device float4 *) (dst);
1216
-
1217
- device const float4 * src1_row[F];
1218
- for (short j = 0; j < F; ++j) {
1219
- src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
1220
- }
1221
-
1222
- float4 res = src0_row[tpig];
1223
-
1224
- #pragma unroll(F)
1225
- for (short j = 0; j < F; ++j) {
1226
- res /= src1_row[j][i];
1227
- }
1228
-
1229
- dst_row[tpig] = res;
1230
- }
1231
-
1232
- typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t;
1233
-
1234
- template [[host_name("kernel_div_row_c4_fuse_1")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>;
1235
-
1236
- kernel void kernel_scale_f32(
1237
- constant ggml_metal_kargs_scale & args,
1238
- device const float * src0,
1239
- device float * dst,
1240
- uint tpig[[thread_position_in_grid]]) {
1241
- dst[tpig] = src0[tpig] * args.scale + args.bias;
1242
- }
1243
-
1244
- kernel void kernel_scale_f32_4(
1245
- constant ggml_metal_kargs_scale & args,
1246
- device const float4 * src0,
1247
- device float4 * dst,
1248
- uint tpig[[thread_position_in_grid]]) {
1249
- dst[tpig] = src0[tpig] * args.scale + args.bias;
1250
- }
1251
-
1252
- kernel void kernel_fill_f32(
1253
- constant ggml_metal_kargs_fill & args,
1254
- device const float * src0,
1255
- device float * dst,
1256
- uint tpig[[thread_position_in_grid]]) {
1257
- dst[tpig] = args.val;
1258
- }
1259
-
1260
- kernel void kernel_fill_f32_4(
1261
- constant ggml_metal_kargs_fill & args,
1262
- device const float4 * src0,
1263
- device float4 * dst,
1264
- uint tpig[[thread_position_in_grid]]) {
1265
- dst[tpig] = args.val;
1266
- }
1267
-
1268
- kernel void kernel_clamp_f32(
1269
- constant ggml_metal_kargs_clamp & args,
1270
- device const float * src0,
1271
- device float * dst,
1272
- uint tpig[[thread_position_in_grid]]) {
1273
- dst[tpig] = clamp(src0[tpig], args.min, args.max);
1274
- }
1275
-
1276
- kernel void kernel_clamp_f32_4(
1277
- constant ggml_metal_kargs_clamp & args,
1278
- device const float4 * src0,
1279
- device float4 * dst,
1280
- uint tpig[[thread_position_in_grid]]) {
1281
- dst[tpig] = clamp(src0[tpig], args.min, args.max);
1282
- }
1283
-
1284
- kernel void kernel_relu_f32(
1285
- device const float * src0,
1286
- device float * dst,
1287
- uint tpig[[thread_position_in_grid]]) {
1288
- dst[tpig] = max(0.0f, src0[tpig]);
1289
- }
1290
-
1291
- kernel void kernel_relu_f32_4(
1292
- device const float4 * src0,
1293
- device float4 * dst,
1294
- uint tpig[[thread_position_in_grid]]) {
1295
- dst[tpig] = max(0.0f, src0[tpig]);
1296
- }
1297
-
1298
- kernel void kernel_sigmoid_f32(
1299
- device const float * src0,
1300
- device float * dst,
1301
- uint tpig[[thread_position_in_grid]]) {
1302
- dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
1303
- }
1304
-
1305
- kernel void kernel_sigmoid_f32_4(
1306
- device const float4 * src0,
1307
- device float4 * dst,
1308
- uint tpig[[thread_position_in_grid]]) {
1309
- dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
1310
- }
1311
-
1312
- kernel void kernel_tanh_f32(
1313
- device const float * src0,
1314
- device float * dst,
1315
- uint tpig[[thread_position_in_grid]]) {
1316
- dst[tpig] = precise::tanh(src0[tpig]);
1317
- }
1318
-
1319
- kernel void kernel_tanh_f32_4(
1320
- device const float4 * src0,
1321
- device float4 * dst,
1322
- uint tpig[[thread_position_in_grid]]) {
1323
- dst[tpig] = precise::tanh(src0[tpig]);
1324
- }
1325
-
1326
- constant float GELU_COEF_A = 0.044715f;
1327
- constant float GELU_QUICK_COEF = -1.702f;
1328
- constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
1329
- constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
1330
-
1331
- kernel void kernel_gelu_f32(
1332
- device const float * src0,
1333
- device float * dst,
1334
- uint tpig[[thread_position_in_grid]]) {
1335
- device const float & x = src0[tpig];
1336
-
1337
- dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
1338
- }
1339
-
1340
- kernel void kernel_gelu_f32_4(
1341
- device const float4 * src0,
1342
- device float4 * dst,
1343
- uint tpig[[thread_position_in_grid]]) {
1344
- device const float4 & x = src0[tpig];
1345
-
1346
- // BEWARE !!!
1347
- // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
1348
- // This was observed with Falcon 7B and 40B models
1349
- //
1350
- dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
1351
- }
1352
-
1353
- kernel void kernel_gelu_quick_f32(
1354
- device const float * src0,
1355
- device float * dst,
1356
- uint tpig[[thread_position_in_grid]]) {
1357
- device const float & x = src0[tpig];
1358
-
1359
- dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
1360
- }
1122
+ uint3 tgpig[[threadgroup_position_in_grid]],
1123
+ ushort3 tpitg[[thread_position_in_threadgroup]],
1124
+ ushort3 ntg[[threads_per_threadgroup]]) {
1125
+ #define FC_OP FC_bin_op
1126
+ #define FC_F FC_bin_f
1127
+ #define FC_RB FC_bin_rb
1128
+ #define FC_CB FC_bin_cb
1361
1129
 
1362
- kernel void kernel_gelu_quick_f32_4(
1363
- device const float4 * src0,
1364
- device float4 * dst,
1365
- uint tpig[[thread_position_in_grid]]) {
1366
- device const float4 & x = src0[tpig];
1130
+ if (FC_RB) {
1131
+ // row broadcast
1132
+ const uint i0 = tgpig.y*args.ne00 + tgpig.x;
1133
+ const uint i1 = FC_CB ? tgpig.x%args.ne10 : tgpig.x;
1367
1134
 
1368
- dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
1369
- }
1135
+ device const T0 * src0_row = (device const T0 *) (src0);
1136
+ device T * dst_row = (device T *) (dst);
1370
1137
 
1371
- // based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
1372
- // ref: https://www.johndcook.com/blog/python_erf/
1373
- constant float p_erf = 0.3275911f;
1374
- constant float a1_erf = 0.254829592f;
1375
- constant float a2_erf = -0.284496736f;
1376
- constant float a3_erf = 1.421413741f;
1377
- constant float a4_erf = -1.453152027f;
1378
- constant float a5_erf = 1.061405429f;
1138
+ if (FC_F == 1) {
1139
+ device const T1 * src1_row = (device const T1 *) (src1 + args.o1[0]);
1379
1140
 
1380
- template<typename T>
1381
- T erf_approx(T x) {
1382
- T sign_x = sign(x);
1383
- x = fabs(x);
1384
- T t = 1.0f / (1.0f + p_erf * x);
1385
- T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
1386
- return sign_x * y;
1387
- }
1141
+ if (FC_OP == 0) {
1142
+ dst_row[i0] = src0_row[i0] + src1_row[i1];
1143
+ }
1388
1144
 
1389
- kernel void kernel_gelu_erf_f32(
1390
- device const float * src0,
1391
- device float * dst,
1392
- uint tpig[[thread_position_in_grid]]) {
1393
- device const float & x = src0[tpig];
1145
+ if (FC_OP == 1) {
1146
+ dst_row[i0] = src0_row[i0] - src1_row[i1];
1147
+ }
1394
1148
 
1395
- dst[tpig] = 0.5f*x*(1.0f+erf_approx<float>(x*SQRT_2_INV));
1396
- }
1149
+ if (FC_OP == 2) {
1150
+ dst_row[i0] = src0_row[i0] * src1_row[i1];
1151
+ }
1397
1152
 
1398
- kernel void kernel_gelu_erf_f32_4(
1399
- device const float4 * src0,
1400
- device float4 * dst,
1401
- uint tpig[[thread_position_in_grid]]) {
1402
- device const float4 & x = src0[tpig];
1153
+ if (FC_OP == 3) {
1154
+ dst_row[i0] = src0_row[i0] / src1_row[i1];
1155
+ }
1156
+ } else {
1157
+ T0 res = src0_row[i0];
1403
1158
 
1404
- dst[tpig] = 0.5f*x*(1.0f+erf_approx<float4>(x*SQRT_2_INV));
1405
- }
1159
+ if (FC_OP == 0) {
1160
+ FOR_UNROLL (short j = 0; j < FC_F; ++j) {
1161
+ res += ((device const T1 *) (src1 + args.o1[j]))[i1];
1162
+ }
1163
+ }
1406
1164
 
1407
- kernel void kernel_silu_f32(
1408
- device const float * src0,
1409
- device float * dst,
1410
- uint tpig[[thread_position_in_grid]]) {
1411
- device const float & x = src0[tpig];
1412
- dst[tpig] = x / (1.0f + exp(-x));
1413
- }
1165
+ if (FC_OP == 1) {
1166
+ FOR_UNROLL (short j = 0; j < FC_F; ++j) {
1167
+ res -= ((device const T1 *) (src1 + args.o1[j]))[i1];
1168
+ }
1169
+ }
1414
1170
 
1415
- kernel void kernel_silu_f32_4(
1416
- device const float4 * src0,
1417
- device float4 * dst,
1418
- uint tpig[[thread_position_in_grid]]) {
1419
- device const float4 & x = src0[tpig];
1420
- dst[tpig] = x / (1.0f + exp(-x));
1421
- }
1171
+ if (FC_OP == 2) {
1172
+ FOR_UNROLL (short j = 0; j < FC_F; ++j) {
1173
+ res *= ((device const T1 *) (src1 + args.o1[j]))[i1];
1174
+ }
1175
+ }
1422
1176
 
1423
- kernel void kernel_elu_f32(
1424
- device const float * src0,
1425
- device float * dst,
1426
- uint tpig[[thread_position_in_grid]]) {
1427
- const float x = src0[tpig];
1428
- dst[tpig] = (x > 0.0f) ? x : (exp(x) - 1.0f);
1429
- }
1177
+ if (FC_OP == 3) {
1178
+ FOR_UNROLL (short j = 0; j < FC_F; ++j) {
1179
+ res /= ((device const T1 *) (src1 + args.o1[j]))[i1];
1180
+ }
1181
+ }
1430
1182
 
1431
- kernel void kernel_elu_f32_4(
1432
- device const float4 * src0,
1433
- device float4 * dst,
1434
- uint tpig[[thread_position_in_grid]]) {
1435
- const float4 x = src0[tpig];
1436
- dst[tpig][0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
1437
- dst[tpig][1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
1438
- dst[tpig][2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
1439
- dst[tpig][3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
1440
- }
1183
+ dst_row[i0] = res;
1184
+ }
1185
+ } else {
1186
+ const int i03 = tgpig.z;
1187
+ const int i02 = tgpig.y;
1188
+ const int i01 = tgpig.x;
1441
1189
 
1442
- kernel void kernel_sqr_f32(
1443
- device const float * src0,
1444
- device float * dst,
1445
- uint tpig[[thread_position_in_grid]]) {
1446
- dst[tpig] = src0[tpig] * src0[tpig];
1447
- }
1190
+ if (i01 >= args.ne01) {
1191
+ return;
1192
+ }
1448
1193
 
1449
- kernel void kernel_sqr_f32_4(
1450
- device const float4 * src0,
1451
- device float4 * dst,
1452
- uint tpig[[thread_position_in_grid]]) {
1453
- dst[tpig] = src0[tpig] * src0[tpig];
1454
- }
1194
+ const int i13 = i03%args.ne13;
1195
+ const int i12 = i02%args.ne12;
1196
+ const int i11 = i01%args.ne11;
1455
1197
 
1456
- kernel void kernel_sqrt_f32(
1457
- device const float * src0,
1458
- device float * dst,
1459
- uint tpig[[thread_position_in_grid]]) {
1460
- dst[tpig] = sqrt(src0[tpig]);
1461
- }
1198
+ device const T0 * src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
1199
+ device T * dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs);
1462
1200
 
1463
- kernel void kernel_sqrt_f32_4(
1464
- device const float4 * src0,
1465
- device float4 * dst,
1466
- uint tpig[[thread_position_in_grid]]) {
1467
- dst[tpig] = sqrt(src0[tpig]);
1468
- }
1201
+ if (FC_F == 1) {
1202
+ device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
1469
1203
 
1470
- kernel void kernel_sin_f32(
1471
- device const float * src0,
1472
- device float * dst,
1473
- uint tpig[[thread_position_in_grid]]) {
1474
- dst[tpig] = sin(src0[tpig]);
1475
- }
1204
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
1205
+ const int i10 = FC_CB ? i0%args.ne10 : i0;
1476
1206
 
1477
- kernel void kernel_sin_f32_4(
1478
- device const float4 * src0,
1479
- device float4 * dst,
1480
- uint tpig[[thread_position_in_grid]]) {
1481
- dst[tpig] = sin(src0[tpig]);
1482
- }
1207
+ if (FC_OP == 0) {
1208
+ dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10];
1209
+ }
1483
1210
 
1484
- kernel void kernel_cos_f32(
1485
- device const float * src0,
1486
- device float * dst,
1487
- uint tpig[[thread_position_in_grid]]) {
1488
- dst[tpig] = cos(src0[tpig]);
1489
- }
1211
+ if (FC_OP == 1) {
1212
+ dst_ptr[i0] = src0_ptr[i0] - src1_ptr[i10];
1213
+ }
1490
1214
 
1491
- kernel void kernel_cos_f32_4(
1492
- device const float4 * src0,
1493
- device float4 * dst,
1494
- uint tpig[[thread_position_in_grid]]) {
1495
- dst[tpig] = cos(src0[tpig]);
1496
- }
1215
+ if (FC_OP == 2) {
1216
+ dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10];
1217
+ }
1497
1218
 
1498
- kernel void kernel_log_f32(
1499
- device const float * src0,
1500
- device float * dst,
1501
- uint tpig[[thread_position_in_grid]]) {
1502
- dst[tpig] = log(src0[tpig]);
1503
- }
1219
+ if (FC_OP == 3) {
1220
+ dst_ptr[i0] = src0_ptr[i0] / src1_ptr[i10];
1221
+ }
1222
+ }
1223
+ } else {
1224
+ device const T1 * src1_ptr[8];
1225
+ FOR_UNROLL (short j = 0; j < FC_F; ++j) {
1226
+ src1_ptr[j] = (device const T1 *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
1227
+ }
1504
1228
 
1505
- kernel void kernel_log_f32_4(
1506
- device const float4 * src0,
1507
- device float4 * dst,
1508
- uint tpig[[thread_position_in_grid]]) {
1509
- dst[tpig] = log(src0[tpig]);
1510
- }
1229
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
1230
+ const int i10 = FC_CB ? i0%args.ne10 : i0;
1511
1231
 
1512
- kernel void kernel_neg_f32(
1513
- device const float * src0,
1514
- device float * dst,
1515
- uint tpig[[thread_position_in_grid]]) {
1516
- dst[tpig] = -src0[tpig];
1517
- }
1232
+ T res = src0_ptr[i0];
1518
1233
 
1519
- kernel void kernel_neg_f32_4(
1520
- device const float4 * src0,
1521
- device float4 * dst,
1522
- uint tpig[[thread_position_in_grid]]) {
1523
- dst[tpig] = -src0[tpig];
1524
- }
1234
+ if (FC_OP == 0) {
1235
+ FOR_UNROLL (short j = 0; j < FC_F; ++j) {
1236
+ res += src1_ptr[j][i10];
1237
+ }
1238
+ }
1525
1239
 
1526
- kernel void kernel_abs_f32(
1527
- device const float * src0,
1528
- device float * dst,
1529
- uint tpig[[thread_position_in_grid]]) {
1530
- dst[tpig] = fabs(src0[tpig]);
1531
- }
1240
+ if (FC_OP == 1) {
1241
+ FOR_UNROLL (short j = 0; j < FC_F; ++j) {
1242
+ res -= src1_ptr[j][i10];
1243
+ }
1244
+ }
1532
1245
 
1533
- kernel void kernel_abs_f32_4(
1534
- device const float4 * src0,
1535
- device float4 * dst,
1536
- uint tpig[[thread_position_in_grid]]) {
1537
- dst[tpig] = fabs(src0[tpig]);
1538
- }
1246
+ if (FC_OP == 2) {
1247
+ FOR_UNROLL (short j = 0; j < FC_F; ++j) {
1248
+ res *= src1_ptr[j][i10];
1249
+ }
1250
+ }
1539
1251
 
1540
- kernel void kernel_sgn_f32(
1541
- device const float * src0,
1542
- device float * dst,
1543
- uint tpig[[thread_position_in_grid]]) {
1544
- dst[tpig] = sign(src0[tpig]);
1545
- }
1252
+ if (FC_OP == 3) {
1253
+ FOR_UNROLL (short j = 0; j < FC_F; ++j) {
1254
+ res /= src1_ptr[j][i10];
1255
+ }
1256
+ }
1546
1257
 
1547
- kernel void kernel_sgn_f32_4(
1548
- device const float4 * src0,
1549
- device float4 * dst,
1550
- uint tpig[[thread_position_in_grid]]) {
1551
- dst[tpig] = sign(src0[tpig]);
1552
- }
1258
+ dst_ptr[i0] = res;
1259
+ }
1260
+ }
1261
+ }
1553
1262
 
1554
- kernel void kernel_step_f32(
1555
- device const float * src0,
1556
- device float * dst,
1557
- uint tpig[[thread_position_in_grid]]) {
1558
- dst[tpig] = step(0.0f, src0[tpig]);
1263
+ #undef FC_OP
1264
+ #undef FC_F
1265
+ #undef FC_RB
1266
+ #undef FC_CB
1559
1267
  }
1560
1268
 
1561
- kernel void kernel_step_f32_4(
1562
- device const float4 * src0,
1563
- device float4 * dst,
1564
- uint tpig[[thread_position_in_grid]]) {
1565
- dst[tpig] = step(0.0f, src0[tpig]);
1566
- }
1269
+ typedef decltype(kernel_bin_fuse_impl<float, float, float>) kernel_bin_fuse_t;
1567
1270
 
1568
- kernel void kernel_hardswish_f32(
1569
- device const float * src0,
1570
- device float * dst,
1571
- uint tpig[[thread_position_in_grid]]) {
1572
- const float x = src0[tpig];
1573
- dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
1574
- }
1271
+ template [[host_name("kernel_bin_fuse_f32_f32_f32")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float, float, float>;
1272
+ template [[host_name("kernel_bin_fuse_f32_f32_f32_4")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float4, float4, float4>;
1575
1273
 
1576
- kernel void kernel_hardswish_f32_4(
1577
- device const float4 * src0,
1578
- device float4 * dst,
1579
- uint tpig[[thread_position_in_grid]]) {
1580
- const float4 x = src0[tpig];
1581
- dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
1582
- }
1274
+ kernel void kernel_add_id(
1275
+ constant ggml_metal_kargs_add_id & args,
1276
+ device const char * src0,
1277
+ device const char * src1,
1278
+ device const char * src2,
1279
+ device char * dst,
1280
+ uint3 tgpig[[threadgroup_position_in_grid]],
1281
+ ushort3 tpitg[[thread_position_in_threadgroup]],
1282
+ ushort3 ntg[[threads_per_threadgroup]]) {
1283
+ const int i1 = tgpig.x;
1284
+ const int i2 = tgpig.y;
1583
1285
 
1584
- kernel void kernel_hardsigmoid_f32(
1585
- device const float * src0,
1586
- device float * dst,
1587
- uint tpig[[thread_position_in_grid]]) {
1588
- const float x = src0[tpig];
1589
- dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
1590
- }
1286
+ const int i11 = *((device const int32_t *) (src2 + i1*sizeof(int32_t) + i2*args.nb21));
1591
1287
 
1592
- kernel void kernel_hardsigmoid_f32_4(
1593
- device const float4 * src0,
1594
- device float4 * dst,
1595
- uint tpig[[thread_position_in_grid]]) {
1596
- const float4 x = src0[tpig];
1597
- dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
1598
- }
1288
+ const size_t nb1 = args.ne0 * sizeof(float);
1289
+ const size_t nb2 = args.ne1 * nb1;
1599
1290
 
1600
- kernel void kernel_exp_f32(
1601
- device const float * src0,
1602
- device float * dst,
1603
- uint tpig[[thread_position_in_grid]]) {
1604
- dst[tpig] = exp(src0[tpig]);
1605
- }
1291
+ device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2);
1292
+ device const float * src0_row = (device const float *)((device char *)src0 + i1*args.nb01 + i2*args.nb02);
1293
+ device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);
1606
1294
 
1607
- kernel void kernel_exp_f32_4(
1608
- device const float4 * src0,
1609
- device float4 * dst,
1610
- uint tpig[[thread_position_in_grid]]) {
1611
- dst[tpig] = exp(src0[tpig]);
1295
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
1296
+ dst_row[i0] = src0_row[i0] + src1_row[i0];
1297
+ }
1612
1298
  }
1613
1299
 
1614
- kernel void kernel_softplus_f32(
1615
- device const float * src0,
1616
- device float * dst,
1617
- uint tpig[[thread_position_in_grid]]) {
1618
- device const float & x = src0[tpig];
1619
- dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
1620
- }
1300
+ template<typename T>
1301
+ kernel void kernel_repeat(
1302
+ constant ggml_metal_kargs_repeat & args,
1303
+ device const char * src0,
1304
+ device char * dst,
1305
+ uint3 tgpig[[threadgroup_position_in_grid]],
1306
+ ushort3 tpitg[[thread_position_in_threadgroup]],
1307
+ ushort3 ntg[[threads_per_threadgroup]]) {
1308
+ const int i3 = tgpig.z;
1309
+ const int i2 = tgpig.y;
1310
+ const int i1 = tgpig.x;
1621
1311
 
1622
- kernel void kernel_softplus_f32_4(
1623
- device const float4 * src0,
1624
- device float4 * dst,
1625
- uint tpig[[thread_position_in_grid]]) {
1626
- device const float4 & x = src0[tpig];
1627
- dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
1628
- }
1312
+ const int i03 = i3%args.ne03;
1313
+ const int i02 = i2%args.ne02;
1314
+ const int i01 = i1%args.ne01;
1629
1315
 
1630
- kernel void kernel_expm1_f32(
1631
- device const float * src0,
1632
- device float * dst,
1633
- uint tpig[[thread_position_in_grid]]) {
1634
- dst[tpig] = exp(src0[tpig]) - 1.0f;
1635
- }
1316
+ device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
1317
+ device char * dst_ptr = dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1;
1636
1318
 
1637
- kernel void kernel_expm1_f32_4(
1638
- device const float4 * src0,
1639
- device float4 * dst,
1640
- uint tpig[[thread_position_in_grid]]) {
1641
- dst[tpig] = exp(src0[tpig]) - 1.0f;
1319
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
1320
+ const int i00 = i0%args.ne00;
1321
+ *((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00));
1322
+ }
1642
1323
  }
1643
1324
 
1325
+ typedef decltype(kernel_repeat<float>) kernel_repeat_t;
1326
+
1327
+ template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
1328
+ template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
1329
+ template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
1330
+ template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
1331
+
1644
1332
  kernel void kernel_reglu_f32(
1645
1333
  constant ggml_metal_kargs_glu & args,
1646
1334
  device const char * src0,
@@ -1824,33 +1512,35 @@ kernel void kernel_op_sum_f32(
1824
1512
  }
1825
1513
  }
1826
1514
 
1827
- template <bool norm>
1828
- kernel void kernel_sum_rows(
1515
+ constant short FC_sum_rows_op [[function_constant(FC_SUM_ROWS + 0)]];
1516
+
1517
+ template <typename T0, typename T>
1518
+ kernel void kernel_sum_rows_impl(
1829
1519
  constant ggml_metal_kargs_sum_rows & args,
1830
- device const float * src0,
1831
- device float * dst,
1832
- threadgroup float * shmem_f32 [[threadgroup(0)]],
1520
+ device const char * src0,
1521
+ device char * dst,
1522
+ threadgroup char * shmem [[threadgroup(0)]],
1833
1523
  uint3 tgpig[[threadgroup_position_in_grid]],
1834
1524
  ushort3 tpitg[[thread_position_in_threadgroup]],
1835
1525
  ushort sgitg[[simdgroup_index_in_threadgroup]],
1836
1526
  ushort tiisg[[thread_index_in_simdgroup]],
1837
1527
  ushort3 ntg[[threads_per_threadgroup]]) {
1838
- int64_t i3 = tgpig.z;
1839
- int64_t i2 = tgpig.y;
1840
- int64_t i1 = tgpig.x;
1528
+ #define FC_OP FC_sum_rows_op
1841
1529
 
1842
- if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
1843
- return;
1844
- }
1530
+ const int i3 = tgpig.z;
1531
+ const int i2 = tgpig.y;
1532
+ const int i1 = tgpig.x;
1533
+
1534
+ threadgroup T0 * shmem_t = (threadgroup T0 *) shmem;
1845
1535
 
1846
1536
  if (sgitg == 0) {
1847
- shmem_f32[tiisg] = 0.0f;
1537
+ shmem_t[tiisg] = 0.0f;
1848
1538
  }
1849
1539
 
1850
- device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
1851
- device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
1540
+ device const T0 * src_row = (device const T0 *) (src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
1541
+ device T * dst_row = (device T *) (dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
1852
1542
 
1853
- float sumf = 0;
1543
+ T0 sumf = T0(0.0f);
1854
1544
 
1855
1545
  for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
1856
1546
  sumf += src_row[i0];
@@ -1861,23 +1551,33 @@ kernel void kernel_sum_rows(
1861
1551
  threadgroup_barrier(mem_flags::mem_threadgroup);
1862
1552
 
1863
1553
  if (tiisg == 0) {
1864
- shmem_f32[sgitg] = sumf;
1554
+ shmem_t[sgitg] = sumf;
1865
1555
  }
1866
1556
 
1867
1557
  threadgroup_barrier(mem_flags::mem_threadgroup);
1868
1558
 
1869
- sumf = shmem_f32[tiisg];
1559
+ sumf = shmem_t[tiisg];
1870
1560
  sumf = simd_sum(sumf);
1871
1561
 
1872
1562
  if (tpitg.x == 0) {
1873
- dst_row[0] = norm ? sumf / args.ne00 : sumf;
1563
+ if (FC_OP == OP_SUM_ROWS_NUM_MEAN) {
1564
+ if (is_same<float4, T0>::value) {
1565
+ dst_row[0] = sum(sumf) / (4*args.ne00);
1566
+ } else {
1567
+ dst_row[0] = sum(sumf) / args.ne00;
1568
+ }
1569
+ } else {
1570
+ dst_row[0] = sum(sumf);
1571
+ }
1874
1572
  }
1573
+
1574
+ #undef FC_OP
1875
1575
  }
1876
1576
 
1877
- typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
1577
+ typedef decltype(kernel_sum_rows_impl<float, float>) kernel_sum_rows_t;
1878
1578
 
1879
- template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
1880
- template [[host_name("kernel_mean_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
1579
+ template [[host_name("kernel_sum_rows_f32_f32")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float, float>;
1580
+ template [[host_name("kernel_sum_rows_f32_f32_4")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float4, float>;
1881
1581
 
1882
1582
  template<typename T>
1883
1583
  kernel void kernel_cumsum_blk(
@@ -2689,51 +2389,347 @@ kernel void kernel_rwkv_wkv7_f32(
2689
2389
  const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
2690
2390
  const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
2691
2391
 
2692
- for (uint t = start_t; t < end_t; t += C) {
2693
- threadgroup_barrier(mem_flags::mem_threadgroup);
2694
- _r[tid] = r[t];
2695
- _w[tid] = w[t];
2696
- _k[tid] = k[t];
2697
- _a[tid] = a[t];
2698
- _b[tid] = b[t];
2392
+ for (uint t = start_t; t < end_t; t += C) {
2393
+ threadgroup_barrier(mem_flags::mem_threadgroup);
2394
+ _r[tid] = r[t];
2395
+ _w[tid] = w[t];
2396
+ _k[tid] = k[t];
2397
+ _a[tid] = a[t];
2398
+ _b[tid] = b[t];
2399
+ threadgroup_barrier(mem_flags::mem_threadgroup);
2400
+
2401
+ const float v_val = v[t];
2402
+ float y = 0.0, sa = 0.0;
2403
+
2404
+ float4 sa_vec(0.0);
2405
+
2406
+ for (uint j = 0; j < head_size; j += 4) {
2407
+ float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
2408
+ float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
2409
+ sa_vec += a_vec * s_vec;
2410
+ }
2411
+ sa = sa_vec[0] + sa_vec[1] + sa_vec[2] + sa_vec[3];
2412
+
2413
+ for (uint j = 0; j < head_size; j += 4) {
2414
+ float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
2415
+ float4 w_vec = float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
2416
+ float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
2417
+ float4 b_vec = float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
2418
+ float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
2419
+
2420
+ float4 kv = k_vec * v_val;
2421
+
2422
+ s_vec = s_vec * w_vec + kv + sa * b_vec;
2423
+ y += dot(s_vec, r_vec);
2424
+
2425
+ state[j] = s_vec[0];
2426
+ state[j+1] = s_vec[1];
2427
+ state[j+2] = s_vec[2];
2428
+ state[j+3] = s_vec[3];
2429
+ }
2430
+
2431
+ dst[t] = y;
2432
+ }
2433
+
2434
+ for (uint i = 0; i < head_size; i++) {
2435
+ dst[T * C + batch_id * state_size + head_id * head_size * head_size
2436
+ + tid * head_size + i] = state[i];
2437
+ }
2438
+ }
2439
+
2440
+ constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]];
2441
+ constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]];
2442
+
2443
+ #if 1
2444
+ template<short NSG>
2445
+ kernel void kernel_gated_delta_net_impl(
2446
+ constant ggml_metal_kargs_gated_delta_net & args,
2447
+ device const char * q,
2448
+ device const char * k,
2449
+ device const char * v,
2450
+ device const char * g,
2451
+ device const char * b,
2452
+ device const char * s,
2453
+ device char * dst,
2454
+ uint3 tgpig[[threadgroup_position_in_grid]],
2455
+ uint3 tpitg[[thread_position_in_threadgroup]],
2456
+ uint3 ntg[[threads_per_threadgroup]]) {
2457
+ #define S_v FC_gated_delta_net_ne20
2458
+ #define G FC_gated_delta_net_ne30
2459
+
2460
+ const uint tx = tpitg.x;
2461
+ const uint ty = tpitg.y;
2462
+
2463
+ const uint i23 = tgpig.z; // B
2464
+ const uint i21 = tgpig.y; // H
2465
+ const uint i20 = tgpig.x*NSG + ty;
2466
+
2467
+ const uint i01 = i21 % args.ne01;
2468
+ const uint i11 = i21 % args.ne11;
2469
+
2470
+ const float scale = 1.0f / sqrt((float)S_v);
2471
+
2472
+ // state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous
2473
+ device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
2474
+
2475
+ float ls[NSG];
2476
+
2477
+ FOR_UNROLL (short j = 0; j < NSG; j++) {
2478
+ const short is = tx*NSG + j;
2479
+ ls[j] = s_ptr[is];
2480
+ }
2481
+
2482
+ device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
2483
+
2484
+ device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);
2485
+ device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);
2486
+ device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);
2487
+
2488
+ device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
2489
+ device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
2490
+
2491
+ for (short t = 0; t < args.ne22; t++) {
2492
+ float s_k = 0.0f;
2493
+
2494
+ if (G == 1) {
2495
+ const float g_exp = exp(g_ptr[0]);
2496
+
2497
+ FOR_UNROLL (short j = 0; j < NSG; j++) {
2498
+ const short is = tx*NSG + j;
2499
+ ls[j] *= g_exp;
2500
+
2501
+ s_k += ls[j]*k_ptr[is];
2502
+ }
2503
+ } else {
2504
+ // KDA
2505
+ FOR_UNROLL (short j = 0; j < NSG; j++) {
2506
+ const short is = tx*NSG + j;
2507
+ ls[j] *= exp(g_ptr[is]);
2508
+
2509
+ s_k += ls[j]*k_ptr[is];
2510
+ }
2511
+ }
2512
+
2513
+ s_k = simd_sum(s_k);
2514
+
2515
+ const float d = (v_ptr[i20] - s_k)*b_ptr[0];
2516
+
2517
+ float y = 0.0f;
2518
+
2519
+ FOR_UNROLL (short j = 0; j < NSG; j++) {
2520
+ const short is = tx*NSG + j;
2521
+ ls[j] += k_ptr[is]*d;
2522
+
2523
+ y += ls[j]*q_ptr[is];
2524
+ }
2525
+
2526
+ y = simd_sum(y);
2527
+
2528
+ if (tx == 0) {
2529
+ dst_attn[t*args.ne21*S_v] = y*scale;
2530
+ }
2531
+
2532
+ q_ptr += args.ns02;
2533
+ k_ptr += args.ns12;
2534
+ v_ptr += args.ns22;
2535
+
2536
+ b_ptr += args.ne21;
2537
+ g_ptr += args.ne21*G;
2538
+ }
2539
+
2540
+ device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
2541
+
2542
+ FOR_UNROLL (short j = 0; j < NSG; j++) {
2543
+ const short is = tx*NSG + j;
2544
+ dst_state[is] = ls[j];
2545
+ }
2546
+
2547
+ #undef S_v
2548
+ #undef G
2549
+ }
2550
+
2551
+ typedef decltype(kernel_gated_delta_net_impl<4>) kernel_gated_delta_net_t;
2552
+
2553
+ template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<1>;
2554
+ template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<2>;
2555
+ template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<4>;
2556
+
2557
+ #else
2558
+ // a simplified version of the above
2559
+ // no performance improvement, so keep the above version for now
2560
+
2561
+ template<typename T, short NSG>
2562
+ kernel void kernel_gated_delta_net_impl(
2563
+ constant ggml_metal_kargs_gated_delta_net & args,
2564
+ device const char * q,
2565
+ device const char * k,
2566
+ device const char * v,
2567
+ device const char * g,
2568
+ device const char * b,
2569
+ device const char * s,
2570
+ device char * dst,
2571
+ uint3 tgpig[[threadgroup_position_in_grid]],
2572
+ uint3 tpitg[[thread_position_in_threadgroup]],
2573
+ uint3 ntg[[threads_per_threadgroup]]) {
2574
+ #define S_v FC_gated_delta_net_ne20
2575
+ #define G FC_gated_delta_net_ne30
2576
+
2577
+ const uint tx = tpitg.x;
2578
+ const uint ty = tpitg.y;
2579
+
2580
+ const uint i23 = tgpig.z; // B
2581
+ const uint i21 = tgpig.y; // H
2582
+ const uint i20 = tgpig.x*NSG + ty;
2583
+
2584
+ const uint i01 = i21 % args.ne01;
2585
+ const uint i11 = i21 % args.ne11;
2586
+
2587
+ const float scale = 1.0f / sqrt((float)S_v);
2588
+
2589
+ device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20;
2590
+
2591
+ float lsf[NSG];
2592
+
2593
+ FOR_UNROLL (short j = 0; j < NSG; j++) {
2594
+ const short is = tx*NSG + j;
2595
+ lsf[j] = s_ptr[is*S_v];
2596
+ }
2597
+
2598
+ thread T * ls = (thread T *) (lsf);
2599
+
2600
+ device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
2601
+
2602
+ device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);
2603
+ device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);
2604
+ device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);
2605
+
2606
+ device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
2607
+ device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
2608
+
2609
+ for (short t = 0; t < args.ne22; t++) {
2610
+ device const T * qt_ptr = (device const T *) (q_ptr);
2611
+ device const T * kt_ptr = (device const T *) (k_ptr);
2612
+ device const T * gt_ptr = (device const T *) (g_ptr);
2613
+
2614
+ if (G == 1) {
2615
+ *ls *= exp(g_ptr[0]);
2616
+ } else {
2617
+ // KDA
2618
+ *ls *= exp(gt_ptr[tx]);
2619
+ }
2620
+
2621
+ const float s_k = simd_sum(dot(*ls, kt_ptr[tx]));
2622
+
2623
+ const float d = (v_ptr[i20] - s_k)*b_ptr[0];
2624
+
2625
+ *ls += kt_ptr[tx]*d;
2626
+
2627
+ const float y = simd_sum(dot(*ls, qt_ptr[tx]));
2628
+
2629
+ if (tx == 0) {
2630
+ *dst_attn = y*scale;
2631
+ }
2632
+
2633
+ q_ptr += args.ns02;
2634
+ k_ptr += args.ns12;
2635
+ v_ptr += args.ns22;
2636
+
2637
+ b_ptr += args.ne21;
2638
+ g_ptr += args.ne21*G;
2639
+
2640
+ dst_attn += args.ne21*S_v;
2641
+ }
2642
+
2643
+ device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20;
2644
+ device T * dstt_state = (device T *) (dst_state);
2645
+
2646
+ FOR_UNROLL (short j = 0; j < NSG; j++) {
2647
+ const short is = tx*NSG + j;
2648
+ dst_state[is*S_v] = lsf[j];
2649
+ }
2650
+
2651
+ #undef S_v
2652
+ #undef G
2653
+ }
2654
+
2655
+ typedef decltype(kernel_gated_delta_net_impl<float4, 4>) kernel_gated_delta_net_t;
2656
+
2657
+ template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float, 1>;
2658
+ template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float2, 2>;
2659
+ template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float4, 4>;
2660
+ #endif
2661
+
2662
+ constant short FC_solve_tri_nsg [[function_constant(FC_SOLVE_TRI + 0)]];
2663
+ constant short FC_solve_tri_n [[function_constant(FC_SOLVE_TRI + 1)]];
2664
+ constant short FC_solve_tri_k [[function_constant(FC_SOLVE_TRI + 2)]];
2665
+
2666
+ kernel void kernel_solve_tri_f32(
2667
+ constant ggml_metal_kargs_solve_tri & args,
2668
+ device const char * src0,
2669
+ device const char * src1,
2670
+ device char * dst,
2671
+ threadgroup char * shmem [[threadgroup(0)]],
2672
+ ushort3 tgpig[[threadgroup_position_in_grid]],
2673
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
2674
+ ushort tiisg[[thread_index_in_simdgroup]],
2675
+ ushort3 ntg[[threads_per_threadgroup]]) {
2676
+ constexpr short NW = N_SIMDWIDTH;
2677
+
2678
+ const short NSG = FC_solve_tri_nsg;
2679
+ const short N = FC_solve_tri_n;
2680
+ const short K = FC_solve_tri_k;
2681
+ const short NP = PAD2(N, NW);
2682
+
2683
+ const int32_t i03 = tgpig.z;
2684
+ const int32_t i02 = tgpig.y;
2685
+ const int32_t i01 = tgpig.x*NSG + sgitg;
2686
+
2687
+ threadgroup float * sh0 = (threadgroup float *) shmem;
2688
+
2689
+ device const float * src0_ptr = (device const float *)(src0 + i02 * args.nb02 + i03 * args.nb03) + sgitg*N;
2690
+ device const float * src1_ptr = (device const float *)(src1 + i02 * args.nb12 + i03 * args.nb13) + i01;
2691
+ device float * dst_ptr = (device float *)(dst + i02 * args.nb2 + i03 * args.nb3) + i01;
2692
+
2693
+ for (short rr = 0; rr < N; rr += NSG) {
2699
2694
  threadgroup_barrier(mem_flags::mem_threadgroup);
2700
2695
 
2701
- const float v_val = v[t];
2702
- float y = 0.0, sa = 0.0;
2696
+ {
2697
+ threadgroup float * sh0_cur = sh0 + sgitg*NP;
2703
2698
 
2704
- float4 sa_vec(0.0);
2699
+ for (short t = 0; t*NW < N; ++t) {
2700
+ const short idx = t*NW + tiisg;
2701
+ sh0_cur[idx] = src0_ptr[idx];
2702
+ }
2705
2703
 
2706
- for (uint j = 0; j < head_size; j += 4) {
2707
- float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
2708
- float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
2709
- sa_vec += a_vec * s_vec;
2704
+ src0_ptr += NSG*N;
2710
2705
  }
2711
- sa = sa_vec[0] + sa_vec[1] + sa_vec[2] + sa_vec[3];
2712
2706
 
2713
- for (uint j = 0; j < head_size; j += 4) {
2714
- float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
2715
- float4 w_vec = float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
2716
- float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
2717
- float4 b_vec = float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
2718
- float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
2707
+ threadgroup_barrier(mem_flags::mem_threadgroup);
2719
2708
 
2720
- float4 kv = k_vec * v_val;
2709
+ if (i01 >= args.ne10) {
2710
+ continue;
2711
+ }
2721
2712
 
2722
- s_vec = s_vec * w_vec + kv + sa * b_vec;
2723
- y += dot(s_vec, r_vec);
2713
+ for (short ir = 0; ir < NSG && rr + ir < N; ++ir) {
2714
+ const short r = rr + ir;
2724
2715
 
2725
- state[j] = s_vec[0];
2726
- state[j+1] = s_vec[1];
2727
- state[j+2] = s_vec[2];
2728
- state[j+3] = s_vec[3];
2729
- }
2716
+ threadgroup float * sh0_cur = sh0 + ir*NP;
2730
2717
 
2731
- dst[t] = y;
2732
- }
2718
+ float sum = 0.0f;
2733
2719
 
2734
- for (uint i = 0; i < head_size; i++) {
2735
- dst[T * C + batch_id * state_size + head_id * head_size * head_size
2736
- + tid * head_size + i] = state[i];
2720
+ for (short t = 0; t*NW < r; ++t) {
2721
+ const short idx = t*NW + tiisg;
2722
+ sum += sh0_cur[idx] * dst_ptr[idx*K] * (idx < r);
2723
+ }
2724
+
2725
+ sum = simd_sum(sum);
2726
+
2727
+ if (tiisg == 0) {
2728
+ const float diag = sh0_cur[r];
2729
+
2730
+ dst_ptr[r*K] = (src1_ptr[r*K] - sum) / diag;
2731
+ }
2732
+ }
2737
2733
  }
2738
2734
  }
2739
2735
 
@@ -2970,26 +2966,32 @@ template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_f
2970
2966
  template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>;
2971
2967
  template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>;
2972
2968
 
2973
- kernel void kernel_l2_norm_f32(
2969
+ template <typename T0, typename T>
2970
+ kernel void kernel_l2_norm_impl(
2974
2971
  constant ggml_metal_kargs_l2_norm & args,
2975
2972
  device const char * src0,
2976
2973
  device char * dst,
2977
2974
  threadgroup float * shmem_f32 [[threadgroup(0)]],
2978
- uint tgpig[[threadgroup_position_in_grid]],
2979
- ushort tpitg[[thread_position_in_threadgroup]],
2980
- ushort sgitg[[simdgroup_index_in_threadgroup]],
2981
- ushort tiisg[[thread_index_in_simdgroup]],
2982
- ushort ntg[[threads_per_threadgroup]]) {
2975
+ uint3 tgpig[[threadgroup_position_in_grid]],
2976
+ ushort3 tpitg[[thread_position_in_threadgroup]],
2977
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
2978
+ ushort tiisg[[thread_index_in_simdgroup]],
2979
+ ushort3 ntg[[threads_per_threadgroup]]) {
2980
+ const int i03 = tgpig.z;
2981
+ const int i02 = tgpig.y;
2982
+ const int i01 = tgpig.x;
2983
+
2983
2984
  if (sgitg == 0) {
2984
2985
  shmem_f32[tiisg] = 0.0f;
2985
2986
  }
2986
2987
 
2987
- device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
2988
+ device const T0 * x = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
2989
+ device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
2988
2990
 
2989
2991
  float sumf = 0.0f;
2990
2992
 
2991
2993
  // parallel sum
2992
- for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
2994
+ for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
2993
2995
  sumf += dot(x[i00], x[i00]);
2994
2996
  }
2995
2997
  sumf = simd_sum(sumf);
@@ -3005,14 +3007,18 @@ kernel void kernel_l2_norm_f32(
3005
3007
  sumf = shmem_f32[tiisg];
3006
3008
  sumf = simd_sum(sumf);
3007
3009
 
3008
- const float scale = 1.0f/sqrt(max(sumf, args.eps));
3010
+ const float scale = 1.0f/max(sqrt(sumf), args.eps);
3009
3011
 
3010
- device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
3011
- for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
3012
+ for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
3012
3013
  y[i00] = x[i00] * scale;
3013
3014
  }
3014
3015
  }
3015
3016
 
3017
+ typedef decltype(kernel_l2_norm_impl<float, float>) kernel_l2_norm_t;
3018
+
3019
+ template [[host_name("kernel_l2_norm_f32_f32")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float, float>;
3020
+ template [[host_name("kernel_l2_norm_f32_f32_4")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float4, float4>;
3021
+
3016
3022
  kernel void kernel_group_norm_f32(
3017
3023
  constant ggml_metal_kargs_group_norm & args,
3018
3024
  device const float * src0,
@@ -3700,6 +3706,13 @@ template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q4
3700
3706
  template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>;
3701
3707
  template [[host_name("kernel_mul_mv_ext_f16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, half4, 4, dequantize_f16_t4>;
3702
3708
 
3709
+ #if defined(GGML_METAL_HAS_BF16)
3710
+ template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, bfloat4, 4, dequantize_bf16_t4>;
3711
+ template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, bfloat4, 4, dequantize_bf16_t4>;
3712
+ template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, bfloat4, 4, dequantize_bf16_t4>;
3713
+ template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, bfloat4, 4, dequantize_bf16_t4>;
3714
+ #endif
3715
+
3703
3716
  template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>;
3704
3717
  template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>;
3705
3718
  template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>;
@@ -3750,6 +3763,16 @@ template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4
3750
3763
  template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>;
3751
3764
  template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>;
3752
3765
 
3766
+ template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q2_K, 256, dequantize_q2_K>;
3767
+ template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q2_K, 256, dequantize_q2_K>;
3768
+ template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q2_K, 256, dequantize_q2_K>;
3769
+ template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q2_K, 256, dequantize_q2_K>;
3770
+
3771
+ template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q3_K, 256, dequantize_q3_K>;
3772
+ template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q3_K, 256, dequantize_q3_K>;
3773
+ template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q3_K, 256, dequantize_q3_K>;
3774
+ template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q3_K, 256, dequantize_q3_K>;
3775
+
3753
3776
  template<typename T0, typename T1, short NR0, typename args_t>
3754
3777
  void kernel_mul_mv_t_t_impl(
3755
3778
  args_t args,
@@ -4437,7 +4460,7 @@ kernel void kernel_im2col(
4437
4460
  template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
4438
4461
  template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
4439
4462
 
4440
- // TODO: obolete -- remove
4463
+ // TODO: obsolete -- remove
4441
4464
  //typedef void (im2col_ext_t)(
4442
4465
  // constant ggml_metal_kargs_im2col & args,
4443
4466
  // device const float * x,
@@ -4749,7 +4772,9 @@ kernel void kernel_conv_transpose_2d<half>(
4749
4772
  uint3 tpitg[[thread_position_in_threadgroup]],
4750
4773
  uint3 ntg[[threads_per_threadgroup]]);
4751
4774
 
4752
- kernel void kernel_upscale_f32(
4775
+ constant bool FC_upscale_aa [[function_constant(FC_UPSCALE + 0)]];
4776
+
4777
+ kernel void kernel_upscale_nearest_f32(
4753
4778
  constant ggml_metal_kargs_upscale & args,
4754
4779
  device const char * src0,
4755
4780
  device char * dst,
@@ -4775,6 +4800,156 @@ kernel void kernel_upscale_f32(
4775
4800
  }
4776
4801
  }
4777
4802
 
4803
+ static inline float bilinear_tri(float x) {
4804
+ return MAX(0.0f, 1.0f - fabs(x));
4805
+ }
4806
+
4807
+ kernel void kernel_upscale_bilinear_f32(
4808
+ constant ggml_metal_kargs_upscale & args,
4809
+ device const char * src0,
4810
+ device char * dst,
4811
+ uint3 tgpig[[threadgroup_position_in_grid]],
4812
+ uint3 tpitg[[thread_position_in_threadgroup]],
4813
+ uint3 ntg[[threads_per_threadgroup]]) {
4814
+
4815
+ const int64_t i3 = tgpig.z;
4816
+ const int64_t i2 = tgpig.y;
4817
+ const int64_t i1 = tgpig.x;
4818
+
4819
+ const int64_t i03 = i3 / args.sf3;
4820
+ const int64_t i02 = i2 / args.sf2;
4821
+
4822
+ const float f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs;
4823
+ const int64_t i01 = MAX(0, MIN(args.ne01 - 1, (int64_t)floor(f01)));
4824
+ const int64_t i01p = MAX(0, MIN(args.ne01 - 1, i01 + 1));
4825
+ const float fd1 = MAX(0.0f, MIN(1.0f, f01 - (float)i01));
4826
+
4827
+ src0 += i03*args.nb03 + i02*args.nb02;
4828
+
4829
+ device float * dst_ptr = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
4830
+
4831
+ if (FC_upscale_aa) {
4832
+ const float support0 = MAX(1.0f, 1.0f / args.sf0);
4833
+ const float invscale0 = 1.0f / support0;
4834
+ const float support1 = MAX(1.0f, 1.0f / args.sf1);
4835
+ const float invscale1 = 1.0f / support1;
4836
+
4837
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
4838
+ const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
4839
+
4840
+ int64_t x_min = MAX((int64_t)0, (int64_t)floor(f00 - support0 + args.poffs));
4841
+ int64_t x_max = MIN(args.ne00, (int64_t)ceil (f00 + support0 + args.poffs));
4842
+
4843
+ int64_t y_min = MAX((int64_t)0, (int64_t)floor(f01 - support1 + args.poffs));
4844
+ int64_t y_max = MIN(args.ne01, (int64_t)ceil (f01 + support1 + args.poffs));
4845
+
4846
+ float sum = 0.0f;
4847
+ float wsum = 0.0f;
4848
+
4849
+ for (int64_t sy = y_min; sy < y_max; ++sy) {
4850
+ const float wy = MAX(0.0f, 1.0f - fabs((float)sy - f01) * invscale1);
4851
+ for (int64_t sx = x_min; sx < x_max; ++sx) {
4852
+ const float wx = MAX(0.0f, 1.0f - fabs((float)sx - f00) * invscale0);
4853
+ const float w = wx * wy;
4854
+ const device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00);
4855
+ sum += (*src_ptr) * w;
4856
+ wsum += w;
4857
+ }
4858
+ }
4859
+
4860
+ const float v = (wsum > 0.0f) ? (sum / wsum) : 0.0f;
4861
+ dst_ptr[i0] = v;
4862
+ }
4863
+ } else {
4864
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
4865
+ const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
4866
+ const int64_t i00 = MAX(0, MIN(args.ne00 - 1, (int64_t)floor(f00)));
4867
+ const int64_t i00p = MAX(0, MIN(args.ne00 - 1, i00 + 1));
4868
+ const float fd0 = MAX(0.0f, MIN(1.0f, f00 - (float)i00));
4869
+
4870
+ device const float * src00 = (device const float *)(src0 + i01*args.nb01 + i00*args.nb00);
4871
+ device const float * src10 = (device const float *)(src0 + i01*args.nb01 + i00p*args.nb00);
4872
+ device const float * src01 = (device const float *)(src0 + i01p*args.nb01 + i00*args.nb00);
4873
+ device const float * src11 = (device const float *)(src0 + i01p*args.nb01 + i00p*args.nb00);
4874
+
4875
+ const float v =
4876
+ (*src00) * (1.0f - fd0) * (1.0f - fd1) +
4877
+ (*src10) * fd0 * (1.0f - fd1) +
4878
+ (*src01) * (1.0f - fd0) * fd1 +
4879
+ (*src11) * fd0 * fd1;
4880
+
4881
+ dst_ptr[i0] = v;
4882
+ }
4883
+ }
4884
+ }
4885
+
4886
+ static inline float bicubic_weight1(float x) {
4887
+ const float a = -0.75f;
4888
+ return ((a + 2) * x - (a + 3)) * x * x + 1;
4889
+ }
4890
+
4891
+ static inline float bicubic_weight2(float x) {
4892
+ const float a = -0.75f;
4893
+ return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a;
4894
+ }
4895
+
4896
+ kernel void kernel_upscale_bicubic_f32(
4897
+ constant ggml_metal_kargs_upscale & args,
4898
+ device const char * src0,
4899
+ device char * dst,
4900
+ uint3 tgpig[[threadgroup_position_in_grid]],
4901
+ uint3 tpitg[[thread_position_in_threadgroup]],
4902
+ uint3 ntg[[threads_per_threadgroup]]) {
4903
+
4904
+ const int64_t i3 = tgpig.z;
4905
+ const int64_t i2 = tgpig.y;
4906
+ const int64_t i1 = tgpig.x;
4907
+
4908
+ const int64_t i03 = i3 / args.sf3;
4909
+ const int64_t i02 = i2 / args.sf2;
4910
+
4911
+ const float f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs;
4912
+ const int64_t i01 = (int64_t)floor(f01);
4913
+ const float fd1 = f01 - (float)i01;
4914
+
4915
+ const float w_y0 = bicubic_weight2(fd1 + 1.0f);
4916
+ const float w_y1 = bicubic_weight1(fd1);
4917
+ const float w_y2 = bicubic_weight1(1.0f - fd1);
4918
+ const float w_y3 = bicubic_weight2(2.0f - fd1);
4919
+
4920
+ const device const char * src_slice = src0 + i03 * args.nb03 + i02 * args.nb02;
4921
+
4922
+ device float * dst_ptr = (device float *)(dst + i3 * args.nb3 + i2 * args.nb2 + i1 * args.nb1);
4923
+
4924
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
4925
+ const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
4926
+ const int64_t i00 = (int64_t)floor(f00);
4927
+ const float fd0 = f00 - (float)i00;
4928
+
4929
+ const float w_x0 = bicubic_weight2(fd0 + 1.0f);
4930
+ const float w_x1 = bicubic_weight1(fd0);
4931
+ const float w_x2 = bicubic_weight1(1.0f - fd0);
4932
+ const float w_x3 = bicubic_weight2(2.0f - fd0);
4933
+
4934
+ float sum = 0.0f;
4935
+
4936
+ for (int dy = -1; dy <= 2; ++dy) {
4937
+ const int64_t iy = MAX(0, MIN(args.ne01 - 1, i01 + dy));
4938
+ const float wy = (dy == -1) ? w_y0 : (dy == 0) ? w_y1 : (dy == 1) ? w_y2 : w_y3;
4939
+
4940
+ for (int dx = -1; dx <= 2; ++dx) {
4941
+ const int64_t ix = MAX(0, MIN(args.ne00 - 1, i00 + dx));
4942
+ const float wx = (dx == -1) ? w_x0 : (dx == 0) ? w_x1 : (dx == 1) ? w_x2 : w_x3;
4943
+
4944
+ const device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00);
4945
+ sum += (*src_ptr) * wx * wy;
4946
+ }
4947
+ }
4948
+
4949
+ dst_ptr[i0] = sum;
4950
+ }
4951
+ }
4952
+
4778
4953
  kernel void kernel_pad_f32(
4779
4954
  constant ggml_metal_kargs_pad & args,
4780
4955
  device const char * src0,
@@ -5114,24 +5289,6 @@ kernel void kernel_argsort_merge_f32_i32(
5114
5289
  template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
5115
5290
  template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
5116
5291
 
5117
- kernel void kernel_leaky_relu_f32(
5118
- constant ggml_metal_kargs_leaky_relu & args,
5119
- device const float * src0,
5120
- device float * dst,
5121
- uint tpig[[thread_position_in_grid]]) {
5122
- const float x = src0[tpig];
5123
- dst[tpig] = x > 0.0f ? x : x * args.slope;
5124
- }
5125
-
5126
- kernel void kernel_leaky_relu_f32_4(
5127
- constant ggml_metal_kargs_leaky_relu & args,
5128
- device const float4 * src0,
5129
- device float4 * dst,
5130
- uint tpig[[thread_position_in_grid]]) {
5131
- const float4 x = src0[tpig];
5132
- dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope);
5133
- }
5134
-
5135
5292
  constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]];
5136
5293
 
5137
5294
  constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]];
@@ -5208,6 +5365,7 @@ constant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_E
5208
5365
  // scan the blocks of the mask that are not masked
5209
5366
  // 0 - masked (i.e. full of -INF, skip)
5210
5367
  // 1 - not masked (i.e. at least one element of the mask is not -INF)
5368
+ // 2 - all zero
5211
5369
  kernel void kernel_flash_attn_ext_blk(
5212
5370
  constant ggml_metal_kargs_flash_attn_ext_blk & args,
5213
5371
  device const char * mask,
@@ -5229,27 +5387,29 @@ kernel void kernel_flash_attn_ext_blk(
5229
5387
 
5230
5388
  device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg;
5231
5389
 
5232
- // fast route
5233
- if (res == 0) {
5234
- if (simd_max(*mask_src) > -MAXHALF/2) {
5235
- res = 1;
5236
- }
5237
- }
5238
-
5239
5390
  // detailed check of the elements of the block
5240
5391
  if ((C > NW || Q > 1) && res == 0) {
5241
- half m = -MAXHALF;
5392
+ half mmin = MAXHALF;
5393
+ half mmax = -MAXHALF;
5242
5394
 
5243
5395
  FOR_UNROLL (short j = 0; j < Q; ++j) {
5244
5396
  FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) {
5245
- m = max(m, mask_src[ii*NW]);
5397
+ mmin = min(mmin, mask_src[ii*NW]);
5398
+ mmax = max(mmax, mask_src[ii*NW]);
5246
5399
  }
5247
5400
 
5248
5401
  mask_src += args.nb31/2;
5249
5402
  }
5250
5403
 
5251
- if (simd_max(m) > -MAXHALF/2) {
5252
- res = 1;
5404
+ mmin = simd_min(mmin);
5405
+ mmax = simd_max(mmax);
5406
+
5407
+ if (mmax > -MAXHALF) {
5408
+ if (mmin == 0.0 && mmax == 0.0) {
5409
+ res = 2;
5410
+ } else {
5411
+ res = 1;
5412
+ }
5253
5413
  }
5254
5414
  }
5255
5415
 
@@ -5491,9 +5651,13 @@ void kernel_flash_attn_ext_impl(
5491
5651
  ic = 0;
5492
5652
  }
5493
5653
 
5654
+ char blk_cur = 1;
5655
+
5494
5656
  // read the mask into shared mem
5495
5657
  if (FC_flash_attn_ext_has_mask) {
5496
- if (blk[ic0] == 0) {
5658
+ blk_cur = blk[ic0];
5659
+
5660
+ if (blk_cur == 0) {
5497
5661
  FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
5498
5662
  pm2[jj] += NW;
5499
5663
  }
@@ -5501,16 +5665,22 @@ void kernel_flash_attn_ext_impl(
5501
5665
  continue;
5502
5666
  }
5503
5667
 
5504
- FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
5505
- const short j = jj*NSG + sgitg;
5668
+ if (blk_cur == 1) {
5669
+ FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
5670
+ const short j = jj*NSG + sgitg;
5506
5671
 
5507
- if (FC_flash_attn_ext_bc_mask) {
5508
- sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
5509
- } else {
5510
- sm2[j*SH + tiisg] = pm2[jj][tiisg];
5511
- }
5672
+ if (FC_flash_attn_ext_bc_mask) {
5673
+ sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
5674
+ } else {
5675
+ sm2[j*SH + tiisg] = pm2[jj][tiisg];
5676
+ }
5512
5677
 
5513
- pm2[jj] += NW;
5678
+ pm2[jj] += NW;
5679
+ }
5680
+ } else if (blk_cur == 2) {
5681
+ FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
5682
+ pm2[jj] += NW;
5683
+ }
5514
5684
  }
5515
5685
 
5516
5686
  #if 0
@@ -5552,9 +5722,7 @@ void kernel_flash_attn_ext_impl(
5552
5722
 
5553
5723
  constexpr short NC = (C/8)/NSG;
5554
5724
 
5555
- // note: do not unroll for large heads
5556
- #pragma unroll (DK <= 64 ? NC : 1)
5557
- for (short cc = 0; cc < NC; ++cc) {
5725
+ FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
5558
5726
  qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
5559
5727
 
5560
5728
  if (DK % 16 != 0) {
@@ -5575,7 +5743,9 @@ void kernel_flash_attn_ext_impl(
5575
5743
  k8x8_t mk[2];
5576
5744
  q8x8_t mq[2];
5577
5745
 
5578
- FOR_UNROLL (short i = 0; i < DK8/2; ++i) {
5746
+ // note: too much unroll can tank the performance for large heads
5747
+ #pragma unroll (MIN(DK8/2, 4*NSG))
5748
+ for (short i = 0; i < DK8/2; ++i) {
5579
5749
  simdgroup_barrier(mem_flags::mem_none);
5580
5750
 
5581
5751
  simdgroup_load(mq[0], pq + 0*8 + 16*i, DK);
@@ -5675,10 +5845,12 @@ void kernel_flash_attn_ext_impl(
5675
5845
  }
5676
5846
 
5677
5847
  // mqk = mqk + slope*mask
5678
- if (FC_flash_attn_ext_has_bias) {
5679
- s2 += s2_t(sm2[j*SH + tiisg])*slope;
5680
- } else {
5681
- s2 += s2_t(sm2[j*SH + tiisg]);
5848
+ if (blk_cur != 2) {
5849
+ if (FC_flash_attn_ext_has_bias) {
5850
+ s2 += s2_t(sm2[j*SH + tiisg])*slope;
5851
+ } else {
5852
+ s2 += s2_t(sm2[j*SH + tiisg]);
5853
+ }
5682
5854
  }
5683
5855
 
5684
5856
  M[jj] = simd_max(max(M[jj], max(s2[0], s2[1])));
@@ -5749,7 +5921,9 @@ void kernel_flash_attn_ext_impl(
5749
5921
  pv += 8*NS20;
5750
5922
  }
5751
5923
  } else {
5752
- FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) {
5924
+ constexpr short NC = (C/8)/2;
5925
+
5926
+ FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
5753
5927
  s8x8_t vs[2];
5754
5928
 
5755
5929
  simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false);
@@ -5929,7 +6103,7 @@ template<
5929
6103
  void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
5930
6104
  short DK, // K head size
5931
6105
  short DV, // V head size
5932
- short Q = OP_FLASH_ATTN_EXT_NQPTG, // queries per threadgroup
6106
+ short Q = OP_FLASH_ATTN_EXT_NQPSG, // queries per threadgroup
5933
6107
  short C = OP_FLASH_ATTN_EXT_NCPSG> // cache items per threadgroup
5934
6108
  kernel void kernel_flash_attn_ext(
5935
6109
  constant ggml_metal_kargs_flash_attn_ext & args,
@@ -5952,6 +6126,7 @@ kernel void kernel_flash_attn_ext(
5952
6126
  //case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
5953
6127
  //case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
5954
6128
  case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
6129
+ case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break;
5955
6130
  }
5956
6131
  #undef FWD_TMPL
5957
6132
  #undef FWD_ARGS
@@ -6001,6 +6176,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk128_dv128")]] kernel flash_at
6001
6176
  template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 192>;
6002
6177
  template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 128>;
6003
6178
  template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 256, 256>;
6179
+ template [[host_name("kernel_flash_attn_ext_f32_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 320, 256>;
6004
6180
  template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 576, 512>;
6005
6181
 
6006
6182
  template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 32, 32>;
@@ -6015,6 +6191,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk128_dv128")]] kernel flash_at
6015
6191
  template [[host_name("kernel_flash_attn_ext_f16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 192>;
6016
6192
  template [[host_name("kernel_flash_attn_ext_f16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 128>;
6017
6193
  template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256, 256>;
6194
+ template [[host_name("kernel_flash_attn_ext_f16_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 320, 256>;
6018
6195
  template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
6019
6196
 
6020
6197
  #if defined(GGML_METAL_HAS_BF16)
@@ -6030,6 +6207,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk128_dv128")]] kernel flash_at
6030
6207
  template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
6031
6208
  template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
6032
6209
  template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
6210
+ template [[host_name("kernel_flash_attn_ext_bf16_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 320, 256>;
6033
6211
  template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
6034
6212
  #endif
6035
6213
 
@@ -6045,6 +6223,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk128_dv128")]] kernel flash_at
6045
6223
  template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 192>;
6046
6224
  template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 128>;
6047
6225
  template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
6226
+ template [[host_name("kernel_flash_attn_ext_q4_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 320, 256>;
6048
6227
  template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;
6049
6228
 
6050
6229
  template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32, 32>;
@@ -6059,6 +6238,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk128_dv128")]] kernel flash_at
6059
6238
  template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 192>;
6060
6239
  template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 128>;
6061
6240
  template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
6241
+ template [[host_name("kernel_flash_attn_ext_q4_1_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 320, 256>;
6062
6242
  template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;
6063
6243
 
6064
6244
  template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32, 32>;
@@ -6073,6 +6253,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk128_dv128")]] kernel flash_at
6073
6253
  template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 192>;
6074
6254
  template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 128>;
6075
6255
  template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
6256
+ template [[host_name("kernel_flash_attn_ext_q5_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 320, 256>;
6076
6257
  template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;
6077
6258
 
6078
6259
  template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32, 32>;
@@ -6087,6 +6268,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk128_dv128")]] kernel flash_at
6087
6268
  template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 192>;
6088
6269
  template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 128>;
6089
6270
  template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
6271
+ template [[host_name("kernel_flash_attn_ext_q5_1_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 320, 256>;
6090
6272
  template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;
6091
6273
 
6092
6274
  template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32, 32>;
@@ -6101,6 +6283,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_dk128_dv128")]] kernel flash_at
6101
6283
  template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 192>;
6102
6284
  template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 128>;
6103
6285
  template [[host_name("kernel_flash_attn_ext_q8_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256, 256>;
6286
+ template [[host_name("kernel_flash_attn_ext_q8_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 320, 256>;
6104
6287
  template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;
6105
6288
 
6106
6289
  #undef FA_TYPES
@@ -6138,11 +6321,10 @@ template<
6138
6321
  void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
6139
6322
  short DK, // K head size
6140
6323
  short DV, // V head size
6141
- short NE, // head elements per thread
6142
- short Q, // queries per threadgroup
6143
- short C, // cache items per threadgroup
6144
- short NSG> // number of simd groups
6145
- void kernel_flash_attn_ext_vec_impl(
6324
+ short NE = 4, // head elements per thread
6325
+ short Q = OP_FLASH_ATTN_EXT_VEC_NQPSG, // queries per threadgroup
6326
+ short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup
6327
+ kernel void kernel_flash_attn_ext_vec(
6146
6328
  constant ggml_metal_kargs_flash_attn_ext_vec & args,
6147
6329
  device const char * q,
6148
6330
  device const char * k,
@@ -6159,6 +6341,7 @@ void kernel_flash_attn_ext_vec_impl(
6159
6341
  static_assert(DV % 32 == 0, "DV must be divisible by 32");
6160
6342
 
6161
6343
  #define NWG (FC_flash_attn_ext_vec_nwg)
6344
+ #define NSG (FC_flash_attn_ext_vec_nsg)
6162
6345
 
6163
6346
  #define NS10 (FC_flash_attn_ext_vec_ns10)
6164
6347
  #define NS20 (FC_flash_attn_ext_vec_ns20)
@@ -6185,14 +6368,14 @@ void kernel_flash_attn_ext_vec_impl(
6185
6368
  static_assert(DK4 % NL == 0, "DK4 must be divisible by NL");
6186
6369
  static_assert(DV4 % NL == 0, "DV4 must be divisible by NL");
6187
6370
 
6188
- const short T = PK + NSG*SH; // shared memory size per query in (half)
6371
+ //const short T = PK + NSG*SH; // shared memory size per query in (half)
6189
6372
 
6190
- //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data
6191
- threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t
6192
- threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*PK); // scratch buffer for attention
6193
- threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*PK); // same as above but in s4_t
6194
- threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + 2*C + Q*PK); // scratch buffer for mask
6195
- threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*PV + Q*T); // scratch buffer for the results
6373
+ //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data
6374
+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t
6375
+ threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + NSG*PK); // scratch buffer for attention
6376
+ threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + NSG*PK); // same as above but in s4_t
6377
+ threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + 2*C + NSG*PK); // scratch buffer for mask
6378
+ threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*PV + NSG*PK + NSG*SH); // scratch buffer for the results
6196
6379
 
6197
6380
  // store the result for all queries in shared memory (the O matrix from the paper)
6198
6381
  so4 += tiisg;
@@ -6210,11 +6393,13 @@ void kernel_flash_attn_ext_vec_impl(
6210
6393
  // load heads from Q to shared memory
6211
6394
  device const float4 * q4 = (device const float4 *) ((device const char *) q);
6212
6395
 
6213
- for (short i = tiisg; i < PK4; i += NW) {
6214
- if (iq1 < args.ne01 && i < DK4) {
6215
- sq4[i] = (q4_t) q4[i];
6216
- } else {
6217
- sq4[i] = (q4_t) 0.0f;
6396
+ if (iq1 < args.ne01) {
6397
+ for (short i = tiisg; i < PK4; i += NW) {
6398
+ if (i < DK4) {
6399
+ sq4[i] = (q4_t) q4[i];
6400
+ } else {
6401
+ sq4[i] = (q4_t) 0.0f;
6402
+ }
6218
6403
  }
6219
6404
  }
6220
6405
 
@@ -6292,7 +6477,7 @@ void kernel_flash_attn_ext_vec_impl(
6292
6477
  }
6293
6478
 
6294
6479
  // skip -INF blocks
6295
- if (simd_max(sm[tiisg]) == -INFINITY) {
6480
+ if (simd_max(sm[tiisg]) <= -MAXHALF) {
6296
6481
  continue;
6297
6482
  }
6298
6483
 
@@ -6566,57 +6751,11 @@ void kernel_flash_attn_ext_vec_impl(
6566
6751
  }
6567
6752
 
6568
6753
  #undef NWG
6754
+ #undef NSG
6569
6755
  #undef NS10
6570
6756
  #undef NS20
6571
6757
  }
6572
6758
 
6573
- template<
6574
- typename q4_t, // query types in shared memory
6575
- typename k4_t, // key types in shared memory
6576
- typename v4_t, // value types in shared memory
6577
- typename qk_t, // Q*K types
6578
- typename s_t, // soft-max types
6579
- typename s4_t,
6580
- typename o4_t, // attention accumulation types
6581
- typename kd4_t, // key type in device memory
6582
- short nl_k,
6583
- void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &),
6584
- typename vd4_t, // value type in device memory
6585
- short nl_v,
6586
- void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
6587
- short DK, // K head size
6588
- short DV, // V head size
6589
- short NE = 4, // head elements per thread
6590
- short Q = OP_FLASH_ATTN_EXT_VEC_NQPTG, // queries per threadgroup
6591
- short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup
6592
- kernel void kernel_flash_attn_ext_vec(
6593
- constant ggml_metal_kargs_flash_attn_ext_vec & args,
6594
- device const char * q,
6595
- device const char * k,
6596
- device const char * v,
6597
- device const char * mask,
6598
- device const char * sinks,
6599
- device const char * pad,
6600
- device char * dst,
6601
- threadgroup half * shmem_f16 [[threadgroup(0)]],
6602
- uint3 tgpig[[threadgroup_position_in_grid]],
6603
- ushort tiisg[[thread_index_in_simdgroup]],
6604
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
6605
- #define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C
6606
- #define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg
6607
- switch (FC_flash_attn_ext_vec_nsg) {
6608
- // note: disabled cases to reduce library load time
6609
- case 1: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 1>(FWD_ARGS); break;
6610
- case 2: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 2>(FWD_ARGS); break;
6611
- case 4: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 4>(FWD_ARGS); break;
6612
- //case 8: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 8>(FWD_ARGS); break;
6613
- //case 16: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 16>(FWD_ARGS); break;
6614
- //case 32: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 32>(FWD_ARGS); break;
6615
- }
6616
- #undef FWD_TMPL
6617
- #undef FWD_ARGS
6618
- }
6619
-
6620
6759
  // note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem
6621
6760
  // in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
6622
6761
  //
@@ -6715,6 +6854,17 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flas
6715
6854
  template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 1>;
6716
6855
  template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 1>;
6717
6856
 
6857
+ template [[host_name("kernel_flash_attn_ext_vec_f32_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 320, 256, 2>;
6858
+ template [[host_name("kernel_flash_attn_ext_vec_f16_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 320, 256, 2>;
6859
+ #if defined(GGML_METAL_HAS_BF16)
6860
+ template [[host_name("kernel_flash_attn_ext_vec_bf16_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 320, 256, 2>;
6861
+ #endif
6862
+ template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 320, 256, 2>;
6863
+ template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 320, 256, 2>;
6864
+ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 320, 256, 2>;
6865
+ template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 320, 256, 2>;
6866
+ template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 320, 256, 2>;
6867
+
6718
6868
  template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 576, 512, 2>;
6719
6869
  template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
6720
6870
  #if defined(GGML_METAL_HAS_BF16)
@@ -8779,6 +8929,26 @@ kernel void kernel_set_rows_f(
8779
8929
  }
8780
8930
  }
8781
8931
 
8932
+ kernel void kernel_diag_f32(
8933
+ constant ggml_metal_kargs_diag & args,
8934
+ device const char * src0,
8935
+ device char * dst,
8936
+ uint3 tgpig[[threadgroup_position_in_grid]],
8937
+ ushort tiitg[[thread_index_in_threadgroup]]) {
8938
+ constexpr short NW = N_SIMDWIDTH;
8939
+
8940
+ const int32_t i3 = tgpig.z;
8941
+ const int32_t i2 = tgpig.y;
8942
+ const int32_t i1 = tgpig.x;
8943
+
8944
+ device const float * src0_ptr = (device const float *)(src0 + i2*args.nb02 + i3*args.nb03);
8945
+ device float * dst_ptr = (device float *)(dst + i1*args.nb01 + i2*args.nb2 + i3*args.nb3);
8946
+
8947
+ for (int i0 = tiitg; i0 < args.ne0; i0 += NW) {
8948
+ dst_ptr[i0] = i0 == i1 ? src0_ptr[i0] : 0.0f;
8949
+ }
8950
+ }
8951
+
8782
8952
  constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]];
8783
8953
  constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]];
8784
8954
 
@@ -8797,7 +8967,9 @@ kernel void kernel_mul_mm(
8797
8967
  threadgroup S0 * sa = (threadgroup S0 *)(shmem);
8798
8968
  threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
8799
8969
 
8970
+ #ifdef GGML_METAL_HAS_TENSOR
8800
8971
  threadgroup float * sc = (threadgroup float *)(shmem);
8972
+ #endif
8801
8973
 
8802
8974
  constexpr int NR0 = 64;
8803
8975
  constexpr int NR1 = 32;
@@ -8920,8 +9092,8 @@ kernel void kernel_mul_mm(
8920
9092
  const short sx = (tiitg%NL1);
8921
9093
  const short sy = (tiitg/NL1)/8;
8922
9094
 
8923
- const short dx = sx;
8924
- const short dy = sy;
9095
+ //const short dx = sx;
9096
+ //const short dy = sy;
8925
9097
 
8926
9098
  const short ly = (tiitg/NL1)%8;
8927
9099
 
@@ -9153,6 +9325,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_
9153
9325
  template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
9154
9326
  template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
9155
9327
  template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
9328
+ template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;
9156
9329
 
9157
9330
  template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
9158
9331
  kernel void kernel_mul_mm_id(
@@ -9170,7 +9343,9 @@ kernel void kernel_mul_mm_id(
9170
9343
  threadgroup S0 * sa = (threadgroup S0 *)(shmem);
9171
9344
  threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
9172
9345
 
9346
+ #ifdef GGML_METAL_HAS_TENSOR
9173
9347
  threadgroup float * sc = (threadgroup float *)(shmem);
9348
+ #endif
9174
9349
 
9175
9350
  constexpr int NR0 = 64;
9176
9351
  constexpr int NR1 = 32;
@@ -9305,8 +9480,8 @@ kernel void kernel_mul_mm_id(
9305
9480
  const short sx = (tiitg%NL1);
9306
9481
  const short sy = (tiitg/NL1)/8;
9307
9482
 
9308
- const short dx = sx;
9309
- const short dy = sy;
9483
+ //const short dx = sx;
9484
+ //const short dy = sy;
9310
9485
 
9311
9486
  const short ly = (tiitg/NL1)%8;
9312
9487
 
@@ -9869,6 +10044,74 @@ kernel void kernel_pool_2d_avg_f32(
9869
10044
  o_ptr[cur_oh * args.OW + cur_ow] = res;
9870
10045
  }
9871
10046
 
10047
+
10048
+ kernel void kernel_pool_1d_max_f32(
10049
+ constant ggml_metal_kargs_pool_1d & args,
10050
+ device const float * src,
10051
+ device float * dst,
10052
+ uint gid [[thread_position_in_grid]]
10053
+ ) {
10054
+
10055
+ if (gid >= args.np) {
10056
+ return;
10057
+ }
10058
+
10059
+ const int ow = (int)gid % args.OW;
10060
+ const int row = (int)gid / args.OW;
10061
+
10062
+ const int base = ow * args.s0 - args.p0;
10063
+
10064
+ float acc = -INFINITY;
10065
+
10066
+ const int src_off = row * args.IW;
10067
+ const int dst_off = row * args.OW;
10068
+
10069
+ for (int ki = 0; ki < args.k0; ++ki) {
10070
+ int j = base + ki;
10071
+ if (j < 0 || j >= args.IW){
10072
+ continue;
10073
+ }
10074
+ float v = src[src_off + j];
10075
+ acc = max(acc, v);
10076
+ }
10077
+
10078
+ dst[dst_off + ow] = acc;
10079
+ }
10080
+
10081
+ kernel void kernel_pool_1d_avg_f32(
10082
+ constant ggml_metal_kargs_pool_1d & args,
10083
+ device const float * src,
10084
+ device float * dst,
10085
+ uint gid [[thread_position_in_grid]]
10086
+ ) {
10087
+
10088
+ if (gid >= args.np) {
10089
+ return;
10090
+ }
10091
+
10092
+ const int ow = (int)gid % args.OW;
10093
+ const int row = (int)gid / args.OW;
10094
+
10095
+ const int base = ow * args.s0 - args.p0;
10096
+
10097
+ float acc = 0.0f;
10098
+ int cnt = 0;
10099
+
10100
+ const int src_off = row * args.IW;
10101
+ const int dst_off = row * args.OW;
10102
+
10103
+ for (int ki = 0; ki < args.k0; ++ki) {
10104
+ const int j = base + ki;
10105
+ if (j < 0 || j >= args.IW) {
10106
+ continue;
10107
+ }
10108
+ acc += src[src_off + j];
10109
+ cnt += 1;
10110
+ }
10111
+
10112
+ dst[dst_off + ow] = (cnt > 0) ? (acc / (float)cnt) : 0.0f;
10113
+ }
10114
+
9872
10115
  kernel void kernel_opt_step_adamw_f32(
9873
10116
  constant ggml_metal_kargs_opt_step_adamw & args,
9874
10117
  device float * x,
@@ -9919,7 +10162,7 @@ kernel void kernel_opt_step_sgd_f32(
9919
10162
 
9920
10163
  template<typename T>
9921
10164
  kernel void kernel_memset(
9922
- constant ggml_metal_kargs_fill & args,
10165
+ constant ggml_metal_kargs_memset & args,
9923
10166
  device T * dst,
9924
10167
  uint tpig[[thread_position_in_grid]]) {
9925
10168
  dst[tpig] = args.val;