whispercpp 1.3.6 → 1.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (828)
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/README.md +38 -5
  5. data/Rakefile +18 -3
  6. data/ext/dependencies.rb +10 -4
  7. data/ext/dependencies_for_windows.rb +17 -0
  8. data/ext/extconf.rb +20 -8
  9. data/ext/options.rb +54 -14
  10. data/ext/options_for_windows.rb +51 -0
  11. data/ext/ruby_whisper.c +36 -42
  12. data/ext/ruby_whisper.h +135 -0
  13. data/ext/ruby_whisper_context.c +107 -28
  14. data/ext/ruby_whisper_log_queue.c +180 -0
  15. data/ext/ruby_whisper_log_settable.h +47 -0
  16. data/ext/ruby_whisper_parakeet.c +49 -0
  17. data/ext/ruby_whisper_parakeet_context.c +304 -0
  18. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  19. data/ext/ruby_whisper_parakeet_model.c +84 -0
  20. data/ext/ruby_whisper_parakeet_params.c +548 -0
  21. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  22. data/ext/ruby_whisper_parakeet_token.c +188 -0
  23. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  24. data/ext/ruby_whisper_params.c +256 -65
  25. data/ext/ruby_whisper_segment.c +6 -6
  26. data/ext/ruby_whisper_transcribe.cpp +42 -15
  27. data/ext/sources/CMakeLists.txt +41 -3
  28. data/ext/sources/CMakePresets.json +95 -0
  29. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  30. data/ext/sources/cmake/parakeet.pc.in +10 -0
  31. data/ext/sources/cmake/whisper.pc.in +1 -1
  32. data/ext/sources/examples/CMakeLists.txt +4 -2
  33. data/ext/sources/examples/bench/bench.cpp +1 -1
  34. data/ext/sources/examples/cli/cli.cpp +43 -9
  35. data/ext/sources/examples/common-ggml.cpp +2 -0
  36. data/ext/sources/examples/common-whisper.cpp +139 -67
  37. data/ext/sources/examples/common-whisper.h +11 -0
  38. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  39. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  40. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  41. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  42. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  43. data/ext/sources/examples/server/server.cpp +199 -163
  44. data/ext/sources/ggml/CMakeLists.txt +21 -13
  45. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  46. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  47. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  48. data/ext/sources/ggml/include/ggml-backend.h +72 -10
  49. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  50. data/ext/sources/ggml/include/ggml-rpc.h +3 -3
  51. data/ext/sources/ggml/include/ggml.h +101 -9
  52. data/ext/sources/ggml/include/gguf.h +10 -2
  53. data/ext/sources/ggml/src/CMakeLists.txt +22 -5
  54. data/ext/sources/ggml/src/ggml-alloc.c +5 -1
  55. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  56. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  57. data/ext/sources/ggml/src/ggml-backend-reg.cpp +12 -0
  58. data/ext/sources/ggml/src/ggml-backend.cpp +110 -9
  59. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +4 -0
  60. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +672 -257
  61. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +71 -0
  62. data/ext/sources/ggml/src/ggml-cann/common.h +20 -10
  63. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +211 -30
  64. data/ext/sources/ggml/src/ggml-common.h +11 -0
  65. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +58 -29
  66. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +2 -0
  67. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +16 -16
  68. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +116 -7
  69. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +65 -0
  70. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  71. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  72. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +4279 -1292
  73. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +5 -35
  74. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +0 -1
  75. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  76. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +177 -27
  77. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +1 -1
  78. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +5 -0
  79. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  80. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +10 -0
  81. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +95 -5
  82. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
  83. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +146 -134
  84. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +88 -70
  85. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +372 -73
  86. data/ext/sources/ggml/src/ggml-cpu/ops.h +3 -0
  87. data/ext/sources/ggml/src/ggml-cpu/quants.c +55 -0
  88. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  89. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3 -0
  90. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +90 -0
  91. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +3 -16
  92. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  93. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  94. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  95. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  96. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  97. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  98. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  99. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  100. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  101. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  102. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  103. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  104. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  105. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  106. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  107. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +37 -53
  108. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  109. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +17 -7
  110. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  111. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  112. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +62 -26
  113. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +44 -18
  114. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  115. data/ext/sources/ggml/src/ggml-cuda/common.cuh +242 -28
  116. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  117. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  118. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  119. data/ext/sources/ggml/src/ggml-cuda/convert.cu +53 -0
  120. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  121. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +14 -6
  122. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  123. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +278 -44
  124. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +331 -130
  125. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  126. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +126 -27
  127. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +40 -15
  128. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +18 -9
  129. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +152 -49
  130. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  131. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  132. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  133. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +84 -35
  134. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  135. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1069 -609
  136. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  137. data/ext/sources/ggml/src/ggml-cuda/mean.cu +4 -2
  138. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +242 -195
  139. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +3 -3
  140. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  141. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +502 -423
  142. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +19 -12
  143. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +485 -57
  144. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -1
  145. data/ext/sources/ggml/src/ggml-cuda/norm.cu +36 -10
  146. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  147. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +133 -26
  148. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  149. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +5 -1
  150. data/ext/sources/ggml/src/ggml-cuda/rope.cu +11 -4
  151. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  152. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  153. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  154. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  155. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  156. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +45 -13
  157. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  158. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  159. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  160. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  161. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +1 -0
  162. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  163. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  164. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  165. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +1 -0
  166. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  167. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  168. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  169. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  170. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  171. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  172. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  173. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  174. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  175. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  176. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  177. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  178. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  179. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  180. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  181. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  182. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  183. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  184. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  185. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  186. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  187. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  188. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  189. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  190. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  191. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -4
  192. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +26 -23
  193. data/ext/sources/ggml/src/ggml-cuda/unary.cu +31 -2
  194. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  195. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +80 -0
  196. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  197. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -4
  198. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  199. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +2 -1
  200. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1428 -743
  201. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +45 -7
  202. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +53 -84
  203. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +25 -12
  204. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +165 -184
  205. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  206. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  207. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +170 -127
  208. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  209. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  210. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  211. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +125 -97
  212. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  213. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +148 -42
  214. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +2 -2
  215. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +252 -62
  216. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +9 -0
  217. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +87 -1
  218. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  219. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  220. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  221. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  222. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  223. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  224. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  225. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  226. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +96 -13
  227. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +182 -57
  228. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  229. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +71 -3
  230. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +27 -10
  231. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +63 -23
  232. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +9 -8
  233. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  235. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +1 -0
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +5 -8
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +529 -815
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2522 -234
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +291 -95
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +59 -37
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +121 -133
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +244 -151
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +6 -6
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +719 -45
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +3 -1
  254. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +22 -9
  255. data/ext/sources/ggml/src/ggml-impl.h +6 -1
  256. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +138 -13
  257. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +32 -1
  258. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +164 -28
  259. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +80 -0
  260. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +190 -19
  261. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +2 -0
  262. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +39 -26
  263. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +823 -322
  264. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  265. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +54 -5
  266. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +12248 -5907
  267. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +67 -0
  268. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +59 -0
  269. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1819 -112
  270. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  271. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  272. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  273. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  274. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  275. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  276. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  277. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  278. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  279. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  280. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  281. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  282. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  283. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  284. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  285. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  286. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mm_q8_0_f32_8x4.cl → gemm_noshuffle_q8_0_f32.cl} +1 -1
  287. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  288. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  289. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  290. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  291. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  292. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  293. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  294. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  295. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  296. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  297. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  300. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  301. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  302. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  303. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  304. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  305. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  306. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  307. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  308. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  309. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  310. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  311. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  312. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  313. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  314. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  315. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  316. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  317. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  318. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +48 -64
  322. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +15 -5
  323. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +18 -11
  324. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +35 -13
  325. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +264 -192
  326. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +33 -7
  327. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  328. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +1 -0
  329. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +1 -0
  330. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  331. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +27 -3
  332. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +67 -36
  333. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +1 -0
  334. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +101 -44
  335. data/ext/sources/ggml/src/ggml-openvino/utils.h +23 -3
  336. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  337. data/ext/sources/ggml/src/ggml-quants.c +289 -114
  338. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  339. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  340. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  341. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  342. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  343. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +50 -4
  344. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +1 -1
  345. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +3 -1
  346. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  347. data/ext/sources/ggml/src/ggml-sycl/common.hpp +41 -1
  348. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +115 -13
  349. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +9 -0
  350. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  351. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  352. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  353. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  354. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  355. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  356. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1 -90
  357. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +0 -2
  358. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  359. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  360. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +7 -5
  361. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +4 -0
  362. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +76 -168
  363. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +7 -0
  364. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +3 -1
  365. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  366. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  367. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +69 -31
  368. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +1 -0
  369. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  370. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  371. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +823 -190
  372. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  373. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  374. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  375. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  376. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  377. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +71 -0
  378. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  379. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  380. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  381. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  382. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  383. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  384. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  385. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  386. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  387. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +1 -0
  388. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +1 -0
  389. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +1 -0
  390. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +1 -0
  391. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +1 -0
  392. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +1 -0
  393. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +1 -0
  394. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +1 -0
  395. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +1 -0
  396. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +1 -0
  397. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +1 -0
  398. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +1 -0
  399. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +1 -0
  400. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +1 -0
  401. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +1 -0
  402. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +1 -0
  403. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +1 -0
  404. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +1 -0
  405. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +1 -0
  406. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +1 -0
  407. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +1 -0
  408. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +1 -0
  409. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +1 -0
  410. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +1 -0
  411. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +1 -0
  412. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +1 -0
  413. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +1 -0
  414. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +1 -0
  415. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +1 -0
  416. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +1 -0
  417. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +1 -0
  418. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +1 -0
  420. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +1 -0
  421. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +1 -0
  422. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +1 -0
  423. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  424. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  425. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  426. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +215 -53
  427. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +4 -0
  428. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +2 -0
  429. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +2 -0
  430. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +1 -0
  431. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +1 -0
  432. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +0 -2
  433. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +11 -0
  434. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +2060 -535
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +197 -48
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +60 -59
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +115 -113
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +122 -31
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +125 -64
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -17
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +43 -10
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +5 -2
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +3 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +11 -1
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +171 -147
  484. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  485. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +2202 -283
  486. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2610 -1403
  487. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  488. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  489. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +8 -7
  490. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +76 -95
  491. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +19 -1
  492. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  493. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -50
  494. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +107 -184
  495. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  496. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  497. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  498. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  499. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  500. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  501. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +183 -78
  502. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  503. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  504. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +655 -495
  505. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  506. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  507. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  508. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +8 -6
  509. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +5 -1
  510. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +80 -409
  511. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  512. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  513. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  514. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  515. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  516. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  517. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  518. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +6 -4
  519. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  520. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +2 -3
  521. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  522. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  523. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  524. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  525. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  526. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +68 -48
  527. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  528. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +18 -14
  529. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +1 -1
  530. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +244 -10
  531. data/ext/sources/ggml/src/ggml.c +110 -28
  532. data/ext/sources/ggml/src/gguf.cpp +173 -28
  533. data/ext/sources/include/parakeet.h +342 -0
  534. data/ext/sources/include/whisper.h +10 -0
  535. data/ext/sources/media/matmul.png +0 -0
  536. data/ext/sources/src/CMakeLists.txt +23 -0
  537. data/ext/sources/src/parakeet-arch.h +188 -0
  538. data/ext/sources/src/parakeet.cpp +3838 -0
  539. data/ext/sources/src/whisper.cpp +56 -12
  540. data/extsources.rb +26 -10
  541. data/lib/whisper/log_settable.rb +36 -0
  542. data/lib/whisper/model/uri.rb +13 -1
  543. data/lib/whisper/output.rb +74 -0
  544. data/sig/whisper.rbs +411 -62
  545. data/test/helper.rb +2 -0
  546. data/test/jfk_reader/jfk_reader.c +50 -7
  547. data/test/test_callback.rb +1 -0
  548. data/test/test_package.rb +6 -5
  549. data/test/test_parakeet.rb +28 -0
  550. data/test/test_parakeet_callback.rb +107 -0
  551. data/test/test_parakeet_context.rb +116 -0
  552. data/test/test_parakeet_context_params.rb +24 -0
  553. data/test/test_parakeet_model.rb +21 -0
  554. data/test/test_parakeet_params.rb +78 -0
  555. data/test/test_parakeet_segment.rb +42 -0
  556. data/test/test_parakeet_token.rb +73 -0
  557. data/test/test_params.rb +2 -0
  558. data/test/test_vad_segment.rb +1 -1
  559. data/test/test_whisper.rb +24 -6
  560. data/whispercpp.gemspec +2 -2
  561. metadata +215 -281
  562. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  563. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  564. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  565. data/ext/sources/bindings/javascript/package.json +0 -26
  566. data/ext/sources/bindings/javascript/whisper.js +0 -19
  567. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  568. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  569. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  570. data/ext/sources/examples/addon.node/index.js +0 -59
  571. data/ext/sources/examples/addon.node/package.json +0 -16
  572. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  573. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  574. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  575. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  576. data/ext/sources/examples/coi-serviceworker.js +0 -146
  577. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  578. data/ext/sources/examples/command/command.cpp +0 -802
  579. data/ext/sources/examples/command/commands.txt +0 -9
  580. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  581. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  582. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  583. data/ext/sources/examples/generate-karaoke.sh +0 -57
  584. data/ext/sources/examples/helpers.js +0 -191
  585. data/ext/sources/examples/livestream.sh +0 -112
  586. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  587. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  588. data/ext/sources/examples/lsp/whisper.vim +0 -362
  589. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  590. data/ext/sources/examples/python/whisper_processor.py +0 -54
  591. data/ext/sources/examples/server/bench.js +0 -29
  592. data/ext/sources/examples/server.py +0 -120
  593. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  594. data/ext/sources/examples/stream/stream.cpp +0 -437
  595. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  596. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  597. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  598. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  599. data/ext/sources/examples/sycl/build.sh +0 -22
  600. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  601. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  602. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -48
  603. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  604. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -488
  605. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -89
  606. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2877
  607. data/ext/sources/examples/talk-llama/llama-arch.h +0 -628
  608. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -919
  609. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  610. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -896
  611. data/ext/sources/examples/talk-llama/llama-chat.h +0 -71
  612. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3633
  613. data/ext/sources/examples/talk-llama/llama-context.h +0 -359
  614. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  615. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -47
  616. data/ext/sources/examples/talk-llama/llama-ext.h +0 -12
  617. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  618. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  619. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2735
  620. data/ext/sources/examples/talk-llama/llama-graph.h +0 -1031
  621. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -258
  622. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -353
  623. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  624. data/ext/sources/examples/talk-llama/llama-impl.h +0 -75
  625. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  626. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  627. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -330
  628. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  629. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2285
  630. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -389
  631. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  632. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +0 -275
  633. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +0 -140
  634. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  635. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  636. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1165
  637. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  638. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  639. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  640. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -752
  641. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  642. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1655
  643. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -206
  644. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -299
  645. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -40
  646. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -9056
  647. data/ext/sources/examples/talk-llama/llama-model.h +0 -597
  648. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1304
  649. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  650. data/ext/sources/examples/talk-llama/llama-sampler.cpp +0 -3885
  651. data/ext/sources/examples/talk-llama/llama-sampler.h +0 -42
  652. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3970
  653. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -187
  654. data/ext/sources/examples/talk-llama/llama.cpp +0 -1194
  655. data/ext/sources/examples/talk-llama/llama.h +0 -1573
  656. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -190
  657. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  658. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  659. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -137
  660. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  661. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -123
  662. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -143
  663. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -133
  664. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -184
  665. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -145
  666. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  667. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  668. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  669. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  670. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  671. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  672. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  673. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -122
  674. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  675. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -142
  676. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -262
  677. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +0 -445
  678. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -132
  679. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  680. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -148
  681. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  682. data/ext/sources/examples/talk-llama/models/eurobert.cpp +0 -97
  683. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +0 -145
  684. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  685. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  686. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -111
  687. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  688. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  689. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  690. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  691. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  692. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  693. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  694. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -157
  695. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  696. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  697. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -195
  698. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -210
  699. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  700. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -139
  701. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  702. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -153
  703. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  704. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  705. data/ext/sources/examples/talk-llama/models/jais2.cpp +0 -123
  706. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  707. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +0 -381
  708. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -196
  709. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  710. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  711. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  712. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -175
  713. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  714. data/ext/sources/examples/talk-llama/models/mamba-base.cpp +0 -289
  715. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -54
  716. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -129
  717. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -200
  718. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -123
  719. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  720. data/ext/sources/examples/talk-llama/models/models.h +0 -704
  721. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -109
  722. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  723. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -162
  724. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  725. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  726. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  727. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  728. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  729. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  730. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  731. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  732. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +0 -122
  733. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  734. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  735. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  736. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  737. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -320
  738. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  739. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -169
  740. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  741. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  742. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  743. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  744. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -120
  745. data/ext/sources/examples/talk-llama/models/qwen35.cpp +0 -381
  746. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +0 -422
  747. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -131
  748. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -525
  749. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -140
  750. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -132
  751. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  752. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  753. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -164
  754. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  755. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  756. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -137
  757. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  758. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  759. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  760. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  761. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  762. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  763. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  764. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +0 -165
  765. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  766. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  767. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  768. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  769. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  770. data/ext/sources/examples/talk-llama/speak +0 -40
  771. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  772. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  773. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  774. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  775. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  776. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1103
  777. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  778. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  779. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  780. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  781. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  782. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  783. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  784. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  785. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  786. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  787. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  788. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  789. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  790. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  791. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -155
  792. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  793. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  794. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +0 -123
  795. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +0 -17
  796. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +0 -333
  797. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  798. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -182
  799. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  800. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -718
  801. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  802. data/ext/sources/tests/CMakeLists.txt +0 -112
  803. data/ext/sources/tests/earnings21/eval.mk +0 -58
  804. data/ext/sources/tests/earnings21/eval.py +0 -68
  805. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  806. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  807. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  808. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  809. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  810. data/ext/sources/tests/en-0-ref.txt +0 -1
  811. data/ext/sources/tests/en-1-ref.txt +0 -1
  812. data/ext/sources/tests/en-2-ref.txt +0 -1
  813. data/ext/sources/tests/es-0-ref.txt +0 -1
  814. data/ext/sources/tests/librispeech/eval.mk +0 -39
  815. data/ext/sources/tests/librispeech/eval.py +0 -47
  816. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  817. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  818. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  819. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  820. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  821. data/ext/sources/tests/run-tests.sh +0 -130
  822. data/ext/sources/tests/test-c.c +0 -3
  823. data/ext/sources/tests/test-vad-full.cpp +0 -56
  824. data/ext/sources/tests/test-vad.cpp +0 -83
  825. data/ext/sources/tests/test-whisper.js +0 -58
  826. data/lib/whisper/context.rb +0 -15
  827. data/lib/whisper/segment.rb +0 -58
  828. /data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general_q8_0_f32.cl → gemv_noshuffle_q8_0_f32.cl} +0 -0
@@ -118,6 +118,56 @@ void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg
118
118
  }
119
119
  #endif
120
120
 
121
+ template <typename type4x4>
122
+ void dequantize_q1_0(device const block_q1_0 * xb, short il, thread type4x4 & reg) {
123
+ device const uint8_t * qs = xb->qs;
124
+ const float d = xb->d;
125
+ const float neg_d = -d;
126
+
127
+ const int byte_offset = il * 2; // il*16 bits = il*2 bytes
128
+ const uint8_t b0 = qs[byte_offset];
129
+ const uint8_t b1 = qs[byte_offset + 1];
130
+
131
+ float4x4 reg_f;
132
+
133
+ reg_f[0][0] = select(neg_d, d, bool(b0 & 0x01));
134
+ reg_f[0][1] = select(neg_d, d, bool(b0 & 0x02));
135
+ reg_f[0][2] = select(neg_d, d, bool(b0 & 0x04));
136
+ reg_f[0][3] = select(neg_d, d, bool(b0 & 0x08));
137
+ reg_f[1][0] = select(neg_d, d, bool(b0 & 0x10));
138
+ reg_f[1][1] = select(neg_d, d, bool(b0 & 0x20));
139
+ reg_f[1][2] = select(neg_d, d, bool(b0 & 0x40));
140
+ reg_f[1][3] = select(neg_d, d, bool(b0 & 0x80));
141
+
142
+ reg_f[2][0] = select(neg_d, d, bool(b1 & 0x01));
143
+ reg_f[2][1] = select(neg_d, d, bool(b1 & 0x02));
144
+ reg_f[2][2] = select(neg_d, d, bool(b1 & 0x04));
145
+ reg_f[2][3] = select(neg_d, d, bool(b1 & 0x08));
146
+ reg_f[3][0] = select(neg_d, d, bool(b1 & 0x10));
147
+ reg_f[3][1] = select(neg_d, d, bool(b1 & 0x20));
148
+ reg_f[3][2] = select(neg_d, d, bool(b1 & 0x40));
149
+ reg_f[3][3] = select(neg_d, d, bool(b1 & 0x80));
150
+
151
+ reg = (type4x4) reg_f;
152
+ }
153
+
154
+ template <typename type4>
155
+ void dequantize_q1_0_t4(device const block_q1_0 * xb, short il, thread type4 & reg) {
156
+ const float d = xb->d;
157
+ const float neg_d = -d;
158
+ const int base = il * 4;
159
+ const uint8_t byte = xb->qs[base / 8];
160
+ const int s = base % 8;
161
+
162
+ float4 reg_f;
163
+ reg_f[0] = select(neg_d, d, bool((byte >> (s )) & 1));
164
+ reg_f[1] = select(neg_d, d, bool((byte >> (s + 1)) & 1));
165
+ reg_f[2] = select(neg_d, d, bool((byte >> (s + 2)) & 1));
166
+ reg_f[3] = select(neg_d, d, bool((byte >> (s + 3)) & 1));
167
+
168
+ reg = (type4) reg_f;
169
+ }
170
+
121
171
  template <typename type4x4>
122
172
  void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {
123
173
  device const uint16_t * qs = ((device const uint16_t *)xb + 1);
@@ -152,6 +202,23 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r
152
202
  }
153
203
  }
154
204
 
205
+ void quantize_q1_0(device const float * src, device block_q1_0 & dst) {
206
+ float sum_abs = 0.0f;
207
+ for (int j = 0; j < QK1_0; j++) {
208
+ sum_abs += fabs(src[j]);
209
+ }
210
+ dst.d = sum_abs / QK1_0;
211
+
212
+ for (int j = 0; j < QK1_0 / 8; j++) {
213
+ dst.qs[j] = 0;
214
+ }
215
+ for (int j = 0; j < QK1_0; j++) {
216
+ if (src[j] >= 0.0f) {
217
+ dst.qs[j / 8] |= (1 << (j % 8));
218
+ }
219
+ }
220
+ }
221
+
155
222
  void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
156
223
  #pragma METAL fp math_mode(safe)
157
224
  float amax = 0.0f; // absolute max
@@ -1094,6 +1161,31 @@ kernel void kernel_unary_impl(
1094
1161
  // TODO: precise implementation
1095
1162
  dst_ptr[i0] = (T) (exp(x) - 1);
1096
1163
  }
1164
+
1165
+ if (FC_OP == OP_UNARY_NUM_FLOOR) {
1166
+ dst_ptr[i0] = (T) floor(x);
1167
+ }
1168
+
1169
+ if (FC_OP == OP_UNARY_NUM_CEIL) {
1170
+ dst_ptr[i0] = (T) ceil(x);
1171
+ }
1172
+
1173
+ if (FC_OP == OP_UNARY_NUM_ROUND) {
1174
+ dst_ptr[i0] = (T) round(x);
1175
+ }
1176
+
1177
+ if (FC_OP == OP_UNARY_NUM_TRUNC) {
1178
+ dst_ptr[i0] = (T) trunc(x);
1179
+ }
1180
+
1181
+ if (FC_OP == OP_UNARY_NUM_XIELU) {
1182
+ const TC xi = x;
1183
+ const TC gate = TC(xi > TC(0.0f));
1184
+ const TC clamped = fmin(xi, TC(args.val));
1185
+ const TC y_pos = TC(args.scale) * xi * xi + TC(args.bias) * xi;
1186
+ const TC y_neg = (exp(clamped) - TC(1.0f) - xi) * TC(args.slope) + TC(args.bias) * xi;
1187
+ dst_ptr[i0] = (T) (gate * y_pos + (TC(1.0f) - gate) * y_neg);
1188
+ }
1097
1189
  }
1098
1190
 
1099
1191
  #undef FC_OP
@@ -1329,7 +1421,8 @@ template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat
1329
1421
  template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
1330
1422
  template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
1331
1423
 
1332
- kernel void kernel_reglu_f32(
1424
+ template<typename T>
1425
+ kernel void kernel_reglu(
1333
1426
  constant ggml_metal_kargs_glu & args,
1334
1427
  device const char * src0,
1335
1428
  device const char * src1,
@@ -1337,19 +1430,25 @@ kernel void kernel_reglu_f32(
1337
1430
  uint tgpig[[threadgroup_position_in_grid]],
1338
1431
  uint tpitg[[thread_position_in_threadgroup]],
1339
1432
  uint ntg[[threads_per_threadgroup]]) {
1340
- device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1341
- device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1342
- device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1433
+ device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1434
+ device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1435
+ device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
1343
1436
 
1344
1437
  for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
1345
1438
  const float x0 = src0_row[i0];
1346
1439
  const float x1 = src1_row[i0];
1347
1440
 
1348
- dst_row[i0] = x0*x1*(x0 > 0.0f);
1441
+ dst_row[i0] = (T)(x0*x1*(x0 > 0.0f));
1349
1442
  }
1350
1443
  }
1351
1444
 
1352
- kernel void kernel_geglu_f32(
1445
+ typedef decltype(kernel_reglu<float>) kernel_reglu_t;
1446
+
1447
+ template [[host_name("kernel_reglu_f32")]] kernel kernel_reglu_t kernel_reglu<float>;
1448
+ template [[host_name("kernel_reglu_f16")]] kernel kernel_reglu_t kernel_reglu<half>;
1449
+
1450
+ template<typename T>
1451
+ kernel void kernel_geglu(
1353
1452
  constant ggml_metal_kargs_glu & args,
1354
1453
  device const char * src0,
1355
1454
  device const char * src1,
@@ -1357,9 +1456,9 @@ kernel void kernel_geglu_f32(
1357
1456
  uint tgpig[[threadgroup_position_in_grid]],
1358
1457
  uint tpitg[[thread_position_in_threadgroup]],
1359
1458
  uint ntg[[threads_per_threadgroup]]) {
1360
- device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1361
- device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1362
- device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1459
+ device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1460
+ device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1461
+ device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
1363
1462
 
1364
1463
  for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
1365
1464
  const float x0 = src0_row[i0];
@@ -1367,11 +1466,17 @@ kernel void kernel_geglu_f32(
1367
1466
 
1368
1467
  const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
1369
1468
 
1370
- dst_row[i0] = gelu*x1;
1469
+ dst_row[i0] = (T)(gelu*x1);
1371
1470
  }
1372
1471
  }
1373
1472
 
1374
- kernel void kernel_swiglu_f32(
1473
+ typedef decltype(kernel_geglu<float>) kernel_geglu_t;
1474
+
1475
+ template [[host_name("kernel_geglu_f32")]] kernel kernel_geglu_t kernel_geglu<float>;
1476
+ template [[host_name("kernel_geglu_f16")]] kernel kernel_geglu_t kernel_geglu<half>;
1477
+
1478
+ template<typename T>
1479
+ kernel void kernel_swiglu(
1375
1480
  constant ggml_metal_kargs_glu & args,
1376
1481
  device const char * src0,
1377
1482
  device const char * src1,
@@ -1379,9 +1484,9 @@ kernel void kernel_swiglu_f32(
1379
1484
  uint tgpig[[threadgroup_position_in_grid]],
1380
1485
  uint tpitg[[thread_position_in_threadgroup]],
1381
1486
  uint ntg[[threads_per_threadgroup]]) {
1382
- device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1383
- device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1384
- device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1487
+ device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1488
+ device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1489
+ device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
1385
1490
 
1386
1491
  for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
1387
1492
  const float x0 = src0_row[i0];
@@ -1389,11 +1494,17 @@ kernel void kernel_swiglu_f32(
1389
1494
 
1390
1495
  const float silu = x0 / (1.0f + exp(-x0));
1391
1496
 
1392
- dst_row[i0] = silu*x1;
1497
+ dst_row[i0] = (T)(silu*x1);
1393
1498
  }
1394
1499
  }
1395
1500
 
1396
- kernel void kernel_swiglu_oai_f32(
1501
+ typedef decltype(kernel_swiglu<float>) kernel_swiglu_t;
1502
+
1503
+ template [[host_name("kernel_swiglu_f32")]] kernel kernel_swiglu_t kernel_swiglu<float>;
1504
+ template [[host_name("kernel_swiglu_f16")]] kernel kernel_swiglu_t kernel_swiglu<half>;
1505
+
1506
+ template<typename T>
1507
+ kernel void kernel_swiglu_oai(
1397
1508
  constant ggml_metal_kargs_glu & args,
1398
1509
  device const char * src0,
1399
1510
  device const char * src1,
@@ -1401,9 +1512,9 @@ kernel void kernel_swiglu_oai_f32(
1401
1512
  uint tgpig[[threadgroup_position_in_grid]],
1402
1513
  uint tpitg[[thread_position_in_threadgroup]],
1403
1514
  uint ntg[[threads_per_threadgroup]]) {
1404
- device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1405
- device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1406
- device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1515
+ device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1516
+ device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1517
+ device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
1407
1518
 
1408
1519
  for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
1409
1520
  float x0 = src0_row[i0];
@@ -1415,11 +1526,17 @@ kernel void kernel_swiglu_oai_f32(
1415
1526
  float out_glu = x0 / (1.0f + exp(-x0 * args.alpha));
1416
1527
  out_glu = out_glu * (1.0f + x1);
1417
1528
 
1418
- dst_row[i0] = out_glu;
1529
+ dst_row[i0] = (T)out_glu;
1419
1530
  }
1420
1531
  }
1421
1532
 
1422
- kernel void kernel_geglu_erf_f32(
1533
+ typedef decltype(kernel_swiglu_oai<float>) kernel_swiglu_oai_t;
1534
+
1535
+ template [[host_name("kernel_swiglu_oai_f32")]] kernel kernel_swiglu_oai_t kernel_swiglu_oai<float>;
1536
+ template [[host_name("kernel_swiglu_oai_f16")]] kernel kernel_swiglu_oai_t kernel_swiglu_oai<half>;
1537
+
1538
+ template<typename T>
1539
+ kernel void kernel_geglu_erf(
1423
1540
  constant ggml_metal_kargs_glu & args,
1424
1541
  device const char * src0,
1425
1542
  device const char * src1,
@@ -1427,9 +1544,9 @@ kernel void kernel_geglu_erf_f32(
1427
1544
  uint tgpig[[threadgroup_position_in_grid]],
1428
1545
  uint tpitg[[thread_position_in_threadgroup]],
1429
1546
  uint ntg[[threads_per_threadgroup]]) {
1430
- device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1431
- device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1432
- device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1547
+ device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1548
+ device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1549
+ device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
1433
1550
 
1434
1551
  for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
1435
1552
  const float x0 = src0_row[i0];
@@ -1437,11 +1554,17 @@ kernel void kernel_geglu_erf_f32(
1437
1554
 
1438
1555
  const float gelu_erf = 0.5f*x0*(1.0f+erf_approx<float>(x0*SQRT_2_INV));
1439
1556
 
1440
- dst_row[i0] = gelu_erf*x1;
1557
+ dst_row[i0] = (T)(gelu_erf*x1);
1441
1558
  }
1442
1559
  }
1443
1560
 
1444
- kernel void kernel_geglu_quick_f32(
1561
+ typedef decltype(kernel_geglu_erf<float>) kernel_geglu_erf_t;
1562
+
1563
+ template [[host_name("kernel_geglu_erf_f32")]] kernel kernel_geglu_erf_t kernel_geglu_erf<float>;
1564
+ template [[host_name("kernel_geglu_erf_f16")]] kernel kernel_geglu_erf_t kernel_geglu_erf<half>;
1565
+
1566
+ template<typename T>
1567
+ kernel void kernel_geglu_quick(
1445
1568
  constant ggml_metal_kargs_glu & args,
1446
1569
  device const char * src0,
1447
1570
  device const char * src1,
@@ -1449,9 +1572,9 @@ kernel void kernel_geglu_quick_f32(
1449
1572
  uint tgpig[[threadgroup_position_in_grid]],
1450
1573
  uint tpitg[[thread_position_in_threadgroup]],
1451
1574
  uint ntg[[threads_per_threadgroup]]) {
1452
- device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1453
- device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1454
- device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1575
+ device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1576
+ device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1577
+ device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
1455
1578
 
1456
1579
  for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
1457
1580
  const float x0 = src0_row[i0];
@@ -1459,10 +1582,15 @@ kernel void kernel_geglu_quick_f32(
1459
1582
 
1460
1583
  const float gelu_quick = x0*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x0)));
1461
1584
 
1462
- dst_row[i0] = gelu_quick*x1;
1585
+ dst_row[i0] = (T)(gelu_quick*x1);
1463
1586
  }
1464
1587
  }
1465
1588
 
1589
+ typedef decltype(kernel_geglu_quick<float>) kernel_geglu_quick_t;
1590
+
1591
+ template [[host_name("kernel_geglu_quick_f32")]] kernel kernel_geglu_quick_t kernel_geglu_quick<float>;
1592
+ template [[host_name("kernel_geglu_quick_f16")]] kernel kernel_geglu_quick_t kernel_geglu_quick<half>;
1593
+
1466
1594
  kernel void kernel_op_sum_f32(
1467
1595
  constant ggml_metal_kargs_sum & args,
1468
1596
  device const float * src0,
@@ -2439,6 +2567,7 @@ kernel void kernel_rwkv_wkv7_f32(
2439
2567
 
2440
2568
  constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]];
2441
2569
  constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]];
2570
+ constant short FC_gated_delta_net_K [[function_constant(FC_GATED_DELTA_NET + 2)]];
2442
2571
 
2443
2572
  #if 1
2444
2573
  template<short NSG>
@@ -2456,21 +2585,24 @@ kernel void kernel_gated_delta_net_impl(
2456
2585
  uint3 ntg[[threads_per_threadgroup]]) {
2457
2586
  #define S_v FC_gated_delta_net_ne20
2458
2587
  #define G FC_gated_delta_net_ne30
2588
+ #define K FC_gated_delta_net_K
2459
2589
 
2460
2590
  const uint tx = tpitg.x;
2461
2591
  const uint ty = tpitg.y;
2462
2592
 
2463
- const uint i23 = tgpig.z; // B
2464
- const uint i21 = tgpig.y; // H
2465
- const uint i20 = tgpig.x*NSG + ty;
2593
+ const uint i23 = tgpig.z; // B (n_seqs)
2594
+ const uint i21 = tgpig.y; // H (head)
2595
+ const uint i20 = tgpig.x*NSG + ty; // row within S_v
2466
2596
 
2467
2597
  const uint i01 = i21 % args.ne01;
2468
2598
  const uint i11 = i21 % args.ne11;
2469
2599
 
2470
2600
  const float scale = 1.0f / sqrt((float)S_v);
2471
2601
 
2602
+ // input state layout [S_v, S_v, H, n_seqs] (s0 only): per-seq stride is H*D.
2472
2603
  // state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous
2473
- device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
2604
+ const uint state_in_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
2605
+ device const float * s_ptr = (device const float *) (s) + state_in_base;
2474
2606
 
2475
2607
  float ls[NSG];
2476
2608
 
@@ -2488,6 +2620,16 @@ kernel void kernel_gated_delta_net_impl(
2488
2620
  device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
2489
2621
  device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
2490
2622
 
2623
+ // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
2624
+ // When n_tokens < K, only slots 0..n_tokens-1 are written; older slots are caller-owned.
2625
+
2626
+ // output state base offset: after attention scores
2627
+ const uint attn_size = args.ne22 * args.ne21 * S_v * args.ne23;
2628
+ // output state per-slot size: S_v * S_v * H * n_seqs
2629
+ const uint state_size_per_snap = S_v * S_v * args.ne21 * args.ne23;
2630
+ // per-(seq,head) offset within a slot
2631
+ const uint state_out_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
2632
+
2491
2633
  for (short t = 0; t < args.ne22; t++) {
2492
2634
  float s_k = 0.0f;
2493
2635
 
@@ -2535,17 +2677,30 @@ kernel void kernel_gated_delta_net_impl(
2535
2677
 
2536
2678
  b_ptr += args.ne21;
2537
2679
  g_ptr += args.ne21*G;
2538
- }
2539
2680
 
2540
- device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
2681
+ if (K > 1) {
2682
+ const int target_slot = (int)args.ne22 - 1 - (int)t;
2683
+ if (target_slot >= 0 && target_slot < (int)K) {
2684
+ device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base;
2685
+ FOR_UNROLL (short j = 0; j < NSG; j++) {
2686
+ const short is = tx*NSG + j;
2687
+ dst_state[is] = ls[j];
2688
+ }
2689
+ }
2690
+ }
2691
+ }
2541
2692
 
2542
- FOR_UNROLL (short j = 0; j < NSG; j++) {
2543
- const short is = tx*NSG + j;
2544
- dst_state[is] = ls[j];
2693
+ if (K == 1) {
2694
+ device float * dst_state = (device float *) (dst) + attn_size + state_out_base;
2695
+ FOR_UNROLL (short j = 0; j < NSG; j++) {
2696
+ const short is = tx*NSG + j;
2697
+ dst_state[is] = ls[j];
2698
+ }
2545
2699
  }
2546
2700
 
2547
2701
  #undef S_v
2548
2702
  #undef G
2703
+ #undef K
2549
2704
  }
2550
2705
 
2551
2706
  typedef decltype(kernel_gated_delta_net_impl<4>) kernel_gated_delta_net_t;
@@ -3100,6 +3255,35 @@ kernel void kernel_group_norm_f32(
3100
3255
  }
3101
3256
  }
3102
3257
 
3258
+ // Q1_0 dot product: dot = d * (2 * Σ(yl[i] where bit=1) - sumy)
3259
+ inline float block_q_n_dot_y(device const block_q1_0 * qb_curr, float sumy, thread float * yl, int il) {
3260
+ device const uint8_t * qs = qb_curr->qs + il / 8;
3261
+ const uint8_t b0 = qs[0];
3262
+ const uint8_t b1 = qs[1];
3263
+
3264
+ float acc = 0.0f;
3265
+
3266
+ acc += select(0.0f, yl[ 0], bool(b0 & 0x01));
3267
+ acc += select(0.0f, yl[ 1], bool(b0 & 0x02));
3268
+ acc += select(0.0f, yl[ 2], bool(b0 & 0x04));
3269
+ acc += select(0.0f, yl[ 3], bool(b0 & 0x08));
3270
+ acc += select(0.0f, yl[ 4], bool(b0 & 0x10));
3271
+ acc += select(0.0f, yl[ 5], bool(b0 & 0x20));
3272
+ acc += select(0.0f, yl[ 6], bool(b0 & 0x40));
3273
+ acc += select(0.0f, yl[ 7], bool(b0 & 0x80));
3274
+
3275
+ acc += select(0.0f, yl[ 8], bool(b1 & 0x01));
3276
+ acc += select(0.0f, yl[ 9], bool(b1 & 0x02));
3277
+ acc += select(0.0f, yl[10], bool(b1 & 0x04));
3278
+ acc += select(0.0f, yl[11], bool(b1 & 0x08));
3279
+ acc += select(0.0f, yl[12], bool(b1 & 0x10));
3280
+ acc += select(0.0f, yl[13], bool(b1 & 0x20));
3281
+ acc += select(0.0f, yl[14], bool(b1 & 0x40));
3282
+ acc += select(0.0f, yl[15], bool(b1 & 0x80));
3283
+
3284
+ return qb_curr->d * (2.0f * acc - sumy);
3285
+ }
3286
+
3103
3287
  // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
3104
3288
  // il indicates where the q4 quants begin (0 or QK4_0/4)
3105
3289
  // we assume that the yl's have been multiplied with the appropriate scale factor
@@ -3232,6 +3416,9 @@ static inline void helper_mv_reduce_and_write(
3232
3416
 
3233
3417
  constant short FC_mul_mv_nsg [[function_constant(FC_MUL_MV + 0)]];
3234
3418
  constant short FC_mul_mv_nxpsg [[function_constant(FC_MUL_MV + 1)]];
3419
+ constant short FC_mul_mv_ne12 [[function_constant(FC_MUL_MV + 2)]];
3420
+ constant short FC_mul_mv_r2 [[function_constant(FC_MUL_MV + 3)]];
3421
+ constant short FC_mul_mv_r3 [[function_constant(FC_MUL_MV + 4)]];
3235
3422
 
3236
3423
  template<typename block_q_type, short NR0, typename args_t>
3237
3424
  void mul_vec_q_n_f32_impl(
@@ -3255,10 +3442,10 @@ void mul_vec_q_n_f32_impl(
3255
3442
  const int r1 = tgpig.y;
3256
3443
  const int im = tgpig.z;
3257
3444
 
3258
- const uint i12 = im%args.ne12;
3259
- const uint i13 = im/args.ne12;
3445
+ const uint i12 = im%FC_mul_mv_ne12;
3446
+ const uint i13 = im/FC_mul_mv_ne12;
3260
3447
 
3261
- //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
3448
+ //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
3262
3449
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
3263
3450
 
3264
3451
  //device const block_q_type * x = (device const block_q_type *) (src0 + offset0);
@@ -3267,7 +3454,7 @@ void mul_vec_q_n_f32_impl(
3267
3454
  // pointers to src0 rows
3268
3455
  device const block_q_type * ax[NR0];
3269
3456
  FOR_UNROLL (int row = 0; row < NR0; ++row) {
3270
- const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
3457
+ const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
3271
3458
 
3272
3459
  ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
3273
3460
  }
@@ -3321,6 +3508,85 @@ void mul_vec_q_n_f32_impl(
3321
3508
  }
3322
3509
  }
3323
3510
 
3511
+ template<int nr0, typename args_t>
3512
+ void kernel_mul_mv_q1_0_f32_impl(
3513
+ args_t args,
3514
+ device const char * src0,
3515
+ device const char * src1,
3516
+ device char * dst,
3517
+ threadgroup char * shmem,
3518
+ uint3 tgpig,
3519
+ ushort tiisg,
3520
+ ushort sgitg) {
3521
+ const short NSG = FC_mul_mv_nsg;
3522
+
3523
+ const int nb = args.ne00/QK1_0;
3524
+
3525
+ const int r0 = tgpig.x;
3526
+ const int r1 = tgpig.y;
3527
+ const int im = tgpig.z;
3528
+
3529
+ const int first_row = (r0 * NSG + sgitg) * nr0;
3530
+
3531
+ const uint i12 = im%FC_mul_mv_ne12;
3532
+ const uint i13 = im/FC_mul_mv_ne12;
3533
+
3534
+ const uint64_t offset1 = r1*args.nb11 + (i12)*args.nb12 + (i13)*args.nb13;
3535
+
3536
+ device const float * y = (device const float *) (src1 + offset1);
3537
+
3538
+ device const block_q1_0 * ax[nr0];
3539
+ for (int row = 0; row < nr0; ++row) {
3540
+ const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
3541
+ ax[row] = (device const block_q1_0 *) ((device char *) src0 + offset0);
3542
+ }
3543
+
3544
+ float yl[16];
3545
+ float sumf[nr0] = {0.f};
3546
+
3547
+ const short ix = (tiisg/8);
3548
+ const short il = (tiisg%8)*16;
3549
+
3550
+ device const float * yb = y + ix*QK1_0 + il;
3551
+
3552
+ for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/8) {
3553
+ float sumy = 0.f;
3554
+
3555
+ FOR_UNROLL (short i = 0; i < 16; i++) {
3556
+ yl[i] = yb[i];
3557
+ sumy += yb[i];
3558
+ }
3559
+
3560
+ FOR_UNROLL (short row = 0; row < nr0; row++) {
3561
+ sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy, yl, il);
3562
+ }
3563
+
3564
+ yb += QK1_0 * (N_SIMDWIDTH/8);
3565
+ }
3566
+
3567
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
3568
+
3569
+ for (int row = 0; row < nr0; ++row) {
3570
+ const float tot = simd_sum(sumf[row]);
3571
+
3572
+ if (tiisg == 0 && first_row + row < args.ne01) {
3573
+ dst_f32[first_row + row] = tot;
3574
+ }
3575
+ }
3576
+ }
3577
+
3578
+ [[host_name("kernel_mul_mv_q1_0_f32")]]
3579
+ kernel void kernel_mul_mv_q1_0_f32(
3580
+ constant ggml_metal_kargs_mul_mv & args,
3581
+ device const char * src0,
3582
+ device const char * src1,
3583
+ device char * dst,
3584
+ uint3 tgpig[[threadgroup_position_in_grid]],
3585
+ ushort tiisg[[thread_index_in_simdgroup]],
3586
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
3587
+ kernel_mul_mv_q1_0_f32_impl<N_R0_Q1_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
3588
+ }
3589
+
3324
3590
  kernel void kernel_mul_mv_q4_0_f32(
3325
3591
  constant ggml_metal_kargs_mul_mv & args,
3326
3592
  device const char * src0,
@@ -3390,10 +3656,10 @@ void kernel_mul_mv_q8_0_f32_impl(
3390
3656
  const int r1 = tgpig.y;
3391
3657
  const int im = tgpig.z;
3392
3658
 
3393
- const uint i12 = im%args.ne12;
3394
- const uint i13 = im/args.ne12;
3659
+ const uint i12 = im%FC_mul_mv_ne12;
3660
+ const uint i13 = im/FC_mul_mv_ne12;
3395
3661
 
3396
- //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
3662
+ //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
3397
3663
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
3398
3664
 
3399
3665
  //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0);
@@ -3402,7 +3668,7 @@ void kernel_mul_mv_q8_0_f32_impl(
3402
3668
  // pointers to src0 rows
3403
3669
  device const block_q8_0 * ax[NR0];
3404
3670
  FOR_UNROLL (short row = 0; row < NR0; ++row) {
3405
- const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
3671
+ const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
3406
3672
 
3407
3673
  ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
3408
3674
  }
@@ -3482,10 +3748,10 @@ void kernel_mul_mv_ext_q4_f32_impl(
3482
3748
  const int i11 = tgpig.y*r1ptg;
3483
3749
  const int i1m = tgpig.z;
3484
3750
 
3485
- const int i12 = i1m%args.ne12;
3486
- const int i13 = i1m/args.ne12;
3751
+ const int i12 = i1m%FC_mul_mv_ne12;
3752
+ const int i13 = i1m/FC_mul_mv_ne12;
3487
3753
 
3488
- const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
3754
+ const uint64_t offset0 = i01*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
3489
3755
  const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
3490
3756
 
3491
3757
  device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;
@@ -3585,10 +3851,10 @@ void kernel_mul_mv_ext_q4x4_f32_impl(
3585
3851
  const int i11 = tgpig.y*r1ptg;
3586
3852
  const int i1m = tgpig.z;
3587
3853
 
3588
- const int i12 = i1m%args.ne12;
3589
- const int i13 = i1m/args.ne12;
3854
+ const int i12 = i1m%FC_mul_mv_ne12;
3855
+ const int i13 = i1m/FC_mul_mv_ne12;
3590
3856
 
3591
- const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
3857
+ const uint64_t offset0 = i01*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
3592
3858
  const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
3593
3859
 
3594
3860
  device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;
@@ -3713,6 +3979,11 @@ template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_4")]] kernel mul_mv_ext_q4
3713
3979
  template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, bfloat4, 4, dequantize_bf16_t4>;
3714
3980
  #endif
3715
3981
 
3982
+ template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q1_0, 128, dequantize_q1_0_t4>;
3983
+ template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q1_0, 128, dequantize_q1_0_t4>;
3984
+ template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q1_0, 128, dequantize_q1_0_t4>;
3985
+ template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q1_0, 128, dequantize_q1_0_t4>;
3986
+
3716
3987
  template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>;
3717
3988
  template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>;
3718
3989
  template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>;
@@ -3795,10 +4066,10 @@ void kernel_mul_mv_t_t_impl(
3795
4066
  const int r1 = tgpig.y;
3796
4067
  const int im = tgpig.z;
3797
4068
 
3798
- const uint i12 = im%args.ne12;
3799
- const uint i13 = im/args.ne12;
4069
+ const uint i12 = im%FC_mul_mv_ne12;
4070
+ const uint i13 = im/FC_mul_mv_ne12;
3800
4071
 
3801
- //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
4072
+ //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
3802
4073
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
3803
4074
 
3804
4075
  //device const T0 * x = (device const T0 *) (src0 + offset0);
@@ -3807,7 +4078,7 @@ void kernel_mul_mv_t_t_impl(
3807
4078
  // pointers to src0 rows
3808
4079
  device const T0 * ax [NR0];
3809
4080
  FOR_UNROLL (short row = 0; row < NR0; ++row) {
3810
- const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
4081
+ const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
3811
4082
 
3812
4083
  ax[row] = (device const T0 *) ((device char *) src0 + offset0);
3813
4084
  }
@@ -3917,10 +4188,10 @@ void kernel_mul_mv_t_t_4_impl(
3917
4188
  const int r1 = tgpig.y;
3918
4189
  const int im = tgpig.z;
3919
4190
 
3920
- const uint i12 = im%args.ne12;
3921
- const uint i13 = im/args.ne12;
4191
+ const uint i12 = im%FC_mul_mv_ne12;
4192
+ const uint i13 = im/FC_mul_mv_ne12;
3922
4193
 
3923
- //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
4194
+ //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
3924
4195
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
3925
4196
 
3926
4197
  device const T1 * y = (device const T1 *) (src1 + offset1);
@@ -3930,7 +4201,7 @@ void kernel_mul_mv_t_t_4_impl(
3930
4201
  device const T0 * ax [NR0];
3931
4202
  device const T04 * ax4[NR0];
3932
4203
  FOR_UNROLL (short row = 0; row < NR0; ++row) {
3933
- const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
4204
+ const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
3934
4205
 
3935
4206
  ax [row] = (device const T0 *) ((device char *) src0 + offset0);
3936
4207
  ax4[row] = (device const T04 *) ((device char *) src0 + offset0);
@@ -4034,10 +4305,10 @@ void kernel_mul_mv_t_t_short_impl(
4034
4305
  return;
4035
4306
  }
4036
4307
 
4037
- const uint i12 = im%args.ne12;
4038
- const uint i13 = im/args.ne12;
4308
+ const uint i12 = im%FC_mul_mv_ne12;
4309
+ const uint i13 = im/FC_mul_mv_ne12;
4039
4310
 
4040
- const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
4311
+ const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
4041
4312
 
4042
4313
  device const T0 * x = (device const T0 *) (src0 + offset0);
4043
4314
 
@@ -4460,59 +4731,59 @@ kernel void kernel_im2col(
4460
4731
  template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
4461
4732
  template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
4462
4733
 
4463
- // TODO: obsolete -- remove
4464
- //typedef void (im2col_ext_t)(
4465
- // constant ggml_metal_kargs_im2col & args,
4466
- // device const float * x,
4467
- // device char * dst,
4468
- // uint3 tgpig[[threadgroup_position_in_grid]],
4469
- // uint3 tgpg[[threadgroups_per_grid]],
4470
- // uint3 tpitg[[thread_position_in_threadgroup]],
4471
- // uint3 ntg[[threads_per_threadgroup]]);
4472
- //
4473
- //template <typename T>
4474
- //kernel void kernel_im2col_ext(
4475
- // constant ggml_metal_kargs_im2col & args,
4476
- // device const float * x,
4477
- // device char * dst,
4478
- // uint3 tgpig[[threadgroup_position_in_grid]],
4479
- // uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
4480
- // uint3 tpitg[[thread_position_in_threadgroup]],
4481
- // uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
4482
- // const int64_t KHW = (int64_t)args.KHW;
4483
- //
4484
- // const int64_t d = tgpig[0] / args.CHW;
4485
- // const int64_t chw = tgpig[0] % args.CHW;
4486
- // const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
4487
- // const int64_t HW = tgpig[0] % KHW;
4488
- //
4489
- // const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
4490
- // if (tpitg_0 >= args.N) {
4491
- // return;
4492
- // }
4493
- //
4494
- // const int64_t tpitg_1 = HW / args.KW;
4495
- // const int64_t tpitg_2 = HW % args.KW;
4496
- //
4497
- // const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
4498
- // const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
4499
- //
4500
- // const int64_t offset_dst =
4501
- // (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
4502
- // (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
4503
- //
4504
- // device T * pdst = (device T *) (dst);
4505
- //
4506
- // if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
4507
- // pdst[offset_dst] = 0.0f;
4508
- // } else {
4509
- // const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
4510
- // pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
4511
- // }
4512
- //}
4513
- //
4514
- //template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
4515
- //template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
4734
+ // TODO: optimize
4735
+ typedef void (im2col_ext_t)(
4736
+ constant ggml_metal_kargs_im2col & args,
4737
+ device const float * x,
4738
+ device char * dst,
4739
+ uint3 tgpig[[threadgroup_position_in_grid]],
4740
+ uint3 tgpg[[threadgroups_per_grid]],
4741
+ uint3 tpitg[[thread_position_in_threadgroup]],
4742
+ uint3 ntg[[threads_per_threadgroup]]);
4743
+
4744
+ template <typename T>
4745
+ kernel void kernel_im2col_ext(
4746
+ constant ggml_metal_kargs_im2col & args,
4747
+ device const float * x,
4748
+ device char * dst,
4749
+ uint3 tgpig[[threadgroup_position_in_grid]],
4750
+ uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
4751
+ uint3 tpitg[[thread_position_in_threadgroup]],
4752
+ uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
4753
+ const int64_t KHW = (int64_t)args.KHW;
4754
+
4755
+ const int64_t d = tgpig[0] / args.CHW;
4756
+ const int64_t chw = tgpig[0] % args.CHW;
4757
+ const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
4758
+ const int64_t HW = tgpig[0] % KHW;
4759
+
4760
+ const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
4761
+ if (tpitg_0 >= args.N) {
4762
+ return;
4763
+ }
4764
+
4765
+ const int64_t tpitg_1 = HW / args.KW;
4766
+ const int64_t tpitg_2 = HW % args.KW;
4767
+
4768
+ const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
4769
+ const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
4770
+
4771
+ const int64_t offset_dst =
4772
+ (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
4773
+ (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
4774
+
4775
+ device T * pdst = (device T *) (dst);
4776
+
4777
+ if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
4778
+ pdst[offset_dst] = 0.0f;
4779
+ } else {
4780
+ const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
4781
+ pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
4782
+ }
4783
+ }
4784
+
4785
+ template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
4786
+ template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
4516
4787
 
4517
4788
  template <typename TK>
4518
4789
  kernel void kernel_conv_2d(
@@ -4645,15 +4916,32 @@ kernel void kernel_conv_transpose_1d(
4645
4916
  uint3 tgpig[[threadgroup_position_in_grid]],
4646
4917
  uint3 tgpg[[threadgroups_per_grid]]) {
4647
4918
 
4648
- float v = 0.0f;
4919
+ // For output position j on the time axis, only input positions
4920
+ // i such that i*s0 <= j < i*s0 + K
4921
+ // contribute -- i.e. i in [ceil((j - K + 1)/s0), floor(j/s0)]
4922
+ // intersected with [0, IL-1]. That's at most ceil(K/s0) values
4923
+ // (typically 2 for stride==K/2 transposed convs).
4924
+ const int32_t j = tgpig[0];
4925
+ const int32_t s0 = args.s0;
4926
+ const int32_t K = args.K;
4927
+ const int32_t IL = args.IL;
4928
+
4929
+ int32_t i_min;
4930
+ {
4931
+ int32_t a = j - K + 1;
4932
+ i_min = a <= 0 ? 0 : (a + s0 - 1) / s0; // ceil(a/s0) for a>0
4933
+ }
4934
+ int32_t i_max = j / s0;
4935
+ if (i_max > IL - 1) i_max = IL - 1;
4649
4936
 
4650
- for (int64_t c = 0; c < args.IC; c++) {
4651
- const int32_t kernel_offset = c * tgpg[1] * args.K + args.K * tgpig[1];
4652
- const int32_t input_offset = c * args.IL;
4937
+ float v = 0.0f;
4938
+ if (i_min <= i_max) {
4939
+ for (int64_t c = 0; c < args.IC; c++) {
4940
+ const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1];
4941
+ const int32_t input_offset = c * IL;
4653
4942
 
4654
- for (int64_t i = 0; i < args.IL; i++) {
4655
- if (tgpig[0] >= i * args.s0 && tgpig[0] < i * args.s0 + args.K) {
4656
- v += src0[kernel_offset + tgpig[0] - i * args.s0] * src1[input_offset + i];
4943
+ for (int32_t i = i_min; i <= i_max; i++) {
4944
+ v += float(src0[kernel_offset + j - i * s0]) * src1[input_offset + i];
4657
4945
  }
4658
4946
  }
4659
4947
  }
@@ -4851,7 +5139,7 @@ kernel void kernel_upscale_bilinear_f32(
4851
5139
  for (int64_t sx = x_min; sx < x_max; ++sx) {
4852
5140
  const float wx = MAX(0.0f, 1.0f - fabs((float)sx - f00) * invscale0);
4853
5141
  const float w = wx * wy;
4854
- const device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00);
5142
+ device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00);
4855
5143
  sum += (*src_ptr) * w;
4856
5144
  wsum += w;
4857
5145
  }
@@ -4883,6 +5171,98 @@ kernel void kernel_upscale_bilinear_f32(
4883
5171
  }
4884
5172
  }
4885
5173
 
5174
+ template <typename T>
5175
+ kernel void kernel_conv_3d(
5176
+ constant ggml_metal_kargs_conv_3d & args,
5177
+ device const char * src0, // Weights [IC * OC, KD, KH, KW]
5178
+ device const char * src1, // Inputs [IC * N, ID, IH, IW]
5179
+ device char * dst, // Outputs [OC * N, OD, OH, OW]
5180
+ uint3 tgpig[[threadgroup_position_in_grid]],
5181
+ uint3 tpitg[[thread_position_in_threadgroup]]) {
5182
+
5183
+ // 1. Un-flatten the spatial dimension from Grid X
5184
+ int64_t spatial_idx = tgpig.x * 32 + tpitg.x;
5185
+
5186
+ if (spatial_idx >= args.OW * args.OH * args.OD) {
5187
+ return; // Thread falls outside the spatial volume
5188
+ }
5189
+
5190
+ int64_t od = spatial_idx / (args.OW * args.OH);
5191
+ int64_t oh = (spatial_idx / args.OW) % args.OH;
5192
+ int64_t ow = spatial_idx % args.OW;
5193
+
5194
+ // 2. Map Y to Channels, Z to Batch
5195
+ int64_t oc = tgpig.y;
5196
+ int64_t batch_idx = tgpig.z;
5197
+
5198
+ // 3. Calculate anchor coordinates in the Input volume
5199
+ int64_t i_w_base = ow * args.s0 - args.p0;
5200
+ int64_t i_h_base = oh * args.s1 - args.p1;
5201
+ int64_t i_d_base = od * args.s2 - args.p2;
5202
+
5203
+ float sum = 0.0f;
5204
+
5205
+ // 4. Gather Loop (Iterate over Input Channels -> Depth -> Height -> Width)
5206
+ for (int64_t ic = 0; ic < args.IC; ++ic) {
5207
+
5208
+ // ggml packs batch and channel together in the 4th dimension
5209
+ int64_t src_cn_idx = batch_idx * args.IC + ic;
5210
+ int64_t w_cn_idx = oc * args.IC + ic;
5211
+
5212
+ for (int64_t kz = 0; kz < args.KD; ++kz) {
5213
+ int64_t id = i_d_base + kz * args.d2;
5214
+ if (id < 0 || id >= args.ID) continue; // Boundary check (Padding)
5215
+
5216
+ for (int64_t ky = 0; ky < args.KH; ++ky) {
5217
+ int64_t ih = i_h_base + ky * args.d1;
5218
+ if (ih < 0 || ih >= args.IH) continue;
5219
+
5220
+ for (int64_t kx = 0; kx < args.KW; ++kx) {
5221
+ int64_t iw = i_w_base + kx * args.d0;
5222
+ if (iw < 0 || iw >= args.IW) continue;
5223
+
5224
+ // Convert multi-dimensional coordinates to flat byte offsets
5225
+ int64_t w_idx = kx*args.nb00 + ky*args.nb01 + kz*args.nb02 + w_cn_idx*args.nb03;
5226
+ int64_t i_idx = iw*args.nb10 + ih*args.nb11 + id*args.nb12 + src_cn_idx*args.nb13;
5227
+
5228
+ // Dereference memory and cast weights to f32 if they were f16
5229
+ float w_val = (float)*(device const T*)((device const char*)src0 + w_idx);
5230
+ float i_val = *(device const float*)((device const char*)src1 + i_idx);
5231
+
5232
+ sum += w_val * i_val;
5233
+ }
5234
+ }
5235
+ }
5236
+ }
5237
+
5238
+ // 5. Write the accumulated value out to RAM
5239
+ int64_t dst_cn_idx = batch_idx * args.OC + oc;
5240
+ int64_t d_idx = ow*args.nb0 + oh*args.nb1 + od*args.nb2 + dst_cn_idx*args.nb3;
5241
+
5242
+ *(device float*)(dst + d_idx) = sum;
5243
+ }
5244
+
5245
+ // Explicit instantiations so the JIT compiler can find them by name
5246
+ template [[host_name("kernel_conv_3d_f32_f32")]]
5247
+ kernel void kernel_conv_3d<float>(
5248
+ constant ggml_metal_kargs_conv_3d & args,
5249
+ device const char * src0,
5250
+ device const char * src1,
5251
+ device char * dst,
5252
+ uint3 tgpig[[threadgroup_position_in_grid]],
5253
+ uint3 tpitg[[thread_position_in_threadgroup]]);
5254
+
5255
+ // Explicit instantiation for f16 weights
5256
+ template [[host_name("kernel_conv_3d_f16_f32")]]
5257
+ kernel void kernel_conv_3d<half>(
5258
+ constant ggml_metal_kargs_conv_3d & args,
5259
+ device const char * src0,
5260
+ device const char * src1,
5261
+ device char * dst,
5262
+ uint3 tgpig[[threadgroup_position_in_grid]],
5263
+ uint3 tpitg[[thread_position_in_threadgroup]]);
5264
+
5265
+
4886
5266
  static inline float bicubic_weight1(float x) {
4887
5267
  const float a = -0.75f;
4888
5268
  return ((a + 2) * x - (a + 3)) * x * x + 1;
@@ -4941,7 +5321,7 @@ kernel void kernel_upscale_bicubic_f32(
4941
5321
  const int64_t ix = MAX(0, MIN(args.ne00 - 1, i00 + dx));
4942
5322
  const float wx = (dx == -1) ? w_x0 : (dx == 0) ? w_x1 : (dx == 1) ? w_x2 : w_x3;
4943
5323
 
4944
- const device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00);
5324
+ device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00);
4945
5325
  sum += (*src_ptr) * wx * wy;
4946
5326
  }
4947
5327
  }
@@ -4950,8 +5330,8 @@ kernel void kernel_upscale_bicubic_f32(
4950
5330
  }
4951
5331
  }
4952
5332
 
4953
- kernel void kernel_pad_f32(
4954
- constant ggml_metal_kargs_pad & args,
5333
+ kernel void kernel_roll_f32(
5334
+ constant ggml_metal_kargs_roll & args,
4955
5335
  device const char * src0,
4956
5336
  device char * dst,
4957
5337
  uint3 tgpig[[threadgroup_position_in_grid]],
@@ -4962,30 +5342,68 @@ kernel void kernel_pad_f32(
4962
5342
  const int64_t i2 = tgpig.y;
4963
5343
  const int64_t i1 = tgpig.x;
4964
5344
 
4965
- const int64_t i03 = i3;
4966
- const int64_t i02 = i2;
4967
- const int64_t i01 = i1;
5345
+ device const float * src0_ptr = (device const float *) src0;
5346
+ device float * dst_ptr = (device float *) dst;
4968
5347
 
4969
- device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
4970
- device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
5348
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
5349
+ // apply shifts and wrap around
5350
+ int64_t i00 = i0 - args.s0;
5351
+ int64_t i01 = i1 - args.s1;
5352
+ int64_t i02 = i2 - args.s2;
5353
+ int64_t i03 = i3 - args.s3;
4971
5354
 
4972
- if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
4973
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
4974
- if (i0 < args.ne00) {
4975
- dst_ptr[i0] = src0_ptr[i0];
4976
- } else {
4977
- dst_ptr[i0] = 0.0f;
4978
- }
4979
- }
5355
+ if (i00 < 0) { i00 += args.ne00; } else if (i00 >= args.ne00) { i00 -= args.ne00; }
5356
+ if (i01 < 0) { i01 += args.ne01; } else if (i01 >= args.ne01) { i01 -= args.ne01; }
5357
+ if (i02 < 0) { i02 += args.ne02; } else if (i02 >= args.ne02) { i02 -= args.ne02; }
5358
+ if (i03 < 0) { i03 += args.ne03; } else if (i03 >= args.ne03) { i03 -= args.ne03; }
4980
5359
 
4981
- return;
5360
+ int64_t src_idx = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00 + i00;
5361
+ int64_t dst_idx = i3 *args.ne2 *args.ne1 *args.ne0 + i2 *args.ne1 *args.ne0 + i1 *args.ne0 + i0;
5362
+
5363
+ dst_ptr[dst_idx] = src0_ptr[src_idx];
4982
5364
  }
5365
+ }
4983
5366
 
4984
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
4985
- dst_ptr[i0] = 0.0f;
5367
+ template <typename T>
5368
+ kernel void kernel_pad_impl(
5369
+ constant ggml_metal_kargs_pad & args,
5370
+ device const char * src0,
5371
+ device char * dst,
5372
+ uint3 tgpig[[threadgroup_position_in_grid]],
5373
+ uint3 tpitg[[thread_position_in_threadgroup]],
5374
+ uint3 ntg[[threads_per_threadgroup]]) {
5375
+ const int32_t i3 = tgpig.z;
5376
+ const int32_t i2 = tgpig.y;
5377
+ const int32_t k0 = tgpig.x/args.ne1;
5378
+ const int32_t i1 = tgpig.x - k0*args.ne1;
5379
+
5380
+ const int32_t i03 = i3;
5381
+ const int32_t i02 = i2;
5382
+ const int32_t i01 = i1;
5383
+
5384
+ device const T * src0_ptr = (device const T *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
5385
+ device T * dst_ptr = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
5386
+
5387
+ for (int32_t l0 = 0; l0 < 1024; l0 += ntg.x) {
5388
+ const int32_t i0 = k0*1024 + tpitg.x + l0;
5389
+ if (i0 >= args.ne0) {
5390
+ break;
5391
+ }
5392
+
5393
+ if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
5394
+ dst_ptr[i0] = src0_ptr[i0];
5395
+ } else {
5396
+ dst_ptr[i0] = 0.0f;
5397
+ }
4986
5398
  }
4987
5399
  }
4988
5400
 
5401
+ typedef decltype(kernel_pad_impl<float>) kernel_pad_t;
5402
+
5403
+ template [[host_name("kernel_pad_f32")]] kernel kernel_pad_t kernel_pad_impl<float>;
5404
+ template [[host_name("kernel_pad_f32_4")]] kernel kernel_pad_t kernel_pad_impl<float4>;
5405
+
5406
+ // TODO: this is slow - optimize
4989
5407
  kernel void kernel_pad_reflect_1d_f32(
4990
5408
  constant ggml_metal_kargs_pad_reflect_1d & args,
4991
5409
  device const char * src0,
@@ -6177,6 +6595,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_at
6177
6595
  template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 128>;
6178
6596
  template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 256, 256>;
6179
6597
  template [[host_name("kernel_flash_attn_ext_f32_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 320, 256>;
6598
+ template [[host_name("kernel_flash_attn_ext_f32_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 512, 512>;
6180
6599
  template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 576, 512>;
6181
6600
 
6182
6601
  template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 32, 32>;
@@ -6192,6 +6611,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk192_dv192")]] kernel flash_at
6192
6611
  template [[host_name("kernel_flash_attn_ext_f16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 128>;
6193
6612
  template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256, 256>;
6194
6613
  template [[host_name("kernel_flash_attn_ext_f16_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 320, 256>;
6614
+ template [[host_name("kernel_flash_attn_ext_f16_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 512, 512>;
6195
6615
  template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
6196
6616
 
6197
6617
  #if defined(GGML_METAL_HAS_BF16)
@@ -6208,6 +6628,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv192")]] kernel flash_at
6208
6628
  template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
6209
6629
  template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
6210
6630
  template [[host_name("kernel_flash_attn_ext_bf16_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 320, 256>;
6631
+ template [[host_name("kernel_flash_attn_ext_bf16_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 512, 512>;
6211
6632
  template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
6212
6633
  #endif
6213
6634
 
@@ -6224,6 +6645,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv192")]] kernel flash_at
6224
6645
  template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 128>;
6225
6646
  template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
6226
6647
  template [[host_name("kernel_flash_attn_ext_q4_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 320, 256>;
6648
+ template [[host_name("kernel_flash_attn_ext_q4_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 512, 512>;
6227
6649
  template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;
6228
6650
 
6229
6651
  template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32, 32>;
@@ -6239,6 +6661,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv192")]] kernel flash_at
6239
6661
  template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 128>;
6240
6662
  template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
6241
6663
  template [[host_name("kernel_flash_attn_ext_q4_1_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 320, 256>;
6664
+ template [[host_name("kernel_flash_attn_ext_q4_1_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 512, 512>;
6242
6665
  template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;
6243
6666
 
6244
6667
  template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32, 32>;
@@ -6254,6 +6677,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv192")]] kernel flash_at
6254
6677
  template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 128>;
6255
6678
  template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
6256
6679
  template [[host_name("kernel_flash_attn_ext_q5_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 320, 256>;
6680
+ template [[host_name("kernel_flash_attn_ext_q5_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 512, 512>;
6257
6681
  template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;
6258
6682
 
6259
6683
  template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32, 32>;
@@ -6269,6 +6693,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv192")]] kernel flash_at
6269
6693
  template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 128>;
6270
6694
  template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
6271
6695
  template [[host_name("kernel_flash_attn_ext_q5_1_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 320, 256>;
6696
+ template [[host_name("kernel_flash_attn_ext_q5_1_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 512, 512>;
6272
6697
  template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;
6273
6698
 
6274
6699
  template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32, 32>;
@@ -6284,6 +6709,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv192")]] kernel flash_at
6284
6709
  template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 128>;
6285
6710
  template [[host_name("kernel_flash_attn_ext_q8_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256, 256>;
6286
6711
  template [[host_name("kernel_flash_attn_ext_q8_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 320, 256>;
6712
+ template [[host_name("kernel_flash_attn_ext_q8_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 512, 512>;
6287
6713
  template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;
6288
6714
 
6289
6715
  #undef FA_TYPES
@@ -6865,6 +7291,17 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk320_dv256")]] kernel flas
6865
7291
  template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 320, 256, 2>;
6866
7292
  template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 320, 256, 2>;
6867
7293
 
7294
+ template [[host_name("kernel_flash_attn_ext_vec_f32_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 512, 512, 1>;
7295
+ template [[host_name("kernel_flash_attn_ext_vec_f16_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 512, 512, 1>;
7296
+ #if defined(GGML_METAL_HAS_BF16)
7297
+ template [[host_name("kernel_flash_attn_ext_vec_bf16_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 512, 512, 1>;
7298
+ #endif
7299
+ template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 512, 512, 1>;
7300
+ template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 512, 512, 1>;
7301
+ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 512, 512, 1>;
7302
+ template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 512, 512, 1>;
7303
+ template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 512, 512, 1>;
7304
+
6868
7305
  template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 576, 512, 2>;
6869
7306
  template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
6870
7307
  #if defined(GGML_METAL_HAS_BF16)
@@ -6930,23 +7367,27 @@ kernel void kernel_cpy_t_t(
6930
7367
  device const char * src0,
6931
7368
  device char * dst,
6932
7369
  uint3 tgpig[[threadgroup_position_in_grid]],
6933
- ushort tiitg[[thread_index_in_threadgroup]],
7370
+ ushort3 tpitg[[thread_position_in_threadgroup]],
6934
7371
  ushort3 ntg[[threads_per_threadgroup]]) {
6935
- const int i03 = tgpig[2];
6936
- const int i02 = tgpig[1];
6937
- const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
6938
- const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
7372
+ const int32_t i03 = tgpig[2];
7373
+ const int32_t i02 = tgpig[1];
7374
+ const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
7375
+ const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
7376
+
7377
+ if (i01 >= args.ne01) {
7378
+ return;
7379
+ }
6939
7380
 
6940
7381
  const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
6941
7382
 
6942
- const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
6943
- const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
6944
- const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
6945
- const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
7383
+ const int32_t i3 = n/(args.ne2*args.ne1*args.ne0);
7384
+ const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
7385
+ const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
7386
+ const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
6946
7387
 
6947
7388
  device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
6948
7389
 
6949
- for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.ne00; ) {
7390
+ for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.ne00;) {
6950
7391
  device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
6951
7392
  dst_data[i00] = (T1) src[0];
6952
7393
  break;
@@ -6978,23 +7419,27 @@ kernel void kernel_cpy_f32_q(
6978
7419
  device const char * src0,
6979
7420
  device char * dst,
6980
7421
  uint3 tgpig[[threadgroup_position_in_grid]],
6981
- ushort tiitg[[thread_index_in_threadgroup]],
7422
+ ushort3 tpitg[[thread_position_in_threadgroup]],
6982
7423
  ushort3 ntg[[threads_per_threadgroup]]) {
6983
- const int i03 = tgpig[2];
6984
- const int i02 = tgpig[1];
6985
- const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
6986
- const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
7424
+ const int32_t i03 = tgpig[2];
7425
+ const int32_t i02 = tgpig[1];
7426
+ const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
7427
+ const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
7428
+
7429
+ if (i01 >= args.ne01) {
7430
+ return;
7431
+ }
6987
7432
 
6988
7433
  const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
6989
7434
 
6990
- const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
6991
- const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
6992
- const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
6993
- const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK;
7435
+ const int32_t i3 = n / (args.ne2*args.ne1*args.ne0);
7436
+ const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
7437
+ const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
7438
+ const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK;
6994
7439
 
6995
7440
  device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
6996
7441
 
6997
- for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {
7442
+ for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) {
6998
7443
  device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00);
6999
7444
 
7000
7445
  quantize_func(src, dst_data[i00]);
@@ -7006,6 +7451,7 @@ kernel void kernel_cpy_f32_q(
7006
7451
  typedef decltype(kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>) cpy_f_q_t;
7007
7452
 
7008
7453
  template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>;
7454
+ template [[host_name("kernel_cpy_f32_q1_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK1_0, block_q1_0, quantize_q1_0>;
7009
7455
  template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_0, block_q4_0, quantize_q4_0>;
7010
7456
  template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_1, block_q4_1, quantize_q4_1>;
7011
7457
  template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_0, block_q5_0, quantize_q5_0>;
@@ -7018,24 +7464,28 @@ kernel void kernel_cpy_q_f32(
7018
7464
  device const char * src0,
7019
7465
  device char * dst,
7020
7466
  uint3 tgpig[[threadgroup_position_in_grid]],
7021
- ushort tiitg[[thread_index_in_threadgroup]],
7467
+ ushort3 tpitg[[thread_position_in_threadgroup]],
7022
7468
  ushort3 ntg[[threads_per_threadgroup]]) {
7023
- const int i03 = tgpig[2];
7024
- const int i02 = tgpig[1];
7025
- const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
7026
- const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
7469
+ const int32_t i03 = tgpig[2];
7470
+ const int32_t i02 = tgpig[1];
7471
+ const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
7472
+ const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
7473
+
7474
+ if (i01 >= args.ne01) {
7475
+ return;
7476
+ }
7027
7477
 
7028
7478
  const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
7029
7479
 
7030
- const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
7031
- const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
7032
- const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
7033
- const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
7480
+ const int32_t i3 = n/(args.ne2*args.ne1*args.ne0);
7481
+ const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
7482
+ const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
7483
+ const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
7034
7484
 
7035
7485
  device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
7036
7486
  device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
7037
7487
 
7038
- for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {
7488
+ for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) {
7039
7489
  T4x4 temp;
7040
7490
  dequantize_func(src_data + i00/nl, i00%nl, temp);
7041
7491
  dst_data[i00] = temp;
@@ -7046,12 +7496,14 @@ kernel void kernel_cpy_q_f32(
7046
7496
 
7047
7497
  typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t;
7048
7498
 
7499
+ template [[host_name("kernel_cpy_q1_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q1_0, 8, dequantize_q1_0>;
7049
7500
  template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>;
7050
7501
  template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>;
7051
7502
  template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>;
7052
7503
  template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>;
7053
7504
  template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>;
7054
7505
 
7506
+ template [[host_name("kernel_cpy_q1_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q1_0, 8, dequantize_q1_0>;
7055
7507
  template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>;
7056
7508
  template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>;
7057
7509
  template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>;
@@ -7069,7 +7521,11 @@ kernel void kernel_concat(
7069
7521
 
7070
7522
  const int i3 = tgpig.z;
7071
7523
  const int i2 = tgpig.y;
7072
- const int i1 = tgpig.x;
7524
+ const int i1 = ntg.y == 1 ? tgpig.x : tgpig.x*ntg.y + tpitg.y;
7525
+
7526
+ if (i1 >= args.ne1) {
7527
+ return;
7528
+ }
7073
7529
 
7074
7530
  int o[4] = {0, 0, 0, 0};
7075
7531
  o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? args.ne02 : args.ne03));
@@ -7109,10 +7565,10 @@ void kernel_mul_mv_q2_K_f32_impl(
7109
7565
 
7110
7566
  const int first_row = (r0 * NSG + sgitg) * nr0;
7111
7567
 
7112
- const uint i12 = im%args.ne12;
7113
- const uint i13 = im/args.ne12;
7568
+ const uint i12 = im%FC_mul_mv_ne12;
7569
+ const uint i13 = im/FC_mul_mv_ne12;
7114
7570
 
7115
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
7571
+ const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
7116
7572
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
7117
7573
 
7118
7574
  device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0);
@@ -7214,10 +7670,10 @@ void kernel_mul_mv_q3_K_f32_impl(
7214
7670
 
7215
7671
  const int first_row = (r0 * NSG + sgitg) * nr0;
7216
7672
 
7217
- const uint i12 = im%args.ne12;
7218
- const uint i13 = im/args.ne12;
7673
+ const uint i12 = im%FC_mul_mv_ne12;
7674
+ const uint i13 = im/FC_mul_mv_ne12;
7219
7675
 
7220
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
7676
+ const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
7221
7677
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
7222
7678
 
7223
7679
  device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0);
@@ -7388,10 +7844,10 @@ void kernel_mul_mv_q4_K_f32_impl(
7388
7844
 
7389
7845
  const int first_row = (r0 * NSG + sgitg) * nr0;
7390
7846
 
7391
- const uint i12 = im%args.ne12;
7392
- const uint i13 = im/args.ne12;
7847
+ const uint i12 = im%FC_mul_mv_ne12;
7848
+ const uint i13 = im/FC_mul_mv_ne12;
7393
7849
 
7394
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
7850
+ const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
7395
7851
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
7396
7852
 
7397
7853
  device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0);
@@ -7500,10 +7956,10 @@ void kernel_mul_mv_q5_K_f32_impl(
7500
7956
 
7501
7957
  const int first_row = (r0 * NSG + sgitg) * nr0;
7502
7958
 
7503
- const uint i12 = im%args.ne12;
7504
- const uint i13 = im/args.ne12;
7959
+ const uint i12 = im%FC_mul_mv_ne12;
7960
+ const uint i13 = im/FC_mul_mv_ne12;
7505
7961
 
7506
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
7962
+ const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
7507
7963
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
7508
7964
 
7509
7965
  device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0);
@@ -7636,10 +8092,10 @@ void kernel_mul_mv_q6_K_f32_impl(
7636
8092
 
7637
8093
  const int first_row = (r0 * NSG + sgitg) * nr0;
7638
8094
 
7639
- const uint i12 = im%args.ne12;
7640
- const uint i13 = im/args.ne12;
8095
+ const uint i12 = im%FC_mul_mv_ne12;
8096
+ const uint i13 = im/FC_mul_mv_ne12;
7641
8097
 
7642
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
8098
+ const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
7643
8099
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
7644
8100
 
7645
8101
  device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0);
@@ -7741,10 +8197,10 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
7741
8197
 
7742
8198
  const int first_row = (r0 * NSG + sgitg) * nr0;
7743
8199
 
7744
- const uint i12 = im%args.ne12;
7745
- const uint i13 = im/args.ne12;
8200
+ const uint i12 = im%FC_mul_mv_ne12;
8201
+ const uint i13 = im/FC_mul_mv_ne12;
7746
8202
 
7747
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
8203
+ const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
7748
8204
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
7749
8205
 
7750
8206
  device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0);
@@ -7849,10 +8305,10 @@ void kernel_mul_mv_iq2_xs_f32_impl(
7849
8305
 
7850
8306
  const int first_row = (r0 * NSG + sgitg) * nr0;
7851
8307
 
7852
- const uint i12 = im%args.ne12;
7853
- const uint i13 = im/args.ne12;
8308
+ const uint i12 = im%FC_mul_mv_ne12;
8309
+ const uint i13 = im/FC_mul_mv_ne12;
7854
8310
 
7855
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
8311
+ const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
7856
8312
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
7857
8313
 
7858
8314
  device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0);
@@ -7968,10 +8424,10 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
7968
8424
 
7969
8425
  const int first_row = (r0 * NSG + sgitg) * nr0;
7970
8426
 
7971
- const uint i12 = im%args.ne12;
7972
- const uint i13 = im/args.ne12;
8427
+ const uint i12 = im%FC_mul_mv_ne12;
8428
+ const uint i13 = im/FC_mul_mv_ne12;
7973
8429
 
7974
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
8430
+ const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
7975
8431
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
7976
8432
 
7977
8433
  device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0);
@@ -8080,10 +8536,10 @@ void kernel_mul_mv_iq3_s_f32_impl(
8080
8536
 
8081
8537
  const int first_row = (r0 * NSG + sgitg) * nr0;
8082
8538
 
8083
- const uint i12 = im%args.ne12;
8084
- const uint i13 = im/args.ne12;
8539
+ const uint i12 = im%FC_mul_mv_ne12;
8540
+ const uint i13 = im/FC_mul_mv_ne12;
8085
8541
 
8086
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
8542
+ const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
8087
8543
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
8088
8544
 
8089
8545
  device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0);
@@ -8192,10 +8648,10 @@ void kernel_mul_mv_iq2_s_f32_impl(
8192
8648
 
8193
8649
  const int first_row = (r0 * NSG + sgitg) * nr0;
8194
8650
 
8195
- const uint i12 = im%args.ne12;
8196
- const uint i13 = im/args.ne12;
8651
+ const uint i12 = im%FC_mul_mv_ne12;
8652
+ const uint i13 = im/FC_mul_mv_ne12;
8197
8653
 
8198
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
8654
+ const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
8199
8655
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
8200
8656
 
8201
8657
  device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0);
@@ -8305,10 +8761,10 @@ void kernel_mul_mv_iq1_s_f32_impl(
8305
8761
 
8306
8762
  const int first_row = (r0 * NSG + sgitg) * nr0;
8307
8763
 
8308
- const uint i12 = im%args.ne12;
8309
- const uint i13 = im/args.ne12;
8764
+ const uint i12 = im%FC_mul_mv_ne12;
8765
+ const uint i13 = im/FC_mul_mv_ne12;
8310
8766
 
8311
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
8767
+ const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
8312
8768
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
8313
8769
 
8314
8770
  device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0);
@@ -8404,10 +8860,10 @@ void kernel_mul_mv_iq1_m_f32_impl(
8404
8860
 
8405
8861
  const int first_row = (r0 * NSG + sgitg) * nr0;
8406
8862
 
8407
- const uint i12 = im%args.ne12;
8408
- const uint i13 = im/args.ne12;
8863
+ const uint i12 = im%FC_mul_mv_ne12;
8864
+ const uint i13 = im/FC_mul_mv_ne12;
8409
8865
 
8410
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
8866
+ const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
8411
8867
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
8412
8868
 
8413
8869
  device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0);
@@ -8513,10 +8969,10 @@ void kernel_mul_mv_iq4_nl_f32_impl(
8513
8969
 
8514
8970
  const int first_row = (r0 * NSG + sgitg) * NR0;
8515
8971
 
8516
- const uint i12 = im%args.ne12;
8517
- const uint i13 = im/args.ne12;
8972
+ const uint i12 = im%FC_mul_mv_ne12;
8973
+ const uint i13 = im/FC_mul_mv_ne12;
8518
8974
 
8519
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
8975
+ const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
8520
8976
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
8521
8977
 
8522
8978
  device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
@@ -8622,10 +9078,10 @@ void kernel_mul_mv_iq4_xs_f32_impl(
8622
9078
  const int im = tgpig.z;
8623
9079
  const int first_row = (r0 * NSG + sgitg) * NR0;
8624
9080
 
8625
- const uint i12 = im%args.ne12;
8626
- const uint i13 = im/args.ne12;
9081
+ const uint i12 = im%FC_mul_mv_ne12;
9082
+ const uint i13 = im/FC_mul_mv_ne12;
8627
9083
 
8628
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
9084
+ const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
8629
9085
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
8630
9086
 
8631
9087
  device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
@@ -8733,10 +9189,10 @@ void kernel_mul_mv_mxfp4_f32_impl(
8733
9189
 
8734
9190
  const int first_row = (r0 * NSG + sgitg) * NR0;
8735
9191
 
8736
- const uint i12 = im%args.ne12;
8737
- const uint i13 = im/args.ne12;
9192
+ const uint i12 = im%FC_mul_mv_ne12;
9193
+ const uint i13 = im/FC_mul_mv_ne12;
8738
9194
 
8739
- const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
9195
+ const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03;
8740
9196
  const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
8741
9197
 
8742
9198
  device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0);
@@ -8951,9 +9407,143 @@ kernel void kernel_diag_f32(
8951
9407
 
8952
9408
  constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]];
8953
9409
  constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]];
9410
+ constant short FC_mul_mm_ne12 [[function_constant(FC_MUL_MM + 2)]];
9411
+ constant short FC_mul_mm_ne13 [[function_constant(FC_MUL_MM + 3)]];
9412
+ constant short FC_mul_mm_r2 [[function_constant(FC_MUL_MM + 4)]];
9413
+ constant short FC_mul_mm_r3 [[function_constant(FC_MUL_MM + 5)]];
8954
9414
 
8955
9415
  // each block_q contains 16*nl weights
8956
- template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
9416
+ #ifdef GGML_METAL_HAS_TENSOR
9417
+ template<
9418
+ typename SA, typename SA_4x4, typename SA_8x8,
9419
+ typename SB, typename SB_2x4, typename SB_8x8,
9420
+ typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread SA_4x4 &),
9421
+ typename T0, typename T0_4x4, typename T1, typename T1_2x4>
9422
+ kernel void kernel_mul_mm(
9423
+ constant ggml_metal_kargs_mul_mm & args,
9424
+ device const char * srcA,
9425
+ device const char * srcB,
9426
+ device char * dst,
9427
+ threadgroup char * shmem [[threadgroup(0)]],
9428
+ uint3 tgpig [[threadgroup_position_in_grid]],
9429
+ ushort tiitg [[thread_index_in_threadgroup]],
9430
+ ushort sgitg [[simdgroup_index_in_threadgroup]]) {
9431
+ (void) sgitg;
9432
+
9433
+ // Matrix dimensions: A(M,K) x B(K,N) -> C(M,N)
9434
+ const int K = args.ne00;
9435
+ const int M = args.ne0;
9436
+ const int N = args.ne1;
9437
+
9438
+ // Batch dimension handling
9439
+ const int im = tgpig.z;
9440
+ const int i12 = im % FC_mul_mm_ne12;
9441
+ const int i13 = im / FC_mul_mm_ne12;
9442
+
9443
+ // Batch offsets for srcA and srcB
9444
+ const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03;
9445
+
9446
+ // Tile dimensions
9447
+ constexpr int NRB = SZ_SIMDGROUP * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X;
9448
+ constexpr int NRA = SZ_SIMDGROUP * N_MM_BLOCK_Y * N_MM_SIMD_GROUP_Y;
9449
+
9450
+ // Tile offsets in output matrix
9451
+ const int ra = tgpig.y * NRA;
9452
+ const int rb = tgpig.x * NRB;
9453
+
9454
+ // Threadgroup memory for dequantized A tile only
9455
+ threadgroup SA * sa = (threadgroup SA *)(shmem);
9456
+
9457
+ // Work-item count for A loading
9458
+ constexpr int A_WORK_ITEMS = NRA * N_MM_NK;
9459
+ constexpr int NUM_THREADS = N_SIMDWIDTH * N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y;
9460
+
9461
+ // tA wraps threadgroup memory
9462
+ auto tA = tensor(sa, dextents<int32_t, 2>(N_MM_NK_TOTAL, NRA));
9463
+
9464
+ // tB wraps device memory directly
9465
+ device T1 * ptrB = (device T1 *)(srcB + args.nb12*i12 + args.nb13*i13);
9466
+ const int strideB = args.nb11 / sizeof(T1);
9467
+ auto tB = tensor(ptrB, dextents<int32_t, 2>(K, N), array<int, 2>({1, strideB}));
9468
+
9469
+ // Configure matmul operation
9470
+ mpp::tensor_ops::matmul2d<
9471
+ mpp::tensor_ops::matmul2d_descriptor(
9472
+ NRB, NRA, N_MM_NK_TOTAL, false, true, true,
9473
+ mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
9474
+ execution_simdgroups<N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y>> mm;
9475
+
9476
+ auto cT = mm.get_destination_cooperative_tensor<decltype(tB), decltype(tA), float>();
9477
+
9478
+ // Accumulate partial results over K dimension
9479
+ for (int loop_k = 0; loop_k < K; loop_k += N_MM_NK_TOTAL) {
9480
+ // === PHASE 1: Dequantization of A into threadgroup memory ===
9481
+ for (int work = tiitg; work < A_WORK_ITEMS; work += NUM_THREADS) {
9482
+ const int row = work / N_MM_NK;
9483
+ const int k_chunk = work % N_MM_NK;
9484
+ const int k_pos = loop_k + k_chunk * 16;
9485
+ const short k_base = k_chunk * 16;
9486
+
9487
+ // Bounds check: skip device read if row is out of matrix bounds
9488
+ if (ra + row < M) {
9489
+ if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
9490
+ // Element-wise reads when K is not aligned (nb01 not aligned for half4x4/float4x4).
9491
+ // MSL spec Table 2.5: half4x4 requires 8-byte alignment. When K is odd,
9492
+ // nb01 = K*2 is not 8-byte aligned, so odd-row pointers are misaligned.
9493
+ // Mirrors the legacy kernel's existing guard.
9494
+ device const T0 * row_ptr = (device const T0 *)(srcA + args.nb01 * (ra + row) + offset0);
9495
+
9496
+ FOR_UNROLL (short i = 0; i < 16; i++) {
9497
+ sa[row * N_MM_NK_TOTAL + (k_base + i)] = (k_pos + i < K) ? (SA) row_ptr[k_pos + i] : (SA)0;
9498
+ }
9499
+ } else {
9500
+ const int block_idx = k_pos / (16 * nl);
9501
+ const short il = (k_pos / 16) % nl;
9502
+
9503
+ device const block_q * row_ptr = (device const block_q *)(srcA + args.nb01 * (ra + row) + offset0);
9504
+
9505
+ SA_4x4 temp_a;
9506
+ dequantize_func(row_ptr + block_idx, il, temp_a);
9507
+
9508
+ FOR_UNROLL (short i = 0; i < 16; i++) {
9509
+ // Zero-pad A for K positions beyond valid range (handles partial K iterations)
9510
+ sa[row * N_MM_NK_TOTAL + (k_base + i)] = (k_pos + i < K) ? temp_a[i/4][i%4] : (SA)0;
9511
+ }
9512
+ }
9513
+ } else {
9514
+ // Zero-pad rows beyond matrix bounds
9515
+ FOR_UNROLL (short i = 0; i < 16; i++) {
9516
+ sa[row * N_MM_NK_TOTAL + (k_base + i)] = (SA)0;
9517
+ }
9518
+ }
9519
+ }
9520
+
9521
+ threadgroup_barrier(mem_flags::mem_threadgroup);
9522
+
9523
+ // === PHASE 2: Tensor matmul ===
9524
+ auto mA = tA.slice(0, 0);
9525
+ auto mB = tB.slice(loop_k, rb);
9526
+
9527
+ mm.run(mB, mA, cT);
9528
+
9529
+ threadgroup_barrier(mem_flags::mem_threadgroup);
9530
+ }
9531
+
9532
+ // Store result tile to output matrix (with batch offset)
9533
+ // cT.store handles bounds checking via tD's extents (M, N)
9534
+ device float * dstBatch = (device float *)dst + im * N * M;
9535
+
9536
+ auto tD = tensor(dstBatch, dextents<int32_t, 2>(M, N), array<int, 2>({1, M}));
9537
+ cT.store(tD.slice(ra, rb));
9538
+ }
9539
+
9540
+ #else
9541
+
9542
+ template<
9543
+ typename S0, typename S0_4x4, typename S0_8x8,
9544
+ typename S1, typename S1_2x4, typename S1_8x8,
9545
+ typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &),
9546
+ typename T0, typename T0_4x4, typename T1, typename T1_2x4>
8957
9547
  kernel void kernel_mul_mm(
8958
9548
  constant ggml_metal_kargs_mul_mm & args,
8959
9549
  device const char * src0,
@@ -8967,10 +9557,6 @@ kernel void kernel_mul_mm(
8967
9557
  threadgroup S0 * sa = (threadgroup S0 *)(shmem);
8968
9558
  threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
8969
9559
 
8970
- #ifdef GGML_METAL_HAS_TENSOR
8971
- threadgroup float * sc = (threadgroup float *)(shmem);
8972
- #endif
8973
-
8974
9560
  constexpr int NR0 = 64;
8975
9561
  constexpr int NR1 = 32;
8976
9562
 
@@ -8994,10 +9580,10 @@ kernel void kernel_mul_mm(
8994
9580
 
8995
9581
  short il = il0;
8996
9582
 
8997
- const int i12 = im%args.ne12;
8998
- const int i13 = im/args.ne12;
9583
+ const int i12 = im % FC_mul_mm_ne12;
9584
+ const int i13 = im / FC_mul_mm_ne12;
8999
9585
 
9000
- const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
9586
+ const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03;
9001
9587
  const short offset1 = il0/nl;
9002
9588
 
9003
9589
  device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;
@@ -9010,7 +9596,6 @@ kernel void kernel_mul_mm(
9010
9596
  + args.nb11*(r1 + lr1)
9011
9597
  + args.nb10*iy);
9012
9598
 
9013
- #ifndef GGML_METAL_HAS_TENSOR
9014
9599
  S0_8x8 ma[4];
9015
9600
  S1_8x8 mb[2];
9016
9601
 
@@ -9019,19 +9604,8 @@ kernel void kernel_mul_mm(
9019
9604
  for (short i = 0; i < 8; i++){
9020
9605
  mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
9021
9606
  }
9022
- #else
9023
- auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0));
9024
- auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK ));
9025
-
9026
- mpp::tensor_ops::matmul2d<
9027
- mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
9028
- execution_simdgroups<4>> mm;
9029
-
9030
- auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
9031
- #endif
9032
9607
 
9033
9608
  for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
9034
- #ifndef GGML_METAL_HAS_TENSOR
9035
9609
  // load data and store to threadgroup memory
9036
9610
  if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
9037
9611
  threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -9101,66 +9675,6 @@ kernel void kernel_mul_mm(
9101
9675
 
9102
9676
  *(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
9103
9677
  }
9104
- #else
9105
- // load data and store to threadgroup memory
9106
- if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
9107
- threadgroup_barrier(mem_flags::mem_threadgroup);
9108
-
9109
- // no need for dequantization
9110
- for (short i = 0; i < 16; i++) {
9111
- const short sx = 2*il0 + i/8;
9112
- const short sy = (tiitg/NL0)/8;
9113
-
9114
- const short lx = i%8;
9115
- const short ly = (tiitg/NL0)%8;
9116
- //const short lx = (tiitg/NL0)%8;
9117
- //const short ly = i%8;
9118
-
9119
- *(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
9120
- }
9121
- } else {
9122
- S0_4x4 temp_a;
9123
- dequantize_func(x, il, temp_a);
9124
-
9125
- threadgroup_barrier(mem_flags::mem_threadgroup);
9126
-
9127
- FOR_UNROLL (short i = 0; i < 16; i++) {
9128
- const short sx = 2*il0 + i/8;
9129
- const short sy = (tiitg/NL0)/8;
9130
-
9131
- const short lx = i%8;
9132
- const short ly = (tiitg/NL0)%8;
9133
- //const short lx = (tiitg/NL0)%8;
9134
- //const short ly = i%8;
9135
-
9136
- *(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4];
9137
- }
9138
- }
9139
-
9140
- if (FC_mul_mm_bc_inp) {
9141
- for (short i = 0; i < 8; ++i) {
9142
- const short sx = (tiitg%NL1);
9143
- const short sy = (tiitg/NL1)/8;
9144
-
9145
- const short lx = i;
9146
- const short ly = (tiitg/NL1)%8;
9147
- //const short lx = (tiitg/NL1)%8;
9148
- //const short ly = i;
9149
-
9150
- *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
9151
- }
9152
- } else {
9153
- const short sx = (tiitg%NL1);
9154
- const short sy = (tiitg/NL1)/8;
9155
-
9156
- //const short lx = i;
9157
- const short ly = (tiitg/NL1)%8;
9158
- //const short lx = (tiitg/NL1)%8;
9159
- //const short ly = i;
9160
-
9161
- *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));
9162
- }
9163
- #endif
9164
9678
 
9165
9679
  il = (il + 2 < nl) ? il + 2 : il % 2;
9166
9680
  x = (il < 2) ? x + (2 + nl - 1)/nl : x;
@@ -9169,7 +9683,6 @@ kernel void kernel_mul_mm(
9169
9683
 
9170
9684
  threadgroup_barrier(mem_flags::mem_threadgroup);
9171
9685
 
9172
- #ifndef GGML_METAL_HAS_TENSOR
9173
9686
  // load matrices from threadgroup memory and conduct outer products
9174
9687
  threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
9175
9688
  threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
@@ -9196,24 +9709,10 @@ kernel void kernel_mul_mm(
9196
9709
  lsma += 8*64;
9197
9710
  lsmb += 4*64;
9198
9711
  }
9199
- #else
9200
- auto sA = tA.slice(0, 0);
9201
- auto sB = tB.slice(0, 0);
9202
-
9203
- mm.run(sB, sA, cT);
9204
- #endif
9205
9712
  }
9206
9713
 
9207
9714
  if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) {
9208
9715
  // if no bounds checks on the output are needed, we can directly write to device memory
9209
- #ifdef GGML_METAL_HAS_TENSOR
9210
- device float * C = (device float *) dst +
9211
- r0 + \
9212
- r1 * args.ne0 + im*args.ne1*args.ne0;
9213
-
9214
- auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(args.ne0, NR1));
9215
- cT.store(tC);
9216
- #else
9217
9716
  device float * C = (device float *) dst +
9218
9717
  (r0 + 32*(sgitg & 1)) + \
9219
9718
  (r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
@@ -9221,21 +9720,15 @@ kernel void kernel_mul_mm(
9221
9720
  for (short i = 0; i < 8; i++) {
9222
9721
  simdgroup_store(mc[i], C + 8*(i%4) + 8*args.ne0*(i/4), args.ne0, 0, false);
9223
9722
  }
9224
- #endif
9225
9723
  } else {
9226
9724
  // block is smaller than 64x32, we should avoid writing data outside of the matrix
9227
9725
  threadgroup_barrier(mem_flags::mem_threadgroup);
9228
9726
 
9229
9727
  threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
9230
9728
 
9231
- #ifdef GGML_METAL_HAS_TENSOR
9232
- auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1));
9233
- cT.store(tC);
9234
- #else
9235
9729
  for (short i = 0; i < 8; i++) {
9236
9730
  simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
9237
9731
  }
9238
- #endif
9239
9732
 
9240
9733
  threadgroup_barrier(mem_flags::mem_threadgroup);
9241
9734
 
@@ -9261,6 +9754,8 @@ kernel void kernel_mul_mm(
9261
9754
  }
9262
9755
  }
9263
9756
 
9757
+ #endif // GGML_METAL_HAS_TENSOR
9758
+
9264
9759
  template<short ne20> // n_expert_used
9265
9760
  kernel void kernel_mul_mm_id_map0(
9266
9761
  constant ggml_metal_kargs_mul_mm_id_map0 & args,
@@ -9436,7 +9931,7 @@ kernel void kernel_mul_mm_id(
9436
9931
 
9437
9932
  const short ib = 8*sx + sy;
9438
9933
 
9439
- *(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
9934
+ *(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? (S0) *((device T0 *) x + i) : (S0) 0;
9440
9935
  }
9441
9936
  } else {
9442
9937
  S0_4x4 temp_a;
@@ -9649,6 +10144,7 @@ template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_ro
9649
10144
 
9650
10145
  typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
9651
10146
 
10147
+ template [[host_name("kernel_get_rows_q1_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q1_0, 8, dequantize_q1_0>;
9652
10148
  template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
9653
10149
  template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
9654
10150
  template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
@@ -9711,6 +10207,7 @@ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_m
9711
10207
  #if defined(GGML_METAL_HAS_BF16)
9712
10208
  template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
9713
10209
  #endif
10210
+ template [[host_name("kernel_mul_mm_q1_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
9714
10211
  template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
9715
10212
  template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
9716
10213
  template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
@@ -9734,6 +10231,7 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_m
9734
10231
 
9735
10232
  template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
9736
10233
  template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
10234
+ template [[host_name("kernel_mul_mm_q1_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
9737
10235
  template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
9738
10236
  template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
9739
10237
  template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
@@ -9766,6 +10264,7 @@ template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mul_mm_id kernel_m
9766
10264
  #if defined(GGML_METAL_HAS_BF16)
9767
10265
  template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
9768
10266
  #endif
10267
+ template [[host_name("kernel_mul_mm_id_q1_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
9769
10268
  template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
9770
10269
  template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
9771
10270
  template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
@@ -9789,6 +10288,7 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_m
9789
10288
 
9790
10289
  template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
9791
10290
  template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
10291
+ template [[host_name("kernel_mul_mm_id_q1_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
9792
10292
  template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
9793
10293
  template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
9794
10294
  template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
@@ -9943,6 +10443,7 @@ template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4
9943
10443
 
9944
10444
  template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0>>>;
9945
10445
 
10446
+ template [[host_name("kernel_mul_mv_id_q1_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q1_0_f32_impl<N_R0_Q1_0>>>;
9946
10447
  template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0>>>;
9947
10448
  template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1>>>;
9948
10449
  template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0>>>;