whispercpp 1.3.5 → 1.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (1017)
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/LICENSE +1 -1
  5. data/README.md +133 -3
  6. data/Rakefile +18 -3
  7. data/ext/dependencies.rb +10 -4
  8. data/ext/dependencies_for_windows.rb +17 -0
  9. data/ext/extconf.rb +20 -7
  10. data/ext/options.rb +54 -14
  11. data/ext/options_for_windows.rb +51 -0
  12. data/ext/ruby_whisper.c +56 -46
  13. data/ext/ruby_whisper.h +165 -2
  14. data/ext/ruby_whisper_context.c +297 -126
  15. data/ext/ruby_whisper_context_params.c +163 -0
  16. data/ext/ruby_whisper_log_queue.c +180 -0
  17. data/ext/ruby_whisper_log_settable.h +47 -0
  18. data/ext/ruby_whisper_model.c +0 -1
  19. data/ext/ruby_whisper_parakeet.c +49 -0
  20. data/ext/ruby_whisper_parakeet_context.c +304 -0
  21. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  22. data/ext/ruby_whisper_parakeet_model.c +84 -0
  23. data/ext/ruby_whisper_parakeet_params.c +548 -0
  24. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  25. data/ext/ruby_whisper_parakeet_token.c +188 -0
  26. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  27. data/ext/ruby_whisper_params.c +256 -66
  28. data/ext/ruby_whisper_segment.c +6 -7
  29. data/ext/ruby_whisper_token.c +29 -9
  30. data/ext/ruby_whisper_transcribe.cpp +46 -16
  31. data/ext/ruby_whisper_vad_context.c +48 -1
  32. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  33. data/ext/ruby_whisper_vad_params.c +0 -1
  34. data/ext/ruby_whisper_vad_segment.c +0 -1
  35. data/ext/ruby_whisper_vad_segments.c +0 -1
  36. data/ext/sources/CMakeLists.txt +41 -3
  37. data/ext/sources/CMakePresets.json +95 -0
  38. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  39. data/ext/sources/cmake/parakeet.pc.in +10 -0
  40. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  41. data/ext/sources/cmake/whisper.pc.in +1 -1
  42. data/ext/sources/examples/CMakeLists.txt +4 -2
  43. data/ext/sources/examples/bench/bench.cpp +24 -19
  44. data/ext/sources/examples/cli/cli.cpp +51 -9
  45. data/ext/sources/examples/common-ggml.cpp +4 -0
  46. data/ext/sources/examples/common-whisper.cpp +139 -67
  47. data/ext/sources/examples/common-whisper.h +11 -0
  48. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  49. data/ext/sources/examples/miniaudio.h +4507 -2131
  50. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  51. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  52. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  53. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  54. data/ext/sources/examples/server/server.cpp +213 -163
  55. data/ext/sources/ggml/CMakeLists.txt +29 -15
  56. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  57. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  58. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  59. data/ext/sources/ggml/include/ggml-backend.h +73 -11
  60. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  61. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  62. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  63. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  64. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  65. data/ext/sources/ggml/include/ggml-rpc.h +8 -3
  66. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  67. data/ext/sources/ggml/include/ggml.h +155 -16
  68. data/ext/sources/ggml/include/gguf.h +10 -2
  69. data/ext/sources/ggml/src/CMakeLists.txt +25 -5
  70. data/ext/sources/ggml/src/ggml-alloc.c +9 -10
  71. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  72. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  73. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  74. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +40 -86
  76. data/ext/sources/ggml/src/ggml-backend.cpp +114 -10
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -2
  79. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +1016 -442
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +111 -85
  83. data/ext/sources/ggml/src/ggml-cann/common.h +23 -14
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +255 -92
  85. data/ext/sources/ggml/src/ggml-common.h +22 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +68 -34
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +44 -19
  88. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  89. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +101 -101
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +194 -1
  91. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2874 -613
  92. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +5480 -840
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1361 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -11
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +186 -36
  99. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +119 -19
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +112 -26
  101. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  102. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  103. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -0
  105. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +153 -16
  106. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +17 -0
  107. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  108. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +976 -251
  109. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +671 -266
  110. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1277 -263
  111. data/ext/sources/ggml/src/ggml-cpu/ops.h +4 -0
  112. data/ext/sources/ggml/src/ggml-cpu/quants.c +95 -0
  113. data/ext/sources/ggml/src/ggml-cpu/quants.h +6 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2893 -679
  115. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  116. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +226 -0
  117. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +114 -19
  118. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  119. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  120. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  121. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  122. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  123. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  124. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  125. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  126. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  127. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  128. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  129. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  130. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  131. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  132. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  133. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  134. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +54 -53
  135. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  136. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +18 -8
  137. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  138. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  139. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +73 -28
  140. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +69 -41
  141. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  142. data/ext/sources/ggml/src/ggml-cuda/common.cuh +359 -29
  143. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  144. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  145. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  146. data/ext/sources/ggml/src/ggml-cuda/convert.cu +94 -27
  147. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  148. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +20 -9
  149. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  150. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +333 -85
  151. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +632 -190
  152. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  153. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +162 -49
  154. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +43 -18
  155. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +44 -14
  156. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  157. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +241 -23
  158. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  159. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  160. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  161. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +312 -0
  162. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  163. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  164. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1454 -599
  165. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  166. data/ext/sources/ggml/src/ggml-cuda/mean.cu +13 -10
  167. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +397 -183
  168. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  169. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +161 -88
  170. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  171. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +522 -431
  172. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +139 -72
  173. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  174. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +608 -88
  175. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -0
  176. data/ext/sources/ggml/src/ggml-cuda/norm.cu +47 -79
  177. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  178. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  179. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +134 -27
  180. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  181. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +7 -17
  182. data/ext/sources/ggml/src/ggml-cuda/rope.cu +244 -137
  183. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  184. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  185. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  186. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  187. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  188. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  189. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  190. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +96 -40
  191. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  192. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  193. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +6 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +2 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +6 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +2 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +2 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +2 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -5
  226. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +202 -135
  227. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  228. data/ext/sources/ggml/src/ggml-cuda/unary.cu +86 -2
  229. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +4 -0
  230. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +111 -17
  231. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  232. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +30 -2
  233. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +84 -46
  235. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1612 -753
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +51 -11
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +361 -261
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +294 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +753 -241
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +295 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +471 -296
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +159 -53
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +3 -3
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +372 -0
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +86 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +137 -0
  254. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  255. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  256. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  257. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  258. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  259. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  260. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  261. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  262. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +97 -14
  263. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +163 -67
  264. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  265. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  266. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +308 -0
  267. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +262 -0
  268. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +291 -0
  269. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  270. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +216 -0
  271. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  272. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  273. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  274. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  275. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  276. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  278. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +142 -0
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -1348
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +547 -635
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +3556 -1101
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +475 -269
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +94 -72
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +222 -217
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +432 -0
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +886 -117
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  297. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +40 -0
  302. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +28 -9
  303. data/ext/sources/ggml/src/ggml-impl.h +68 -1
  304. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  305. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  306. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  307. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  308. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +409 -83
  309. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +54 -5
  310. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +254 -52
  311. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +254 -23
  312. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +756 -285
  313. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +7 -4
  314. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +359 -133
  315. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1867 -1123
  316. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  317. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +71 -4
  318. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +14127 -5314
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +97 -88
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +104 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1978 -67
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl +129 -0
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +195 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  367. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  368. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  369. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  370. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  371. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  372. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  373. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  374. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  375. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  376. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  377. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  378. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  379. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  380. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  381. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  382. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  383. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  384. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  385. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  386. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  387. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  388. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  389. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +178 -0
  390. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  391. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  392. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  393. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  394. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  395. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  396. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  397. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  398. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  399. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  400. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  401. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +985 -0
  402. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  403. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +380 -0
  404. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  405. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1132 -0
  406. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +956 -0
  407. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  412. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  413. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  414. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  415. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  416. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  417. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  418. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  419. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  420. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  421. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  422. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  423. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  424. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +149 -0
  425. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  426. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  427. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  428. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  429. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  430. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  431. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  432. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +47 -0
  433. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +40 -0
  434. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  435. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  436. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  437. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  438. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  439. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  440. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +317 -0
  441. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  442. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +257 -0
  443. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +86 -0
  444. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +880 -0
  445. data/ext/sources/ggml/src/ggml-openvino/utils.h +143 -0
  446. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  447. data/ext/sources/ggml/src/ggml-quants.c +385 -119
  448. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  449. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  450. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  451. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  452. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  453. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +64 -91
  454. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  455. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +4 -1
  456. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  457. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  458. data/ext/sources/ggml/src/ggml-sycl/common.hpp +356 -11
  459. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +184 -14
  460. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +31 -1
  461. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  462. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  463. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  464. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  465. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  466. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  467. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  468. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  469. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +77 -156
  470. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -2
  471. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  472. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  473. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1181 -0
  474. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +59 -0
  475. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1246 -0
  476. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +674 -0
  477. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +227 -0
  478. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  479. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  480. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  481. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +347 -0
  482. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +9 -0
  483. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  484. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  485. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +1134 -236
  486. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  487. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  488. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  489. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  490. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  491. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  492. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  493. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  494. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +72 -1
  495. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  496. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  497. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  498. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  499. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  500. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  501. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  502. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  503. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  504. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  505. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  506. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  507. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  508. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  509. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  510. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  511. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  512. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  513. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  514. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  515. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  516. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +8 -0
  517. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +8 -0
  518. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +8 -0
  519. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +8 -0
  520. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +8 -0
  521. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +8 -0
  522. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +8 -0
  523. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +8 -0
  524. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +8 -0
  525. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +8 -0
  526. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +8 -0
  527. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +8 -0
  530. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +8 -0
  531. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +8 -0
  532. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +8 -0
  534. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +8 -0
  535. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +8 -0
  536. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +8 -0
  537. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +8 -0
  538. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +8 -0
  539. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +8 -0
  540. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +8 -0
  541. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +8 -0
  542. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +8 -0
  543. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +8 -0
  544. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +8 -0
  545. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +8 -0
  547. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +8 -0
  548. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +8 -0
  549. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +8 -0
  550. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +8 -0
  551. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +8 -0
  552. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  553. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  554. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  555. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +228 -53
  556. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  557. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  558. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  559. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  560. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  561. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  562. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  563. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  564. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  565. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  566. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  567. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  568. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  569. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  570. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  571. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  572. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  573. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  574. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  575. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  576. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  577. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  578. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +123 -0
  579. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +160 -0
  580. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  581. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +71 -0
  582. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  583. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  584. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  585. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  586. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  587. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  588. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  589. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  590. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  591. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  592. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +99 -0
  593. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  594. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  595. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  596. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +545 -0
  597. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +115 -0
  598. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  599. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3250 -940
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +533 -180
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +113 -68
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +412 -222
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +222 -83
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +189 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +22 -20
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +51 -14
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +39 -63
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +13 -7
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +27 -11
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -149
  663. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  664. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +3221 -97
  665. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3493 -1997
  666. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  667. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  668. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  669. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  670. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  671. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +142 -0
  672. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +115 -141
  673. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +93 -0
  674. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  675. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -44
  676. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  677. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +198 -230
  678. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  679. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  680. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  681. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  682. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  683. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  684. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +234 -335
  685. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  686. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  687. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +871 -42
  688. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  689. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  690. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  691. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +149 -0
  692. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +36 -138
  693. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +151 -0
  694. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  695. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  696. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  697. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  698. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  699. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  700. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  701. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  702. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  703. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +15 -40
  704. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  705. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +39 -12
  706. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  707. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  708. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  709. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  710. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  711. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  712. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +213 -0
  713. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  714. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +24 -15
  715. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  716. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +253 -16
  717. data/ext/sources/ggml/src/ggml.c +268 -52
  718. data/ext/sources/ggml/src/gguf.cpp +377 -47
  719. data/ext/sources/include/parakeet.h +342 -0
  720. data/ext/sources/include/whisper.h +10 -0
  721. data/ext/sources/media/matmul.png +0 -0
  722. data/ext/sources/src/CMakeLists.txt +23 -0
  723. data/ext/sources/src/parakeet-arch.h +188 -0
  724. data/ext/sources/src/parakeet.cpp +3838 -0
  725. data/ext/sources/src/whisper.cpp +62 -40
  726. data/extsources.rb +26 -10
  727. data/lib/whisper/log_settable.rb +36 -0
  728. data/lib/whisper/model/uri.rb +13 -1
  729. data/lib/whisper/output.rb +74 -0
  730. data/sig/whisper.rbs +445 -55
  731. data/test/helper.rb +2 -0
  732. data/test/jfk_reader/jfk_reader.c +50 -7
  733. data/test/test_callback.rb +1 -0
  734. data/test/test_context_params.rb +82 -0
  735. data/test/test_package.rb +6 -5
  736. data/test/test_parakeet.rb +28 -0
  737. data/test/test_parakeet_callback.rb +107 -0
  738. data/test/test_parakeet_context.rb +116 -0
  739. data/test/test_parakeet_context_params.rb +24 -0
  740. data/test/test_parakeet_model.rb +21 -0
  741. data/test/test_parakeet_params.rb +78 -0
  742. data/test/test_parakeet_segment.rb +42 -0
  743. data/test/test_parakeet_token.rb +73 -0
  744. data/test/test_params.rb +2 -0
  745. data/test/test_token.rb +11 -0
  746. data/test/test_vad_context.rb +58 -8
  747. data/test/test_vad_segment.rb +1 -1
  748. data/test/test_whisper.rb +44 -6
  749. data/whispercpp.gemspec +2 -2
  750. metadata +426 -280
  751. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  752. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  753. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  754. data/ext/sources/bindings/javascript/package.json +0 -26
  755. data/ext/sources/bindings/javascript/whisper.js +0 -19
  756. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  757. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  758. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  759. data/ext/sources/examples/addon.node/index.js +0 -59
  760. data/ext/sources/examples/addon.node/package.json +0 -16
  761. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  762. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  763. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  764. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  765. data/ext/sources/examples/coi-serviceworker.js +0 -146
  766. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  767. data/ext/sources/examples/command/command.cpp +0 -802
  768. data/ext/sources/examples/command/commands.txt +0 -9
  769. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  770. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  771. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  772. data/ext/sources/examples/generate-karaoke.sh +0 -57
  773. data/ext/sources/examples/helpers.js +0 -191
  774. data/ext/sources/examples/livestream.sh +0 -112
  775. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  776. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  777. data/ext/sources/examples/lsp/whisper.vim +0 -362
  778. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  779. data/ext/sources/examples/python/whisper_processor.py +0 -54
  780. data/ext/sources/examples/server/bench.js +0 -29
  781. data/ext/sources/examples/server.py +0 -120
  782. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  783. data/ext/sources/examples/stream/stream.cpp +0 -437
  784. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  785. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  786. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  787. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  788. data/ext/sources/examples/sycl/build.sh +0 -22
  789. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  790. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  791. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -47
  792. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  793. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -494
  794. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -88
  795. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2559
  796. data/ext/sources/examples/talk-llama/llama-arch.h +0 -586
  797. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -917
  798. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  799. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -876
  800. data/ext/sources/examples/talk-llama/llama-chat.h +0 -70
  801. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3645
  802. data/ext/sources/examples/talk-llama/llama-context.h +0 -360
  803. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  804. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -42
  805. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  806. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  807. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2282
  808. data/ext/sources/examples/talk-llama/llama-graph.h +0 -910
  809. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -241
  810. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -284
  811. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  812. data/ext/sources/examples/talk-llama/llama-impl.h +0 -63
  813. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  814. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  815. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -328
  816. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  817. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2100
  818. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -390
  819. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  820. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  821. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  822. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1167
  823. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  824. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  825. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  826. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -735
  827. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  828. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1247
  829. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -176
  830. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -285
  831. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -37
  832. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -8338
  833. data/ext/sources/examples/talk-llama/llama-model.h +0 -544
  834. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1072
  835. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  836. data/ext/sources/examples/talk-llama/llama-sampling.cpp +0 -3771
  837. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -44
  838. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3900
  839. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -182
  840. data/ext/sources/examples/talk-llama/llama.cpp +0 -1140
  841. data/ext/sources/examples/talk-llama/llama.h +0 -1540
  842. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -191
  843. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  844. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  845. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -138
  846. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  847. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -122
  848. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -144
  849. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -135
  850. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -178
  851. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -160
  852. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  853. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  854. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  855. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  856. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  857. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  858. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  859. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -123
  860. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  861. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -144
  862. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -259
  863. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -134
  864. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  865. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -150
  866. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  867. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  868. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  869. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -113
  870. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  871. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  872. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  873. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  874. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  875. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  876. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  877. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -150
  878. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  879. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  880. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -196
  881. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -211
  882. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +0 -283
  883. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  884. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -141
  885. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  886. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -154
  887. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  888. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  889. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  890. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -175
  891. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  892. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  893. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  894. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -168
  895. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  896. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -55
  897. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -123
  898. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -199
  899. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -124
  900. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  901. data/ext/sources/examples/talk-llama/models/models.h +0 -569
  902. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -116
  903. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  904. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -150
  905. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  906. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  907. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  908. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  909. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  910. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  911. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  912. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  913. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  914. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  915. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  916. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  917. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -316
  918. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  919. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -168
  920. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  921. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  922. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  923. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  924. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -117
  925. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -124
  926. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -873
  927. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -149
  928. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -141
  929. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  930. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  931. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -162
  932. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  933. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  934. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -135
  935. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  936. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  937. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  938. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  939. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  940. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  941. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  942. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  943. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  944. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  945. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  946. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  947. data/ext/sources/examples/talk-llama/speak +0 -40
  948. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  949. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  950. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  951. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  952. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  953. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1147
  954. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  955. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  956. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  957. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  958. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  959. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  960. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  961. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  962. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  963. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  964. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  965. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  966. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  967. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  968. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  969. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  970. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +0 -157
  971. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -165
  972. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  973. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  974. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  975. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  976. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  977. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  978. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  979. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  980. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  981. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  982. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  983. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  984. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -147
  985. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  986. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +0 -907
  987. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +0 -247
  988. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  989. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  990. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  991. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
  992. data/ext/sources/tests/CMakeLists.txt +0 -112
  993. data/ext/sources/tests/earnings21/eval.mk +0 -58
  994. data/ext/sources/tests/earnings21/eval.py +0 -68
  995. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  996. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  997. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  998. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  999. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  1000. data/ext/sources/tests/en-0-ref.txt +0 -1
  1001. data/ext/sources/tests/en-1-ref.txt +0 -1
  1002. data/ext/sources/tests/en-2-ref.txt +0 -1
  1003. data/ext/sources/tests/es-0-ref.txt +0 -1
  1004. data/ext/sources/tests/librispeech/eval.mk +0 -39
  1005. data/ext/sources/tests/librispeech/eval.py +0 -47
  1006. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  1007. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  1008. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  1009. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  1010. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  1011. data/ext/sources/tests/run-tests.sh +0 -130
  1012. data/ext/sources/tests/test-c.c +0 -3
  1013. data/ext/sources/tests/test-vad-full.cpp +0 -56
  1014. data/ext/sources/tests/test-vad.cpp +0 -83
  1015. data/ext/sources/tests/test-whisper.js +0 -58
  1016. data/lib/whisper/context.rb +0 -15
  1017. data/lib/whisper/segment.rb +0 -58
@@ -77,6 +77,14 @@ static inline float dot(float x, float y) {
77
77
  return x*y;
78
78
  }
79
79
 
80
+ static inline float sum(float x) {
81
+ return x;
82
+ }
83
+
84
+ static inline float sum(float4 x) {
85
+ return x[0] + x[1] + x[2] + x[3];
86
+ }
87
+
80
88
  // NOTE: this is not dequantizing - we are simply fitting the template
81
89
  template <typename type4x4>
82
90
  void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
@@ -110,6 +118,56 @@ void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg
110
118
  }
111
119
  #endif
112
120
 
121
+ template <typename type4x4>
122
+ void dequantize_q1_0(device const block_q1_0 * xb, short il, thread type4x4 & reg) {
123
+ device const uint8_t * qs = xb->qs;
124
+ const float d = xb->d;
125
+ const float neg_d = -d;
126
+
127
+ const int byte_offset = il * 2; // il*16 bits = il*2 bytes
128
+ const uint8_t b0 = qs[byte_offset];
129
+ const uint8_t b1 = qs[byte_offset + 1];
130
+
131
+ float4x4 reg_f;
132
+
133
+ reg_f[0][0] = select(neg_d, d, bool(b0 & 0x01));
134
+ reg_f[0][1] = select(neg_d, d, bool(b0 & 0x02));
135
+ reg_f[0][2] = select(neg_d, d, bool(b0 & 0x04));
136
+ reg_f[0][3] = select(neg_d, d, bool(b0 & 0x08));
137
+ reg_f[1][0] = select(neg_d, d, bool(b0 & 0x10));
138
+ reg_f[1][1] = select(neg_d, d, bool(b0 & 0x20));
139
+ reg_f[1][2] = select(neg_d, d, bool(b0 & 0x40));
140
+ reg_f[1][3] = select(neg_d, d, bool(b0 & 0x80));
141
+
142
+ reg_f[2][0] = select(neg_d, d, bool(b1 & 0x01));
143
+ reg_f[2][1] = select(neg_d, d, bool(b1 & 0x02));
144
+ reg_f[2][2] = select(neg_d, d, bool(b1 & 0x04));
145
+ reg_f[2][3] = select(neg_d, d, bool(b1 & 0x08));
146
+ reg_f[3][0] = select(neg_d, d, bool(b1 & 0x10));
147
+ reg_f[3][1] = select(neg_d, d, bool(b1 & 0x20));
148
+ reg_f[3][2] = select(neg_d, d, bool(b1 & 0x40));
149
+ reg_f[3][3] = select(neg_d, d, bool(b1 & 0x80));
150
+
151
+ reg = (type4x4) reg_f;
152
+ }
153
+
154
+ template <typename type4>
155
+ void dequantize_q1_0_t4(device const block_q1_0 * xb, short il, thread type4 & reg) {
156
+ const float d = xb->d;
157
+ const float neg_d = -d;
158
+ const int base = il * 4;
159
+ const uint8_t byte = xb->qs[base / 8];
160
+ const int s = base % 8;
161
+
162
+ float4 reg_f;
163
+ reg_f[0] = select(neg_d, d, bool((byte >> (s )) & 1));
164
+ reg_f[1] = select(neg_d, d, bool((byte >> (s + 1)) & 1));
165
+ reg_f[2] = select(neg_d, d, bool((byte >> (s + 2)) & 1));
166
+ reg_f[3] = select(neg_d, d, bool((byte >> (s + 3)) & 1));
167
+
168
+ reg = (type4) reg_f;
169
+ }
170
+
113
171
  template <typename type4x4>
114
172
  void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {
115
173
  device const uint16_t * qs = ((device const uint16_t *)xb + 1);
@@ -144,6 +202,23 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r
144
202
  }
145
203
  }
146
204
 
205
+ void quantize_q1_0(device const float * src, device block_q1_0 & dst) {
206
+ float sum_abs = 0.0f;
207
+ for (int j = 0; j < QK1_0; j++) {
208
+ sum_abs += fabs(src[j]);
209
+ }
210
+ dst.d = sum_abs / QK1_0;
211
+
212
+ for (int j = 0; j < QK1_0 / 8; j++) {
213
+ dst.qs[j] = 0;
214
+ }
215
+ for (int j = 0; j < QK1_0; j++) {
216
+ if (src[j] >= 0.0f) {
217
+ dst.qs[j / 8] |= (1 << (j % 8));
218
+ }
219
+ }
220
+ }
221
+
147
222
  void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
148
223
  #pragma METAL fp math_mode(safe)
149
224
  float amax = 0.0f; // absolute max
@@ -895,753 +970,459 @@ enum ggml_sort_order {
895
970
  GGML_SORT_ORDER_DESC,
896
971
  };
897
972
 
898
- // general-purpose kernel for addition, subtraction, multiplication and division of two tensors
899
- // pros: works for non-contiguous tensors, supports broadcast across all dims
900
- // cons: not very efficient
901
- template <int F>
902
- kernel void kernel_add_fuse_impl(
903
- constant ggml_metal_kargs_bin & args,
904
- device const char * src0,
905
- device const char * src1,
906
- device char * dst,
907
- uint3 tgpig[[threadgroup_position_in_grid]],
908
- ushort3 tpitg[[thread_position_in_threadgroup]],
909
- ushort3 ntg[[threads_per_threadgroup]]) {
910
- const int i03 = tgpig.z;
911
- const int i02 = tgpig.y;
912
- const int i01 = tgpig.x;
973
+ constant float GELU_COEF_A = 0.044715f;
974
+ constant float GELU_QUICK_COEF = -1.702f;
975
+ constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
976
+ constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
913
977
 
914
- const int i13 = i03%args.ne13;
915
- const int i12 = i02%args.ne12;
916
- const int i11 = i01%args.ne11;
978
+ // based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
979
+ // ref: https://www.johndcook.com/blog/python_erf/
980
+ constant float p_erf = 0.3275911f;
981
+ constant float a1_erf = 0.254829592f;
982
+ constant float a2_erf = -0.284496736f;
983
+ constant float a3_erf = 1.421413741f;
984
+ constant float a4_erf = -1.453152027f;
985
+ constant float a5_erf = 1.061405429f;
917
986
 
918
- device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
919
- device float * dst_ptr = (device float *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs);
987
+ template<typename T>
988
+ inline T erf_approx(T x) {
989
+ T sign_x = sign(x);
990
+ x = fabs(x);
991
+ T t = 1.0f / (1.0f + p_erf * x);
992
+ T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
993
+ return sign_x * y;
994
+ }
920
995
 
921
- device const float * src1_ptr[F];
922
- for (short j = 0; j < F; ++j) {
923
- src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
924
- }
996
+ template<typename T> T elu_approx(T x);
925
997
 
926
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
927
- const int i10 = i0%args.ne10;
998
+ template<> inline float elu_approx<float>(float x) {
999
+ return (x > 0.f) ? x : (exp(x) - 1);
1000
+ }
928
1001
 
929
- float res = src0_ptr[i0];
1002
+ template<> inline float4 elu_approx<float4>(float4 x) {
1003
+ float4 res;
930
1004
 
931
- #pragma unroll
932
- for (short j = 0; j < F; ++j) {
933
- res += src1_ptr[j][i10];
934
- }
1005
+ res[0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
1006
+ res[1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
1007
+ res[2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
1008
+ res[3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
935
1009
 
936
- dst_ptr[i0] = res;
937
- }
1010
+ return res;
938
1011
  }
939
1012
 
940
- typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t;
941
-
942
- template [[host_name("kernel_add_fuse_1")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>;
943
- template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>;
944
- template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>;
945
- template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>;
946
- template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>;
947
- template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>;
948
- template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>;
949
- template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>;
1013
+ constant short FC_unary_op [[function_constant(FC_UNARY + 0)]];
1014
+ constant bool FC_unary_cnt[[function_constant(FC_UNARY + 1)]];
950
1015
 
951
- kernel void kernel_sub_fuse_1(
952
- constant ggml_metal_kargs_bin & args,
1016
+ template <typename T0, typename T, typename TC>
1017
+ kernel void kernel_unary_impl(
1018
+ constant ggml_metal_kargs_unary & args,
953
1019
  device const char * src0,
954
- device const char * src1,
955
1020
  device char * dst,
956
1021
  uint3 tgpig[[threadgroup_position_in_grid]],
957
1022
  ushort3 tpitg[[thread_position_in_threadgroup]],
958
1023
  ushort3 ntg[[threads_per_threadgroup]]) {
959
- const int i03 = tgpig.z;
960
- const int i02 = tgpig.y;
961
- const int i01 = tgpig.x;
1024
+ #define FC_OP FC_unary_op
1025
+ #define FC_CNT FC_unary_cnt
962
1026
 
963
- const int i13 = i03%args.ne13;
964
- const int i12 = i02%args.ne12;
965
- const int i11 = i01%args.ne11;
1027
+ device const T0 * src0_ptr;
1028
+ device T * dst_ptr;
966
1029
 
967
- device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
968
- device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
969
- device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
1030
+ int i0;
970
1031
 
971
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
972
- const int i10 = i0%args.ne10;
973
- *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10));
974
- }
975
- }
1032
+ if (FC_CNT) {
1033
+ i0 = tgpig.x;
976
1034
 
977
- kernel void kernel_mul_fuse_1(
978
- constant ggml_metal_kargs_bin & args,
979
- device const char * src0,
980
- device const char * src1,
981
- device char * dst,
982
- uint3 tgpig[[threadgroup_position_in_grid]],
983
- ushort3 tpitg[[thread_position_in_threadgroup]],
984
- ushort3 ntg[[threads_per_threadgroup]]) {
985
- const int i03 = tgpig.z;
986
- const int i02 = tgpig.y;
987
- const int i01 = tgpig.x;
988
-
989
- const int i13 = i03%args.ne13;
990
- const int i12 = i02%args.ne12;
991
- const int i11 = i01%args.ne11;
992
-
993
- device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
994
- device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
995
- device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
996
-
997
- if (args.ne10 == 1) {
998
- const float x = *((device float *)(src1_ptr));
999
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
1000
- *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
1001
- }
1035
+ src0_ptr = (device const T0 *) (src0);
1036
+ dst_ptr = (device T *) (dst);
1002
1037
  } else {
1003
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
1004
- const int i10 = i0%args.ne10;
1005
- *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
1006
- }
1007
- }
1008
- }
1038
+ const int i03 = tgpig.z;
1039
+ const int i02 = tgpig.y;
1040
+ const int k0 = tgpig.x/args.ne01;
1041
+ const int i01 = tgpig.x - k0*args.ne01;
1009
1042
 
1010
- kernel void kernel_div_fuse_1(
1011
- constant ggml_metal_kargs_bin & args,
1012
- device const char * src0,
1013
- device const char * src1,
1014
- device char * dst,
1015
- uint3 tgpig[[threadgroup_position_in_grid]],
1016
- ushort3 tpitg[[thread_position_in_threadgroup]],
1017
- ushort3 ntg[[threads_per_threadgroup]]) {
1018
- const int i03 = tgpig.z;
1019
- const int i02 = tgpig.y;
1020
- const int i01 = tgpig.x;
1043
+ i0 = k0*ntg.x + tpitg.x;
1021
1044
 
1022
- const int i13 = i03%args.ne13;
1023
- const int i12 = i02%args.ne12;
1024
- const int i11 = i01%args.ne11;
1045
+ src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
1046
+ dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 );
1047
+ }
1025
1048
 
1026
- device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
1027
- device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
1028
- device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
1049
+ {
1050
+ //threadgroup_barrier(mem_flags::mem_none);
1029
1051
 
1030
- if (args.ne10 == 1) {
1031
- const float x = 1.0f / *((device float *)(src1_ptr));
1032
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
1033
- *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
1034
- }
1035
- } else {
1036
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
1037
- const int i10 = i0%args.ne10;
1038
- *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
1052
+ if (!FC_CNT) {
1053
+ if (i0 >= args.ne0) {
1054
+ return;
1055
+ }
1039
1056
  }
1040
- }
1041
- }
1042
1057
 
1043
- kernel void kernel_add_id(
1044
- constant ggml_metal_kargs_add_id & args,
1045
- device const char * src0,
1046
- device const char * src1,
1047
- device const char * src2,
1048
- device char * dst,
1049
- uint3 tgpig[[threadgroup_position_in_grid]],
1050
- ushort3 tpitg[[thread_position_in_threadgroup]],
1051
- ushort3 ntg[[threads_per_threadgroup]]) {
1052
- const int i1 = tgpig.x;
1053
- const int i2 = tgpig.y;
1058
+ const TC x = (TC) src0_ptr[i0];
1054
1059
 
1055
- const int i11 = *((device const int32_t *) (src2 + i1*sizeof(int32_t) + i2*args.nb21));
1056
-
1057
- const size_t nb1 = args.ne0 * sizeof(float);
1058
- const size_t nb2 = args.ne1 * nb1;
1060
+ if (FC_OP == OP_UNARY_NUM_SCALE) {
1061
+ dst_ptr[i0] = (T) (args.scale * x + args.bias);
1062
+ }
1059
1063
 
1060
- device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2);
1061
- device const float * src0_row = (device const float *)((device char *)src0 + i1*args.nb01 + i2*args.nb02);
1062
- device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);
1064
+ if (FC_OP == OP_UNARY_NUM_FILL) {
1065
+ dst_ptr[i0] = (T) args.val;
1066
+ }
1063
1067
 
1064
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
1065
- dst_row[i0] = src0_row[i0] + src1_row[i0];
1066
- }
1067
- }
1068
+ if (FC_OP == OP_UNARY_NUM_CLAMP) {
1069
+ dst_ptr[i0] = (T) clamp(x, args.min, args.max);
1070
+ }
1068
1071
 
1069
- template<typename T>
1070
- kernel void kernel_repeat(
1071
- constant ggml_metal_kargs_repeat & args,
1072
- device const char * src0,
1073
- device char * dst,
1074
- uint3 tgpig[[threadgroup_position_in_grid]],
1075
- ushort3 tpitg[[thread_position_in_threadgroup]],
1076
- ushort3 ntg[[threads_per_threadgroup]]) {
1077
- const int i3 = tgpig.z;
1078
- const int i2 = tgpig.y;
1079
- const int i1 = tgpig.x;
1072
+ if (FC_OP == OP_UNARY_NUM_SQR) {
1073
+ dst_ptr[i0] = (T) (x * x);
1074
+ }
1080
1075
 
1081
- const int i03 = i3%args.ne03;
1082
- const int i02 = i2%args.ne02;
1083
- const int i01 = i1%args.ne01;
1076
+ if (FC_OP == OP_UNARY_NUM_SQRT) {
1077
+ dst_ptr[i0] = (T) sqrt(x);
1078
+ }
1084
1079
 
1085
- device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
1086
- device char * dst_ptr = dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1;
1080
+ if (FC_OP == OP_UNARY_NUM_SIN) {
1081
+ dst_ptr[i0] = (T) sin(x);
1082
+ }
1087
1083
 
1088
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
1089
- const int i00 = i0%args.ne00;
1090
- *((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00));
1091
- }
1092
- }
1084
+ if (FC_OP == OP_UNARY_NUM_COS) {
1085
+ dst_ptr[i0] = (T) cos(x);
1086
+ }
1093
1087
 
1094
- typedef decltype(kernel_repeat<float>) kernel_repeat_t;
1088
+ if (FC_OP == OP_UNARY_NUM_LOG) {
1089
+ dst_ptr[i0] = (T) log(x);
1090
+ }
1095
1091
 
1096
- template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
1097
- template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
1098
- template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
1099
- template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
1092
+ if (FC_OP == OP_UNARY_NUM_LEAKY_RELU) {
1093
+ dst_ptr[i0] = (T) (TC(x > 0)*x + TC(x <= 0)*(x * args.slope));
1094
+ }
1100
1095
 
1101
- // assumption: src1 is a row
1102
- // broadcast src1 into src0
1103
- template <short F>
1104
- kernel void kernel_add_row_c4_fuse_impl(
1105
- constant ggml_metal_kargs_bin & args,
1106
- device const char * src0,
1107
- device const char * src1,
1108
- device char * dst,
1109
- uint tpig[[thread_position_in_grid]]) {
1110
- const uint nb = args.ne00/4;
1111
- const uint i = tpig % nb;
1096
+ if (FC_OP == OP_UNARY_NUM_TANH) {
1097
+ dst_ptr[i0] = (T) precise::tanh(x);
1098
+ }
1112
1099
 
1113
- device const float4 * src0_row = (device const float4 *) (src0);
1114
- device float4 * dst_row = (device float4 *) (dst);
1100
+ if (FC_OP == OP_UNARY_NUM_RELU) {
1101
+ dst_ptr[i0] = (T) fmax(0, x);
1102
+ }
1115
1103
 
1116
- float4 res = src0_row[tpig];
1104
+ if (FC_OP == OP_UNARY_NUM_SIGMOID) {
1105
+ dst_ptr[i0] = (T) (1 / (1 + exp(-x)));
1106
+ }
1117
1107
 
1118
- #pragma unroll(F)
1119
- for (short j = 0; j < F; ++j) {
1120
- res += ((device const float4 *) (src1 + args.o1[j]))[i];
1121
- }
1108
+ if (FC_OP == OP_UNARY_NUM_GELU) {
1109
+ dst_ptr[i0] = (T) (0.5*x*(1 + precise::tanh(SQRT_2_OVER_PI*x*(1 + GELU_COEF_A*x*x))));
1110
+ }
1122
1111
 
1123
- dst_row[tpig] = res;
1124
- }
1112
+ if (FC_OP == OP_UNARY_NUM_GELU_ERF) {
1113
+ dst_ptr[i0] = (T) (0.5*x*(1 + erf_approx(SQRT_2_INV*x)));
1114
+ }
1125
1115
 
1126
- typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t;
1116
+ if (FC_OP == OP_UNARY_NUM_GELU_QUICK) {
1117
+ dst_ptr[i0] = (T) (x * (1/(1 + exp(GELU_QUICK_COEF*x))));
1118
+ }
1127
1119
 
1128
- template [[host_name("kernel_add_row_c4_fuse_1")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>;
1129
- template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>;
1130
- template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>;
1131
- template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>;
1132
- template [[host_name("kernel_add_row_c4_fuse_5")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<5>;
1133
- template [[host_name("kernel_add_row_c4_fuse_6")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<6>;
1134
- template [[host_name("kernel_add_row_c4_fuse_7")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<7>;
1135
- template [[host_name("kernel_add_row_c4_fuse_8")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<8>;
1120
+ if (FC_OP == OP_UNARY_NUM_SILU) {
1121
+ dst_ptr[i0] = (T) (x / (1 + exp(-x)));
1122
+ }
1136
1123
 
1137
- template <short F>
1138
- kernel void kernel_sub_row_c4_fuse_impl(
1139
- constant ggml_metal_kargs_bin & args,
1140
- device const char * src0,
1141
- device const char * src1,
1142
- device char * dst,
1143
- uint tpig[[thread_position_in_grid]]) {
1124
+ if (FC_OP == OP_UNARY_NUM_ELU) {
1125
+ dst_ptr[i0] = (T) elu_approx(x);
1126
+ }
1144
1127
 
1145
- const uint nb = args.ne00/4;
1146
- const uint i = tpig % nb;
1128
+ if (FC_OP == OP_UNARY_NUM_NEG) {
1129
+ dst_ptr[i0] = (T) -x;
1130
+ }
1147
1131
 
1148
- device const float4 * src0_row = (device const float4 *) (src0);
1149
- device float4 * dst_row = (device float4 *) (dst);
1132
+ if (FC_OP == OP_UNARY_NUM_ABS) {
1133
+ dst_ptr[i0] = (T) fabs(x);
1134
+ }
1150
1135
 
1151
- device const float4 * src1_row[F];
1152
- for (short j = 0; j < F; ++j) {
1153
- src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
1154
- }
1136
+ if (FC_OP == OP_UNARY_NUM_SGN) {
1137
+ dst_ptr[i0] = T(x > 0) - T(x < 0);
1138
+ }
1155
1139
 
1156
- float4 res = src0_row[tpig];
1140
+ if (FC_OP == OP_UNARY_NUM_STEP) {
1141
+ dst_ptr[i0] = T(x > 0);
1142
+ }
1157
1143
 
1158
- #pragma unroll(F)
1159
- for (short j = 0; j < F; ++j) {
1160
- res -= src1_row[j][i];
1161
- }
1144
+ if (FC_OP == OP_UNARY_NUM_HARDSWISH) {
1145
+ dst_ptr[i0] = (T) (x * fmax(0, fmin(1, x/6 + 0.5)));
1146
+ }
1162
1147
 
1163
- dst_row[tpig] = res;
1164
- }
1148
+ if (FC_OP == OP_UNARY_NUM_HARDSIGMOID) {
1149
+ dst_ptr[i0] = (T) fmax(0, fmin(1, x/6 + 0.5));
1150
+ }
1165
1151
 
1166
- typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t;
1152
+ if (FC_OP == OP_UNARY_NUM_EXP) {
1153
+ dst_ptr[i0] = (T) exp(x);
1154
+ }
1167
1155
 
1168
- template [[host_name("kernel_sub_row_c4_fuse_1")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>;
1156
+ if (FC_OP == OP_UNARY_NUM_SOFTPLUS) {
1157
+ dst_ptr[i0] = (T) select(log(1 + exp(x)), x, x > 20);
1158
+ }
1169
1159
 
1170
- template <short F>
1171
- kernel void kernel_mul_row_c4_fuse_impl(
1172
- constant ggml_metal_kargs_bin & args,
1173
- device const char * src0,
1174
- device const char * src1,
1175
- device char * dst,
1176
- uint tpig[[thread_position_in_grid]]) {
1160
+ if (FC_OP == OP_UNARY_NUM_EXPM1) {
1161
+ // TODO: precise implementation
1162
+ dst_ptr[i0] = (T) (exp(x) - 1);
1163
+ }
1177
1164
 
1178
- const uint nb = args.ne00/4;
1179
- const uint i = tpig % nb;
1165
+ if (FC_OP == OP_UNARY_NUM_FLOOR) {
1166
+ dst_ptr[i0] = (T) floor(x);
1167
+ }
1180
1168
 
1181
- device const float4 * src0_row = (device const float4 *) (src0);
1182
- device float4 * dst_row = (device float4 *) (dst);
1169
+ if (FC_OP == OP_UNARY_NUM_CEIL) {
1170
+ dst_ptr[i0] = (T) ceil(x);
1171
+ }
1183
1172
 
1184
- device const float4 * src1_row[F];
1185
- for (short j = 0; j < F; ++j) {
1186
- src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
1187
- }
1173
+ if (FC_OP == OP_UNARY_NUM_ROUND) {
1174
+ dst_ptr[i0] = (T) round(x);
1175
+ }
1188
1176
 
1189
- float4 res = src0_row[tpig];
1177
+ if (FC_OP == OP_UNARY_NUM_TRUNC) {
1178
+ dst_ptr[i0] = (T) trunc(x);
1179
+ }
1190
1180
 
1191
- #pragma unroll(F)
1192
- for (short j = 0; j < F; ++j) {
1193
- res *= src1_row[j][i];
1181
+ if (FC_OP == OP_UNARY_NUM_XIELU) {
1182
+ const TC xi = x;
1183
+ const TC gate = TC(xi > TC(0.0f));
1184
+ const TC clamped = fmin(xi, TC(args.val));
1185
+ const TC y_pos = TC(args.scale) * xi * xi + TC(args.bias) * xi;
1186
+ const TC y_neg = (exp(clamped) - TC(1.0f) - xi) * TC(args.slope) + TC(args.bias) * xi;
1187
+ dst_ptr[i0] = (T) (gate * y_pos + (TC(1.0f) - gate) * y_neg);
1188
+ }
1194
1189
  }
1195
1190
 
1196
- dst_row[tpig] = res;
1191
+ #undef FC_OP
1192
+ #undef FC_CNT
1197
1193
  }
1198
1194
 
1199
- typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t;
1195
+ typedef decltype(kernel_unary_impl<float, float, float>) kernel_unary_t;
1200
1196
 
1201
- template [[host_name("kernel_mul_row_c4_fuse_1")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>;
1197
+ template [[host_name("kernel_unary_f32_f32")]] kernel kernel_unary_t kernel_unary_impl<float, float, float>;
1198
+ template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl<float4, float4, float4>;
1199
+ template [[host_name("kernel_unary_f16_f16")]] kernel kernel_unary_t kernel_unary_impl<half, half, float>;
1200
+ template [[host_name("kernel_unary_f16_f16_4")]] kernel kernel_unary_t kernel_unary_impl<half4, half4, float4>;
1202
1201
 
1203
- template <short F>
1204
- kernel void kernel_div_row_c4_fuse_impl(
1202
+ // OP: 0 - add, 1 - sub, 2 - mul, 3 - div
1203
+ constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
1204
+ constant short FC_bin_f [[function_constant(FC_BIN + 1)]];
1205
+ constant bool FC_bin_rb [[function_constant(FC_BIN + 2)]];
1206
+ constant bool FC_bin_cb [[function_constant(FC_BIN + 3)]];
1207
+
1208
+ template <typename T0, typename T1, typename T>
1209
+ kernel void kernel_bin_fuse_impl(
1205
1210
  constant ggml_metal_kargs_bin & args,
1206
1211
  device const char * src0,
1207
1212
  device const char * src1,
1208
1213
  device char * dst,
1209
- uint tpig[[thread_position_in_grid]]) {
1210
-
1211
- const uint nb = args.ne00/4;
1212
- const uint i = tpig % nb;
1213
-
1214
- device const float4 * src0_row = (device const float4 *) (src0);
1215
- device float4 * dst_row = (device float4 *) (dst);
1216
-
1217
- device const float4 * src1_row[F];
1218
- for (short j = 0; j < F; ++j) {
1219
- src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
1220
- }
1221
-
1222
- float4 res = src0_row[tpig];
1223
-
1224
- #pragma unroll(F)
1225
- for (short j = 0; j < F; ++j) {
1226
- res /= src1_row[j][i];
1227
- }
1228
-
1229
- dst_row[tpig] = res;
1230
- }
1231
-
1232
- typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t;
1233
-
1234
- template [[host_name("kernel_div_row_c4_fuse_1")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>;
1235
-
1236
- kernel void kernel_scale_f32(
1237
- constant ggml_metal_kargs_scale & args,
1238
- device const float * src0,
1239
- device float * dst,
1240
- uint tpig[[thread_position_in_grid]]) {
1241
- dst[tpig] = src0[tpig] * args.scale + args.bias;
1242
- }
1243
-
1244
- kernel void kernel_scale_f32_4(
1245
- constant ggml_metal_kargs_scale & args,
1246
- device const float4 * src0,
1247
- device float4 * dst,
1248
- uint tpig[[thread_position_in_grid]]) {
1249
- dst[tpig] = src0[tpig] * args.scale + args.bias;
1250
- }
1251
-
1252
- kernel void kernel_fill_f32(
1253
- constant ggml_metal_kargs_fill & args,
1254
- device const float * src0,
1255
- device float * dst,
1256
- uint tpig[[thread_position_in_grid]]) {
1257
- dst[tpig] = args.val;
1258
- }
1259
-
1260
- kernel void kernel_fill_f32_4(
1261
- constant ggml_metal_kargs_fill & args,
1262
- device const float4 * src0,
1263
- device float4 * dst,
1264
- uint tpig[[thread_position_in_grid]]) {
1265
- dst[tpig] = args.val;
1266
- }
1267
-
1268
- kernel void kernel_clamp_f32(
1269
- constant ggml_metal_kargs_clamp & args,
1270
- device const float * src0,
1271
- device float * dst,
1272
- uint tpig[[thread_position_in_grid]]) {
1273
- dst[tpig] = clamp(src0[tpig], args.min, args.max);
1274
- }
1275
-
1276
- kernel void kernel_clamp_f32_4(
1277
- constant ggml_metal_kargs_clamp & args,
1278
- device const float4 * src0,
1279
- device float4 * dst,
1280
- uint tpig[[thread_position_in_grid]]) {
1281
- dst[tpig] = clamp(src0[tpig], args.min, args.max);
1282
- }
1283
-
1284
- kernel void kernel_relu_f32(
1285
- device const float * src0,
1286
- device float * dst,
1287
- uint tpig[[thread_position_in_grid]]) {
1288
- dst[tpig] = max(0.0f, src0[tpig]);
1289
- }
1290
-
1291
- kernel void kernel_relu_f32_4(
1292
- device const float4 * src0,
1293
- device float4 * dst,
1294
- uint tpig[[thread_position_in_grid]]) {
1295
- dst[tpig] = max(0.0f, src0[tpig]);
1296
- }
1297
-
1298
- kernel void kernel_sigmoid_f32(
1299
- device const float * src0,
1300
- device float * dst,
1301
- uint tpig[[thread_position_in_grid]]) {
1302
- dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
1303
- }
1304
-
1305
- kernel void kernel_sigmoid_f32_4(
1306
- device const float4 * src0,
1307
- device float4 * dst,
1308
- uint tpig[[thread_position_in_grid]]) {
1309
- dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
1310
- }
1311
-
1312
- kernel void kernel_tanh_f32(
1313
- device const float * src0,
1314
- device float * dst,
1315
- uint tpig[[thread_position_in_grid]]) {
1316
- dst[tpig] = precise::tanh(src0[tpig]);
1317
- }
1318
-
1319
- kernel void kernel_tanh_f32_4(
1320
- device const float4 * src0,
1321
- device float4 * dst,
1322
- uint tpig[[thread_position_in_grid]]) {
1323
- dst[tpig] = precise::tanh(src0[tpig]);
1324
- }
1325
-
1326
- constant float GELU_COEF_A = 0.044715f;
1327
- constant float GELU_QUICK_COEF = -1.702f;
1328
- constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
1329
- constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
1330
-
1331
- kernel void kernel_gelu_f32(
1332
- device const float * src0,
1333
- device float * dst,
1334
- uint tpig[[thread_position_in_grid]]) {
1335
- device const float & x = src0[tpig];
1336
-
1337
- dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
1338
- }
1339
-
1340
- kernel void kernel_gelu_f32_4(
1341
- device const float4 * src0,
1342
- device float4 * dst,
1343
- uint tpig[[thread_position_in_grid]]) {
1344
- device const float4 & x = src0[tpig];
1345
-
1346
- // BEWARE !!!
1347
- // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
1348
- // This was observed with Falcon 7B and 40B models
1349
- //
1350
- dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
1351
- }
1214
+ uint3 tgpig[[threadgroup_position_in_grid]],
1215
+ ushort3 tpitg[[thread_position_in_threadgroup]],
1216
+ ushort3 ntg[[threads_per_threadgroup]]) {
1217
+ #define FC_OP FC_bin_op
1218
+ #define FC_F FC_bin_f
1219
+ #define FC_RB FC_bin_rb
1220
+ #define FC_CB FC_bin_cb
1352
1221
 
1353
- kernel void kernel_gelu_quick_f32(
1354
- device const float * src0,
1355
- device float * dst,
1356
- uint tpig[[thread_position_in_grid]]) {
1357
- device const float & x = src0[tpig];
1222
+ if (FC_RB) {
1223
+ // row broadcast
1224
+ const uint i0 = tgpig.y*args.ne00 + tgpig.x;
1225
+ const uint i1 = FC_CB ? tgpig.x%args.ne10 : tgpig.x;
1358
1226
 
1359
- dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
1360
- }
1227
+ device const T0 * src0_row = (device const T0 *) (src0);
1228
+ device T * dst_row = (device T *) (dst);
1361
1229
 
1362
- kernel void kernel_gelu_quick_f32_4(
1363
- device const float4 * src0,
1364
- device float4 * dst,
1365
- uint tpig[[thread_position_in_grid]]) {
1366
- device const float4 & x = src0[tpig];
1230
+ if (FC_F == 1) {
1231
+ device const T1 * src1_row = (device const T1 *) (src1 + args.o1[0]);
1367
1232
 
1368
- dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
1369
- }
1233
+ if (FC_OP == 0) {
1234
+ dst_row[i0] = src0_row[i0] + src1_row[i1];
1235
+ }
1370
1236
 
1371
- // based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
1372
- // ref: https://www.johndcook.com/blog/python_erf/
1373
- constant float p_erf = 0.3275911f;
1374
- constant float a1_erf = 0.254829592f;
1375
- constant float a2_erf = -0.284496736f;
1376
- constant float a3_erf = 1.421413741f;
1377
- constant float a4_erf = -1.453152027f;
1378
- constant float a5_erf = 1.061405429f;
1237
+ if (FC_OP == 1) {
1238
+ dst_row[i0] = src0_row[i0] - src1_row[i1];
1239
+ }
1379
1240
 
1380
- template<typename T>
1381
- T erf_approx(T x) {
1382
- T sign_x = sign(x);
1383
- x = fabs(x);
1384
- T t = 1.0f / (1.0f + p_erf * x);
1385
- T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
1386
- return sign_x * y;
1387
- }
1241
+ if (FC_OP == 2) {
1242
+ dst_row[i0] = src0_row[i0] * src1_row[i1];
1243
+ }
1388
1244
 
1389
- kernel void kernel_gelu_erf_f32(
1390
- device const float * src0,
1391
- device float * dst,
1392
- uint tpig[[thread_position_in_grid]]) {
1393
- device const float & x = src0[tpig];
1245
+ if (FC_OP == 3) {
1246
+ dst_row[i0] = src0_row[i0] / src1_row[i1];
1247
+ }
1248
+ } else {
1249
+ T0 res = src0_row[i0];
1394
1250
 
1395
- dst[tpig] = 0.5f*x*(1.0f+erf_approx<float>(x*SQRT_2_INV));
1396
- }
1251
+ if (FC_OP == 0) {
1252
+ FOR_UNROLL (short j = 0; j < FC_F; ++j) {
1253
+ res += ((device const T1 *) (src1 + args.o1[j]))[i1];
1254
+ }
1255
+ }
1397
1256
 
1398
- kernel void kernel_gelu_erf_f32_4(
1399
- device const float4 * src0,
1400
- device float4 * dst,
1401
- uint tpig[[thread_position_in_grid]]) {
1402
- device const float4 & x = src0[tpig];
1257
+ if (FC_OP == 1) {
1258
+ FOR_UNROLL (short j = 0; j < FC_F; ++j) {
1259
+ res -= ((device const T1 *) (src1 + args.o1[j]))[i1];
1260
+ }
1261
+ }
1403
1262
 
1404
- dst[tpig] = 0.5f*x*(1.0f+erf_approx<float4>(x*SQRT_2_INV));
1405
- }
1263
+ if (FC_OP == 2) {
1264
+ FOR_UNROLL (short j = 0; j < FC_F; ++j) {
1265
+ res *= ((device const T1 *) (src1 + args.o1[j]))[i1];
1266
+ }
1267
+ }
1406
1268
 
1407
- kernel void kernel_silu_f32(
1408
- device const float * src0,
1409
- device float * dst,
1410
- uint tpig[[thread_position_in_grid]]) {
1411
- device const float & x = src0[tpig];
1412
- dst[tpig] = x / (1.0f + exp(-x));
1413
- }
1269
+ if (FC_OP == 3) {
1270
+ FOR_UNROLL (short j = 0; j < FC_F; ++j) {
1271
+ res /= ((device const T1 *) (src1 + args.o1[j]))[i1];
1272
+ }
1273
+ }
1414
1274
 
1415
- kernel void kernel_silu_f32_4(
1416
- device const float4 * src0,
1417
- device float4 * dst,
1418
- uint tpig[[thread_position_in_grid]]) {
1419
- device const float4 & x = src0[tpig];
1420
- dst[tpig] = x / (1.0f + exp(-x));
1421
- }
1275
+ dst_row[i0] = res;
1276
+ }
1277
+ } else {
1278
+ const int i03 = tgpig.z;
1279
+ const int i02 = tgpig.y;
1280
+ const int i01 = tgpig.x;
1422
1281
 
1423
- kernel void kernel_elu_f32(
1424
- device const float * src0,
1425
- device float * dst,
1426
- uint tpig[[thread_position_in_grid]]) {
1427
- const float x = src0[tpig];
1428
- dst[tpig] = (x > 0.0f) ? x : (exp(x) - 1.0f);
1429
- }
1282
+ if (i01 >= args.ne01) {
1283
+ return;
1284
+ }
1430
1285
 
1431
- kernel void kernel_elu_f32_4(
1432
- device const float4 * src0,
1433
- device float4 * dst,
1434
- uint tpig[[thread_position_in_grid]]) {
1435
- const float4 x = src0[tpig];
1436
- dst[tpig][0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
1437
- dst[tpig][1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
1438
- dst[tpig][2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
1439
- dst[tpig][3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
1440
- }
1286
+ const int i13 = i03%args.ne13;
1287
+ const int i12 = i02%args.ne12;
1288
+ const int i11 = i01%args.ne11;
1441
1289
 
1442
- kernel void kernel_sqr_f32(
1443
- device const float * src0,
1444
- device float * dst,
1445
- uint tpig[[thread_position_in_grid]]) {
1446
- dst[tpig] = src0[tpig] * src0[tpig];
1447
- }
1290
+ device const T0 * src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
1291
+ device T * dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs);
1448
1292
 
1449
- kernel void kernel_sqr_f32_4(
1450
- device const float4 * src0,
1451
- device float4 * dst,
1452
- uint tpig[[thread_position_in_grid]]) {
1453
- dst[tpig] = src0[tpig] * src0[tpig];
1454
- }
1293
+ if (FC_F == 1) {
1294
+ device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
1455
1295
 
1456
- kernel void kernel_sqrt_f32(
1457
- device const float * src0,
1458
- device float * dst,
1459
- uint tpig[[thread_position_in_grid]]) {
1460
- dst[tpig] = sqrt(src0[tpig]);
1461
- }
1296
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
1297
+ const int i10 = FC_CB ? i0%args.ne10 : i0;
1462
1298
 
1463
- kernel void kernel_sqrt_f32_4(
1464
- device const float4 * src0,
1465
- device float4 * dst,
1466
- uint tpig[[thread_position_in_grid]]) {
1467
- dst[tpig] = sqrt(src0[tpig]);
1468
- }
1299
+ if (FC_OP == 0) {
1300
+ dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10];
1301
+ }
1469
1302
 
1470
- kernel void kernel_sin_f32(
1471
- device const float * src0,
1472
- device float * dst,
1473
- uint tpig[[thread_position_in_grid]]) {
1474
- dst[tpig] = sin(src0[tpig]);
1475
- }
1303
+ if (FC_OP == 1) {
1304
+ dst_ptr[i0] = src0_ptr[i0] - src1_ptr[i10];
1305
+ }
1476
1306
 
1477
- kernel void kernel_sin_f32_4(
1478
- device const float4 * src0,
1479
- device float4 * dst,
1480
- uint tpig[[thread_position_in_grid]]) {
1481
- dst[tpig] = sin(src0[tpig]);
1482
- }
1307
+ if (FC_OP == 2) {
1308
+ dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10];
1309
+ }
1483
1310
 
1484
- kernel void kernel_cos_f32(
1485
- device const float * src0,
1486
- device float * dst,
1487
- uint tpig[[thread_position_in_grid]]) {
1488
- dst[tpig] = cos(src0[tpig]);
1489
- }
1311
+ if (FC_OP == 3) {
1312
+ dst_ptr[i0] = src0_ptr[i0] / src1_ptr[i10];
1313
+ }
1314
+ }
1315
+ } else {
1316
+ device const T1 * src1_ptr[8];
1317
+ FOR_UNROLL (short j = 0; j < FC_F; ++j) {
1318
+ src1_ptr[j] = (device const T1 *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
1319
+ }
1490
1320
 
1491
- kernel void kernel_cos_f32_4(
1492
- device const float4 * src0,
1493
- device float4 * dst,
1494
- uint tpig[[thread_position_in_grid]]) {
1495
- dst[tpig] = cos(src0[tpig]);
1496
- }
1321
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
1322
+ const int i10 = FC_CB ? i0%args.ne10 : i0;
1497
1323
 
1498
- kernel void kernel_log_f32(
1499
- device const float * src0,
1500
- device float * dst,
1501
- uint tpig[[thread_position_in_grid]]) {
1502
- dst[tpig] = log(src0[tpig]);
1503
- }
1324
+ T res = src0_ptr[i0];
1504
1325
 
1505
- kernel void kernel_log_f32_4(
1506
- device const float4 * src0,
1507
- device float4 * dst,
1508
- uint tpig[[thread_position_in_grid]]) {
1509
- dst[tpig] = log(src0[tpig]);
1510
- }
1326
+ if (FC_OP == 0) {
1327
+ FOR_UNROLL (short j = 0; j < FC_F; ++j) {
1328
+ res += src1_ptr[j][i10];
1329
+ }
1330
+ }
1511
1331
 
1512
- kernel void kernel_neg_f32(
1513
- device const float * src0,
1514
- device float * dst,
1515
- uint tpig[[thread_position_in_grid]]) {
1516
- dst[tpig] = -src0[tpig];
1517
- }
1332
+ if (FC_OP == 1) {
1333
+ FOR_UNROLL (short j = 0; j < FC_F; ++j) {
1334
+ res -= src1_ptr[j][i10];
1335
+ }
1336
+ }
1518
1337
 
1519
- kernel void kernel_neg_f32_4(
1520
- device const float4 * src0,
1521
- device float4 * dst,
1522
- uint tpig[[thread_position_in_grid]]) {
1523
- dst[tpig] = -src0[tpig];
1524
- }
1338
+ if (FC_OP == 2) {
1339
+ FOR_UNROLL (short j = 0; j < FC_F; ++j) {
1340
+ res *= src1_ptr[j][i10];
1341
+ }
1342
+ }
1525
1343
 
1526
- kernel void kernel_abs_f32(
1527
- device const float * src0,
1528
- device float * dst,
1529
- uint tpig[[thread_position_in_grid]]) {
1530
- dst[tpig] = fabs(src0[tpig]);
1531
- }
1344
+ if (FC_OP == 3) {
1345
+ FOR_UNROLL (short j = 0; j < FC_F; ++j) {
1346
+ res /= src1_ptr[j][i10];
1347
+ }
1348
+ }
1532
1349
 
1533
- kernel void kernel_abs_f32_4(
1534
- device const float4 * src0,
1535
- device float4 * dst,
1536
- uint tpig[[thread_position_in_grid]]) {
1537
- dst[tpig] = fabs(src0[tpig]);
1538
- }
1350
+ dst_ptr[i0] = res;
1351
+ }
1352
+ }
1353
+ }
1539
1354
 
1540
- kernel void kernel_sgn_f32(
1541
- device const float * src0,
1542
- device float * dst,
1543
- uint tpig[[thread_position_in_grid]]) {
1544
- dst[tpig] = sign(src0[tpig]);
1355
+ #undef FC_OP
1356
+ #undef FC_F
1357
+ #undef FC_RB
1358
+ #undef FC_CB
1545
1359
  }
1546
1360
 
1547
- kernel void kernel_sgn_f32_4(
1548
- device const float4 * src0,
1549
- device float4 * dst,
1550
- uint tpig[[thread_position_in_grid]]) {
1551
- dst[tpig] = sign(src0[tpig]);
1552
- }
1361
+ typedef decltype(kernel_bin_fuse_impl<float, float, float>) kernel_bin_fuse_t;
1553
1362
 
1554
- kernel void kernel_step_f32(
1555
- device const float * src0,
1556
- device float * dst,
1557
- uint tpig[[thread_position_in_grid]]) {
1558
- dst[tpig] = step(0.0f, src0[tpig]);
1559
- }
1363
+ template [[host_name("kernel_bin_fuse_f32_f32_f32")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float, float, float>;
1364
+ template [[host_name("kernel_bin_fuse_f32_f32_f32_4")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float4, float4, float4>;
1560
1365
 
1561
- kernel void kernel_step_f32_4(
1562
- device const float4 * src0,
1563
- device float4 * dst,
1564
- uint tpig[[thread_position_in_grid]]) {
1565
- dst[tpig] = step(0.0f, src0[tpig]);
1566
- }
1366
+ kernel void kernel_add_id(
1367
+ constant ggml_metal_kargs_add_id & args,
1368
+ device const char * src0,
1369
+ device const char * src1,
1370
+ device const char * src2,
1371
+ device char * dst,
1372
+ uint3 tgpig[[threadgroup_position_in_grid]],
1373
+ ushort3 tpitg[[thread_position_in_threadgroup]],
1374
+ ushort3 ntg[[threads_per_threadgroup]]) {
1375
+ const int i1 = tgpig.x;
1376
+ const int i2 = tgpig.y;
1567
1377
 
1568
- kernel void kernel_hardswish_f32(
1569
- device const float * src0,
1570
- device float * dst,
1571
- uint tpig[[thread_position_in_grid]]) {
1572
- const float x = src0[tpig];
1573
- dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
1574
- }
1378
+ const int i11 = *((device const int32_t *) (src2 + i1*sizeof(int32_t) + i2*args.nb21));
1575
1379
 
1576
- kernel void kernel_hardswish_f32_4(
1577
- device const float4 * src0,
1578
- device float4 * dst,
1579
- uint tpig[[thread_position_in_grid]]) {
1580
- const float4 x = src0[tpig];
1581
- dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
1582
- }
1380
+ const size_t nb1 = args.ne0 * sizeof(float);
1381
+ const size_t nb2 = args.ne1 * nb1;
1583
1382
 
1584
- kernel void kernel_hardsigmoid_f32(
1585
- device const float * src0,
1586
- device float * dst,
1587
- uint tpig[[thread_position_in_grid]]) {
1588
- const float x = src0[tpig];
1589
- dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
1590
- }
1383
+ device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2);
1384
+ device const float * src0_row = (device const float *)((device char *)src0 + i1*args.nb01 + i2*args.nb02);
1385
+ device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);
1591
1386
 
1592
- kernel void kernel_hardsigmoid_f32_4(
1593
- device const float4 * src0,
1594
- device float4 * dst,
1595
- uint tpig[[thread_position_in_grid]]) {
1596
- const float4 x = src0[tpig];
1597
- dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
1387
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
1388
+ dst_row[i0] = src0_row[i0] + src1_row[i0];
1389
+ }
1598
1390
  }
1599
1391
 
1600
- kernel void kernel_exp_f32(
1601
- device const float * src0,
1602
- device float * dst,
1603
- uint tpig[[thread_position_in_grid]]) {
1604
- dst[tpig] = exp(src0[tpig]);
1605
- }
1392
+ template<typename T>
1393
+ kernel void kernel_repeat(
1394
+ constant ggml_metal_kargs_repeat & args,
1395
+ device const char * src0,
1396
+ device char * dst,
1397
+ uint3 tgpig[[threadgroup_position_in_grid]],
1398
+ ushort3 tpitg[[thread_position_in_threadgroup]],
1399
+ ushort3 ntg[[threads_per_threadgroup]]) {
1400
+ const int i3 = tgpig.z;
1401
+ const int i2 = tgpig.y;
1402
+ const int i1 = tgpig.x;
1606
1403
 
1607
- kernel void kernel_exp_f32_4(
1608
- device const float4 * src0,
1609
- device float4 * dst,
1610
- uint tpig[[thread_position_in_grid]]) {
1611
- dst[tpig] = exp(src0[tpig]);
1612
- }
1404
+ const int i03 = i3%args.ne03;
1405
+ const int i02 = i2%args.ne02;
1406
+ const int i01 = i1%args.ne01;
1613
1407
 
1614
- kernel void kernel_softplus_f32(
1615
- device const float * src0,
1616
- device float * dst,
1617
- uint tpig[[thread_position_in_grid]]) {
1618
- device const float & x = src0[tpig];
1619
- dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
1620
- }
1408
+ device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
1409
+ device char * dst_ptr = dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1;
1621
1410
 
1622
- kernel void kernel_softplus_f32_4(
1623
- device const float4 * src0,
1624
- device float4 * dst,
1625
- uint tpig[[thread_position_in_grid]]) {
1626
- device const float4 & x = src0[tpig];
1627
- dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
1411
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
1412
+ const int i00 = i0%args.ne00;
1413
+ *((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00));
1414
+ }
1628
1415
  }
1629
1416
 
1630
- kernel void kernel_expm1_f32(
1631
- device const float * src0,
1632
- device float * dst,
1633
- uint tpig[[thread_position_in_grid]]) {
1634
- dst[tpig] = exp(src0[tpig]) - 1.0f;
1635
- }
1417
+ typedef decltype(kernel_repeat<float>) kernel_repeat_t;
1636
1418
 
1637
- kernel void kernel_expm1_f32_4(
1638
- device const float4 * src0,
1639
- device float4 * dst,
1640
- uint tpig[[thread_position_in_grid]]) {
1641
- dst[tpig] = exp(src0[tpig]) - 1.0f;
1642
- }
1419
+ template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
1420
+ template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
1421
+ template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
1422
+ template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
1643
1423
 
1644
- kernel void kernel_reglu_f32(
1424
+ template<typename T>
1425
+ kernel void kernel_reglu(
1645
1426
  constant ggml_metal_kargs_glu & args,
1646
1427
  device const char * src0,
1647
1428
  device const char * src1,
@@ -1649,19 +1430,25 @@ kernel void kernel_reglu_f32(
1649
1430
  uint tgpig[[threadgroup_position_in_grid]],
1650
1431
  uint tpitg[[thread_position_in_threadgroup]],
1651
1432
  uint ntg[[threads_per_threadgroup]]) {
1652
- device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1653
- device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1654
- device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1433
+ device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1434
+ device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1435
+ device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
1655
1436
 
1656
1437
  for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
1657
1438
  const float x0 = src0_row[i0];
1658
1439
  const float x1 = src1_row[i0];
1659
1440
 
1660
- dst_row[i0] = x0*x1*(x0 > 0.0f);
1441
+ dst_row[i0] = (T)(x0*x1*(x0 > 0.0f));
1661
1442
  }
1662
1443
  }
1663
1444
 
1664
- kernel void kernel_geglu_f32(
1445
+ typedef decltype(kernel_reglu<float>) kernel_reglu_t;
1446
+
1447
+ template [[host_name("kernel_reglu_f32")]] kernel kernel_reglu_t kernel_reglu<float>;
1448
+ template [[host_name("kernel_reglu_f16")]] kernel kernel_reglu_t kernel_reglu<half>;
1449
+
1450
+ template<typename T>
1451
+ kernel void kernel_geglu(
1665
1452
  constant ggml_metal_kargs_glu & args,
1666
1453
  device const char * src0,
1667
1454
  device const char * src1,
@@ -1669,9 +1456,9 @@ kernel void kernel_geglu_f32(
1669
1456
  uint tgpig[[threadgroup_position_in_grid]],
1670
1457
  uint tpitg[[thread_position_in_threadgroup]],
1671
1458
  uint ntg[[threads_per_threadgroup]]) {
1672
- device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1673
- device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1674
- device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1459
+ device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1460
+ device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1461
+ device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
1675
1462
 
1676
1463
  for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
1677
1464
  const float x0 = src0_row[i0];
@@ -1679,11 +1466,17 @@ kernel void kernel_geglu_f32(
1679
1466
 
1680
1467
  const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
1681
1468
 
1682
- dst_row[i0] = gelu*x1;
1469
+ dst_row[i0] = (T)(gelu*x1);
1683
1470
  }
1684
1471
  }
1685
1472
 
1686
- kernel void kernel_swiglu_f32(
1473
+ typedef decltype(kernel_geglu<float>) kernel_geglu_t;
1474
+
1475
+ template [[host_name("kernel_geglu_f32")]] kernel kernel_geglu_t kernel_geglu<float>;
1476
+ template [[host_name("kernel_geglu_f16")]] kernel kernel_geglu_t kernel_geglu<half>;
1477
+
1478
+ template<typename T>
1479
+ kernel void kernel_swiglu(
1687
1480
  constant ggml_metal_kargs_glu & args,
1688
1481
  device const char * src0,
1689
1482
  device const char * src1,
@@ -1691,9 +1484,9 @@ kernel void kernel_swiglu_f32(
1691
1484
  uint tgpig[[threadgroup_position_in_grid]],
1692
1485
  uint tpitg[[thread_position_in_threadgroup]],
1693
1486
  uint ntg[[threads_per_threadgroup]]) {
1694
- device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1695
- device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1696
- device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1487
+ device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1488
+ device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1489
+ device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
1697
1490
 
1698
1491
  for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
1699
1492
  const float x0 = src0_row[i0];
@@ -1701,11 +1494,17 @@ kernel void kernel_swiglu_f32(
1701
1494
 
1702
1495
  const float silu = x0 / (1.0f + exp(-x0));
1703
1496
 
1704
- dst_row[i0] = silu*x1;
1497
+ dst_row[i0] = (T)(silu*x1);
1705
1498
  }
1706
1499
  }
1707
1500
 
1708
- kernel void kernel_swiglu_oai_f32(
1501
+ typedef decltype(kernel_swiglu<float>) kernel_swiglu_t;
1502
+
1503
+ template [[host_name("kernel_swiglu_f32")]] kernel kernel_swiglu_t kernel_swiglu<float>;
1504
+ template [[host_name("kernel_swiglu_f16")]] kernel kernel_swiglu_t kernel_swiglu<half>;
1505
+
1506
+ template<typename T>
1507
+ kernel void kernel_swiglu_oai(
1709
1508
  constant ggml_metal_kargs_glu & args,
1710
1509
  device const char * src0,
1711
1510
  device const char * src1,
@@ -1713,9 +1512,9 @@ kernel void kernel_swiglu_oai_f32(
1713
1512
  uint tgpig[[threadgroup_position_in_grid]],
1714
1513
  uint tpitg[[thread_position_in_threadgroup]],
1715
1514
  uint ntg[[threads_per_threadgroup]]) {
1716
- device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1717
- device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1718
- device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1515
+ device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1516
+ device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1517
+ device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
1719
1518
 
1720
1519
  for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
1721
1520
  float x0 = src0_row[i0];
@@ -1727,11 +1526,17 @@ kernel void kernel_swiglu_oai_f32(
1727
1526
  float out_glu = x0 / (1.0f + exp(-x0 * args.alpha));
1728
1527
  out_glu = out_glu * (1.0f + x1);
1729
1528
 
1730
- dst_row[i0] = out_glu;
1529
+ dst_row[i0] = (T)out_glu;
1731
1530
  }
1732
1531
  }
1733
1532
 
1734
- kernel void kernel_geglu_erf_f32(
1533
+ typedef decltype(kernel_swiglu_oai<float>) kernel_swiglu_oai_t;
1534
+
1535
+ template [[host_name("kernel_swiglu_oai_f32")]] kernel kernel_swiglu_oai_t kernel_swiglu_oai<float>;
1536
+ template [[host_name("kernel_swiglu_oai_f16")]] kernel kernel_swiglu_oai_t kernel_swiglu_oai<half>;
1537
+
1538
+ template<typename T>
1539
+ kernel void kernel_geglu_erf(
1735
1540
  constant ggml_metal_kargs_glu & args,
1736
1541
  device const char * src0,
1737
1542
  device const char * src1,
@@ -1739,9 +1544,9 @@ kernel void kernel_geglu_erf_f32(
1739
1544
  uint tgpig[[threadgroup_position_in_grid]],
1740
1545
  uint tpitg[[thread_position_in_threadgroup]],
1741
1546
  uint ntg[[threads_per_threadgroup]]) {
1742
- device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1743
- device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1744
- device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1547
+ device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1548
+ device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1549
+ device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
1745
1550
 
1746
1551
  for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
1747
1552
  const float x0 = src0_row[i0];
@@ -1749,11 +1554,17 @@ kernel void kernel_geglu_erf_f32(
1749
1554
 
1750
1555
  const float gelu_erf = 0.5f*x0*(1.0f+erf_approx<float>(x0*SQRT_2_INV));
1751
1556
 
1752
- dst_row[i0] = gelu_erf*x1;
1557
+ dst_row[i0] = (T)(gelu_erf*x1);
1753
1558
  }
1754
1559
  }
1755
1560
 
1756
- kernel void kernel_geglu_quick_f32(
1561
+ typedef decltype(kernel_geglu_erf<float>) kernel_geglu_erf_t;
1562
+
1563
+ template [[host_name("kernel_geglu_erf_f32")]] kernel kernel_geglu_erf_t kernel_geglu_erf<float>;
1564
+ template [[host_name("kernel_geglu_erf_f16")]] kernel kernel_geglu_erf_t kernel_geglu_erf<half>;
1565
+
1566
+ template<typename T>
1567
+ kernel void kernel_geglu_quick(
1757
1568
  constant ggml_metal_kargs_glu & args,
1758
1569
  device const char * src0,
1759
1570
  device const char * src1,
@@ -1761,9 +1572,9 @@ kernel void kernel_geglu_quick_f32(
1761
1572
  uint tgpig[[threadgroup_position_in_grid]],
1762
1573
  uint tpitg[[thread_position_in_threadgroup]],
1763
1574
  uint ntg[[threads_per_threadgroup]]) {
1764
- device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1765
- device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1766
- device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1575
+ device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1576
+ device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1577
+ device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
1767
1578
 
1768
1579
  for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
1769
1580
  const float x0 = src0_row[i0];
@@ -1771,10 +1582,15 @@ kernel void kernel_geglu_quick_f32(
1771
1582
 
1772
1583
  const float gelu_quick = x0*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x0)));
1773
1584
 
1774
- dst_row[i0] = gelu_quick*x1;
1585
+ dst_row[i0] = (T)(gelu_quick*x1);
1775
1586
  }
1776
1587
  }
1777
1588
 
1589
+ typedef decltype(kernel_geglu_quick<float>) kernel_geglu_quick_t;
1590
+
1591
+ template [[host_name("kernel_geglu_quick_f32")]] kernel kernel_geglu_quick_t kernel_geglu_quick<float>;
1592
+ template [[host_name("kernel_geglu_quick_f16")]] kernel kernel_geglu_quick_t kernel_geglu_quick<half>;
1593
+
1778
1594
  kernel void kernel_op_sum_f32(
1779
1595
  constant ggml_metal_kargs_sum & args,
1780
1596
  device const float * src0,
@@ -1824,33 +1640,35 @@ kernel void kernel_op_sum_f32(
1824
1640
  }
1825
1641
  }
1826
1642
 
1827
- template <bool norm>
1828
- kernel void kernel_sum_rows(
1643
+ constant short FC_sum_rows_op [[function_constant(FC_SUM_ROWS + 0)]];
1644
+
1645
+ template <typename T0, typename T>
1646
+ kernel void kernel_sum_rows_impl(
1829
1647
  constant ggml_metal_kargs_sum_rows & args,
1830
- device const float * src0,
1831
- device float * dst,
1832
- threadgroup float * shmem_f32 [[threadgroup(0)]],
1648
+ device const char * src0,
1649
+ device char * dst,
1650
+ threadgroup char * shmem [[threadgroup(0)]],
1833
1651
  uint3 tgpig[[threadgroup_position_in_grid]],
1834
1652
  ushort3 tpitg[[thread_position_in_threadgroup]],
1835
1653
  ushort sgitg[[simdgroup_index_in_threadgroup]],
1836
1654
  ushort tiisg[[thread_index_in_simdgroup]],
1837
1655
  ushort3 ntg[[threads_per_threadgroup]]) {
1838
- int64_t i3 = tgpig.z;
1839
- int64_t i2 = tgpig.y;
1840
- int64_t i1 = tgpig.x;
1656
+ #define FC_OP FC_sum_rows_op
1841
1657
 
1842
- if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
1843
- return;
1844
- }
1658
+ const int i3 = tgpig.z;
1659
+ const int i2 = tgpig.y;
1660
+ const int i1 = tgpig.x;
1661
+
1662
+ threadgroup T0 * shmem_t = (threadgroup T0 *) shmem;
1845
1663
 
1846
1664
  if (sgitg == 0) {
1847
- shmem_f32[tiisg] = 0.0f;
1665
+ shmem_t[tiisg] = 0.0f;
1848
1666
  }
1849
1667
 
1850
- device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
1851
- device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
1668
+ device const T0 * src_row = (device const T0 *) (src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
1669
+ device T * dst_row = (device T *) (dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
1852
1670
 
1853
- float sumf = 0;
1671
+ T0 sumf = T0(0.0f);
1854
1672
 
1855
1673
  for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
1856
1674
  sumf += src_row[i0];
@@ -1861,23 +1679,33 @@ kernel void kernel_sum_rows(
1861
1679
  threadgroup_barrier(mem_flags::mem_threadgroup);
1862
1680
 
1863
1681
  if (tiisg == 0) {
1864
- shmem_f32[sgitg] = sumf;
1682
+ shmem_t[sgitg] = sumf;
1865
1683
  }
1866
1684
 
1867
1685
  threadgroup_barrier(mem_flags::mem_threadgroup);
1868
1686
 
1869
- sumf = shmem_f32[tiisg];
1687
+ sumf = shmem_t[tiisg];
1870
1688
  sumf = simd_sum(sumf);
1871
1689
 
1872
1690
  if (tpitg.x == 0) {
1873
- dst_row[0] = norm ? sumf / args.ne00 : sumf;
1691
+ if (FC_OP == OP_SUM_ROWS_NUM_MEAN) {
1692
+ if (is_same<float4, T0>::value) {
1693
+ dst_row[0] = sum(sumf) / (4*args.ne00);
1694
+ } else {
1695
+ dst_row[0] = sum(sumf) / args.ne00;
1696
+ }
1697
+ } else {
1698
+ dst_row[0] = sum(sumf);
1699
+ }
1874
1700
  }
1701
+
1702
+ #undef FC_OP
1875
1703
  }
1876
1704
 
1877
- typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
1705
+ typedef decltype(kernel_sum_rows_impl<float, float>) kernel_sum_rows_t;
1878
1706
 
1879
- template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
1880
- template [[host_name("kernel_mean_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
1707
+ template [[host_name("kernel_sum_rows_f32_f32")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float, float>;
1708
+ template [[host_name("kernel_sum_rows_f32_f32_4")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float4, float>;
1881
1709
 
1882
1710
  template<typename T>
1883
1711
  kernel void kernel_cumsum_blk(
@@ -2737,6 +2565,329 @@ kernel void kernel_rwkv_wkv7_f32(
2737
2565
  }
2738
2566
  }
2739
2567
 
2568
+ constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]];
2569
+ constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]];
2570
+ constant short FC_gated_delta_net_K [[function_constant(FC_GATED_DELTA_NET + 2)]];
2571
+
2572
+ #if 1
2573
+ template<short NSG> // NSG: state elements cached per thread (ls[NSG]); also the row stride of tgpig.x (i20 = tgpig.x*NSG + ty)
2574
+ kernel void kernel_gated_delta_net_impl(
2575
+ constant ggml_metal_kargs_gated_delta_net & args,
2576
+ device const char * q,
2577
+ device const char * k,
2578
+ device const char * v,
2579
+ device const char * g,
2580
+ device const char * b,
2581
+ device const char * s,
2582
+ device char * dst,
2583
+ uint3 tgpig[[threadgroup_position_in_grid]],
2584
+ uint3 tpitg[[thread_position_in_threadgroup]],
2585
+ uint3 ntg[[threads_per_threadgroup]]) { // ntg is not read in this body
2586
+ #define S_v FC_gated_delta_net_ne20
2587
+ #define G FC_gated_delta_net_ne30
2588
+ #define K FC_gated_delta_net_K
2589
+ // Runs the gated delta-net recurrence over args.ne22 tokens for one state row (i20) of head i21 in sequence i23.
2590
+ const uint tx = tpitg.x;
2591
+ const uint ty = tpitg.y;
2592
+
2593
+ const uint i23 = tgpig.z; // B (n_seqs)
2594
+ const uint i21 = tgpig.y; // H (head)
2595
+ const uint i20 = tgpig.x*NSG + ty; // row within S_v
2596
+ // the q/k head index wraps with ne01/ne11, so several state heads may read the same q/k head
2597
+ const uint i01 = i21 % args.ne01;
2598
+ const uint i11 = i21 % args.ne11;
2599
+
2600
+ const float scale = 1.0f / sqrt((float)S_v);
2601
+
2602
+ // input state layout [S_v, S_v, H, n_seqs] (s0 only): per-seq stride is H*S_v*S_v floats.
2603
+ // state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous
2604
+ const uint state_in_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
2605
+ device const float * s_ptr = (device const float *) (s) + state_in_base;
2606
+
2607
+ float ls[NSG];
2608
+ // each thread caches NSG consecutive elements of row i20 (columns tx*NSG .. tx*NSG+NSG-1)
2609
+ FOR_UNROLL (short j = 0; j < NSG; j++) {
2610
+ const short is = tx*NSG + j;
2611
+ ls[j] = s_ptr[is];
2612
+ }
2613
+ // attention output base for (seq i23, head i21, row i20); the per-token offset is added inside the loop
2614
+ device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
2615
+
2616
+ device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);
2617
+ device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);
2618
+ device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);
2619
+
2620
+ device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
2621
+ device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
2622
+
2623
+ // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
2624
+ // When n_tokens < K, only slots 0..n_tokens-1 are written; older slots are caller-owned.
2625
+
2626
+ // output state base offset: after attention scores
2627
+ const uint attn_size = args.ne22 * args.ne21 * S_v * args.ne23;
2628
+ // output state per-slot size: S_v * S_v * H * n_seqs
2629
+ const uint state_size_per_snap = S_v * S_v * args.ne21 * args.ne23;
2630
+ // per-(seq,head) offset within a slot
2631
+ const uint state_out_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
2632
+
2633
+ for (short t = 0; t < args.ne22; t++) {
2634
+ float s_k = 0.0f;
2635
+ // 1) decay the cached state by exp(g) and accumulate this thread's part of (state row . k)
2636
+ if (G == 1) {
2637
+ const float g_exp = exp(g_ptr[0]); // G == 1: a single gate value per (token, head)
2638
+
2639
+ FOR_UNROLL (short j = 0; j < NSG; j++) {
2640
+ const short is = tx*NSG + j;
2641
+ ls[j] *= g_exp;
2642
+
2643
+ s_k += ls[j]*k_ptr[is];
2644
+ }
2645
+ } else {
2646
+ // KDA
2647
+ FOR_UNROLL (short j = 0; j < NSG; j++) {
2648
+ const short is = tx*NSG + j;
2649
+ ls[j] *= exp(g_ptr[is]); // one gate value per state column
2650
+
2651
+ s_k += ls[j]*k_ptr[is];
2652
+ }
2653
+ }
2654
+ // simd_sum adds the partials across the simdgroup (assumes the tx lanes of one row form exactly one simdgroup - TODO confirm against the dispatch)
2655
+ s_k = simd_sum(s_k);
2656
+ // 2) delta term: (v - state.k) scaled by b for this token/head
2657
+ const float d = (v_ptr[i20] - s_k)*b_ptr[0];
2658
+
2659
+ float y = 0.0f;
2660
+ // 3) update ls += k*d, then accumulate this thread's part of (updated state row . q)
2661
+ FOR_UNROLL (short j = 0; j < NSG; j++) {
2662
+ const short is = tx*NSG + j;
2663
+ ls[j] += k_ptr[is]*d;
2664
+
2665
+ y += ls[j]*q_ptr[is];
2666
+ }
2667
+
2668
+ y = simd_sum(y);
2669
+ // one lane per row writes the scaled output for token t
2670
+ if (tx == 0) {
2671
+ dst_attn[t*args.ne21*S_v] = y*scale;
2672
+ }
2673
+ // advance the per-token input pointers
2674
+ q_ptr += args.ns02;
2675
+ k_ptr += args.ns12;
2676
+ v_ptr += args.ns22;
2677
+
2678
+ b_ptr += args.ne21;
2679
+ g_ptr += args.ne21*G;
2680
+ // K > 1: store the state after each of the last K tokens; slot 0 receives the final token's state
2681
+ if (K > 1) {
2682
+ const int target_slot = (int)args.ne22 - 1 - (int)t;
2683
+ if (target_slot >= 0 && target_slot < (int)K) {
2684
+ device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base;
2685
+ FOR_UNROLL (short j = 0; j < NSG; j++) {
2686
+ const short is = tx*NSG + j;
2687
+ dst_state[is] = ls[j];
2688
+ }
2689
+ }
2690
+ }
2691
+ }
2692
+ // K == 1: only the final state is written, directly after the attention output
2693
+ if (K == 1) {
2694
+ device float * dst_state = (device float *) (dst) + attn_size + state_out_base;
2695
+ FOR_UNROLL (short j = 0; j < NSG; j++) {
2696
+ const short is = tx*NSG + j;
2697
+ dst_state[is] = ls[j];
2698
+ }
2699
+ }
2700
+
2701
+ #undef S_v
2702
+ #undef G
2703
+ #undef K
2704
+ }
2705
+
2706
+ typedef decltype(kernel_gated_delta_net_impl<4>) kernel_gated_delta_net_t;
2707
+ // one exported entry point per supported NSG value (1, 2, 4)
2708
+ template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<1>;
2709
+ template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<2>;
2710
+ template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<4>;
2711
+
2712
+ #else
2713
+ // a simplified version of the above
2714
+ // no performance improvement, so keep the above version for now
2715
+
2716
+ template<typename T, short NSG> // T: float/float2/float4 view over the NSG per-thread state values (see the instantiations below)
2717
+ kernel void kernel_gated_delta_net_impl(
2718
+ constant ggml_metal_kargs_gated_delta_net & args,
2719
+ device const char * q,
2720
+ device const char * k,
2721
+ device const char * v,
2722
+ device const char * g,
2723
+ device const char * b,
2724
+ device const char * s,
2725
+ device char * dst,
2726
+ uint3 tgpig[[threadgroup_position_in_grid]],
2727
+ uint3 tpitg[[thread_position_in_threadgroup]],
2728
+ uint3 ntg[[threads_per_threadgroup]]) { // ntg is not read in this body
2729
+ #define S_v FC_gated_delta_net_ne20
2730
+ #define G FC_gated_delta_net_ne30
2731
+ // Disabled (#else) variant: the same per-token recurrence written with vector ops; it does not read FC_gated_delta_net_K, so no state snapshots are written.
2732
+ const uint tx = tpitg.x;
2733
+ const uint ty = tpitg.y;
2734
+
2735
+ const uint i23 = tgpig.z; // B
2736
+ const uint i21 = tgpig.y; // H
2737
+ const uint i20 = tgpig.x*NSG + ty;
2738
+
2739
+ const uint i01 = i21 % args.ne01;
2740
+ const uint i11 = i21 % args.ne11;
2741
+
2742
+ const float scale = 1.0f / sqrt((float)S_v);
2743
+
2744
+ device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20;
2745
+ // the state is read with stride S_v (s_ptr[is*S_v]): element [is][i20] of the stored S_v x S_v matrix
2746
+ float lsf[NSG];
2747
+
2748
+ FOR_UNROLL (short j = 0; j < NSG; j++) {
2749
+ const short is = tx*NSG + j;
2750
+ lsf[j] = s_ptr[is*S_v];
2751
+ }
2752
+ // view the NSG floats as a single T so each update below is one vector operation per thread
2753
+ thread T * ls = (thread T *) (lsf);
2754
+
2755
+ device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
2756
+
2757
+ device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);
2758
+ device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);
2759
+ device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);
2760
+
2761
+ device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
2762
+ device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
2763
+
2764
+ for (short t = 0; t < args.ne22; t++) {
2765
+ device const T * qt_ptr = (device const T *) (q_ptr);
2766
+ device const T * kt_ptr = (device const T *) (k_ptr);
2767
+ device const T * gt_ptr = (device const T *) (g_ptr);
2768
+ // decay the state: one shared gate when G == 1, otherwise one gate per state element
2769
+ if (G == 1) {
2770
+ *ls *= exp(g_ptr[0]);
2771
+ } else {
2772
+ // KDA
2773
+ *ls *= exp(gt_ptr[tx]);
2774
+ }
2775
+ // (state . k), with per-thread partials reduced by simd_sum
2776
+ const float s_k = simd_sum(dot(*ls, kt_ptr[tx]));
2777
+
2778
+ const float d = (v_ptr[i20] - s_k)*b_ptr[0];
2779
+
2780
+ *ls += kt_ptr[tx]*d;
2781
+
2782
+ const float y = simd_sum(dot(*ls, qt_ptr[tx]));
2783
+
2784
+ if (tx == 0) {
2785
+ *dst_attn = y*scale;
2786
+ }
2787
+
2788
+ q_ptr += args.ns02;
2789
+ k_ptr += args.ns12;
2790
+ v_ptr += args.ns22;
2791
+
2792
+ b_ptr += args.ne21;
2793
+ g_ptr += args.ne21*G;
2794
+
2795
+ dst_attn += args.ne21*S_v;
2796
+ }
2797
+ // the final state is written after the attention block (ne23*ne22*ne21*S_v floats), with the same stride-S_v indexing as the input
2798
+ device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20;
2799
+ device T * dstt_state = (device T *) (dst_state); // NOTE(review): dstt_state is never used below
2800
+
2801
+ FOR_UNROLL (short j = 0; j < NSG; j++) {
2802
+ const short is = tx*NSG + j;
2803
+ dst_state[is*S_v] = lsf[j];
2804
+ }
2805
+
2806
+ #undef S_v
2807
+ #undef G
2808
+ }
2809
+
2810
+ typedef decltype(kernel_gated_delta_net_impl<float4, 4>) kernel_gated_delta_net_t;
2811
+ // T is chosen so that sizeof(T) == NSG floats for each exported variant
2812
+ template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float, 1>;
2813
+ template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float2, 2>;
2814
+ template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float4, 4>;
2815
+ #endif
2816
+
2817
+ constant short FC_solve_tri_nsg [[function_constant(FC_SOLVE_TRI + 0)]];
2818
+ constant short FC_solve_tri_n [[function_constant(FC_SOLVE_TRI + 1)]];
2819
+ constant short FC_solve_tri_k [[function_constant(FC_SOLVE_TRI + 2)]];
2820
+ // Triangular solve, one right-hand-side column per simdgroup: x[r] = (b[r] - sum_{j<r} A[r][j]*x[j]) / A[r][r] for r = 0..N-1.
2821
+ kernel void kernel_solve_tri_f32(
2822
+ constant ggml_metal_kargs_solve_tri & args,
2823
+ device const char * src0, // A: rows are read N floats apart
2824
+ device const char * src1, // b: element (r, i01) is src1_ptr[r*K]
2825
+ device char * dst, // x: indexed like src1
2826
+ threadgroup char * shmem [[threadgroup(0)]], // used as float rows of NP elements, one per simdgroup
2827
+ ushort3 tgpig[[threadgroup_position_in_grid]],
2828
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
2829
+ ushort tiisg[[thread_index_in_simdgroup]],
2830
+ ushort3 ntg[[threads_per_threadgroup]]) { // ntg is not read in this body
2831
+ constexpr short NW = N_SIMDWIDTH;
2832
+
2833
+ const short NSG = FC_solve_tri_nsg;
2834
+ const short N = FC_solve_tri_n;
2835
+ const short K = FC_solve_tri_k;
2836
+ const short NP = PAD2(N, NW); // presumably N rounded up to a multiple of NW - confirm against the PAD2 definition
2837
+ // i01: right-hand-side column handled by this simdgroup; i02/i03: batch indices
2838
+ const int32_t i03 = tgpig.z;
2839
+ const int32_t i02 = tgpig.y;
2840
+ const int32_t i01 = tgpig.x*NSG + sgitg;
2841
+
2842
+ threadgroup float * sh0 = (threadgroup float *) shmem;
2843
+ // src0_ptr starts at row sgitg, so simdgroup s stages row rr+s of A on each outer iteration
2844
+ device const float * src0_ptr = (device const float *)(src0 + i02 * args.nb02 + i03 * args.nb03) + sgitg*N;
2845
+ device const float * src1_ptr = (device const float *)(src1 + i02 * args.nb12 + i03 * args.nb13) + i01;
2846
+ device float * dst_ptr = (device float *)(dst + i02 * args.nb2 + i03 * args.nb3) + i01;
2847
+ // NSG rows of A are processed per outer iteration
2848
+ for (short rr = 0; rr < N; rr += NSG) {
2849
+ threadgroup_barrier(mem_flags::mem_threadgroup); // the rows staged in the previous pass must not be overwritten while still being read
2850
+ // stage: each simdgroup copies one row of A into its own threadgroup slot
2851
+ {
2852
+ threadgroup float * sh0_cur = sh0 + sgitg*NP;
2853
+ // NOTE(review): idx is not checked against N and row rr+sgitg is not checked against N, so lanes may load src0_ptr past the row/matrix - confirm the buffers make this safe
2854
+ for (short t = 0; t*NW < N; ++t) {
2855
+ const short idx = t*NW + tiisg;
2856
+ sh0_cur[idx] = src0_ptr[idx];
2857
+ }
2858
+
2859
+ src0_ptr += NSG*N;
2860
+ }
2861
+
2862
+ threadgroup_barrier(mem_flags::mem_threadgroup);
2863
+ // out-of-range columns skip the solve only after both barriers, so every simdgroup still reaches them
2864
+ if (i01 >= args.ne10) {
2865
+ continue;
2866
+ }
2867
+ // rows rr..rr+NSG-1 are solved in order; each one reads the x values written for earlier rows
2868
+ for (short ir = 0; ir < NSG && rr + ir < N; ++ir) {
2869
+ const short r = rr + ir;
2870
+
2871
+ threadgroup float * sh0_cur = sh0 + ir*NP;
2872
+
2873
+ float sum = 0.0f;
2874
+ // lanes with idx >= r contribute 0 through the (idx < r) factor, but still load dst_ptr[idx*K]
2875
+ for (short t = 0; t*NW < r; ++t) {
2876
+ const short idx = t*NW + tiisg;
2877
+ sum += sh0_cur[idx] * dst_ptr[idx*K] * (idx < r);
2878
+ }
2879
+
2880
+ sum = simd_sum(sum);
2881
+ // lane 0 finishes the row: subtract the accumulated sum and divide by the diagonal A[r][r]
2882
+ if (tiisg == 0) {
2883
+ const float diag = sh0_cur[r];
2884
+
2885
+ dst_ptr[r*K] = (src1_ptr[r*K] - sum) / diag;
2886
+ }
2887
+ }
2888
+ }
2889
+ }
2890
+
2740
2891
  kernel void kernel_argmax_f32(
2741
2892
  constant ggml_metal_kargs_argmax & args,
2742
2893
  device const char * src0,
@@ -2970,26 +3121,32 @@ template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_f
2970
3121
  template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>;
2971
3122
  template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>;
2972
3123
 
2973
- kernel void kernel_l2_norm_f32(
3124
+ template <typename T0, typename T>
3125
+ kernel void kernel_l2_norm_impl(
2974
3126
  constant ggml_metal_kargs_l2_norm & args,
2975
3127
  device const char * src0,
2976
3128
  device char * dst,
2977
3129
  threadgroup float * shmem_f32 [[threadgroup(0)]],
2978
- uint tgpig[[threadgroup_position_in_grid]],
2979
- ushort tpitg[[thread_position_in_threadgroup]],
2980
- ushort sgitg[[simdgroup_index_in_threadgroup]],
2981
- ushort tiisg[[thread_index_in_simdgroup]],
2982
- ushort ntg[[threads_per_threadgroup]]) {
3130
+ uint3 tgpig[[threadgroup_position_in_grid]],
3131
+ ushort3 tpitg[[thread_position_in_threadgroup]],
3132
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
3133
+ ushort tiisg[[thread_index_in_simdgroup]],
3134
+ ushort3 ntg[[threads_per_threadgroup]]) {
3135
+ const int i03 = tgpig.z;
3136
+ const int i02 = tgpig.y;
3137
+ const int i01 = tgpig.x;
3138
+
2983
3139
  if (sgitg == 0) {
2984
3140
  shmem_f32[tiisg] = 0.0f;
2985
3141
  }
2986
3142
 
2987
- device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
3143
+ device const T0 * x = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
3144
+ device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
2988
3145
 
2989
3146
  float sumf = 0.0f;
2990
3147
 
2991
3148
  // parallel sum
2992
- for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
3149
+ for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
2993
3150
  sumf += dot(x[i00], x[i00]);
2994
3151
  }
2995
3152
  sumf = simd_sum(sumf);
@@ -3005,14 +3162,18 @@ kernel void kernel_l2_norm_f32(
3005
3162
  sumf = shmem_f32[tiisg];
3006
3163
  sumf = simd_sum(sumf);
3007
3164
 
3008
- const float scale = 1.0f/sqrt(max(sumf, args.eps));
3165
+ const float scale = 1.0f/max(sqrt(sumf), args.eps);
3009
3166
 
3010
- device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
3011
- for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
3167
+ for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
3012
3168
  y[i00] = x[i00] * scale;
3013
3169
  }
3014
3170
  }
3015
3171
 
3172
+ typedef decltype(kernel_l2_norm_impl<float, float>) kernel_l2_norm_t;
3173
+
3174
+ template [[host_name("kernel_l2_norm_f32_f32")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float, float>;
3175
+ template [[host_name("kernel_l2_norm_f32_f32_4")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float4, float4>;
3176
+
3016
3177
  kernel void kernel_group_norm_f32(
3017
3178
  constant ggml_metal_kargs_group_norm & args,
3018
3179
  device const float * src0,
@@ -3094,6 +3255,35 @@ kernel void kernel_group_norm_f32(
3094
3255
  }
3095
3256
  }
3096
3257
 
3258
+ // Q1_0 dot product: dot = d * (2 * Σ(yl[i] where bit=1) - sumy)
3259
+ inline float block_q_n_dot_y(device const block_q1_0 * qb_curr, float sumy, thread float * yl, int il) {
3260
+ device const uint8_t * qs = qb_curr->qs + il / 8; // il is a value offset inside the block; 8 one-bit quants per byte
3261
+ const uint8_t b0 = qs[0];
3262
+ const uint8_t b1 = qs[1];
3263
+ // b0 covers yl[0..7] and b1 covers yl[8..15]; within a byte, bit i (LSB first) selects yl[i]
3264
+ float acc = 0.0f;
3265
+ // acc = sum of the yl entries whose bit is set
3266
+ acc += select(0.0f, yl[ 0], bool(b0 & 0x01));
3267
+ acc += select(0.0f, yl[ 1], bool(b0 & 0x02));
3268
+ acc += select(0.0f, yl[ 2], bool(b0 & 0x04));
3269
+ acc += select(0.0f, yl[ 3], bool(b0 & 0x08));
3270
+ acc += select(0.0f, yl[ 4], bool(b0 & 0x10));
3271
+ acc += select(0.0f, yl[ 5], bool(b0 & 0x20));
3272
+ acc += select(0.0f, yl[ 6], bool(b0 & 0x40));
3273
+ acc += select(0.0f, yl[ 7], bool(b0 & 0x80));
3274
+
3275
+ acc += select(0.0f, yl[ 8], bool(b1 & 0x01));
3276
+ acc += select(0.0f, yl[ 9], bool(b1 & 0x02));
3277
+ acc += select(0.0f, yl[10], bool(b1 & 0x04));
3278
+ acc += select(0.0f, yl[11], bool(b1 & 0x08));
3279
+ acc += select(0.0f, yl[12], bool(b1 & 0x10));
3280
+ acc += select(0.0f, yl[13], bool(b1 & 0x20));
3281
+ acc += select(0.0f, yl[14], bool(b1 & 0x40));
3282
+ acc += select(0.0f, yl[15], bool(b1 & 0x80));
3283
+ // 2*acc - sumy equals (sum over set bits) - (sum over clear bits) when sumy is the sum of all 16 yl values
3284
+ return qb_curr->d * (2.0f * acc - sumy);
3285
+ }
3286
+
3097
3287
  // function to calculate the inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
3098
3288
  // il indicates where the q4 quants begin (0 or QK4_0/4)
3099
3289
  // we assume that the yl's have been multiplied with the appropriate scale factor
@@ -3226,6 +3416,9 @@ static inline void helper_mv_reduce_and_write(
3226
3416
 
3227
3417
  constant short FC_mul_mv_nsg [[function_constant(FC_MUL_MV + 0)]];
3228
3418
  constant short FC_mul_mv_nxpsg [[function_constant(FC_MUL_MV + 1)]];
3419
+ constant short FC_mul_mv_ne12 [[function_constant(FC_MUL_MV + 2)]];
3420
+ constant short FC_mul_mv_r2 [[function_constant(FC_MUL_MV + 3)]];
3421
+ constant short FC_mul_mv_r3 [[function_constant(FC_MUL_MV + 4)]];
3229
3422
 
3230
3423
  template<typename block_q_type, short NR0, typename args_t>
3231
3424
  void mul_vec_q_n_f32_impl(
@@ -3249,72 +3442,151 @@ void mul_vec_q_n_f32_impl(
3249
3442
  const int r1 = tgpig.y;
3250
3443
  const int im = tgpig.z;
3251
3444
 
3252
- const uint i12 = im%args.ne12;
3253
- const uint i13 = im/args.ne12;
3445
+ const uint i12 = im%FC_mul_mv_ne12;
3446
+ const uint i13 = im/FC_mul_mv_ne12;
3254
3447
 
3255
- //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
3448
+ //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
3256
3449
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
3257
3450
 
3258
3451
  //device const block_q_type * x = (device const block_q_type *) (src0 + offset0);
3259
3452
  device const float * y = (device const float *) (src1 + offset1);
3260
3453
 
3261
- // pointers to src0 rows
3262
- device const block_q_type * ax[NR0];
3263
- FOR_UNROLL (int row = 0; row < NR0; ++row) {
3264
- const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
3454
+ // pointers to src0 rows
3455
+ device const block_q_type * ax[NR0];
3456
+ FOR_UNROLL (int row = 0; row < NR0; ++row) {
3457
+ const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
3458
+
3459
+ ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
3460
+ }
3461
+
3462
+ float sumf[NR0] = {0.f};
3463
+
3464
+ const short ix = (tiisg/(NW/NQ));
3465
+ const short il = (tiisg%(NW/NQ))*8;
3466
+
3467
+ //const int ib0 = sgitg*NQ + ix;
3468
+ const int ib0 = ix;
3469
+
3470
+ float yl[16]; // src1 vector cache
3471
+
3472
+ //device const float * yb = y + ix*QK4_0 + il;
3473
+ device const float * yb = y + ib0*QK4_0 + il;
3474
+
3475
+ // each thread in a SIMD group deals with half a block.
3476
+ //for (int ib = ib0; ib < nb; ib += NSG*NQ) {
3477
+ for (int ib = ib0; ib < nb; ib += NQ) {
3478
+ float sumy[2] = { 0.f, 0.f };
3479
+
3480
+ FOR_UNROLL (short i = 0; i < 8; i += 2) {
3481
+ sumy[0] += yb[i + 0] + yb[i + 1];
3482
+ yl[i + 0] = yb[i + 0];
3483
+ yl[i + 1] = yb[i + 1]/256.f;
3484
+
3485
+ sumy[1] += yb[i + 16] + yb[i + 17];
3486
+ yl[i + 8] = yb[i + 16]/16.f;
3487
+ yl[i + 9] = yb[i + 17]/4096.f;
3488
+ }
3489
+
3490
+ FOR_UNROLL (short row = 0; row < NR0; row++) {
3491
+ sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
3492
+ }
3493
+
3494
+ yb += QK4_0 * 16;
3495
+ //yb += NSG*NQ*QK4_0;
3496
+ }
3497
+
3498
+ device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
3499
+
3500
+ //helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
3501
+
3502
+ for (int row = 0; row < NR0; ++row) {
3503
+ const float tot = simd_sum(sumf[row]);
3504
+
3505
+ if (tiisg == 0 && r0 + row < args.ne01) {
3506
+ dst_f32[r0 + row] = tot;
3507
+ }
3508
+ }
3509
+ }
3510
+
3511
+ template<int nr0, typename args_t>
3512
+ void kernel_mul_mv_q1_0_f32_impl(
3513
+ args_t args,
3514
+ device const char * src0,
3515
+ device const char * src1,
3516
+ device char * dst,
3517
+ threadgroup char * shmem,
3518
+ uint3 tgpig,
3519
+ ushort tiisg,
3520
+ ushort sgitg) {
3521
+ const short NSG = FC_mul_mv_nsg;
3522
+
3523
+ const int nb = args.ne00/QK1_0;
3524
+
3525
+ const int r0 = tgpig.x;
3526
+ const int r1 = tgpig.y;
3527
+ const int im = tgpig.z;
3528
+
3529
+ const int first_row = (r0 * NSG + sgitg) * nr0;
3265
3530
 
3266
- ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
3267
- }
3531
+ const uint i12 = im%FC_mul_mv_ne12;
3532
+ const uint i13 = im/FC_mul_mv_ne12;
3268
3533
 
3269
- float sumf[NR0] = {0.f};
3534
+ const uint64_t offset1 = r1*args.nb11 + (i12)*args.nb12 + (i13)*args.nb13;
3270
3535
 
3271
- const short ix = (tiisg/(NW/NQ));
3272
- const short il = (tiisg%(NW/NQ))*8;
3536
+ device const float * y = (device const float *) (src1 + offset1);
3273
3537
 
3274
- //const int ib0 = sgitg*NQ + ix;
3275
- const int ib0 = ix;
3538
+ device const block_q1_0 * ax[nr0];
3539
+ for (int row = 0; row < nr0; ++row) {
3540
+ const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
3541
+ ax[row] = (device const block_q1_0 *) ((device char *) src0 + offset0);
3542
+ }
3276
3543
 
3277
- float yl[16]; // src1 vector cache
3544
+ float yl[16];
3545
+ float sumf[nr0] = {0.f};
3278
3546
 
3279
- //device const float * yb = y + ix*QK4_0 + il;
3280
- device const float * yb = y + ib0*QK4_0 + il;
3547
+ const short ix = (tiisg/8);
3548
+ const short il = (tiisg%8)*16;
3281
3549
 
3282
- // each thread in a SIMD group deals with half a block.
3283
- //for (int ib = ib0; ib < nb; ib += NSG*NQ) {
3284
- for (int ib = ib0; ib < nb; ib += NQ) {
3285
- float sumy[2] = { 0.f, 0.f };
3550
+ device const float * yb = y + ix*QK1_0 + il;
3286
3551
 
3287
- FOR_UNROLL (short i = 0; i < 8; i += 2) {
3288
- sumy[0] += yb[i + 0] + yb[i + 1];
3289
- yl[i + 0] = yb[i + 0];
3290
- yl[i + 1] = yb[i + 1]/256.f;
3552
+ for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/8) {
3553
+ float sumy = 0.f;
3291
3554
 
3292
- sumy[1] += yb[i + 16] + yb[i + 17];
3293
- yl[i + 8] = yb[i + 16]/16.f;
3294
- yl[i + 9] = yb[i + 17]/4096.f;
3555
+ FOR_UNROLL (short i = 0; i < 16; i++) {
3556
+ yl[i] = yb[i];
3557
+ sumy += yb[i];
3295
3558
  }
3296
3559
 
3297
- FOR_UNROLL (short row = 0; row < NR0; row++) {
3298
- sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
3560
+ FOR_UNROLL (short row = 0; row < nr0; row++) {
3561
+ sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy, yl, il);
3299
3562
  }
3300
3563
 
3301
- yb += QK4_0 * 16;
3302
- //yb += NSG*NQ*QK4_0;
3564
+ yb += QK1_0 * (N_SIMDWIDTH/8);
3303
3565
  }
3304
3566
 
3305
- device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
3306
-
3307
- //helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
3567
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
3308
3568
 
3309
- for (int row = 0; row < NR0; ++row) {
3569
+ for (int row = 0; row < nr0; ++row) {
3310
3570
  const float tot = simd_sum(sumf[row]);
3311
3571
 
3312
- if (tiisg == 0 && r0 + row < args.ne01) {
3313
- dst_f32[r0 + row] = tot;
3572
+ if (tiisg == 0 && first_row + row < args.ne01) {
3573
+ dst_f32[first_row + row] = tot;
3314
3574
  }
3315
3575
  }
3316
3576
  }
3317
3577
 
3578
+ [[host_name("kernel_mul_mv_q1_0_f32")]] // kernel entry point: forwards every argument unchanged to kernel_mul_mv_q1_0_f32_impl
3579
+ kernel void kernel_mul_mv_q1_0_f32(
3580
+ constant ggml_metal_kargs_mul_mv & args,
3581
+ device const char * src0,
3582
+ device const char * src1,
3583
+ device char * dst,
3584
+ uint3 tgpig[[threadgroup_position_in_grid]],
3585
+ ushort tiisg[[thread_index_in_simdgroup]],
3586
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
3587
+ kernel_mul_mv_q1_0_f32_impl<N_R0_Q1_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); // nr0 = N_R0_Q1_0; nullptr is passed for the impl's shmem parameter
3588
+ }
3589
+
3318
3590
  kernel void kernel_mul_mv_q4_0_f32(
3319
3591
  constant ggml_metal_kargs_mul_mv & args,
3320
3592
  device const char * src0,
@@ -3384,10 +3656,10 @@ void kernel_mul_mv_q8_0_f32_impl(
3384
3656
  const int r1 = tgpig.y;
3385
3657
  const int im = tgpig.z;
3386
3658
 
3387
- const uint i12 = im%args.ne12;
3388
- const uint i13 = im/args.ne12;
3659
+ const uint i12 = im%FC_mul_mv_ne12;
3660
+ const uint i13 = im/FC_mul_mv_ne12;
3389
3661
 
3390
- //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
3662
+ //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
3391
3663
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
3392
3664
 
3393
3665
  //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0);
@@ -3396,7 +3668,7 @@ void kernel_mul_mv_q8_0_f32_impl(
3396
3668
  // pointers to src0 rows
3397
3669
  device const block_q8_0 * ax[NR0];
3398
3670
  FOR_UNROLL (short row = 0; row < NR0; ++row) {
3399
- const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
3671
+ const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
3400
3672
 
3401
3673
  ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
3402
3674
  }
@@ -3476,10 +3748,10 @@ void kernel_mul_mv_ext_q4_f32_impl(
3476
3748
  const int i11 = tgpig.y*r1ptg;
3477
3749
  const int i1m = tgpig.z;
3478
3750
 
3479
- const int i12 = i1m%args.ne12;
3480
- const int i13 = i1m/args.ne12;
3751
+ const int i12 = i1m%FC_mul_mv_ne12;
3752
+ const int i13 = i1m/FC_mul_mv_ne12;
3481
3753
 
3482
- const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
3754
+ const uint64_t offset0 = i01*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
3483
3755
  const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
3484
3756
 
3485
3757
  device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;
@@ -3579,10 +3851,10 @@ void kernel_mul_mv_ext_q4x4_f32_impl(
3579
3851
  const int i11 = tgpig.y*r1ptg;
3580
3852
  const int i1m = tgpig.z;
3581
3853
 
3582
- const int i12 = i1m%args.ne12;
3583
- const int i13 = i1m/args.ne12;
3854
+ const int i12 = i1m%FC_mul_mv_ne12;
3855
+ const int i13 = i1m/FC_mul_mv_ne12;
3584
3856
 
3585
- const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
3857
+ const uint64_t offset0 = i01*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
3586
3858
  const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
3587
3859
 
3588
3860
  device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;
@@ -3700,6 +3972,18 @@ template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q4
3700
3972
  template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>;
3701
3973
  template [[host_name("kernel_mul_mv_ext_f16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, half4, 4, dequantize_f16_t4>;
3702
3974
 
3975
+ #if defined(GGML_METAL_HAS_BF16)
3976
+ template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, bfloat4, 4, dequantize_bf16_t4>;
3977
+ template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, bfloat4, 4, dequantize_bf16_t4>;
3978
+ template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, bfloat4, 4, dequantize_bf16_t4>;
3979
+ template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, bfloat4, 4, dequantize_bf16_t4>;
3980
+ #endif
3981
+
3982
+ template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q1_0, 128, dequantize_q1_0_t4>;
3983
+ template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q1_0, 128, dequantize_q1_0_t4>;
3984
+ template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q1_0, 128, dequantize_q1_0_t4>;
3985
+ template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q1_0, 128, dequantize_q1_0_t4>;
3986
+
3703
3987
  template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>;
3704
3988
  template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>;
3705
3989
  template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>;
@@ -3750,6 +4034,16 @@ template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4
3750
4034
  template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>;
3751
4035
  template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>;
3752
4036
 
4037
+ template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q2_K, 256, dequantize_q2_K>;
4038
+ template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q2_K, 256, dequantize_q2_K>;
4039
+ template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q2_K, 256, dequantize_q2_K>;
4040
+ template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q2_K, 256, dequantize_q2_K>;
4041
+
4042
+ template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q3_K, 256, dequantize_q3_K>;
4043
+ template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q3_K, 256, dequantize_q3_K>;
4044
+ template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q3_K, 256, dequantize_q3_K>;
4045
+ template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q3_K, 256, dequantize_q3_K>;
4046
+
3753
4047
  template<typename T0, typename T1, short NR0, typename args_t>
3754
4048
  void kernel_mul_mv_t_t_impl(
3755
4049
  args_t args,
@@ -3772,10 +4066,10 @@ void kernel_mul_mv_t_t_impl(
3772
4066
  const int r1 = tgpig.y;
3773
4067
  const int im = tgpig.z;
3774
4068
 
3775
- const uint i12 = im%args.ne12;
3776
- const uint i13 = im/args.ne12;
4069
+ const uint i12 = im%FC_mul_mv_ne12;
4070
+ const uint i13 = im/FC_mul_mv_ne12;
3777
4071
 
3778
- //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
4072
+ //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
3779
4073
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
3780
4074
 
3781
4075
  //device const T0 * x = (device const T0 *) (src0 + offset0);
@@ -3784,7 +4078,7 @@ void kernel_mul_mv_t_t_impl(
3784
4078
  // pointers to src0 rows
3785
4079
  device const T0 * ax [NR0];
3786
4080
  FOR_UNROLL (short row = 0; row < NR0; ++row) {
3787
- const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
4081
+ const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
3788
4082
 
3789
4083
  ax[row] = (device const T0 *) ((device char *) src0 + offset0);
3790
4084
  }
@@ -3894,10 +4188,10 @@ void kernel_mul_mv_t_t_4_impl(
3894
4188
  const int r1 = tgpig.y;
3895
4189
  const int im = tgpig.z;
3896
4190
 
3897
- const uint i12 = im%args.ne12;
3898
- const uint i13 = im/args.ne12;
4191
+ const uint i12 = im%FC_mul_mv_ne12;
4192
+ const uint i13 = im/FC_mul_mv_ne12;
3899
4193
 
3900
- //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
4194
+ //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
3901
4195
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
3902
4196
 
3903
4197
  device const T1 * y = (device const T1 *) (src1 + offset1);
@@ -3907,7 +4201,7 @@ void kernel_mul_mv_t_t_4_impl(
3907
4201
  device const T0 * ax [NR0];
3908
4202
  device const T04 * ax4[NR0];
3909
4203
  FOR_UNROLL (short row = 0; row < NR0; ++row) {
3910
- const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
4204
+ const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
3911
4205
 
3912
4206
  ax [row] = (device const T0 *) ((device char *) src0 + offset0);
3913
4207
  ax4[row] = (device const T04 *) ((device char *) src0 + offset0);
@@ -4011,10 +4305,10 @@ void kernel_mul_mv_t_t_short_impl(
4011
4305
  return;
4012
4306
  }
4013
4307
 
4014
- const uint i12 = im%args.ne12;
4015
- const uint i13 = im/args.ne12;
4308
+ const uint i12 = im%FC_mul_mv_ne12;
4309
+ const uint i13 = im/FC_mul_mv_ne12;
4016
4310
 
4017
- const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
4311
+ const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
4018
4312
 
4019
4313
  device const T0 * x = (device const T0 *) (src0 + offset0);
4020
4314
 
@@ -4437,59 +4731,59 @@ kernel void kernel_im2col(
4437
4731
  template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
4438
4732
  template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
4439
4733
 
4440
- // TODO: obolete -- remove
4441
- //typedef void (im2col_ext_t)(
4442
- // constant ggml_metal_kargs_im2col & args,
4443
- // device const float * x,
4444
- // device char * dst,
4445
- // uint3 tgpig[[threadgroup_position_in_grid]],
4446
- // uint3 tgpg[[threadgroups_per_grid]],
4447
- // uint3 tpitg[[thread_position_in_threadgroup]],
4448
- // uint3 ntg[[threads_per_threadgroup]]);
4449
- //
4450
- //template <typename T>
4451
- //kernel void kernel_im2col_ext(
4452
- // constant ggml_metal_kargs_im2col & args,
4453
- // device const float * x,
4454
- // device char * dst,
4455
- // uint3 tgpig[[threadgroup_position_in_grid]],
4456
- // uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
4457
- // uint3 tpitg[[thread_position_in_threadgroup]],
4458
- // uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
4459
- // const int64_t KHW = (int64_t)args.KHW;
4460
- //
4461
- // const int64_t d = tgpig[0] / args.CHW;
4462
- // const int64_t chw = tgpig[0] % args.CHW;
4463
- // const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
4464
- // const int64_t HW = tgpig[0] % KHW;
4465
- //
4466
- // const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
4467
- // if (tpitg_0 >= args.N) {
4468
- // return;
4469
- // }
4470
- //
4471
- // const int64_t tpitg_1 = HW / args.KW;
4472
- // const int64_t tpitg_2 = HW % args.KW;
4473
- //
4474
- // const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
4475
- // const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
4476
- //
4477
- // const int64_t offset_dst =
4478
- // (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
4479
- // (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
4480
- //
4481
- // device T * pdst = (device T *) (dst);
4482
- //
4483
- // if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
4484
- // pdst[offset_dst] = 0.0f;
4485
- // } else {
4486
- // const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
4487
- // pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
4488
- // }
4489
- //}
4490
- //
4491
- //template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
4492
- //template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
4734
+ // TODO: optimize
4735
+ typedef void (im2col_ext_t)(
4736
+ constant ggml_metal_kargs_im2col & args,
4737
+ device const float * x,
4738
+ device char * dst,
4739
+ uint3 tgpig[[threadgroup_position_in_grid]],
4740
+ uint3 tgpg[[threadgroups_per_grid]],
4741
+ uint3 tpitg[[thread_position_in_threadgroup]],
4742
+ uint3 ntg[[threads_per_threadgroup]]);
4743
+
4744
+ template <typename T>
4745
+ kernel void kernel_im2col_ext(
4746
+ constant ggml_metal_kargs_im2col & args,
4747
+ device const float * x,
4748
+ device char * dst,
4749
+ uint3 tgpig[[threadgroup_position_in_grid]],
4750
+ uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
4751
+ uint3 tpitg[[thread_position_in_threadgroup]],
4752
+ uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
4753
+ const int64_t KHW = (int64_t)args.KHW;
4754
+
4755
+ const int64_t d = tgpig[0] / args.CHW;
4756
+ const int64_t chw = tgpig[0] % args.CHW;
4757
+ const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
4758
+ const int64_t HW = tgpig[0] % KHW;
4759
+
4760
+ const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
4761
+ if (tpitg_0 >= args.N) {
4762
+ return;
4763
+ }
4764
+
4765
+ const int64_t tpitg_1 = HW / args.KW;
4766
+ const int64_t tpitg_2 = HW % args.KW;
4767
+
4768
+ const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
4769
+ const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
4770
+
4771
+ const int64_t offset_dst =
4772
+ (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
4773
+ (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
4774
+
4775
+ device T * pdst = (device T *) (dst);
4776
+
4777
+ if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
4778
+ pdst[offset_dst] = 0.0f;
4779
+ } else {
4780
+ const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
4781
+ pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
4782
+ }
4783
+ }
4784
+
4785
+ template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
4786
+ template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
4493
4787
 
4494
4788
  template <typename TK>
4495
4789
  kernel void kernel_conv_2d(
@@ -4622,15 +4916,32 @@ kernel void kernel_conv_transpose_1d(
4622
4916
  uint3 tgpig[[threadgroup_position_in_grid]],
4623
4917
  uint3 tgpg[[threadgroups_per_grid]]) {
4624
4918
 
4625
- float v = 0.0f;
4919
+ // For output position j on the time axis, only input positions
4920
+ // i such that i*s0 <= j < i*s0 + K
4921
+ // contribute -- i.e. i in [ceil((j - K + 1)/s0), floor(j/s0)]
4922
+ // intersected with [0, IL-1]. That's at most ceil(K/s0) values
4923
+ // (typically 2 for stride==K/2 transposed convs).
4924
+ const int32_t j = tgpig[0];
4925
+ const int32_t s0 = args.s0;
4926
+ const int32_t K = args.K;
4927
+ const int32_t IL = args.IL;
4928
+
4929
+ int32_t i_min;
4930
+ {
4931
+ int32_t a = j - K + 1;
4932
+ i_min = a <= 0 ? 0 : (a + s0 - 1) / s0; // ceil(a/s0) for a>0
4933
+ }
4934
+ int32_t i_max = j / s0;
4935
+ if (i_max > IL - 1) i_max = IL - 1;
4626
4936
 
4627
- for (int64_t c = 0; c < args.IC; c++) {
4628
- const int32_t kernel_offset = c * tgpg[1] * args.K + args.K * tgpig[1];
4629
- const int32_t input_offset = c * args.IL;
4937
+ float v = 0.0f;
4938
+ if (i_min <= i_max) {
4939
+ for (int64_t c = 0; c < args.IC; c++) {
4940
+ const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1];
4941
+ const int32_t input_offset = c * IL;
4630
4942
 
4631
- for (int64_t i = 0; i < args.IL; i++) {
4632
- if (tgpig[0] >= i * args.s0 && tgpig[0] < i * args.s0 + args.K) {
4633
- v += src0[kernel_offset + tgpig[0] - i * args.s0] * src1[input_offset + i];
4943
+ for (int32_t i = i_min; i <= i_max; i++) {
4944
+ v += float(src0[kernel_offset + j - i * s0]) * src1[input_offset + i];
4634
4945
  }
4635
4946
  }
4636
4947
  }
@@ -4749,7 +5060,9 @@ kernel void kernel_conv_transpose_2d<half>(
4749
5060
  uint3 tpitg[[thread_position_in_threadgroup]],
4750
5061
  uint3 ntg[[threads_per_threadgroup]]);
4751
5062
 
4752
- kernel void kernel_upscale_f32(
5063
+ constant bool FC_upscale_aa [[function_constant(FC_UPSCALE + 0)]];
5064
+
5065
+ kernel void kernel_upscale_nearest_f32(
4753
5066
  constant ggml_metal_kargs_upscale & args,
4754
5067
  device const char * src0,
4755
5068
  device char * dst,
@@ -4775,8 +5088,12 @@ kernel void kernel_upscale_f32(
4775
5088
  }
4776
5089
  }
4777
5090
 
4778
- kernel void kernel_pad_f32(
4779
- constant ggml_metal_kargs_pad & args,
5091
+ static inline float bilinear_tri(float x) {
5092
+ return MAX(0.0f, 1.0f - fabs(x));
5093
+ }
5094
+
5095
+ kernel void kernel_upscale_bilinear_f32(
5096
+ constant ggml_metal_kargs_upscale & args,
4780
5097
  device const char * src0,
4781
5098
  device char * dst,
4782
5099
  uint3 tgpig[[threadgroup_position_in_grid]],
@@ -4787,30 +5104,306 @@ kernel void kernel_pad_f32(
4787
5104
  const int64_t i2 = tgpig.y;
4788
5105
  const int64_t i1 = tgpig.x;
4789
5106
 
4790
- const int64_t i03 = i3;
4791
- const int64_t i02 = i2;
4792
- const int64_t i01 = i1;
5107
+ const int64_t i03 = i3 / args.sf3;
5108
+ const int64_t i02 = i2 / args.sf2;
4793
5109
 
4794
- device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
4795
- device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
5110
+ const float f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs;
5111
+ const int64_t i01 = MAX(0, MIN(args.ne01 - 1, (int64_t)floor(f01)));
5112
+ const int64_t i01p = MAX(0, MIN(args.ne01 - 1, i01 + 1));
5113
+ const float fd1 = MAX(0.0f, MIN(1.0f, f01 - (float)i01));
5114
+
5115
+ src0 += i03*args.nb03 + i02*args.nb02;
5116
+
5117
+ device float * dst_ptr = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
5118
+
5119
+ if (FC_upscale_aa) {
5120
+ const float support0 = MAX(1.0f, 1.0f / args.sf0);
5121
+ const float invscale0 = 1.0f / support0;
5122
+ const float support1 = MAX(1.0f, 1.0f / args.sf1);
5123
+ const float invscale1 = 1.0f / support1;
4796
5124
 
4797
- if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
4798
5125
  for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
4799
- if (i0 < args.ne00) {
4800
- dst_ptr[i0] = src0_ptr[i0];
4801
- } else {
4802
- dst_ptr[i0] = 0.0f;
5126
+ const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
5127
+
5128
+ int64_t x_min = MAX((int64_t)0, (int64_t)floor(f00 - support0 + args.poffs));
5129
+ int64_t x_max = MIN(args.ne00, (int64_t)ceil (f00 + support0 + args.poffs));
5130
+
5131
+ int64_t y_min = MAX((int64_t)0, (int64_t)floor(f01 - support1 + args.poffs));
5132
+ int64_t y_max = MIN(args.ne01, (int64_t)ceil (f01 + support1 + args.poffs));
5133
+
5134
+ float sum = 0.0f;
5135
+ float wsum = 0.0f;
5136
+
5137
+ for (int64_t sy = y_min; sy < y_max; ++sy) {
5138
+ const float wy = MAX(0.0f, 1.0f - fabs((float)sy - f01) * invscale1);
5139
+ for (int64_t sx = x_min; sx < x_max; ++sx) {
5140
+ const float wx = MAX(0.0f, 1.0f - fabs((float)sx - f00) * invscale0);
5141
+ const float w = wx * wy;
5142
+ device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00);
5143
+ sum += (*src_ptr) * w;
5144
+ wsum += w;
5145
+ }
4803
5146
  }
5147
+
5148
+ const float v = (wsum > 0.0f) ? (sum / wsum) : 0.0f;
5149
+ dst_ptr[i0] = v;
4804
5150
  }
5151
+ } else {
5152
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
5153
+ const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
5154
+ const int64_t i00 = MAX(0, MIN(args.ne00 - 1, (int64_t)floor(f00)));
5155
+ const int64_t i00p = MAX(0, MIN(args.ne00 - 1, i00 + 1));
5156
+ const float fd0 = MAX(0.0f, MIN(1.0f, f00 - (float)i00));
4805
5157
 
4806
- return;
5158
+ device const float * src00 = (device const float *)(src0 + i01*args.nb01 + i00*args.nb00);
5159
+ device const float * src10 = (device const float *)(src0 + i01*args.nb01 + i00p*args.nb00);
5160
+ device const float * src01 = (device const float *)(src0 + i01p*args.nb01 + i00*args.nb00);
5161
+ device const float * src11 = (device const float *)(src0 + i01p*args.nb01 + i00p*args.nb00);
5162
+
5163
+ const float v =
5164
+ (*src00) * (1.0f - fd0) * (1.0f - fd1) +
5165
+ (*src10) * fd0 * (1.0f - fd1) +
5166
+ (*src01) * (1.0f - fd0) * fd1 +
5167
+ (*src11) * fd0 * fd1;
5168
+
5169
+ dst_ptr[i0] = v;
5170
+ }
5171
+ }
5172
+ }
5173
+
5174
+ template <typename T>
5175
+ kernel void kernel_conv_3d(
5176
+ constant ggml_metal_kargs_conv_3d & args,
5177
+ device const char * src0, // Weights [IC * OC, KD, KH, KW]
5178
+ device const char * src1, // Inputs [IC * N, ID, IH, IW]
5179
+ device char * dst, // Outputs [OC * N, OD, OH, OW]
5180
+ uint3 tgpig[[threadgroup_position_in_grid]],
5181
+ uint3 tpitg[[thread_position_in_threadgroup]]) {
5182
+
5183
+ // 1. Un-flatten the spatial dimension from Grid X
5184
+ int64_t spatial_idx = tgpig.x * 32 + tpitg.x;
5185
+
5186
+ if (spatial_idx >= args.OW * args.OH * args.OD) {
5187
+ return; // Thread falls outside the spatial volume
5188
+ }
5189
+
5190
+ int64_t od = spatial_idx / (args.OW * args.OH);
5191
+ int64_t oh = (spatial_idx / args.OW) % args.OH;
5192
+ int64_t ow = spatial_idx % args.OW;
5193
+
5194
+ // 2. Map Y to Channels, Z to Batch
5195
+ int64_t oc = tgpig.y;
5196
+ int64_t batch_idx = tgpig.z;
5197
+
5198
+ // 3. Calculate anchor coordinates in the Input volume
5199
+ int64_t i_w_base = ow * args.s0 - args.p0;
5200
+ int64_t i_h_base = oh * args.s1 - args.p1;
5201
+ int64_t i_d_base = od * args.s2 - args.p2;
5202
+
5203
+ float sum = 0.0f;
5204
+
5205
+ // 4. Gather Loop (Iterate over Input Channels -> Depth -> Height -> Width)
5206
+ for (int64_t ic = 0; ic < args.IC; ++ic) {
5207
+
5208
+ // ggml packs batch and channel together in the 4th dimension
5209
+ int64_t src_cn_idx = batch_idx * args.IC + ic;
5210
+ int64_t w_cn_idx = oc * args.IC + ic;
5211
+
5212
+ for (int64_t kz = 0; kz < args.KD; ++kz) {
5213
+ int64_t id = i_d_base + kz * args.d2;
5214
+ if (id < 0 || id >= args.ID) continue; // Boundary check (Padding)
5215
+
5216
+ for (int64_t ky = 0; ky < args.KH; ++ky) {
5217
+ int64_t ih = i_h_base + ky * args.d1;
5218
+ if (ih < 0 || ih >= args.IH) continue;
5219
+
5220
+ for (int64_t kx = 0; kx < args.KW; ++kx) {
5221
+ int64_t iw = i_w_base + kx * args.d0;
5222
+ if (iw < 0 || iw >= args.IW) continue;
5223
+
5224
+ // Convert multi-dimensional coordinates to flat byte offsets
5225
+ int64_t w_idx = kx*args.nb00 + ky*args.nb01 + kz*args.nb02 + w_cn_idx*args.nb03;
5226
+ int64_t i_idx = iw*args.nb10 + ih*args.nb11 + id*args.nb12 + src_cn_idx*args.nb13;
5227
+
5228
+ // Dereference memory and cast weights to f32 if they were f16
5229
+ float w_val = (float)*(device const T*)((device const char*)src0 + w_idx);
5230
+ float i_val = *(device const float*)((device const char*)src1 + i_idx);
5231
+
5232
+ sum += w_val * i_val;
5233
+ }
5234
+ }
5235
+ }
5236
+ }
5237
+
5238
+ // 5. Write the accumulated value out to RAM
5239
+ int64_t dst_cn_idx = batch_idx * args.OC + oc;
5240
+ int64_t d_idx = ow*args.nb0 + oh*args.nb1 + od*args.nb2 + dst_cn_idx*args.nb3;
5241
+
5242
+ *(device float*)(dst + d_idx) = sum;
5243
+ }
5244
+
5245
+ // Explicit instantiations so the JIT compiler can find them by name
5246
+ template [[host_name("kernel_conv_3d_f32_f32")]]
5247
+ kernel void kernel_conv_3d<float>(
5248
+ constant ggml_metal_kargs_conv_3d & args,
5249
+ device const char * src0,
5250
+ device const char * src1,
5251
+ device char * dst,
5252
+ uint3 tgpig[[threadgroup_position_in_grid]],
5253
+ uint3 tpitg[[thread_position_in_threadgroup]]);
5254
+
5255
+ // Explicit instantiation for f16 weights
5256
+ template [[host_name("kernel_conv_3d_f16_f32")]]
5257
+ kernel void kernel_conv_3d<half>(
5258
+ constant ggml_metal_kargs_conv_3d & args,
5259
+ device const char * src0,
5260
+ device const char * src1,
5261
+ device char * dst,
5262
+ uint3 tgpig[[threadgroup_position_in_grid]],
5263
+ uint3 tpitg[[thread_position_in_threadgroup]]);
5264
+
5265
+
5266
+ static inline float bicubic_weight1(float x) {
5267
+ const float a = -0.75f;
5268
+ return ((a + 2) * x - (a + 3)) * x * x + 1;
5269
+ }
5270
+
5271
+ static inline float bicubic_weight2(float x) {
5272
+ const float a = -0.75f;
5273
+ return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a;
5274
+ }
5275
+
5276
+ kernel void kernel_upscale_bicubic_f32(
5277
+ constant ggml_metal_kargs_upscale & args,
5278
+ device const char * src0,
5279
+ device char * dst,
5280
+ uint3 tgpig[[threadgroup_position_in_grid]],
5281
+ uint3 tpitg[[thread_position_in_threadgroup]],
5282
+ uint3 ntg[[threads_per_threadgroup]]) {
5283
+
5284
+ const int64_t i3 = tgpig.z;
5285
+ const int64_t i2 = tgpig.y;
5286
+ const int64_t i1 = tgpig.x;
5287
+
5288
+ const int64_t i03 = i3 / args.sf3;
5289
+ const int64_t i02 = i2 / args.sf2;
5290
+
5291
+ const float f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs;
5292
+ const int64_t i01 = (int64_t)floor(f01);
5293
+ const float fd1 = f01 - (float)i01;
5294
+
5295
+ const float w_y0 = bicubic_weight2(fd1 + 1.0f);
5296
+ const float w_y1 = bicubic_weight1(fd1);
5297
+ const float w_y2 = bicubic_weight1(1.0f - fd1);
5298
+ const float w_y3 = bicubic_weight2(2.0f - fd1);
5299
+
5300
+ const device const char * src_slice = src0 + i03 * args.nb03 + i02 * args.nb02;
5301
+
5302
+ device float * dst_ptr = (device float *)(dst + i3 * args.nb3 + i2 * args.nb2 + i1 * args.nb1);
5303
+
5304
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
5305
+ const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
5306
+ const int64_t i00 = (int64_t)floor(f00);
5307
+ const float fd0 = f00 - (float)i00;
5308
+
5309
+ const float w_x0 = bicubic_weight2(fd0 + 1.0f);
5310
+ const float w_x1 = bicubic_weight1(fd0);
5311
+ const float w_x2 = bicubic_weight1(1.0f - fd0);
5312
+ const float w_x3 = bicubic_weight2(2.0f - fd0);
5313
+
5314
+ float sum = 0.0f;
5315
+
5316
+ for (int dy = -1; dy <= 2; ++dy) {
5317
+ const int64_t iy = MAX(0, MIN(args.ne01 - 1, i01 + dy));
5318
+ const float wy = (dy == -1) ? w_y0 : (dy == 0) ? w_y1 : (dy == 1) ? w_y2 : w_y3;
5319
+
5320
+ for (int dx = -1; dx <= 2; ++dx) {
5321
+ const int64_t ix = MAX(0, MIN(args.ne00 - 1, i00 + dx));
5322
+ const float wx = (dx == -1) ? w_x0 : (dx == 0) ? w_x1 : (dx == 1) ? w_x2 : w_x3;
5323
+
5324
+ device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00);
5325
+ sum += (*src_ptr) * wx * wy;
5326
+ }
5327
+ }
5328
+
5329
+ dst_ptr[i0] = sum;
4807
5330
  }
5331
+ }
5332
+
5333
+ kernel void kernel_roll_f32(
5334
+ constant ggml_metal_kargs_roll & args,
5335
+ device const char * src0,
5336
+ device char * dst,
5337
+ uint3 tgpig[[threadgroup_position_in_grid]],
5338
+ uint3 tpitg[[thread_position_in_threadgroup]],
5339
+ uint3 ntg[[threads_per_threadgroup]]) {
5340
+
5341
+ const int64_t i3 = tgpig.z;
5342
+ const int64_t i2 = tgpig.y;
5343
+ const int64_t i1 = tgpig.x;
5344
+
5345
+ device const float * src0_ptr = (device const float *) src0;
5346
+ device float * dst_ptr = (device float *) dst;
4808
5347
 
4809
5348
  for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
4810
- dst_ptr[i0] = 0.0f;
5349
+ // apply shifts and wrap around
5350
+ int64_t i00 = i0 - args.s0;
5351
+ int64_t i01 = i1 - args.s1;
5352
+ int64_t i02 = i2 - args.s2;
5353
+ int64_t i03 = i3 - args.s3;
5354
+
5355
+ if (i00 < 0) { i00 += args.ne00; } else if (i00 >= args.ne00) { i00 -= args.ne00; }
5356
+ if (i01 < 0) { i01 += args.ne01; } else if (i01 >= args.ne01) { i01 -= args.ne01; }
5357
+ if (i02 < 0) { i02 += args.ne02; } else if (i02 >= args.ne02) { i02 -= args.ne02; }
5358
+ if (i03 < 0) { i03 += args.ne03; } else if (i03 >= args.ne03) { i03 -= args.ne03; }
5359
+
5360
+ int64_t src_idx = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00 + i00;
5361
+ int64_t dst_idx = i3 *args.ne2 *args.ne1 *args.ne0 + i2 *args.ne1 *args.ne0 + i1 *args.ne0 + i0;
5362
+
5363
+ dst_ptr[dst_idx] = src0_ptr[src_idx];
5364
+ }
5365
+ }
5366
+
5367
+ template <typename T>
5368
+ kernel void kernel_pad_impl(
5369
+ constant ggml_metal_kargs_pad & args,
5370
+ device const char * src0,
5371
+ device char * dst,
5372
+ uint3 tgpig[[threadgroup_position_in_grid]],
5373
+ uint3 tpitg[[thread_position_in_threadgroup]],
5374
+ uint3 ntg[[threads_per_threadgroup]]) {
5375
+ const int32_t i3 = tgpig.z;
5376
+ const int32_t i2 = tgpig.y;
5377
+ const int32_t k0 = tgpig.x/args.ne1;
5378
+ const int32_t i1 = tgpig.x - k0*args.ne1;
5379
+
5380
+ const int32_t i03 = i3;
5381
+ const int32_t i02 = i2;
5382
+ const int32_t i01 = i1;
5383
+
5384
+ device const T * src0_ptr = (device const T *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
5385
+ device T * dst_ptr = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
5386
+
5387
+ for (int32_t l0 = 0; l0 < 1024; l0 += ntg.x) {
5388
+ const int32_t i0 = k0*1024 + tpitg.x + l0;
5389
+ if (i0 >= args.ne0) {
5390
+ break;
5391
+ }
5392
+
5393
+ if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
5394
+ dst_ptr[i0] = src0_ptr[i0];
5395
+ } else {
5396
+ dst_ptr[i0] = 0.0f;
5397
+ }
4811
5398
  }
4812
5399
  }
4813
5400
 
5401
+ typedef decltype(kernel_pad_impl<float>) kernel_pad_t;
5402
+
5403
+ template [[host_name("kernel_pad_f32")]] kernel kernel_pad_t kernel_pad_impl<float>;
5404
+ template [[host_name("kernel_pad_f32_4")]] kernel kernel_pad_t kernel_pad_impl<float4>;
5405
+
5406
+ // TODO: this is slow - optimize
4814
5407
  kernel void kernel_pad_reflect_1d_f32(
4815
5408
  constant ggml_metal_kargs_pad_reflect_1d & args,
4816
5409
  device const char * src0,
@@ -5114,24 +5707,6 @@ kernel void kernel_argsort_merge_f32_i32(
5114
5707
  template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
5115
5708
  template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
5116
5709
 
5117
- kernel void kernel_leaky_relu_f32(
5118
- constant ggml_metal_kargs_leaky_relu & args,
5119
- device const float * src0,
5120
- device float * dst,
5121
- uint tpig[[thread_position_in_grid]]) {
5122
- const float x = src0[tpig];
5123
- dst[tpig] = x > 0.0f ? x : x * args.slope;
5124
- }
5125
-
5126
- kernel void kernel_leaky_relu_f32_4(
5127
- constant ggml_metal_kargs_leaky_relu & args,
5128
- device const float4 * src0,
5129
- device float4 * dst,
5130
- uint tpig[[thread_position_in_grid]]) {
5131
- const float4 x = src0[tpig];
5132
- dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope);
5133
- }
5134
-
5135
5710
  constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]];
5136
5711
 
5137
5712
  constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]];
@@ -5208,6 +5783,7 @@ constant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_E
5208
5783
  // scan the blocks of the mask that are not masked
5209
5784
  // 0 - masked (i.e. full of -INF, skip)
5210
5785
  // 1 - not masked (i.e. at least one element of the mask is not -INF)
5786
+ // 2 - all zero
5211
5787
  kernel void kernel_flash_attn_ext_blk(
5212
5788
  constant ggml_metal_kargs_flash_attn_ext_blk & args,
5213
5789
  device const char * mask,
@@ -5229,27 +5805,29 @@ kernel void kernel_flash_attn_ext_blk(
5229
5805
 
5230
5806
  device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg;
5231
5807
 
5232
- // fast route
5233
- if (res == 0) {
5234
- if (simd_max(*mask_src) > -MAXHALF/2) {
5235
- res = 1;
5236
- }
5237
- }
5238
-
5239
5808
  // detailed check of the elements of the block
5240
5809
  if ((C > NW || Q > 1) && res == 0) {
5241
- half m = -MAXHALF;
5810
+ half mmin = MAXHALF;
5811
+ half mmax = -MAXHALF;
5242
5812
 
5243
5813
  FOR_UNROLL (short j = 0; j < Q; ++j) {
5244
5814
  FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) {
5245
- m = max(m, mask_src[ii*NW]);
5815
+ mmin = min(mmin, mask_src[ii*NW]);
5816
+ mmax = max(mmax, mask_src[ii*NW]);
5246
5817
  }
5247
5818
 
5248
5819
  mask_src += args.nb31/2;
5249
5820
  }
5250
5821
 
5251
- if (simd_max(m) > -MAXHALF/2) {
5252
- res = 1;
5822
+ mmin = simd_min(mmin);
5823
+ mmax = simd_max(mmax);
5824
+
5825
+ if (mmax > -MAXHALF) {
5826
+ if (mmin == 0.0 && mmax == 0.0) {
5827
+ res = 2;
5828
+ } else {
5829
+ res = 1;
5830
+ }
5253
5831
  }
5254
5832
  }
5255
5833
 
@@ -5491,9 +6069,13 @@ void kernel_flash_attn_ext_impl(
5491
6069
  ic = 0;
5492
6070
  }
5493
6071
 
6072
+ char blk_cur = 1;
6073
+
5494
6074
  // read the mask into shared mem
5495
6075
  if (FC_flash_attn_ext_has_mask) {
5496
- if (blk[ic0] == 0) {
6076
+ blk_cur = blk[ic0];
6077
+
6078
+ if (blk_cur == 0) {
5497
6079
  FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
5498
6080
  pm2[jj] += NW;
5499
6081
  }
@@ -5501,16 +6083,22 @@ void kernel_flash_attn_ext_impl(
5501
6083
  continue;
5502
6084
  }
5503
6085
 
5504
- FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
5505
- const short j = jj*NSG + sgitg;
6086
+ if (blk_cur == 1) {
6087
+ FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
6088
+ const short j = jj*NSG + sgitg;
5506
6089
 
5507
- if (FC_flash_attn_ext_bc_mask) {
5508
- sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
5509
- } else {
5510
- sm2[j*SH + tiisg] = pm2[jj][tiisg];
5511
- }
6090
+ if (FC_flash_attn_ext_bc_mask) {
6091
+ sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
6092
+ } else {
6093
+ sm2[j*SH + tiisg] = pm2[jj][tiisg];
6094
+ }
5512
6095
 
5513
- pm2[jj] += NW;
6096
+ pm2[jj] += NW;
6097
+ }
6098
+ } else if (blk_cur == 2) {
6099
+ FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
6100
+ pm2[jj] += NW;
6101
+ }
5514
6102
  }
5515
6103
 
5516
6104
  #if 0
@@ -5552,9 +6140,7 @@ void kernel_flash_attn_ext_impl(
5552
6140
 
5553
6141
  constexpr short NC = (C/8)/NSG;
5554
6142
 
5555
- // note: do not unroll for large heads
5556
- #pragma unroll (DK <= 64 ? NC : 1)
5557
- for (short cc = 0; cc < NC; ++cc) {
6143
+ FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
5558
6144
  qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
5559
6145
 
5560
6146
  if (DK % 16 != 0) {
@@ -5575,7 +6161,9 @@ void kernel_flash_attn_ext_impl(
5575
6161
  k8x8_t mk[2];
5576
6162
  q8x8_t mq[2];
5577
6163
 
5578
- FOR_UNROLL (short i = 0; i < DK8/2; ++i) {
6164
+ // note: too much unroll can tank the performance for large heads
6165
+ #pragma unroll (MIN(DK8/2, 4*NSG))
6166
+ for (short i = 0; i < DK8/2; ++i) {
5579
6167
  simdgroup_barrier(mem_flags::mem_none);
5580
6168
 
5581
6169
  simdgroup_load(mq[0], pq + 0*8 + 16*i, DK);
@@ -5675,10 +6263,12 @@ void kernel_flash_attn_ext_impl(
5675
6263
  }
5676
6264
 
5677
6265
  // mqk = mqk + slope*mask
5678
- if (FC_flash_attn_ext_has_bias) {
5679
- s2 += s2_t(sm2[j*SH + tiisg])*slope;
5680
- } else {
5681
- s2 += s2_t(sm2[j*SH + tiisg]);
6266
+ if (blk_cur != 2) {
6267
+ if (FC_flash_attn_ext_has_bias) {
6268
+ s2 += s2_t(sm2[j*SH + tiisg])*slope;
6269
+ } else {
6270
+ s2 += s2_t(sm2[j*SH + tiisg]);
6271
+ }
5682
6272
  }
5683
6273
 
5684
6274
  M[jj] = simd_max(max(M[jj], max(s2[0], s2[1])));
@@ -5749,7 +6339,9 @@ void kernel_flash_attn_ext_impl(
5749
6339
  pv += 8*NS20;
5750
6340
  }
5751
6341
  } else {
5752
- FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) {
6342
+ constexpr short NC = (C/8)/2;
6343
+
6344
+ FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
5753
6345
  s8x8_t vs[2];
5754
6346
 
5755
6347
  simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false);
@@ -5929,7 +6521,7 @@ template<
5929
6521
  void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
5930
6522
  short DK, // K head size
5931
6523
  short DV, // V head size
5932
- short Q = OP_FLASH_ATTN_EXT_NQPTG, // queries per threadgroup
6524
+ short Q = OP_FLASH_ATTN_EXT_NQPSG, // queries per threadgroup
5933
6525
  short C = OP_FLASH_ATTN_EXT_NCPSG> // cache items per threadgroup
5934
6526
  kernel void kernel_flash_attn_ext(
5935
6527
  constant ggml_metal_kargs_flash_attn_ext & args,
@@ -5952,6 +6544,7 @@ kernel void kernel_flash_attn_ext(
5952
6544
  //case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
5953
6545
  //case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
5954
6546
  case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
6547
+ case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break;
5955
6548
  }
5956
6549
  #undef FWD_TMPL
5957
6550
  #undef FWD_ARGS
@@ -6001,6 +6594,8 @@ template [[host_name("kernel_flash_attn_ext_f32_dk128_dv128")]] kernel flash_at
6001
6594
  template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 192>;
6002
6595
  template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 128>;
6003
6596
  template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 256, 256>;
6597
+ template [[host_name("kernel_flash_attn_ext_f32_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 320, 256>;
6598
+ template [[host_name("kernel_flash_attn_ext_f32_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 512, 512>;
6004
6599
  template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 576, 512>;
6005
6600
 
6006
6601
  template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 32, 32>;
@@ -6015,6 +6610,8 @@ template [[host_name("kernel_flash_attn_ext_f16_dk128_dv128")]] kernel flash_at
6015
6610
  template [[host_name("kernel_flash_attn_ext_f16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 192>;
6016
6611
  template [[host_name("kernel_flash_attn_ext_f16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 128>;
6017
6612
  template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256, 256>;
6613
+ template [[host_name("kernel_flash_attn_ext_f16_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 320, 256>;
6614
+ template [[host_name("kernel_flash_attn_ext_f16_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 512, 512>;
6018
6615
  template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
6019
6616
 
6020
6617
  #if defined(GGML_METAL_HAS_BF16)
@@ -6030,6 +6627,8 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk128_dv128")]] kernel flash_at
6030
6627
  template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
6031
6628
  template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
6032
6629
  template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
6630
+ template [[host_name("kernel_flash_attn_ext_bf16_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 320, 256>;
6631
+ template [[host_name("kernel_flash_attn_ext_bf16_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 512, 512>;
6033
6632
  template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
6034
6633
  #endif
6035
6634
 
@@ -6045,6 +6644,8 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk128_dv128")]] kernel flash_at
6045
6644
  template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 192>;
6046
6645
  template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 128>;
6047
6646
  template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
6647
+ template [[host_name("kernel_flash_attn_ext_q4_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 320, 256>;
6648
+ template [[host_name("kernel_flash_attn_ext_q4_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 512, 512>;
6048
6649
  template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;
6049
6650
 
6050
6651
  template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32, 32>;
@@ -6059,6 +6660,8 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk128_dv128")]] kernel flash_at
6059
6660
  template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 192>;
6060
6661
  template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 128>;
6061
6662
  template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
6663
+ template [[host_name("kernel_flash_attn_ext_q4_1_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 320, 256>;
6664
+ template [[host_name("kernel_flash_attn_ext_q4_1_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 512, 512>;
6062
6665
  template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;
6063
6666
 
6064
6667
  template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32, 32>;
@@ -6073,6 +6676,8 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk128_dv128")]] kernel flash_at
6073
6676
  template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 192>;
6074
6677
  template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 128>;
6075
6678
  template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
6679
+ template [[host_name("kernel_flash_attn_ext_q5_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 320, 256>;
6680
+ template [[host_name("kernel_flash_attn_ext_q5_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 512, 512>;
6076
6681
  template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;
6077
6682
 
6078
6683
  template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32, 32>;
@@ -6087,6 +6692,8 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk128_dv128")]] kernel flash_at
6087
6692
  template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 192>;
6088
6693
  template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 128>;
6089
6694
  template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
6695
+ template [[host_name("kernel_flash_attn_ext_q5_1_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 320, 256>;
6696
+ template [[host_name("kernel_flash_attn_ext_q5_1_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 512, 512>;
6090
6697
  template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;
6091
6698
 
6092
6699
  template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32, 32>;
@@ -6101,6 +6708,8 @@ template [[host_name("kernel_flash_attn_ext_q8_0_dk128_dv128")]] kernel flash_at
6101
6708
  template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 192>;
6102
6709
  template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 128>;
6103
6710
  template [[host_name("kernel_flash_attn_ext_q8_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256, 256>;
6711
+ template [[host_name("kernel_flash_attn_ext_q8_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 320, 256>;
6712
+ template [[host_name("kernel_flash_attn_ext_q8_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 512, 512>;
6104
6713
  template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;
6105
6714
 
6106
6715
  #undef FA_TYPES
@@ -6138,11 +6747,10 @@ template<
6138
6747
  void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
6139
6748
  short DK, // K head size
6140
6749
  short DV, // V head size
6141
- short NE, // head elements per thread
6142
- short Q, // queries per threadgroup
6143
- short C, // cache items per threadgroup
6144
- short NSG> // number of simd groups
6145
- void kernel_flash_attn_ext_vec_impl(
6750
+ short NE = 4, // head elements per thread
6751
+ short Q = OP_FLASH_ATTN_EXT_VEC_NQPSG, // queries per threadgroup
6752
+ short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup
6753
+ kernel void kernel_flash_attn_ext_vec(
6146
6754
  constant ggml_metal_kargs_flash_attn_ext_vec & args,
6147
6755
  device const char * q,
6148
6756
  device const char * k,
@@ -6159,6 +6767,7 @@ void kernel_flash_attn_ext_vec_impl(
6159
6767
  static_assert(DV % 32 == 0, "DV must be divisible by 32");
6160
6768
 
6161
6769
  #define NWG (FC_flash_attn_ext_vec_nwg)
6770
+ #define NSG (FC_flash_attn_ext_vec_nsg)
6162
6771
 
6163
6772
  #define NS10 (FC_flash_attn_ext_vec_ns10)
6164
6773
  #define NS20 (FC_flash_attn_ext_vec_ns20)
@@ -6185,14 +6794,14 @@ void kernel_flash_attn_ext_vec_impl(
6185
6794
  static_assert(DK4 % NL == 0, "DK4 must be divisible by NL");
6186
6795
  static_assert(DV4 % NL == 0, "DV4 must be divisible by NL");
6187
6796
 
6188
- const short T = PK + NSG*SH; // shared memory size per query in (half)
6797
+ //const short T = PK + NSG*SH; // shared memory size per query in (half)
6189
6798
 
6190
- //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data
6191
- threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t
6192
- threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*PK); // scratch buffer for attention
6193
- threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*PK); // same as above but in s4_t
6194
- threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + 2*C + Q*PK); // scratch buffer for mask
6195
- threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*PV + Q*T); // scratch buffer for the results
6799
+ //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data
6800
+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t
6801
+ threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + NSG*PK); // scratch buffer for attention
6802
+ threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + NSG*PK); // same as above but in s4_t
6803
+ threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + 2*C + NSG*PK); // scratch buffer for mask
6804
+ threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*PV + NSG*PK + NSG*SH); // scratch buffer for the results
6196
6805
 
6197
6806
  // store the result for all queries in shared memory (the O matrix from the paper)
6198
6807
  so4 += tiisg;
@@ -6210,11 +6819,13 @@ void kernel_flash_attn_ext_vec_impl(
6210
6819
  // load heads from Q to shared memory
6211
6820
  device const float4 * q4 = (device const float4 *) ((device const char *) q);
6212
6821
 
6213
- for (short i = tiisg; i < PK4; i += NW) {
6214
- if (iq1 < args.ne01 && i < DK4) {
6215
- sq4[i] = (q4_t) q4[i];
6216
- } else {
6217
- sq4[i] = (q4_t) 0.0f;
6822
+ if (iq1 < args.ne01) {
6823
+ for (short i = tiisg; i < PK4; i += NW) {
6824
+ if (i < DK4) {
6825
+ sq4[i] = (q4_t) q4[i];
6826
+ } else {
6827
+ sq4[i] = (q4_t) 0.0f;
6828
+ }
6218
6829
  }
6219
6830
  }
6220
6831
 
@@ -6292,7 +6903,7 @@ void kernel_flash_attn_ext_vec_impl(
6292
6903
  }
6293
6904
 
6294
6905
  // skip -INF blocks
6295
- if (simd_max(sm[tiisg]) == -INFINITY) {
6906
+ if (simd_max(sm[tiisg]) <= -MAXHALF) {
6296
6907
  continue;
6297
6908
  }
6298
6909
 
@@ -6566,57 +7177,11 @@ void kernel_flash_attn_ext_vec_impl(
6566
7177
  }
6567
7178
 
6568
7179
  #undef NWG
7180
+ #undef NSG
6569
7181
  #undef NS10
6570
7182
  #undef NS20
6571
7183
  }
6572
7184
 
6573
- template<
6574
- typename q4_t, // query types in shared memory
6575
- typename k4_t, // key types in shared memory
6576
- typename v4_t, // value types in shared memory
6577
- typename qk_t, // Q*K types
6578
- typename s_t, // soft-max types
6579
- typename s4_t,
6580
- typename o4_t, // attention accumulation types
6581
- typename kd4_t, // key type in device memory
6582
- short nl_k,
6583
- void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &),
6584
- typename vd4_t, // value type in device memory
6585
- short nl_v,
6586
- void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
6587
- short DK, // K head size
6588
- short DV, // V head size
6589
- short NE = 4, // head elements per thread
6590
- short Q = OP_FLASH_ATTN_EXT_VEC_NQPTG, // queries per threadgroup
6591
- short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup
6592
- kernel void kernel_flash_attn_ext_vec(
6593
- constant ggml_metal_kargs_flash_attn_ext_vec & args,
6594
- device const char * q,
6595
- device const char * k,
6596
- device const char * v,
6597
- device const char * mask,
6598
- device const char * sinks,
6599
- device const char * pad,
6600
- device char * dst,
6601
- threadgroup half * shmem_f16 [[threadgroup(0)]],
6602
- uint3 tgpig[[threadgroup_position_in_grid]],
6603
- ushort tiisg[[thread_index_in_simdgroup]],
6604
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
6605
- #define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C
6606
- #define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg
6607
- switch (FC_flash_attn_ext_vec_nsg) {
6608
- // note: disabled cases to reduce library load time
6609
- case 1: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 1>(FWD_ARGS); break;
6610
- case 2: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 2>(FWD_ARGS); break;
6611
- case 4: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 4>(FWD_ARGS); break;
6612
- //case 8: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 8>(FWD_ARGS); break;
6613
- //case 16: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 16>(FWD_ARGS); break;
6614
- //case 32: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 32>(FWD_ARGS); break;
6615
- }
6616
- #undef FWD_TMPL
6617
- #undef FWD_ARGS
6618
- }
6619
-
6620
7185
  // note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem
6621
7186
  // in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
6622
7187
  //
@@ -6715,6 +7280,28 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flas
6715
7280
  template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 1>;
6716
7281
  template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 1>;
6717
7282
 
7283
+ template [[host_name("kernel_flash_attn_ext_vec_f32_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 320, 256, 2>;
7284
+ template [[host_name("kernel_flash_attn_ext_vec_f16_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 320, 256, 2>;
7285
+ #if defined(GGML_METAL_HAS_BF16)
7286
+ template [[host_name("kernel_flash_attn_ext_vec_bf16_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 320, 256, 2>;
7287
+ #endif
7288
+ template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 320, 256, 2>;
7289
+ template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 320, 256, 2>;
7290
+ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 320, 256, 2>;
7291
+ template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 320, 256, 2>;
7292
+ template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 320, 256, 2>;
7293
+
7294
+ template [[host_name("kernel_flash_attn_ext_vec_f32_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 512, 512, 1>;
7295
+ template [[host_name("kernel_flash_attn_ext_vec_f16_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 512, 512, 1>;
7296
+ #if defined(GGML_METAL_HAS_BF16)
7297
+ template [[host_name("kernel_flash_attn_ext_vec_bf16_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 512, 512, 1>;
7298
+ #endif
7299
+ template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 512, 512, 1>;
7300
+ template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 512, 512, 1>;
7301
+ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 512, 512, 1>;
7302
+ template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 512, 512, 1>;
7303
+ template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 512, 512, 1>;
7304
+
6718
7305
  template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 576, 512, 2>;
6719
7306
  template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
6720
7307
  #if defined(GGML_METAL_HAS_BF16)
@@ -6780,23 +7367,27 @@ kernel void kernel_cpy_t_t(
6780
7367
  device const char * src0,
6781
7368
  device char * dst,
6782
7369
  uint3 tgpig[[threadgroup_position_in_grid]],
6783
- ushort tiitg[[thread_index_in_threadgroup]],
7370
+ ushort3 tpitg[[thread_position_in_threadgroup]],
6784
7371
  ushort3 ntg[[threads_per_threadgroup]]) {
6785
- const int i03 = tgpig[2];
6786
- const int i02 = tgpig[1];
6787
- const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
6788
- const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
7372
+ const int32_t i03 = tgpig[2];
7373
+ const int32_t i02 = tgpig[1];
7374
+ const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
7375
+ const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
7376
+
7377
+ if (i01 >= args.ne01) {
7378
+ return;
7379
+ }
6789
7380
 
6790
7381
  const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
6791
7382
 
6792
- const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
6793
- const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
6794
- const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
6795
- const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
7383
+ const int32_t i3 = n/(args.ne2*args.ne1*args.ne0);
7384
+ const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
7385
+ const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
7386
+ const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
6796
7387
 
6797
7388
  device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
6798
7389
 
6799
- for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.ne00; ) {
7390
+ for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.ne00;) {
6800
7391
  device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
6801
7392
  dst_data[i00] = (T1) src[0];
6802
7393
  break;
@@ -6828,23 +7419,27 @@ kernel void kernel_cpy_f32_q(
6828
7419
  device const char * src0,
6829
7420
  device char * dst,
6830
7421
  uint3 tgpig[[threadgroup_position_in_grid]],
6831
- ushort tiitg[[thread_index_in_threadgroup]],
7422
+ ushort3 tpitg[[thread_position_in_threadgroup]],
6832
7423
  ushort3 ntg[[threads_per_threadgroup]]) {
6833
- const int i03 = tgpig[2];
6834
- const int i02 = tgpig[1];
6835
- const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
6836
- const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
7424
+ const int32_t i03 = tgpig[2];
7425
+ const int32_t i02 = tgpig[1];
7426
+ const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
7427
+ const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
7428
+
7429
+ if (i01 >= args.ne01) {
7430
+ return;
7431
+ }
6837
7432
 
6838
7433
  const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
6839
7434
 
6840
- const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
6841
- const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
6842
- const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
6843
- const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK;
7435
+ const int32_t i3 = n / (args.ne2*args.ne1*args.ne0);
7436
+ const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
7437
+ const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
7438
+ const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK;
6844
7439
 
6845
7440
  device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
6846
7441
 
6847
- for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {
7442
+ for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) {
6848
7443
  device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00);
6849
7444
 
6850
7445
  quantize_func(src, dst_data[i00]);
@@ -6856,6 +7451,7 @@ kernel void kernel_cpy_f32_q(
6856
7451
  typedef decltype(kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>) cpy_f_q_t;
6857
7452
 
6858
7453
  template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>;
7454
+ template [[host_name("kernel_cpy_f32_q1_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK1_0, block_q1_0, quantize_q1_0>;
6859
7455
  template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_0, block_q4_0, quantize_q4_0>;
6860
7456
  template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_1, block_q4_1, quantize_q4_1>;
6861
7457
  template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_0, block_q5_0, quantize_q5_0>;
@@ -6868,24 +7464,28 @@ kernel void kernel_cpy_q_f32(
6868
7464
  device const char * src0,
6869
7465
  device char * dst,
6870
7466
  uint3 tgpig[[threadgroup_position_in_grid]],
6871
- ushort tiitg[[thread_index_in_threadgroup]],
7467
+ ushort3 tpitg[[thread_position_in_threadgroup]],
6872
7468
  ushort3 ntg[[threads_per_threadgroup]]) {
6873
- const int i03 = tgpig[2];
6874
- const int i02 = tgpig[1];
6875
- const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
6876
- const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
7469
+ const int32_t i03 = tgpig[2];
7470
+ const int32_t i02 = tgpig[1];
7471
+ const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
7472
+ const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
7473
+
7474
+ if (i01 >= args.ne01) {
7475
+ return;
7476
+ }
6877
7477
 
6878
7478
  const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
6879
7479
 
6880
- const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
6881
- const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
6882
- const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
6883
- const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
7480
+ const int32_t i3 = n/(args.ne2*args.ne1*args.ne0);
7481
+ const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
7482
+ const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
7483
+ const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
6884
7484
 
6885
7485
  device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
6886
7486
  device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
6887
7487
 
6888
- for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {
7488
+ for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) {
6889
7489
  T4x4 temp;
6890
7490
  dequantize_func(src_data + i00/nl, i00%nl, temp);
6891
7491
  dst_data[i00] = temp;
@@ -6896,12 +7496,14 @@ kernel void kernel_cpy_q_f32(
6896
7496
 
6897
7497
  typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t;
6898
7498
 
7499
+ template [[host_name("kernel_cpy_q1_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q1_0, 8, dequantize_q1_0>;
6899
7500
  template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>;
6900
7501
  template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>;
6901
7502
  template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>;
6902
7503
  template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>;
6903
7504
  template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>;
6904
7505
 
7506
+ template [[host_name("kernel_cpy_q1_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q1_0, 8, dequantize_q1_0>;
6905
7507
  template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>;
6906
7508
  template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>;
6907
7509
  template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>;
@@ -6919,7 +7521,11 @@ kernel void kernel_concat(
6919
7521
 
6920
7522
  const int i3 = tgpig.z;
6921
7523
  const int i2 = tgpig.y;
6922
- const int i1 = tgpig.x;
7524
+ const int i1 = ntg.y == 1 ? tgpig.x : tgpig.x*ntg.y + tpitg.y;
7525
+
7526
+ if (i1 >= args.ne1) {
7527
+ return;
7528
+ }
6923
7529
 
6924
7530
  int o[4] = {0, 0, 0, 0};
6925
7531
  o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? args.ne02 : args.ne03));
@@ -6959,10 +7565,10 @@ void kernel_mul_mv_q2_K_f32_impl(
6959
7565
 
6960
7566
  const int first_row = (r0 * NSG + sgitg) * nr0;
6961
7567
 
6962
- const uint i12 = im%args.ne12;
6963
- const uint i13 = im/args.ne12;
7568
+ const uint i12 = im%FC_mul_mv_ne12;
7569
+ const uint i13 = im/FC_mul_mv_ne12;
6964
7570
 
6965
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
7571
+ const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
6966
7572
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
6967
7573
 
6968
7574
  device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0);
@@ -7064,10 +7670,10 @@ void kernel_mul_mv_q3_K_f32_impl(
7064
7670
 
7065
7671
  const int first_row = (r0 * NSG + sgitg) * nr0;
7066
7672
 
7067
- const uint i12 = im%args.ne12;
7068
- const uint i13 = im/args.ne12;
7673
+ const uint i12 = im%FC_mul_mv_ne12;
7674
+ const uint i13 = im/FC_mul_mv_ne12;
7069
7675
 
7070
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
7676
+ const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
7071
7677
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
7072
7678
 
7073
7679
  device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0);
@@ -7238,10 +7844,10 @@ void kernel_mul_mv_q4_K_f32_impl(
7238
7844
 
7239
7845
  const int first_row = (r0 * NSG + sgitg) * nr0;
7240
7846
 
7241
- const uint i12 = im%args.ne12;
7242
- const uint i13 = im/args.ne12;
7847
+ const uint i12 = im%FC_mul_mv_ne12;
7848
+ const uint i13 = im/FC_mul_mv_ne12;
7243
7849
 
7244
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
7850
+ const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
7245
7851
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
7246
7852
 
7247
7853
  device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0);
@@ -7350,10 +7956,10 @@ void kernel_mul_mv_q5_K_f32_impl(
7350
7956
 
7351
7957
  const int first_row = (r0 * NSG + sgitg) * nr0;
7352
7958
 
7353
- const uint i12 = im%args.ne12;
7354
- const uint i13 = im/args.ne12;
7959
+ const uint i12 = im%FC_mul_mv_ne12;
7960
+ const uint i13 = im/FC_mul_mv_ne12;
7355
7961
 
7356
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
7962
+ const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
7357
7963
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
7358
7964
 
7359
7965
  device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0);
@@ -7486,10 +8092,10 @@ void kernel_mul_mv_q6_K_f32_impl(
7486
8092
 
7487
8093
  const int first_row = (r0 * NSG + sgitg) * nr0;
7488
8094
 
7489
- const uint i12 = im%args.ne12;
7490
- const uint i13 = im/args.ne12;
8095
+ const uint i12 = im%FC_mul_mv_ne12;
8096
+ const uint i13 = im/FC_mul_mv_ne12;
7491
8097
 
7492
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
8098
+ const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
7493
8099
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
7494
8100
 
7495
8101
  device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0);
@@ -7591,10 +8197,10 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
7591
8197
 
7592
8198
  const int first_row = (r0 * NSG + sgitg) * nr0;
7593
8199
 
7594
- const uint i12 = im%args.ne12;
7595
- const uint i13 = im/args.ne12;
8200
+ const uint i12 = im%FC_mul_mv_ne12;
8201
+ const uint i13 = im/FC_mul_mv_ne12;
7596
8202
 
7597
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
8203
+ const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
7598
8204
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
7599
8205
 
7600
8206
  device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0);
@@ -7699,10 +8305,10 @@ void kernel_mul_mv_iq2_xs_f32_impl(
7699
8305
 
7700
8306
  const int first_row = (r0 * NSG + sgitg) * nr0;
7701
8307
 
7702
- const uint i12 = im%args.ne12;
7703
- const uint i13 = im/args.ne12;
8308
+ const uint i12 = im%FC_mul_mv_ne12;
8309
+ const uint i13 = im/FC_mul_mv_ne12;
7704
8310
 
7705
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
8311
+ const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
7706
8312
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
7707
8313
 
7708
8314
  device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0);
@@ -7818,10 +8424,10 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
7818
8424
 
7819
8425
  const int first_row = (r0 * NSG + sgitg) * nr0;
7820
8426
 
7821
- const uint i12 = im%args.ne12;
7822
- const uint i13 = im/args.ne12;
8427
+ const uint i12 = im%FC_mul_mv_ne12;
8428
+ const uint i13 = im/FC_mul_mv_ne12;
7823
8429
 
7824
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
8430
+ const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
7825
8431
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
7826
8432
 
7827
8433
  device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0);
@@ -7930,10 +8536,10 @@ void kernel_mul_mv_iq3_s_f32_impl(
7930
8536
 
7931
8537
  const int first_row = (r0 * NSG + sgitg) * nr0;
7932
8538
 
7933
- const uint i12 = im%args.ne12;
7934
- const uint i13 = im/args.ne12;
8539
+ const uint i12 = im%FC_mul_mv_ne12;
8540
+ const uint i13 = im/FC_mul_mv_ne12;
7935
8541
 
7936
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
8542
+ const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
7937
8543
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
7938
8544
 
7939
8545
  device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0);
@@ -8042,10 +8648,10 @@ void kernel_mul_mv_iq2_s_f32_impl(
8042
8648
 
8043
8649
  const int first_row = (r0 * NSG + sgitg) * nr0;
8044
8650
 
8045
- const uint i12 = im%args.ne12;
8046
- const uint i13 = im/args.ne12;
8651
+ const uint i12 = im%FC_mul_mv_ne12;
8652
+ const uint i13 = im/FC_mul_mv_ne12;
8047
8653
 
8048
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
8654
+ const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
8049
8655
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
8050
8656
 
8051
8657
  device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0);
@@ -8155,10 +8761,10 @@ void kernel_mul_mv_iq1_s_f32_impl(
8155
8761
 
8156
8762
  const int first_row = (r0 * NSG + sgitg) * nr0;
8157
8763
 
8158
- const uint i12 = im%args.ne12;
8159
- const uint i13 = im/args.ne12;
8764
+ const uint i12 = im%FC_mul_mv_ne12;
8765
+ const uint i13 = im/FC_mul_mv_ne12;
8160
8766
 
8161
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
8767
+ const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
8162
8768
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
8163
8769
 
8164
8770
  device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0);
@@ -8254,10 +8860,10 @@ void kernel_mul_mv_iq1_m_f32_impl(
8254
8860
 
8255
8861
  const int first_row = (r0 * NSG + sgitg) * nr0;
8256
8862
 
8257
- const uint i12 = im%args.ne12;
8258
- const uint i13 = im/args.ne12;
8863
+ const uint i12 = im%FC_mul_mv_ne12;
8864
+ const uint i13 = im/FC_mul_mv_ne12;
8259
8865
 
8260
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
8866
+ const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
8261
8867
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
8262
8868
 
8263
8869
  device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0);
@@ -8363,10 +8969,10 @@ void kernel_mul_mv_iq4_nl_f32_impl(
8363
8969
 
8364
8970
  const int first_row = (r0 * NSG + sgitg) * NR0;
8365
8971
 
8366
- const uint i12 = im%args.ne12;
8367
- const uint i13 = im/args.ne12;
8972
+ const uint i12 = im%FC_mul_mv_ne12;
8973
+ const uint i13 = im/FC_mul_mv_ne12;
8368
8974
 
8369
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
8975
+ const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
8370
8976
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
8371
8977
 
8372
8978
  device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
@@ -8472,10 +9078,10 @@ void kernel_mul_mv_iq4_xs_f32_impl(
8472
9078
  const int im = tgpig.z;
8473
9079
  const int first_row = (r0 * NSG + sgitg) * NR0;
8474
9080
 
8475
- const uint i12 = im%args.ne12;
8476
- const uint i13 = im/args.ne12;
9081
+ const uint i12 = im%FC_mul_mv_ne12;
9082
+ const uint i13 = im/FC_mul_mv_ne12;
8477
9083
 
8478
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
9084
+ const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
8479
9085
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
8480
9086
 
8481
9087
  device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
@@ -8583,10 +9189,10 @@ void kernel_mul_mv_mxfp4_f32_impl(
8583
9189
 
8584
9190
  const int first_row = (r0 * NSG + sgitg) * NR0;
8585
9191
 
8586
- const uint i12 = im%args.ne12;
8587
- const uint i13 = im/args.ne12;
9192
+ const uint i12 = im%FC_mul_mv_ne12;
9193
+ const uint i13 = im/FC_mul_mv_ne12;
8588
9194
 
8589
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
9195
+ const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
8590
9196
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
8591
9197
 
8592
9198
  device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0);
@@ -8779,11 +9385,165 @@ kernel void kernel_set_rows_f(
8779
9385
  }
8780
9386
  }
8781
9387
 
9388
+ kernel void kernel_diag_f32(
9389
+ constant ggml_metal_kargs_diag & args,
9390
+ device const char * src0,
9391
+ device char * dst,
9392
+ uint3 tgpig[[threadgroup_position_in_grid]],
9393
+ ushort tiitg[[thread_index_in_threadgroup]]) {
9394
+ constexpr short NW = N_SIMDWIDTH;
9395
+
9396
+ const int32_t i3 = tgpig.z;
9397
+ const int32_t i2 = tgpig.y;
9398
+ const int32_t i1 = tgpig.x;
9399
+
9400
+ device const float * src0_ptr = (device const float *)(src0 + i2*args.nb02 + i3*args.nb03);
9401
+ device float * dst_ptr = (device float *)(dst + i1*args.nb01 + i2*args.nb2 + i3*args.nb3);
9402
+
9403
+ for (int i0 = tiitg; i0 < args.ne0; i0 += NW) {
9404
+ dst_ptr[i0] = i0 == i1 ? src0_ptr[i0] : 0.0f;
9405
+ }
9406
+ }
9407
+
8782
9408
  constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]];
8783
9409
  constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]];
9410
+ constant short FC_mul_mm_ne12 [[function_constant(FC_MUL_MM + 2)]];
9411
+ constant short FC_mul_mm_ne13 [[function_constant(FC_MUL_MM + 3)]];
9412
+ constant short FC_mul_mm_r2 [[function_constant(FC_MUL_MM + 4)]];
9413
+ constant short FC_mul_mm_r3 [[function_constant(FC_MUL_MM + 5)]];
8784
9414
 
8785
9415
  // each block_q contains 16*nl weights
8786
- template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
9416
+ #ifdef GGML_METAL_HAS_TENSOR
9417
+ template<
9418
+ typename SA, typename SA_4x4, typename SA_8x8,
9419
+ typename SB, typename SB_2x4, typename SB_8x8,
9420
+ typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread SA_4x4 &),
9421
+ typename T0, typename T0_4x4, typename T1, typename T1_2x4>
9422
+ kernel void kernel_mul_mm(
9423
+ constant ggml_metal_kargs_mul_mm & args,
9424
+ device const char * srcA,
9425
+ device const char * srcB,
9426
+ device char * dst,
9427
+ threadgroup char * shmem [[threadgroup(0)]],
9428
+ uint3 tgpig [[threadgroup_position_in_grid]],
9429
+ ushort tiitg [[thread_index_in_threadgroup]],
9430
+ ushort sgitg [[simdgroup_index_in_threadgroup]]) {
9431
+ (void) sgitg;
9432
+
9433
+ // Matrix dimensions: A(M,K) x B(K,N) -> C(M,N)
9434
+ const int K = args.ne00;
9435
+ const int M = args.ne0;
9436
+ const int N = args.ne1;
9437
+
9438
+ // Batch dimension handling
9439
+ const int im = tgpig.z;
9440
+ const int i12 = im % FC_mul_mm_ne12;
9441
+ const int i13 = im / FC_mul_mm_ne12;
9442
+
9443
+ // Batch offsets for srcA and srcB
9444
+ const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03;
9445
+
9446
+ // Tile dimensions
9447
+ constexpr int NRB = SZ_SIMDGROUP * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X;
9448
+ constexpr int NRA = SZ_SIMDGROUP * N_MM_BLOCK_Y * N_MM_SIMD_GROUP_Y;
9449
+
9450
+ // Tile offsets in output matrix
9451
+ const int ra = tgpig.y * NRA;
9452
+ const int rb = tgpig.x * NRB;
9453
+
9454
+ // Threadgroup memory for dequantized A tile only
9455
+ threadgroup SA * sa = (threadgroup SA *)(shmem);
9456
+
9457
+ // Work-item count for A loading
9458
+ constexpr int A_WORK_ITEMS = NRA * N_MM_NK;
9459
+ constexpr int NUM_THREADS = N_SIMDWIDTH * N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y;
9460
+
9461
+ // tA wraps threadgroup memory
9462
+ auto tA = tensor(sa, dextents<int32_t, 2>(N_MM_NK_TOTAL, NRA));
9463
+
9464
+ // tB wraps device memory directly
9465
+ device T1 * ptrB = (device T1 *)(srcB + args.nb12*i12 + args.nb13*i13);
9466
+ const int strideB = args.nb11 / sizeof(T1);
9467
+ auto tB = tensor(ptrB, dextents<int32_t, 2>(K, N), array<int, 2>({1, strideB}));
9468
+
9469
+ // Configure matmul operation
9470
+ mpp::tensor_ops::matmul2d<
9471
+ mpp::tensor_ops::matmul2d_descriptor(
9472
+ NRB, NRA, N_MM_NK_TOTAL, false, true, true,
9473
+ mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
9474
+ execution_simdgroups<N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y>> mm;
9475
+
9476
+ auto cT = mm.get_destination_cooperative_tensor<decltype(tB), decltype(tA), float>();
9477
+
9478
+ // Accumulate partial results over K dimension
9479
+ for (int loop_k = 0; loop_k < K; loop_k += N_MM_NK_TOTAL) {
9480
+ // === PHASE 1: Dequantization of A into threadgroup memory ===
9481
+ for (int work = tiitg; work < A_WORK_ITEMS; work += NUM_THREADS) {
9482
+ const int row = work / N_MM_NK;
9483
+ const int k_chunk = work % N_MM_NK;
9484
+ const int k_pos = loop_k + k_chunk * 16;
9485
+ const short k_base = k_chunk * 16;
9486
+
9487
+ // Bounds check: skip device read if row is out of matrix bounds
9488
+ if (ra + row < M) {
9489
+ if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
9490
+ // Element-wise reads when K is not aligned (nb01 not aligned for half4x4/float4x4).
9491
+ // MSL spec Table 2.5: half4x4 requires 8-byte alignment. When K is odd,
9492
+ // nb01 = K*2 is not 8-byte aligned, so odd-row pointers are misaligned.
9493
+ // Mirrors the legacy kernel's existing guard.
9494
+ device const T0 * row_ptr = (device const T0 *)(srcA + args.nb01 * (ra + row) + offset0);
9495
+
9496
+ FOR_UNROLL (short i = 0; i < 16; i++) {
9497
+ sa[row * N_MM_NK_TOTAL + (k_base + i)] = (k_pos + i < K) ? (SA) row_ptr[k_pos + i] : (SA)0;
9498
+ }
9499
+ } else {
9500
+ const int block_idx = k_pos / (16 * nl);
9501
+ const short il = (k_pos / 16) % nl;
9502
+
9503
+ device const block_q * row_ptr = (device const block_q *)(srcA + args.nb01 * (ra + row) + offset0);
9504
+
9505
+ SA_4x4 temp_a;
9506
+ dequantize_func(row_ptr + block_idx, il, temp_a);
9507
+
9508
+ FOR_UNROLL (short i = 0; i < 16; i++) {
9509
+ // Zero-pad A for K positions beyond valid range (handles partial K iterations)
9510
+ sa[row * N_MM_NK_TOTAL + (k_base + i)] = (k_pos + i < K) ? temp_a[i/4][i%4] : (SA)0;
9511
+ }
9512
+ }
9513
+ } else {
9514
+ // Zero-pad rows beyond matrix bounds
9515
+ FOR_UNROLL (short i = 0; i < 16; i++) {
9516
+ sa[row * N_MM_NK_TOTAL + (k_base + i)] = (SA)0;
9517
+ }
9518
+ }
9519
+ }
9520
+
9521
+ threadgroup_barrier(mem_flags::mem_threadgroup);
9522
+
9523
+ // === PHASE 2: Tensor matmul ===
9524
+ auto mA = tA.slice(0, 0);
9525
+ auto mB = tB.slice(loop_k, rb);
9526
+
9527
+ mm.run(mB, mA, cT);
9528
+
9529
+ threadgroup_barrier(mem_flags::mem_threadgroup);
9530
+ }
9531
+
9532
+ // Store result tile to output matrix (with batch offset)
9533
+ // cT.store handles bounds checking via tD's extents (M, N)
9534
+ device float * dstBatch = (device float *)dst + im * N * M;
9535
+
9536
+ auto tD = tensor(dstBatch, dextents<int32_t, 2>(M, N), array<int, 2>({1, M}));
9537
+ cT.store(tD.slice(ra, rb));
9538
+ }
9539
+
9540
+ #else
9541
+
9542
+ template<
9543
+ typename S0, typename S0_4x4, typename S0_8x8,
9544
+ typename S1, typename S1_2x4, typename S1_8x8,
9545
+ typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &),
9546
+ typename T0, typename T0_4x4, typename T1, typename T1_2x4>
8787
9547
  kernel void kernel_mul_mm(
8788
9548
  constant ggml_metal_kargs_mul_mm & args,
8789
9549
  device const char * src0,
@@ -8797,8 +9557,6 @@ kernel void kernel_mul_mm(
8797
9557
  threadgroup S0 * sa = (threadgroup S0 *)(shmem);
8798
9558
  threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
8799
9559
 
8800
- threadgroup float * sc = (threadgroup float *)(shmem);
8801
-
8802
9560
  constexpr int NR0 = 64;
8803
9561
  constexpr int NR1 = 32;
8804
9562
 
@@ -8822,10 +9580,10 @@ kernel void kernel_mul_mm(
8822
9580
 
8823
9581
  short il = il0;
8824
9582
 
8825
- const int i12 = im%args.ne12;
8826
- const int i13 = im/args.ne12;
9583
+ const int i12 = im % FC_mul_mm_ne12;
9584
+ const int i13 = im / FC_mul_mm_ne12;
8827
9585
 
8828
- const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
9586
+ const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03;
8829
9587
  const short offset1 = il0/nl;
8830
9588
 
8831
9589
  device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;
@@ -8838,7 +9596,6 @@ kernel void kernel_mul_mm(
8838
9596
  + args.nb11*(r1 + lr1)
8839
9597
  + args.nb10*iy);
8840
9598
 
8841
- #ifndef GGML_METAL_HAS_TENSOR
8842
9599
  S0_8x8 ma[4];
8843
9600
  S1_8x8 mb[2];
8844
9601
 
@@ -8847,19 +9604,8 @@ kernel void kernel_mul_mm(
8847
9604
  for (short i = 0; i < 8; i++){
8848
9605
  mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
8849
9606
  }
8850
- #else
8851
- auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0));
8852
- auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK ));
8853
-
8854
- mpp::tensor_ops::matmul2d<
8855
- mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
8856
- execution_simdgroups<4>> mm;
8857
-
8858
- auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
8859
- #endif
8860
9607
 
8861
9608
  for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
8862
- #ifndef GGML_METAL_HAS_TENSOR
8863
9609
  // load data and store to threadgroup memory
8864
9610
  if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
8865
9611
  threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -8920,8 +9666,8 @@ kernel void kernel_mul_mm(
8920
9666
  const short sx = (tiitg%NL1);
8921
9667
  const short sy = (tiitg/NL1)/8;
8922
9668
 
8923
- const short dx = sx;
8924
- const short dy = sy;
9669
+ //const short dx = sx;
9670
+ //const short dy = sy;
8925
9671
 
8926
9672
  const short ly = (tiitg/NL1)%8;
8927
9673
 
@@ -8929,66 +9675,6 @@ kernel void kernel_mul_mm(
8929
9675
 
8930
9676
  *(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
8931
9677
  }
8932
- #else
8933
- // load data and store to threadgroup memory
8934
- if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
8935
- threadgroup_barrier(mem_flags::mem_threadgroup);
8936
-
8937
- // no need for dequantization
8938
- for (short i = 0; i < 16; i++) {
8939
- const short sx = 2*il0 + i/8;
8940
- const short sy = (tiitg/NL0)/8;
8941
-
8942
- const short lx = i%8;
8943
- const short ly = (tiitg/NL0)%8;
8944
- //const short lx = (tiitg/NL0)%8;
8945
- //const short ly = i%8;
8946
-
8947
- *(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
8948
- }
8949
- } else {
8950
- S0_4x4 temp_a;
8951
- dequantize_func(x, il, temp_a);
8952
-
8953
- threadgroup_barrier(mem_flags::mem_threadgroup);
8954
-
8955
- FOR_UNROLL (short i = 0; i < 16; i++) {
8956
- const short sx = 2*il0 + i/8;
8957
- const short sy = (tiitg/NL0)/8;
8958
-
8959
- const short lx = i%8;
8960
- const short ly = (tiitg/NL0)%8;
8961
- //const short lx = (tiitg/NL0)%8;
8962
- //const short ly = i%8;
8963
-
8964
- *(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4];
8965
- }
8966
- }
8967
-
8968
- if (FC_mul_mm_bc_inp) {
8969
- for (short i = 0; i < 8; ++i) {
8970
- const short sx = (tiitg%NL1);
8971
- const short sy = (tiitg/NL1)/8;
8972
-
8973
- const short lx = i;
8974
- const short ly = (tiitg/NL1)%8;
8975
- //const short lx = (tiitg/NL1)%8;
8976
- //const short ly = i;
8977
-
8978
- *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
8979
- }
8980
- } else {
8981
- const short sx = (tiitg%NL1);
8982
- const short sy = (tiitg/NL1)/8;
8983
-
8984
- //const short lx = i;
8985
- const short ly = (tiitg/NL1)%8;
8986
- //const short lx = (tiitg/NL1)%8;
8987
- //const short ly = i;
8988
-
8989
- *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));
8990
- }
8991
- #endif
8992
9678
 
8993
9679
  il = (il + 2 < nl) ? il + 2 : il % 2;
8994
9680
  x = (il < 2) ? x + (2 + nl - 1)/nl : x;
@@ -8997,7 +9683,6 @@ kernel void kernel_mul_mm(
8997
9683
 
8998
9684
  threadgroup_barrier(mem_flags::mem_threadgroup);
8999
9685
 
9000
- #ifndef GGML_METAL_HAS_TENSOR
9001
9686
  // load matrices from threadgroup memory and conduct outer products
9002
9687
  threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
9003
9688
  threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
@@ -9024,24 +9709,10 @@ kernel void kernel_mul_mm(
9024
9709
  lsma += 8*64;
9025
9710
  lsmb += 4*64;
9026
9711
  }
9027
- #else
9028
- auto sA = tA.slice(0, 0);
9029
- auto sB = tB.slice(0, 0);
9030
-
9031
- mm.run(sB, sA, cT);
9032
- #endif
9033
9712
  }
9034
9713
 
9035
9714
  if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) {
9036
9715
  // if no bounds checks on the output are needed, we can directly write to device memory
9037
- #ifdef GGML_METAL_HAS_TENSOR
9038
- device float * C = (device float *) dst +
9039
- r0 + \
9040
- r1 * args.ne0 + im*args.ne1*args.ne0;
9041
-
9042
- auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(args.ne0, NR1));
9043
- cT.store(tC);
9044
- #else
9045
9716
  device float * C = (device float *) dst +
9046
9717
  (r0 + 32*(sgitg & 1)) + \
9047
9718
  (r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
@@ -9049,21 +9720,15 @@ kernel void kernel_mul_mm(
9049
9720
  for (short i = 0; i < 8; i++) {
9050
9721
  simdgroup_store(mc[i], C + 8*(i%4) + 8*args.ne0*(i/4), args.ne0, 0, false);
9051
9722
  }
9052
- #endif
9053
9723
  } else {
9054
9724
  // block is smaller than 64x32, we should avoid writing data outside of the matrix
9055
9725
  threadgroup_barrier(mem_flags::mem_threadgroup);
9056
9726
 
9057
9727
  threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
9058
9728
 
9059
- #ifdef GGML_METAL_HAS_TENSOR
9060
- auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1));
9061
- cT.store(tC);
9062
- #else
9063
9729
  for (short i = 0; i < 8; i++) {
9064
9730
  simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
9065
9731
  }
9066
- #endif
9067
9732
 
9068
9733
  threadgroup_barrier(mem_flags::mem_threadgroup);
9069
9734
 
@@ -9089,6 +9754,8 @@ kernel void kernel_mul_mm(
9089
9754
  }
9090
9755
  }
9091
9756
 
9757
+ #endif // GGML_METAL_HAS_TENSOR
9758
+
9092
9759
  template<short ne20> // n_expert_used
9093
9760
  kernel void kernel_mul_mm_id_map0(
9094
9761
  constant ggml_metal_kargs_mul_mm_id_map0 & args,
@@ -9153,6 +9820,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_
9153
9820
  template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
9154
9821
  template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
9155
9822
  template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
9823
+ template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;
9156
9824
 
9157
9825
  template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
9158
9826
  kernel void kernel_mul_mm_id(
@@ -9170,7 +9838,9 @@ kernel void kernel_mul_mm_id(
9170
9838
  threadgroup S0 * sa = (threadgroup S0 *)(shmem);
9171
9839
  threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
9172
9840
 
9841
+ #ifdef GGML_METAL_HAS_TENSOR
9173
9842
  threadgroup float * sc = (threadgroup float *)(shmem);
9843
+ #endif
9174
9844
 
9175
9845
  constexpr int NR0 = 64;
9176
9846
  constexpr int NR1 = 32;
@@ -9261,7 +9931,7 @@ kernel void kernel_mul_mm_id(
9261
9931
 
9262
9932
  const short ib = 8*sx + sy;
9263
9933
 
9264
- *(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
9934
+ *(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? (S0) *((device T0 *) x + i) : (S0) 0;
9265
9935
  }
9266
9936
  } else {
9267
9937
  S0_4x4 temp_a;
@@ -9305,8 +9975,8 @@ kernel void kernel_mul_mm_id(
9305
9975
  const short sx = (tiitg%NL1);
9306
9976
  const short sy = (tiitg/NL1)/8;
9307
9977
 
9308
- const short dx = sx;
9309
- const short dy = sy;
9978
+ //const short dx = sx;
9979
+ //const short dy = sy;
9310
9980
 
9311
9981
  const short ly = (tiitg/NL1)%8;
9312
9982
 
@@ -9474,6 +10144,7 @@ template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_ro
9474
10144
 
9475
10145
  typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
9476
10146
 
10147
+ template [[host_name("kernel_get_rows_q1_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q1_0, 8, dequantize_q1_0>;
9477
10148
  template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
9478
10149
  template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
9479
10150
  template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
@@ -9536,6 +10207,7 @@ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_m
9536
10207
  #if defined(GGML_METAL_HAS_BF16)
9537
10208
  template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
9538
10209
  #endif
10210
+ template [[host_name("kernel_mul_mm_q1_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
9539
10211
  template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
9540
10212
  template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
9541
10213
  template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
@@ -9559,6 +10231,7 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_m
9559
10231
 
9560
10232
  template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
9561
10233
  template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
10234
+ template [[host_name("kernel_mul_mm_q1_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
9562
10235
  template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
9563
10236
  template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
9564
10237
  template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
@@ -9591,6 +10264,7 @@ template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mul_mm_id kernel_m
9591
10264
  #if defined(GGML_METAL_HAS_BF16)
9592
10265
  template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
9593
10266
  #endif
10267
+ template [[host_name("kernel_mul_mm_id_q1_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
9594
10268
  template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
9595
10269
  template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
9596
10270
  template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
@@ -9614,6 +10288,7 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_m
9614
10288
 
9615
10289
  template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
9616
10290
  template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
10291
+ template [[host_name("kernel_mul_mm_id_q1_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
9617
10292
  template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
9618
10293
  template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
9619
10294
  template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
@@ -9768,6 +10443,7 @@ template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4
9768
10443
 
9769
10444
  template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0>>>;
9770
10445
 
10446
+ template [[host_name("kernel_mul_mv_id_q1_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q1_0_f32_impl<N_R0_Q1_0>>>;
9771
10447
  template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0>>>;
9772
10448
  template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1>>>;
9773
10449
  template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0>>>;
@@ -9869,6 +10545,74 @@ kernel void kernel_pool_2d_avg_f32(
9869
10545
  o_ptr[cur_oh * args.OW + cur_ow] = res;
9870
10546
  }
9871
10547
 
10548
+
10549
+ kernel void kernel_pool_1d_max_f32(
10550
+ constant ggml_metal_kargs_pool_1d & args,
10551
+ device const float * src,
10552
+ device float * dst,
10553
+ uint gid [[thread_position_in_grid]]
10554
+ ) {
10555
+
10556
+ if (gid >= args.np) {
10557
+ return;
10558
+ }
10559
+
10560
+ const int ow = (int)gid % args.OW;
10561
+ const int row = (int)gid / args.OW;
10562
+
10563
+ const int base = ow * args.s0 - args.p0;
10564
+
10565
+ float acc = -INFINITY;
10566
+
10567
+ const int src_off = row * args.IW;
10568
+ const int dst_off = row * args.OW;
10569
+
10570
+ for (int ki = 0; ki < args.k0; ++ki) {
10571
+ int j = base + ki;
10572
+ if (j < 0 || j >= args.IW){
10573
+ continue;
10574
+ }
10575
+ float v = src[src_off + j];
10576
+ acc = max(acc, v);
10577
+ }
10578
+
10579
+ dst[dst_off + ow] = acc;
10580
+ }
10581
+
10582
+ kernel void kernel_pool_1d_avg_f32(
10583
+ constant ggml_metal_kargs_pool_1d & args,
10584
+ device const float * src,
10585
+ device float * dst,
10586
+ uint gid [[thread_position_in_grid]]
10587
+ ) {
10588
+
10589
+ if (gid >= args.np) {
10590
+ return;
10591
+ }
10592
+
10593
+ const int ow = (int)gid % args.OW;
10594
+ const int row = (int)gid / args.OW;
10595
+
10596
+ const int base = ow * args.s0 - args.p0;
10597
+
10598
+ float acc = 0.0f;
10599
+ int cnt = 0;
10600
+
10601
+ const int src_off = row * args.IW;
10602
+ const int dst_off = row * args.OW;
10603
+
10604
+ for (int ki = 0; ki < args.k0; ++ki) {
10605
+ const int j = base + ki;
10606
+ if (j < 0 || j >= args.IW) {
10607
+ continue;
10608
+ }
10609
+ acc += src[src_off + j];
10610
+ cnt += 1;
10611
+ }
10612
+
10613
+ dst[dst_off + ow] = (cnt > 0) ? (acc / (float)cnt) : 0.0f;
10614
+ }
10615
+
9872
10616
  kernel void kernel_opt_step_adamw_f32(
9873
10617
  constant ggml_metal_kargs_opt_step_adamw & args,
9874
10618
  device float * x,
@@ -9919,7 +10663,7 @@ kernel void kernel_opt_step_sgd_f32(
9919
10663
 
9920
10664
  template<typename T>
9921
10665
  kernel void kernel_memset(
9922
- constant ggml_metal_kargs_fill & args,
10666
+ constant ggml_metal_kargs_memset & args,
9923
10667
  device T * dst,
9924
10668
  uint tpig[[thread_position_in_grid]]) {
9925
10669
  dst[tpig] = args.val;