whispercpp 1.3.5 → 1.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (1017)
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/LICENSE +1 -1
  5. data/README.md +133 -3
  6. data/Rakefile +18 -3
  7. data/ext/dependencies.rb +10 -4
  8. data/ext/dependencies_for_windows.rb +17 -0
  9. data/ext/extconf.rb +20 -7
  10. data/ext/options.rb +54 -14
  11. data/ext/options_for_windows.rb +51 -0
  12. data/ext/ruby_whisper.c +56 -46
  13. data/ext/ruby_whisper.h +165 -2
  14. data/ext/ruby_whisper_context.c +297 -126
  15. data/ext/ruby_whisper_context_params.c +163 -0
  16. data/ext/ruby_whisper_log_queue.c +180 -0
  17. data/ext/ruby_whisper_log_settable.h +47 -0
  18. data/ext/ruby_whisper_model.c +0 -1
  19. data/ext/ruby_whisper_parakeet.c +49 -0
  20. data/ext/ruby_whisper_parakeet_context.c +304 -0
  21. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  22. data/ext/ruby_whisper_parakeet_model.c +84 -0
  23. data/ext/ruby_whisper_parakeet_params.c +548 -0
  24. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  25. data/ext/ruby_whisper_parakeet_token.c +188 -0
  26. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  27. data/ext/ruby_whisper_params.c +256 -66
  28. data/ext/ruby_whisper_segment.c +6 -7
  29. data/ext/ruby_whisper_token.c +29 -9
  30. data/ext/ruby_whisper_transcribe.cpp +46 -16
  31. data/ext/ruby_whisper_vad_context.c +48 -1
  32. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  33. data/ext/ruby_whisper_vad_params.c +0 -1
  34. data/ext/ruby_whisper_vad_segment.c +0 -1
  35. data/ext/ruby_whisper_vad_segments.c +0 -1
  36. data/ext/sources/CMakeLists.txt +41 -3
  37. data/ext/sources/CMakePresets.json +95 -0
  38. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  39. data/ext/sources/cmake/parakeet.pc.in +10 -0
  40. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  41. data/ext/sources/cmake/whisper.pc.in +1 -1
  42. data/ext/sources/examples/CMakeLists.txt +4 -2
  43. data/ext/sources/examples/bench/bench.cpp +24 -19
  44. data/ext/sources/examples/cli/cli.cpp +51 -9
  45. data/ext/sources/examples/common-ggml.cpp +4 -0
  46. data/ext/sources/examples/common-whisper.cpp +139 -67
  47. data/ext/sources/examples/common-whisper.h +11 -0
  48. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  49. data/ext/sources/examples/miniaudio.h +4507 -2131
  50. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  51. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  52. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  53. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  54. data/ext/sources/examples/server/server.cpp +213 -163
  55. data/ext/sources/ggml/CMakeLists.txt +29 -15
  56. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  57. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  58. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  59. data/ext/sources/ggml/include/ggml-backend.h +73 -11
  60. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  61. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  62. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  63. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  64. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  65. data/ext/sources/ggml/include/ggml-rpc.h +8 -3
  66. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  67. data/ext/sources/ggml/include/ggml.h +155 -16
  68. data/ext/sources/ggml/include/gguf.h +10 -2
  69. data/ext/sources/ggml/src/CMakeLists.txt +25 -5
  70. data/ext/sources/ggml/src/ggml-alloc.c +9 -10
  71. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  72. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  73. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  74. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +40 -86
  76. data/ext/sources/ggml/src/ggml-backend.cpp +114 -10
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -2
  79. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +1016 -442
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +111 -85
  83. data/ext/sources/ggml/src/ggml-cann/common.h +23 -14
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +255 -92
  85. data/ext/sources/ggml/src/ggml-common.h +22 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +68 -34
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +44 -19
  88. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  89. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +101 -101
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +194 -1
  91. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2874 -613
  92. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +5480 -840
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1361 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -11
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +186 -36
  99. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +119 -19
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +112 -26
  101. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  102. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  103. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -0
  105. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +153 -16
  106. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +17 -0
  107. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  108. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +976 -251
  109. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +671 -266
  110. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1277 -263
  111. data/ext/sources/ggml/src/ggml-cpu/ops.h +4 -0
  112. data/ext/sources/ggml/src/ggml-cpu/quants.c +95 -0
  113. data/ext/sources/ggml/src/ggml-cpu/quants.h +6 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2893 -679
  115. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  116. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +226 -0
  117. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +114 -19
  118. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  119. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  120. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  121. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  122. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  123. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  124. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  125. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  126. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  127. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  128. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  129. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  130. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  131. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  132. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  133. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  134. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +54 -53
  135. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  136. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +18 -8
  137. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  138. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  139. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +73 -28
  140. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +69 -41
  141. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  142. data/ext/sources/ggml/src/ggml-cuda/common.cuh +359 -29
  143. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  144. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  145. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  146. data/ext/sources/ggml/src/ggml-cuda/convert.cu +94 -27
  147. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  148. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +20 -9
  149. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  150. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +333 -85
  151. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +632 -190
  152. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  153. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +162 -49
  154. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +43 -18
  155. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +44 -14
  156. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  157. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +241 -23
  158. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  159. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  160. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  161. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +312 -0
  162. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  163. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  164. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1454 -599
  165. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  166. data/ext/sources/ggml/src/ggml-cuda/mean.cu +13 -10
  167. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +397 -183
  168. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  169. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +161 -88
  170. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  171. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +522 -431
  172. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +139 -72
  173. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  174. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +608 -88
  175. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -0
  176. data/ext/sources/ggml/src/ggml-cuda/norm.cu +47 -79
  177. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  178. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  179. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +134 -27
  180. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  181. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +7 -17
  182. data/ext/sources/ggml/src/ggml-cuda/rope.cu +244 -137
  183. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  184. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  185. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  186. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  187. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  188. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  189. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  190. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +96 -40
  191. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  192. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  193. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +6 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +2 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +6 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +2 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +2 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +2 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -5
  226. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +202 -135
  227. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  228. data/ext/sources/ggml/src/ggml-cuda/unary.cu +86 -2
  229. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +4 -0
  230. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +111 -17
  231. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  232. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +30 -2
  233. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +84 -46
  235. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1612 -753
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +51 -11
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +361 -261
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +294 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +753 -241
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +295 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +471 -296
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +159 -53
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +3 -3
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +372 -0
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +86 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +137 -0
  254. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  255. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  256. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  257. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  258. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  259. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  260. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  261. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  262. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +97 -14
  263. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +163 -67
  264. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  265. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  266. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +308 -0
  267. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +262 -0
  268. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +291 -0
  269. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  270. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +216 -0
  271. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  272. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  273. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  274. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  275. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  276. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  278. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +142 -0
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -1348
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +547 -635
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +3556 -1101
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +475 -269
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +94 -72
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +222 -217
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +432 -0
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +886 -117
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  297. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +40 -0
  302. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +28 -9
  303. data/ext/sources/ggml/src/ggml-impl.h +68 -1
  304. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  305. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  306. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  307. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  308. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +409 -83
  309. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +54 -5
  310. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +254 -52
  311. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +254 -23
  312. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +756 -285
  313. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +7 -4
  314. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +359 -133
  315. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1867 -1123
  316. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  317. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +71 -4
  318. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +14127 -5314
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +97 -88
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +104 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1978 -67
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl +129 -0
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +195 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  367. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  368. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  369. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  370. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  371. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  372. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  373. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  374. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  375. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  376. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  377. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  378. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  379. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  380. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  381. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  382. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  383. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  384. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  385. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  386. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  387. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  388. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  389. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +178 -0
  390. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  391. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  392. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  393. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  394. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  395. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  396. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  397. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  398. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  399. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  400. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  401. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +985 -0
  402. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  403. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +380 -0
  404. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  405. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1132 -0
  406. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +956 -0
  407. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  412. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  413. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  414. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  415. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  416. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  417. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  418. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  419. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  420. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  421. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  422. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  423. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  424. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +149 -0
  425. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  426. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  427. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  428. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  429. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  430. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  431. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  432. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +47 -0
  433. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +40 -0
  434. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  435. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  436. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  437. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  438. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  439. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  440. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +317 -0
  441. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  442. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +257 -0
  443. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +86 -0
  444. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +880 -0
  445. data/ext/sources/ggml/src/ggml-openvino/utils.h +143 -0
  446. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  447. data/ext/sources/ggml/src/ggml-quants.c +385 -119
  448. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  449. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  450. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  451. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  452. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  453. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +64 -91
  454. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  455. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +4 -1
  456. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  457. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  458. data/ext/sources/ggml/src/ggml-sycl/common.hpp +356 -11
  459. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +184 -14
  460. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +31 -1
  461. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  462. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  463. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  464. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  465. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  466. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  467. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  468. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  469. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +77 -156
  470. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -2
  471. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  472. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  473. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1181 -0
  474. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +59 -0
  475. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1246 -0
  476. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +674 -0
  477. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +227 -0
  478. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  479. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  480. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  481. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +347 -0
  482. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +9 -0
  483. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  484. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  485. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +1134 -236
  486. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  487. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  488. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  489. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  490. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  491. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  492. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  493. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  494. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +72 -1
  495. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  496. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  497. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  498. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  499. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  500. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  501. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  502. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  503. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  504. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  505. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  506. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  507. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  508. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  509. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  510. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  511. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  512. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  513. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  514. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  515. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  516. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +8 -0
  517. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +8 -0
  518. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +8 -0
  519. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +8 -0
  520. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +8 -0
  521. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +8 -0
  522. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +8 -0
  523. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +8 -0
  524. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +8 -0
  525. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +8 -0
  526. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +8 -0
  527. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +8 -0
  530. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +8 -0
  531. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +8 -0
  532. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +8 -0
  534. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +8 -0
  535. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +8 -0
  536. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +8 -0
  537. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +8 -0
  538. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +8 -0
  539. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +8 -0
  540. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +8 -0
  541. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +8 -0
  542. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +8 -0
  543. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +8 -0
  544. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +8 -0
  545. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +8 -0
  547. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +8 -0
  548. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +8 -0
  549. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +8 -0
  550. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +8 -0
  551. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +8 -0
  552. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  553. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  554. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  555. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +228 -53
  556. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  557. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  558. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  559. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  560. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  561. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  562. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  563. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  564. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  565. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  566. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  567. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  568. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  569. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  570. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  571. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  572. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  573. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  574. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  575. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  576. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  577. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  578. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +123 -0
  579. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +160 -0
  580. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  581. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +71 -0
  582. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  583. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  584. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  585. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  586. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  587. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  588. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  589. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  590. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  591. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  592. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +99 -0
  593. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  594. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  595. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  596. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +545 -0
  597. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +115 -0
  598. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  599. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3250 -940
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +533 -180
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +113 -68
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +412 -222
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +222 -83
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +189 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +22 -20
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +51 -14
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +39 -63
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +13 -7
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +27 -11
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -149
  663. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  664. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +3221 -97
  665. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3493 -1997
  666. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  667. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  668. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  669. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  670. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  671. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +142 -0
  672. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +115 -141
  673. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +93 -0
  674. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  675. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -44
  676. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  677. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +198 -230
  678. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  679. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  680. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  681. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  682. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  683. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  684. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +234 -335
  685. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  686. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  687. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +871 -42
  688. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  689. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  690. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  691. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +149 -0
  692. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +36 -138
  693. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +151 -0
  694. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  695. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  696. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  697. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  698. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  699. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  700. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  701. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  702. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  703. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +15 -40
  704. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  705. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +39 -12
  706. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  707. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  708. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  709. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  710. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  711. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  712. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +213 -0
  713. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  714. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +24 -15
  715. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  716. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +253 -16
  717. data/ext/sources/ggml/src/ggml.c +268 -52
  718. data/ext/sources/ggml/src/gguf.cpp +377 -47
  719. data/ext/sources/include/parakeet.h +342 -0
  720. data/ext/sources/include/whisper.h +10 -0
  721. data/ext/sources/media/matmul.png +0 -0
  722. data/ext/sources/src/CMakeLists.txt +23 -0
  723. data/ext/sources/src/parakeet-arch.h +188 -0
  724. data/ext/sources/src/parakeet.cpp +3838 -0
  725. data/ext/sources/src/whisper.cpp +62 -40
  726. data/extsources.rb +26 -10
  727. data/lib/whisper/log_settable.rb +36 -0
  728. data/lib/whisper/model/uri.rb +13 -1
  729. data/lib/whisper/output.rb +74 -0
  730. data/sig/whisper.rbs +445 -55
  731. data/test/helper.rb +2 -0
  732. data/test/jfk_reader/jfk_reader.c +50 -7
  733. data/test/test_callback.rb +1 -0
  734. data/test/test_context_params.rb +82 -0
  735. data/test/test_package.rb +6 -5
  736. data/test/test_parakeet.rb +28 -0
  737. data/test/test_parakeet_callback.rb +107 -0
  738. data/test/test_parakeet_context.rb +116 -0
  739. data/test/test_parakeet_context_params.rb +24 -0
  740. data/test/test_parakeet_model.rb +21 -0
  741. data/test/test_parakeet_params.rb +78 -0
  742. data/test/test_parakeet_segment.rb +42 -0
  743. data/test/test_parakeet_token.rb +73 -0
  744. data/test/test_params.rb +2 -0
  745. data/test/test_token.rb +11 -0
  746. data/test/test_vad_context.rb +58 -8
  747. data/test/test_vad_segment.rb +1 -1
  748. data/test/test_whisper.rb +44 -6
  749. data/whispercpp.gemspec +2 -2
  750. metadata +426 -280
  751. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  752. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  753. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  754. data/ext/sources/bindings/javascript/package.json +0 -26
  755. data/ext/sources/bindings/javascript/whisper.js +0 -19
  756. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  757. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  758. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  759. data/ext/sources/examples/addon.node/index.js +0 -59
  760. data/ext/sources/examples/addon.node/package.json +0 -16
  761. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  762. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  763. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  764. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  765. data/ext/sources/examples/coi-serviceworker.js +0 -146
  766. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  767. data/ext/sources/examples/command/command.cpp +0 -802
  768. data/ext/sources/examples/command/commands.txt +0 -9
  769. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  770. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  771. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  772. data/ext/sources/examples/generate-karaoke.sh +0 -57
  773. data/ext/sources/examples/helpers.js +0 -191
  774. data/ext/sources/examples/livestream.sh +0 -112
  775. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  776. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  777. data/ext/sources/examples/lsp/whisper.vim +0 -362
  778. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  779. data/ext/sources/examples/python/whisper_processor.py +0 -54
  780. data/ext/sources/examples/server/bench.js +0 -29
  781. data/ext/sources/examples/server.py +0 -120
  782. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  783. data/ext/sources/examples/stream/stream.cpp +0 -437
  784. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  785. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  786. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  787. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  788. data/ext/sources/examples/sycl/build.sh +0 -22
  789. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  790. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  791. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -47
  792. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  793. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -494
  794. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -88
  795. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2559
  796. data/ext/sources/examples/talk-llama/llama-arch.h +0 -586
  797. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -917
  798. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  799. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -876
  800. data/ext/sources/examples/talk-llama/llama-chat.h +0 -70
  801. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3645
  802. data/ext/sources/examples/talk-llama/llama-context.h +0 -360
  803. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  804. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -42
  805. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  806. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  807. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2282
  808. data/ext/sources/examples/talk-llama/llama-graph.h +0 -910
  809. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -241
  810. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -284
  811. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  812. data/ext/sources/examples/talk-llama/llama-impl.h +0 -63
  813. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  814. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  815. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -328
  816. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  817. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2100
  818. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -390
  819. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  820. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  821. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  822. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1167
  823. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  824. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  825. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  826. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -735
  827. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  828. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1247
  829. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -176
  830. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -285
  831. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -37
  832. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -8338
  833. data/ext/sources/examples/talk-llama/llama-model.h +0 -544
  834. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1072
  835. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  836. data/ext/sources/examples/talk-llama/llama-sampling.cpp +0 -3771
  837. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -44
  838. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3900
  839. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -182
  840. data/ext/sources/examples/talk-llama/llama.cpp +0 -1140
  841. data/ext/sources/examples/talk-llama/llama.h +0 -1540
  842. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -191
  843. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  844. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  845. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -138
  846. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  847. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -122
  848. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -144
  849. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -135
  850. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -178
  851. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -160
  852. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  853. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  854. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  855. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  856. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  857. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  858. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  859. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -123
  860. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  861. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -144
  862. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -259
  863. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -134
  864. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  865. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -150
  866. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  867. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  868. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  869. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -113
  870. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  871. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  872. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  873. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  874. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  875. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  876. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  877. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -150
  878. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  879. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  880. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -196
  881. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -211
  882. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +0 -283
  883. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  884. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -141
  885. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  886. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -154
  887. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  888. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  889. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  890. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -175
  891. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  892. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  893. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  894. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -168
  895. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  896. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -55
  897. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -123
  898. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -199
  899. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -124
  900. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  901. data/ext/sources/examples/talk-llama/models/models.h +0 -569
  902. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -116
  903. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  904. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -150
  905. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  906. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  907. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  908. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  909. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  910. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  911. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  912. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  913. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  914. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  915. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  916. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  917. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -316
  918. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  919. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -168
  920. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  921. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  922. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  923. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  924. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -117
  925. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -124
  926. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -873
  927. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -149
  928. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -141
  929. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  930. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  931. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -162
  932. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  933. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  934. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -135
  935. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  936. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  937. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  938. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  939. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  940. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  941. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  942. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  943. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  944. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  945. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  946. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  947. data/ext/sources/examples/talk-llama/speak +0 -40
  948. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  949. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  950. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  951. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  952. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  953. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1147
  954. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  955. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  956. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  957. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  958. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  959. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  960. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  961. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  962. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  963. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  964. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  965. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  966. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  967. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  968. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  969. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  970. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +0 -157
  971. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -165
  972. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  973. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  974. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  975. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  976. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  977. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  978. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  979. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  980. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  981. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  982. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  983. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  984. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -147
  985. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  986. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +0 -907
  987. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +0 -247
  988. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  989. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  990. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  991. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
  992. data/ext/sources/tests/CMakeLists.txt +0 -112
  993. data/ext/sources/tests/earnings21/eval.mk +0 -58
  994. data/ext/sources/tests/earnings21/eval.py +0 -68
  995. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  996. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  997. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  998. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  999. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  1000. data/ext/sources/tests/en-0-ref.txt +0 -1
  1001. data/ext/sources/tests/en-1-ref.txt +0 -1
  1002. data/ext/sources/tests/en-2-ref.txt +0 -1
  1003. data/ext/sources/tests/es-0-ref.txt +0 -1
  1004. data/ext/sources/tests/librispeech/eval.mk +0 -39
  1005. data/ext/sources/tests/librispeech/eval.py +0 -47
  1006. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  1007. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  1008. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  1009. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  1010. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  1011. data/ext/sources/tests/run-tests.sh +0 -130
  1012. data/ext/sources/tests/test-c.c +0 -3
  1013. data/ext/sources/tests/test-vad-full.cpp +0 -56
  1014. data/ext/sources/tests/test-vad.cpp +0 -83
  1015. data/ext/sources/tests/test-whisper.js +0 -58
  1016. data/lib/whisper/context.rb +0 -15
  1017. data/lib/whisper/segment.rb +0 -58
@@ -203,6 +203,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
203
203
  GGML_ABORT("unsupported op");
204
204
  }
205
205
 
206
+ if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
207
+ return 1;
208
+ }
209
+
206
210
  int n_fuse = 1;
207
211
 
208
212
  // check if the current node can run concurrently with other nodes before it
@@ -283,17 +287,9 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
283
287
  n_fuse = ggml_metal_op_acc(ctx, idx);
284
288
  } break;
285
289
  case GGML_OP_SCALE:
286
- {
287
- n_fuse = ggml_metal_op_scale(ctx, idx);
288
- } break;
289
290
  case GGML_OP_FILL:
290
- {
291
- n_fuse = ggml_metal_op_fill(ctx, idx);
292
- } break;
293
291
  case GGML_OP_CLAMP:
294
- {
295
- n_fuse = ggml_metal_op_clamp(ctx, idx);
296
- } break;
292
+ case GGML_OP_LEAKY_RELU:
297
293
  case GGML_OP_SQR:
298
294
  case GGML_OP_SQRT:
299
295
  case GGML_OP_SIN:
@@ -337,6 +333,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
337
333
  {
338
334
  n_fuse = ggml_metal_op_rwkv(ctx, idx);
339
335
  } break;
336
+ case GGML_OP_GATED_DELTA_NET:
337
+ {
338
+ n_fuse = ggml_metal_op_gated_delta_net(ctx, idx);
339
+ } break;
340
+ case GGML_OP_SOLVE_TRI:
341
+ {
342
+ n_fuse = ggml_metal_op_solve_tri(ctx, idx);
343
+ } break;
340
344
  case GGML_OP_MUL_MAT:
341
345
  {
342
346
  n_fuse = ggml_metal_op_mul_mat(ctx, idx);
@@ -353,6 +357,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
353
357
  {
354
358
  n_fuse = ggml_metal_op_set_rows(ctx, idx);
355
359
  } break;
360
+ case GGML_OP_DIAG:
361
+ {
362
+ n_fuse = ggml_metal_op_diag(ctx, idx);
363
+ } break;
356
364
  case GGML_OP_L2_NORM:
357
365
  {
358
366
  n_fuse = ggml_metal_op_l2_norm(ctx, idx);
@@ -386,6 +394,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
386
394
  {
387
395
  n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx);
388
396
  } break;
397
+ case GGML_OP_CONV_3D:
398
+ {
399
+ n_fuse = ggml_metal_op_conv_3d(ctx, idx);
400
+ } break;
389
401
  case GGML_OP_UPSCALE:
390
402
  {
391
403
  n_fuse = ggml_metal_op_upscale(ctx, idx);
@@ -398,6 +410,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
398
410
  {
399
411
  n_fuse = ggml_metal_op_pad_reflect_1d(ctx, idx);
400
412
  } break;
413
+ case GGML_OP_ROLL:
414
+ {
415
+ n_fuse = ggml_metal_op_roll(ctx, idx);
416
+ } break;
401
417
  case GGML_OP_ARANGE:
402
418
  {
403
419
  n_fuse = ggml_metal_op_arange(ctx, idx);
@@ -414,10 +430,6 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
414
430
  {
415
431
  n_fuse = ggml_metal_op_top_k(ctx, idx);
416
432
  } break;
417
- case GGML_OP_LEAKY_RELU:
418
- {
419
- n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
420
- } break;
421
433
  case GGML_OP_TRI:
422
434
  {
423
435
  n_fuse = ggml_metal_op_tri(ctx, idx);
@@ -426,12 +438,20 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
426
438
  {
427
439
  n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx);
428
440
  } break;
441
+ case GGML_OP_SET:
442
+ {
443
+ n_fuse = ggml_metal_op_set(ctx, idx);
444
+ } break;
429
445
  case GGML_OP_DUP:
430
446
  case GGML_OP_CPY:
431
447
  case GGML_OP_CONT:
432
448
  {
433
449
  n_fuse = ggml_metal_op_cpy(ctx, idx);
434
450
  } break;
451
+ case GGML_OP_POOL_1D:
452
+ {
453
+ n_fuse = ggml_metal_op_pool_1d(ctx, idx);
454
+ } break;
435
455
  case GGML_OP_POOL_2D:
436
456
  {
437
457
  n_fuse = ggml_metal_op_pool_2d(ctx, idx);
@@ -544,9 +564,20 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) {
544
564
  ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
545
565
  ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
546
566
 
547
- const int nth = std::min(1024, ne0);
567
+ int nth = std::min(256, ne0);
548
568
 
549
- ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
569
+ // when rows are small, we can batch them together in a single threadgroup
570
+ int nrptg = 1;
571
+ if (nth < 256) {
572
+ nrptg = std::min((256 + nth - 1) / nth, ne1);
573
+ if (nrptg * nth > 256) {
574
+ nrptg = 256 / nth;
575
+ }
576
+ }
577
+
578
+ const int nw0 = (ne1 + nrptg - 1) / nrptg;
579
+
580
+ ggml_metal_encoder_dispatch_threadgroups(enc, nw0, ne2, ne3, nth, nrptg, 1);
550
581
 
551
582
  return 1;
552
583
  }
@@ -612,8 +643,8 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
612
643
  GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
613
644
  GGML_ASSERT(op->type == GGML_TYPE_F32);
614
645
 
615
- GGML_ASSERT(ggml_is_contiguous(op->src[0]));
616
- GGML_ASSERT(ggml_is_contiguous(op->src[1]));
646
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
647
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
617
648
 
618
649
  const size_t pnb1 = ((const int32_t *) op->op_params)[0];
619
650
  const size_t pnb2 = ((const int32_t *) op->op_params)[1];
@@ -623,7 +654,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
623
654
  const bool inplace = (bool) ((const int32_t *) op->op_params)[4];
624
655
 
625
656
  if (!inplace) {
626
- // run a separete kernel to cpy src->dst
657
+ // run a separate kernel to cpy src->dst
627
658
  // not sure how to avoid this
628
659
  // TODO: make a simpler cpy_bytes kernel
629
660
 
@@ -663,10 +694,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
663
694
  }
664
695
 
665
696
  ggml_metal_kargs_bin args = {
666
- /*.ne00 =*/ ne00,
667
- /*.ne01 =*/ ne01,
668
- /*.ne02 =*/ ne02,
669
- /*.ne03 =*/ ne03,
697
+ /*.ne00 =*/ ne10,
698
+ /*.ne01 =*/ ne11,
699
+ /*.ne02 =*/ ne12,
700
+ /*.ne03 =*/ ne13,
670
701
  /*.nb00 =*/ nb00,
671
702
  /*.nb01 =*/ pnb1,
672
703
  /*.nb02 =*/ pnb2,
@@ -679,10 +710,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
679
710
  /*.nb11 =*/ nb11,
680
711
  /*.nb12 =*/ nb12,
681
712
  /*.nb13 =*/ nb13,
682
- /*.ne0 =*/ ne0,
683
- /*.ne1 =*/ ne1,
684
- /*.ne2 =*/ ne2,
685
- /*.ne3 =*/ ne3,
713
+ /*.ne0 =*/ ne10,
714
+ /*.ne1 =*/ ne11,
715
+ /*.ne2 =*/ ne12,
716
+ /*.ne3 =*/ ne13,
686
717
  /*.nb0 =*/ nb0,
687
718
  /*.nb1 =*/ pnb1,
688
719
  /*.nb2 =*/ pnb2,
@@ -691,7 +722,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
691
722
  /*.o1 =*/ { 0 },
692
723
  };
693
724
 
694
- auto pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false);
725
+ auto pipeline = ggml_metal_library_get_pipeline_bin_one(lib, GGML_OP_ADD);
695
726
 
696
727
  ggml_metal_encoder_set_pipeline(enc, pipeline);
697
728
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -699,53 +730,20 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
699
730
  ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
700
731
  ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
701
732
 
702
- const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
703
-
704
- ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1);
733
+ const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
705
734
 
706
- return 1;
707
- }
708
-
709
- int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
710
- ggml_tensor * op = ctx->node(idx);
711
-
712
- ggml_metal_library_t lib = ctx->lib;
713
- ggml_metal_encoder_t enc = ctx->enc;
714
-
715
- GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
716
- GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
717
- GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
718
- GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
719
-
720
- float scale;
721
- float bias;
722
- memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(float));
723
- memcpy(&bias, ((const int32_t *) op->op_params) + 1, sizeof(float));
724
-
725
- ggml_metal_kargs_scale args = {
726
- /*.scale =*/ scale,
727
- /*.bias =*/ bias,
728
- };
729
-
730
- int64_t n = ggml_nelements(op);
735
+ int nth = 1;
731
736
 
732
- if (n % 4 == 0) {
733
- n /= 4;
737
+ while (2*nth < args.ne0 && nth < nth_max) {
738
+ nth *= 2;
734
739
  }
735
740
 
736
- auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
737
-
738
- ggml_metal_encoder_set_pipeline(enc, pipeline);
739
- ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
740
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
741
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
742
-
743
- ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
741
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1);
744
742
 
745
743
  return 1;
746
744
  }
747
745
 
748
- int ggml_metal_op_fill(ggml_metal_op_t ctx, int idx) {
746
+ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
749
747
  ggml_tensor * op = ctx->node(idx);
750
748
 
751
749
  ggml_metal_library_t lib = ctx->lib;
@@ -756,94 +754,85 @@ int ggml_metal_op_fill(ggml_metal_op_t ctx, int idx) {
756
754
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
757
755
  GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
758
756
 
759
- const float val = ggml_get_op_params_f32(op, 0);
757
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
760
758
 
761
- ggml_metal_kargs_fill args = {
762
- /*.val =*/ val
763
- };
759
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
760
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
764
761
 
765
- int64_t n = ggml_nelements(op);
762
+ ggml_metal_kargs_unary args = {
763
+ /*.ne00 =*/ ne00,
764
+ /*.ne01 =*/ ne01,
765
+ /*.ne02 =*/ ne02,
766
+ /*.ne03 =*/ ne03,
767
+ /*.nb00 =*/ nb00,
768
+ /*.nb01 =*/ nb01,
769
+ /*.nb02 =*/ nb02,
770
+ /*.nb03 =*/ nb03,
771
+ /*.ne0 =*/ ne0,
772
+ /*.ne1 =*/ ne1,
773
+ /*.ne2 =*/ ne2,
774
+ /*.ne3 =*/ ne3,
775
+ /*.nb0 =*/ nb0,
776
+ /*.nb1 =*/ nb1,
777
+ /*.nb2 =*/ nb2,
778
+ /*.nb3 =*/ nb3,
779
+ /*.slope =*/ 0.0,
780
+ /*.scale =*/ 0.0,
781
+ /*.bias =*/ 0.0,
782
+ /*.val =*/ 0.0,
783
+ /*.min =*/ 0.0,
784
+ /*.max =*/ 0.0,
785
+ };
766
786
 
767
- if (n % 4 == 0) {
768
- n /= 4;
787
+ if (op->op == GGML_OP_LEAKY_RELU) {
788
+ args.slope = ggml_get_op_params_f32(op, 0);
769
789
  }
770
790
 
771
- auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
772
-
773
- ggml_metal_encoder_set_pipeline(enc, pipeline);
774
- ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
775
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
776
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
777
-
778
- ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
779
-
780
- return 1;
781
- }
782
-
783
- int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
784
- ggml_tensor * op = ctx->node(idx);
785
-
786
- ggml_metal_library_t lib = ctx->lib;
787
- ggml_metal_encoder_t enc = ctx->enc;
788
-
789
- GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
790
- GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
791
- GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
792
- GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
793
-
794
- float min;
795
- float max;
796
- memcpy(&min, ((const int32_t *) op->op_params) + 0, sizeof(float));
797
- memcpy(&max, ((const int32_t *) op->op_params) + 1, sizeof(float));
791
+ if (op->op == GGML_OP_SCALE) {
792
+ args.scale = ggml_get_op_params_f32(op, 0);
793
+ args.bias = ggml_get_op_params_f32(op, 1);
794
+ }
798
795
 
799
- ggml_metal_kargs_clamp args = {
800
- /*.min =*/ min,
801
- /*.max =*/ max,
802
- };
796
+ if (op->op == GGML_OP_FILL) {
797
+ args.val = ggml_get_op_params_f32(op, 0);
798
+ }
803
799
 
804
- int64_t n = ggml_nelements(op);
800
+ if (op->op == GGML_OP_CLAMP) {
801
+ args.min = ggml_get_op_params_f32(op, 0);
802
+ args.max = ggml_get_op_params_f32(op, 1);
803
+ }
805
804
 
806
- if (n % 4 == 0) {
807
- n /= 4;
805
+ if (op->op == GGML_OP_UNARY && ggml_get_unary_op(op) == GGML_UNARY_OP_XIELU) {
806
+ args.slope = ggml_get_op_params_f32(op, 1); // alpha_n
807
+ args.scale = ggml_get_op_params_f32(op, 2); // alpha_p
808
+ args.bias = ggml_get_op_params_f32(op, 3); // beta
809
+ args.val = ggml_get_op_params_f32(op, 4); // eps
808
810
  }
809
811
 
810
812
  auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
811
813
 
814
+ if (pipeline.c4) {
815
+ args.ne00 = ne00/4;
816
+ args.ne0 = ne0/4;
817
+ }
818
+
812
819
  ggml_metal_encoder_set_pipeline(enc, pipeline);
813
820
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
814
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
815
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
816
-
817
- ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
818
-
819
- return 1;
820
- }
821
-
822
- int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
823
- ggml_tensor * op = ctx->node(idx);
824
-
825
- ggml_metal_library_t lib = ctx->lib;
826
- ggml_metal_encoder_t enc = ctx->enc;
821
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
822
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
827
823
 
828
- GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
829
- GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
830
- GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
831
- GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
824
+ if (pipeline.cnt) {
825
+ const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);
832
826
 
833
- int64_t n = ggml_nelements(op);
827
+ ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
828
+ } else {
829
+ const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
830
+ const int nth = MIN(args.ne00, nth_max);
831
+ const int nk0 = (args.ne00 + nth - 1)/nth;
834
832
 
835
- if (n % 4 == 0) {
836
- n /= 4;
833
+ ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1);
837
834
  }
838
835
 
839
- auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
840
-
841
- ggml_metal_encoder_set_pipeline(enc, pipeline);
842
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 0);
843
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 1);
844
-
845
- ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
846
-
847
836
  return 1;
848
837
  }
849
838
 
@@ -953,6 +942,11 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
953
942
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
954
943
  GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
955
944
 
945
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
946
+
947
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
948
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
949
+
956
950
  ggml_metal_kargs_sum_rows args = {
957
951
  /*.ne00 =*/ ne00,
958
952
  /*.ne01 =*/ ne01,
@@ -974,21 +968,26 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
974
968
 
975
969
  auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
976
970
 
971
+ if (pipeline.c4) {
972
+ args.ne00 = ne00/4;
973
+ args.ne0 = ne0/4;
974
+ }
975
+
977
976
  int nth = 32; // SIMD width
978
977
 
979
- while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
978
+ while (nth < args.ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
980
979
  nth *= 2;
981
980
  }
982
981
 
983
982
  nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
984
- nth = std::min(nth, ne00);
983
+ nth = std::min(nth, (int) args.ne00);
985
984
 
986
985
  const size_t smem = pipeline.smem;
987
986
 
988
987
  ggml_metal_encoder_set_pipeline(enc, pipeline);
989
988
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
990
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
991
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
989
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
990
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
992
991
 
993
992
  ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
994
993
 
@@ -1247,6 +1246,48 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
1247
1246
  return 1;
1248
1247
  }
1249
1248
 
1249
+ int ggml_metal_op_diag(ggml_metal_op_t ctx, int idx) {
1250
+ ggml_tensor * op = ctx->node(idx);
1251
+
1252
+ ggml_metal_library_t lib = ctx->lib;
1253
+ ggml_metal_encoder_t enc = ctx->enc;
1254
+
1255
+ GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne);
1256
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1257
+ GGML_TENSOR_LOCALS(int32_t, ne, op, ne);
1258
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1259
+
1260
+ ggml_metal_kargs_diag args = {
1261
+ /*.ne00 =*/ne00,
1262
+ /*.ne01 =*/ne01,
1263
+ /*.ne02 =*/ne02,
1264
+ /*.ne03 =*/ne03,
1265
+ /*.nb00 =*/nb00,
1266
+ /*.nb01 =*/nb01,
1267
+ /*.nb02 =*/nb02,
1268
+ /*.nb03 =*/nb03,
1269
+ /*.ne0 =*/ne0,
1270
+ /*.ne1 =*/ne1,
1271
+ /*.ne2 =*/ne2,
1272
+ /*.ne3 =*/ne3,
1273
+ /*.nb0 =*/nb0,
1274
+ /*.nb1 =*/nb1,
1275
+ /*.nb2 =*/nb2,
1276
+ /*.nb3 =*/nb3,
1277
+ };
1278
+
1279
+ auto pipeline = ggml_metal_library_get_pipeline_diag(lib, op);
1280
+
1281
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
1282
+ ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
1283
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1284
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 2);
1285
+
1286
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, 32, 1, 1);
1287
+
1288
+ return 1;
1289
+ }
1290
+
1250
1291
  int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
1251
1292
  ggml_tensor * op = ctx->node(idx);
1252
1293
 
@@ -1524,27 +1565,287 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
1524
1565
  const int64_t C = op->ne[0];
1525
1566
  const int64_t H = op->src[0]->ne[1];
1526
1567
 
1527
- auto pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op);
1568
+ auto pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op);
1569
+
1570
+ int ida = 0;
1571
+
1572
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
1573
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
1574
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
1575
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
1576
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
1577
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);
1578
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), ida++);
1579
+ if (op->op == GGML_OP_RWKV_WKV7) {
1580
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), ida++);
1581
+ }
1582
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), ida++);
1583
+ ggml_metal_encoder_set_bytes (enc, (void *) &B, sizeof(B), ida++);
1584
+ ggml_metal_encoder_set_bytes (enc, (void *) &T, sizeof(T), ida++);
1585
+ ggml_metal_encoder_set_bytes (enc, (void *) &C, sizeof(C), ida++);
1586
+ ggml_metal_encoder_set_bytes (enc, (void *) &H, sizeof(H), ida++);
1587
+
1588
+ ggml_metal_encoder_dispatch_threadgroups(enc, B * H, 1, 1, C/H, 1, 1);
1589
+
1590
+ return 1;
1591
+ }
1592
+
1593
+ int ggml_metal_op_gated_delta_net(ggml_metal_op_t ctx, int idx) {
1594
+ ggml_tensor * op = ctx->node(idx);
1595
+
1596
+ ggml_metal_library_t lib = ctx->lib;
1597
+ ggml_metal_encoder_t enc = ctx->enc;
1598
+
1599
+
1600
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1601
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1602
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1603
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1604
+ GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
1605
+ GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
1606
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1607
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1608
+
1609
+ auto pipeline = ggml_metal_library_get_pipeline_gated_delta_net(lib, op);
1610
+
1611
+ int ida = 0;
1612
+
1613
+ ggml_metal_kargs_gated_delta_net args = {
1614
+ /*.ne00 =*/ ne00,
1615
+ /*.ne01 =*/ ne01,
1616
+ /*.ne02 =*/ ne02,
1617
+ /*.ne03 =*/ ne03,
1618
+ /*.nb00 =*/ nb00,
1619
+ /*.nb01 =*/ nb01,
1620
+ /*.nb02 =*/ nb02,
1621
+ /*.nb03 =*/ nb03,
1622
+ /*.ne10 =*/ ne10,
1623
+ /*.ne11 =*/ ne11,
1624
+ /*.ne12 =*/ ne12,
1625
+ /*.ne13 =*/ ne13,
1626
+ /*.nb10 =*/ nb10,
1627
+ /*.nb11 =*/ nb11,
1628
+ /*.nb12 =*/ nb12,
1629
+ /*.nb13 =*/ nb13,
1630
+ /*.ne20 =*/ ne20,
1631
+ /*.ne21 =*/ ne21,
1632
+ /*.ne22 =*/ ne22,
1633
+ /*.ne23 =*/ ne23,
1634
+ /*.nb20 =*/ nb20,
1635
+ /*.nb21 =*/ nb21,
1636
+ /*.nb22 =*/ nb22,
1637
+ /*.nb23 =*/ nb23,
1638
+ /*.ns02 =*/ (int32_t) (nb02/sizeof(float)),
1639
+ /*.ns12 =*/ (int32_t) (nb12/sizeof(float)),
1640
+ /*.ns22 =*/ (int32_t) (nb22/sizeof(float)),
1641
+ /*.ne0 =*/ ne0,
1642
+ /*.ne1 =*/ ne1,
1643
+ /*.ne2 =*/ ne2,
1644
+ /*.ne3 =*/ ne3,
1645
+ /*.nb0 =*/ nb0,
1646
+ /*.nb1 =*/ nb1,
1647
+ /*.nb2 =*/ nb2,
1648
+ /*.nb3 =*/ nb3,
1649
+ };
1650
+
1651
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
1652
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
1653
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); // q
1654
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); // k
1655
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); // v
1656
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++); // gate
1657
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++); // beta
1658
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), ida++); // state
1659
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), ida++); // dst
1660
+
1661
+ const int nsg = pipeline.nsg;
1662
+
1663
+ ggml_metal_encoder_dispatch_threadgroups(enc, op->src[2]->ne[0]/nsg, op->src[2]->ne[1], op->src[2]->ne[3], 32, nsg, 1);
1664
+
1665
+ return 1;
1666
+ }
1667
+
1668
+ int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
1669
+ ggml_tensor * op = ctx->node(idx);
1670
+
1671
+ ggml_metal_library_t lib = ctx->lib;
1672
+ ggml_metal_encoder_t enc = ctx->enc;
1673
+
1674
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1675
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1676
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1677
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1678
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1679
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1680
+
1681
+ ggml_metal_kargs_solve_tri args = {
1682
+ /*.ne00 =*/ ne00,
1683
+ /*.ne01 =*/ ne01,
1684
+ /*.ne02 =*/ ne02,
1685
+ /*.ne03 =*/ ne03,
1686
+ /*.nb00 =*/ nb00,
1687
+ /*.nb01 =*/ nb01,
1688
+ /*.nb02 =*/ nb02,
1689
+ /*.nb03 =*/ nb03,
1690
+ /*.ne10 =*/ ne10,
1691
+ /*.ne11 =*/ ne11,
1692
+ /*.ne12 =*/ ne12,
1693
+ /*.ne13 =*/ ne13,
1694
+ /*.nb10 =*/ nb10,
1695
+ /*.nb11 =*/ nb11,
1696
+ /*.nb12 =*/ nb12,
1697
+ /*.nb13 =*/ nb13,
1698
+ /*.ne0 =*/ ne0,
1699
+ /*.ne1 =*/ ne1,
1700
+ /*.ne2 =*/ ne2,
1701
+ /*.ne3 =*/ ne3,
1702
+ /*.nb0 =*/ nb0,
1703
+ /*.nb1 =*/ nb1,
1704
+ /*.nb2 =*/ nb2,
1705
+ /*.nb3 =*/ nb3,
1706
+ };
1707
+
1708
+ auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op);
1709
+
1710
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
1711
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1712
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1713
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1714
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
1715
+
1716
+ const int nsg = pipeline.nsg;
1717
+
1718
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, pipeline.smem, 0);
1719
+
1720
+ ggml_metal_encoder_dispatch_threadgroups(enc, (ne10 + nsg - 1)/nsg, ne02, ne03, 32, nsg, 1);
1721
+
1722
+ return 1;
1723
+ }
1724
+
1725
+ int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) {
1726
+ ggml_tensor * op = ctx->node(idx);
1727
+
1728
+ ggml_metal_library_t lib = ctx->lib;
1729
+ ggml_metal_encoder_t enc = ctx->enc;
1730
+
1731
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1732
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1733
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1734
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1735
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1736
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1737
+
1738
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
1739
+ ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
1740
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
1741
+
1742
+ const size_t pnb1 = ((const int32_t *) op->op_params)[0];
1743
+ const size_t pnb2 = ((const int32_t *) op->op_params)[1];
1744
+ const size_t pnb3 = ((const int32_t *) op->op_params)[2];
1745
+ const size_t offs = ((const int32_t *) op->op_params)[3];
1746
+
1747
+ const bool inplace = (bool) ((const int32_t *) op->op_params)[4];
1748
+
1749
+ if (!inplace) {
1750
+ // run a separate kernel to cpy src->dst
1751
+ // not sure how to avoid this
1752
+ // TODO: make a simpler cpy_bytes kernel
1753
+
1754
+ //const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
1755
+ auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
1756
+
1757
+ ggml_metal_kargs_cpy args = {
1758
+ /*.nk0 =*/ ne00,
1759
+ /*.ne00 =*/ ne00,
1760
+ /*.ne01 =*/ ne01,
1761
+ /*.ne02 =*/ ne02,
1762
+ /*.ne03 =*/ ne03,
1763
+ /*.nb00 =*/ nb00,
1764
+ /*.nb01 =*/ nb01,
1765
+ /*.nb02 =*/ nb02,
1766
+ /*.nb03 =*/ nb03,
1767
+ /*.ne0 =*/ ne0,
1768
+ /*.ne1 =*/ ne1,
1769
+ /*.ne2 =*/ ne2,
1770
+ /*.ne3 =*/ ne3,
1771
+ /*.nb0 =*/ nb0,
1772
+ /*.nb1 =*/ nb1,
1773
+ /*.nb2 =*/ nb2,
1774
+ /*.nb3 =*/ nb3,
1775
+ };
1776
+
1777
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
1778
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1779
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
1780
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
1781
+
1782
+ const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
1783
+
1784
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
1785
+
1786
+ ggml_metal_op_concurrency_reset(ctx);
1787
+ }
1788
+
1789
+ auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[1]->type, op->type);
1790
+
1791
+ GGML_ASSERT(ne10 % ggml_blck_size(op->src[1]->type) == 0);
1792
+
1793
+ int64_t nk0 = ne10;
1794
+ if (ggml_is_quantized(op->src[1]->type)) {
1795
+ nk0 = ne10/16;
1796
+ } else if (ggml_is_quantized(op->type)) {
1797
+ nk0 = ne10/ggml_blck_size(op->type);
1798
+ }
1799
+
1800
+ int nth = std::min<int>(nk0*ne11, 256);
1801
+
1802
+ // when rows are small, we can batch them together in a single threadgroup
1803
+ int nrptg = 1;
1804
+
1805
+ // TODO: relax this constraint in the future
1806
+ if (ggml_blck_size(op->src[1]->type) == 1 && ggml_blck_size(op->type) == 1) {
1807
+ if (nth > nk0) {
1808
+ nrptg = (nth + nk0 - 1)/nk0;
1809
+ nth = nk0;
1810
+
1811
+ if (nrptg*nth > 256) {
1812
+ nrptg--;
1813
+ }
1814
+ }
1815
+ }
1816
+
1817
+ nth = std::min<int>(nth, nk0);
1818
+
1819
+ ggml_metal_kargs_cpy args = {
1820
+ /*.nk0 =*/ nk0,
1821
+ /*.ne00 =*/ ne10,
1822
+ /*.ne01 =*/ ne11,
1823
+ /*.ne02 =*/ ne12,
1824
+ /*.ne03 =*/ ne13,
1825
+ /*.nb00 =*/ nb10,
1826
+ /*.nb01 =*/ nb11,
1827
+ /*.nb02 =*/ nb12,
1828
+ /*.nb03 =*/ nb13,
1829
+ /*.ne0 =*/ ne10,
1830
+ /*.ne1 =*/ ne11,
1831
+ /*.ne2 =*/ ne12,
1832
+ /*.ne3 =*/ ne13,
1833
+ /*.nb0 =*/ ggml_element_size(op),
1834
+ /*.nb1 =*/ pnb1,
1835
+ /*.nb2 =*/ pnb2,
1836
+ /*.nb3 =*/ pnb3,
1837
+ };
1838
+
1839
+ const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
1528
1840
 
1529
- int ida = 0;
1841
+ bid_dst.offs += offs;
1530
1842
 
1531
1843
  ggml_metal_encoder_set_pipeline(enc, pipeline);
1532
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
1533
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
1534
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
1535
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
1536
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);
1537
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), ida++);
1538
- if (op->op == GGML_OP_RWKV_WKV7) {
1539
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), ida++);
1540
- }
1541
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), ida++);
1542
- ggml_metal_encoder_set_bytes (enc, (void *) &B, sizeof(B), ida++);
1543
- ggml_metal_encoder_set_bytes (enc, (void *) &T, sizeof(T), ida++);
1544
- ggml_metal_encoder_set_bytes (enc, (void *) &C, sizeof(C), ida++);
1545
- ggml_metal_encoder_set_bytes (enc, (void *) &H, sizeof(H), ida++);
1844
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1845
+ ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
1846
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
1546
1847
 
1547
- ggml_metal_encoder_dispatch_threadgroups(enc, B * H, 1, 1, C/H, 1, 1);
1848
+ ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne11 + nrptg - 1)/nrptg, ne12, ne13, nth, nrptg, 1);
1548
1849
 
1549
1850
  return 1;
1550
1851
  }
@@ -1571,7 +1872,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
1571
1872
  nk0 = ne00/ggml_blck_size(op->type);
1572
1873
  }
1573
1874
 
1574
- int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1875
+ int nth = std::min<int>(nk0*ne01, 256);
1575
1876
 
1576
1877
  // when rows are small, we can batch them together in a single threadgroup
1577
1878
  int nrptg = 1;
@@ -1582,7 +1883,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
1582
1883
  nrptg = (nth + nk0 - 1)/nk0;
1583
1884
  nth = nk0;
1584
1885
 
1585
- if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
1886
+ if (nrptg*nth > 256) {
1586
1887
  nrptg--;
1587
1888
  }
1588
1889
  }
@@ -1622,6 +1923,54 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
1622
1923
  return 1;
1623
1924
  }
1624
1925
 
1926
+ int ggml_metal_op_pool_1d(ggml_metal_op_t ctx, int idx) {
1927
+ ggml_tensor * op = ctx->node(idx);
1928
+
1929
+ ggml_metal_library_t lib = ctx->lib;
1930
+ ggml_metal_encoder_t enc = ctx->enc;
1931
+
1932
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1933
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1934
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1935
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1936
+
1937
+ const int32_t * opts = op->op_params;
1938
+ ggml_op_pool op_pool = (ggml_op_pool) opts[0];
1939
+
1940
+ const int32_t k0 = opts[1];
1941
+ const int32_t s0 = opts[2];
1942
+ const int32_t p0 = opts[3];
1943
+
1944
+ const int64_t IW = op->src[0]->ne[0];
1945
+ const int64_t OW = op->ne[0];
1946
+
1947
+ const int64_t np = ggml_nelements(op);
1948
+
1949
+ ggml_metal_kargs_pool_1d args_pool_1d = {
1950
+ /* .k0 = */ k0,
1951
+ /* .s0 = */ s0,
1952
+ /* .p0 = */ p0,
1953
+ /* .IW = */ IW,
1954
+ /* .OW = */ OW,
1955
+ /* .np = */ np
1956
+ };
1957
+
1958
+ auto pipeline = ggml_metal_library_get_pipeline_pool_1d(lib, op, op_pool);
1959
+
1960
+ const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
1961
+ const int ntg = (np + nth - 1) / nth;
1962
+
1963
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
1964
+ ggml_metal_encoder_set_bytes (enc, &args_pool_1d, sizeof(args_pool_1d), 0);
1965
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1966
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
1967
+
1968
+ ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1);
1969
+
1970
+ return 1;
1971
+ }
1972
+
1973
+
1625
1974
  int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
1626
1975
  ggml_tensor * op = ctx->node(idx);
1627
1976
 
@@ -1717,6 +2066,8 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
1717
2066
  (
1718
2067
  op->src[0]->type == GGML_TYPE_F32 || // TODO: helper function
1719
2068
  op->src[0]->type == GGML_TYPE_F16 ||
2069
+ op->src[0]->type == GGML_TYPE_BF16 ||
2070
+ op->src[0]->type == GGML_TYPE_Q1_0 ||
1720
2071
  op->src[0]->type == GGML_TYPE_Q4_0 ||
1721
2072
  op->src[0]->type == GGML_TYPE_Q4_1 ||
1722
2073
  op->src[0]->type == GGML_TYPE_Q5_0 ||
@@ -1731,6 +2082,8 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
1731
2082
  op->src[0]->type == GGML_TYPE_Q4_K ||
1732
2083
  op->src[0]->type == GGML_TYPE_Q5_K ||
1733
2084
  op->src[0]->type == GGML_TYPE_Q6_K ||
2085
+ op->src[0]->type == GGML_TYPE_Q2_K ||
2086
+ op->src[0]->type == GGML_TYPE_Q3_K ||
1734
2087
  false) && (ne11 >= 4 && ne11 <= 8)
1735
2088
  )
1736
2089
  )
@@ -1759,7 +2112,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
1759
2112
  const int16_t r0ptg = nypsg*nsg; // num src0 rows per threadgroup
1760
2113
  int16_t r1ptg = 4; // num src1 rows per threadgroup
1761
2114
 
1762
- // note: not sure how optimal are those across all different hardware. there might be someting cleverer
2115
+ // note: not sure how optimal are those across all different hardware. there might be something cleverer
1763
2116
  switch (ne11) {
1764
2117
  case 2:
1765
2118
  r1ptg = 2; break;
@@ -1776,7 +2129,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
1776
2129
  GGML_ABORT("unsupported ne11");
1777
2130
  };
1778
2131
 
1779
- auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);
2132
+ auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op, nsg, nxpsg, r1ptg);
1780
2133
 
1781
2134
  ggml_metal_kargs_mul_mv_ext args = {
1782
2135
  /*.ne00 =*/ ne00,
@@ -1851,7 +2204,12 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
1851
2204
  const size_t smem = pipeline.smem;
1852
2205
 
1853
2206
  ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1854
- ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31)/32), ((ne01 + 63)/64), ne12*ne13, 128, 1, 1);
2207
+
2208
+ const int nr0 = pipeline.nr0;
2209
+ const int nr1 = pipeline.nr1;
2210
+ const int nsg = pipeline.nsg;
2211
+
2212
+ ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + nr1 - 1) / nr1), ((ne01 + nr0 - 1) / nr0), ne12 * ne13, 32, nsg, 1);
1855
2213
  } else {
1856
2214
  auto pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op);
1857
2215
 
@@ -2239,7 +2597,7 @@ size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
2239
2597
  // return res;
2240
2598
  //}
2241
2599
 
2242
- const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
2600
+ const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPSG : OP_FLASH_ATTN_EXT_NQPSG;
2243
2601
  const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
2244
2602
 
2245
2603
  const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
@@ -2355,7 +2713,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2355
2713
 
2356
2714
  if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
2357
2715
  // half8x8 kernel
2358
- const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup
2716
+ const int nqptg = OP_FLASH_ATTN_EXT_NQPSG; // queries per threadgroup
2359
2717
  const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup
2360
2718
 
2361
2719
  GGML_ASSERT(nqptg <= 32);
@@ -2464,7 +2822,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2464
2822
 
2465
2823
  // simdgroups per threadgroup (a.k.a. warps)
2466
2824
  //nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
2467
- int32_t nsg = 4;
2825
+ int32_t nsg = ne00 >= 512 ? 8 : 4;
2468
2826
 
2469
2827
  const size_t smem = FATTN_SMEM(nsg);
2470
2828
 
@@ -2522,9 +2880,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2522
2880
  #undef FATTN_SMEM
2523
2881
  } else {
2524
2882
  // half4x4 kernel
2525
- const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup
2883
+ const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPSG; // queries per threadgroup
2526
2884
  const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
2527
- const int nkpsg = 1*ncpsg;
2885
+ const int nhptg = 1; // heads per threadgroup
2528
2886
 
2529
2887
  GGML_ASSERT(nqptg <= 32);
2530
2888
  GGML_ASSERT(nqptg % 1 == 0);
@@ -2576,6 +2934,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2576
2934
  ggml_metal_op_concurrency_reset(ctx);
2577
2935
  }
2578
2936
 
2937
+ // note: for simplicity assume the K is larger or equal than V
2938
+ GGML_ASSERT(ne10 >= ne20);
2939
+
2579
2940
  // ne00 + 2*ncpsg*(nsg)
2580
2941
  // for each query, we load it as f16 in shared memory (ne00)
2581
2942
  // and store the soft_max values and the mask
@@ -2583,28 +2944,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2583
2944
  // ne20*(nsg)
2584
2945
  // each simdgroup has a full f32 head vector in shared mem to accumulate results
2585
2946
  //
2586
- #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*GGML_PAD(ne20, 128)*(nsg))*(sizeof(float)/2), 16))
2587
-
2588
- int64_t nsgmax = 2;
2589
- while (true) {
2590
- const size_t smem = FATTN_SMEM(nsgmax);
2591
- // avoid using more than half of the threadgroup memory - can cause slow downs especially for large head sizes
2592
- if (smem > props_dev->max_theadgroup_memory_size/2) {
2593
- break;
2594
- }
2595
- nsgmax *= 2;
2596
- }
2597
- nsgmax /= 2;
2598
-
2599
- // simdgroups per threadgroup (a.k.a. warps)
2600
- //const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
2601
- const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) 1024/32)));
2947
+ #define FATTN_SMEM(nsg) (GGML_PAD(((GGML_PAD(ne00, 128) + 4*ncpsg + 2*GGML_PAD(ne20, 128))*(nsg))*(sizeof(float)/2), 16))
2602
2948
 
2603
2949
  int64_t nsg = 1;
2604
- while (nsg <= nsgt) {
2605
- nsg *= 2;
2606
- }
2607
- nsg /= 2;
2608
2950
 
2609
2951
  // workgroups
2610
2952
  // each workgroup handles nsg*nkpsg cache values
@@ -2617,7 +2959,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2617
2959
  } else {
2618
2960
  nwg = 32;
2619
2961
  nsg = 1;
2620
- while (2*nwg*nsg*nkpsg < ne11 && nsg < 4) {
2962
+ while (2*nwg*nsg*ncpsg < ne11 && nsg < 4) {
2621
2963
  nsg *= 2;
2622
2964
  }
2623
2965
  }
@@ -2683,7 +3025,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2683
3025
 
2684
3026
  ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2685
3027
 
2686
- ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
3028
+ ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1);
2687
3029
  } else {
2688
3030
  // sanity checks
2689
3031
  assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);
@@ -2696,7 +3038,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2696
3038
  ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);
2697
3039
 
2698
3040
  ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2699
- ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
3041
+ ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1);
2700
3042
 
2701
3043
  // sync the 2 kernels
2702
3044
  ggml_metal_op_concurrency_reset(ctx);
@@ -2748,8 +3090,6 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
2748
3090
  GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
2749
3091
  GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
2750
3092
 
2751
- bool bcast_row = false;
2752
-
2753
3093
  ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
2754
3094
  ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
2755
3095
  ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
@@ -2843,18 +3183,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
2843
3183
 
2844
3184
  struct ggml_metal_pipeline_with_params pipeline;
2845
3185
 
2846
- if (ggml_nelements(op->src[1]) == ne10 && ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) {
2847
- GGML_ASSERT(ggml_is_contiguous(op->src[0]));
2848
-
2849
- // src1 is a row
2850
- GGML_ASSERT(ne11 == 1);
2851
-
2852
- pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, true);
2853
-
2854
- bcast_row = true;
2855
- } else {
2856
- pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, false);
2857
- }
3186
+ pipeline = ggml_metal_library_get_pipeline_bin(lib, op, n_fuse);
2858
3187
 
2859
3188
  if (n_fuse > 1) {
2860
3189
  bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));
@@ -2868,20 +3197,26 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
2868
3197
  }
2869
3198
  }
2870
3199
 
3200
+ if (pipeline.c4) {
3201
+ args.ne00 = ne00/4;
3202
+ args.ne10 = ne10/4;
3203
+ args.ne0 = ne0/4;
3204
+ }
3205
+
2871
3206
  ggml_metal_encoder_set_pipeline(enc, pipeline);
2872
3207
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2873
3208
  ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
2874
3209
  ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
2875
3210
  ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
2876
3211
 
2877
- if (bcast_row) {
2878
- const int64_t n = ggml_nelements(op)/4;
2879
-
2880
- ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
3212
+ if (pipeline.cnt) {
3213
+ ggml_metal_encoder_dispatch_threadgroups(enc, args.ne0, ggml_nrows(op), 1, 1, 1, 1);
2881
3214
  } else {
2882
- int nth = 32;
3215
+ const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
2883
3216
 
2884
- while (16*nth < ne0 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
3217
+ int nth = 1;
3218
+
3219
+ while (2*nth < args.ne0 && nth < nth_max) {
2885
3220
  nth *= 2;
2886
3221
  }
2887
3222
 
@@ -2902,39 +3237,59 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
2902
3237
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2903
3238
  GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2904
3239
 
3240
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
3241
+
3242
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
3243
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
3244
+
2905
3245
  float eps;
2906
3246
  memcpy(&eps, op->op_params, sizeof(float));
2907
3247
 
2908
- int nth = 32; // SIMD width
2909
-
2910
3248
  ggml_metal_kargs_l2_norm args = {
2911
- /*.ne00 =*/ ne00,
2912
- /*.ne00_4 =*/ ne00/4,
2913
- /*.nb01 =*/ nb01,
2914
- /*.eps =*/ eps,
3249
+ /*.ne00 =*/ ne00,
3250
+ /*.ne01 =*/ ne01,
3251
+ /*.ne02 =*/ ne02,
3252
+ /*.ne03 =*/ ne03,
3253
+ /*.nb00 =*/ nb00,
3254
+ /*.nb01 =*/ nb01,
3255
+ /*.nb02 =*/ nb02,
3256
+ /*.nb03 =*/ nb03,
3257
+ /*.ne0 =*/ ne0,
3258
+ /*.ne1 =*/ ne1,
3259
+ /*.ne2 =*/ ne2,
3260
+ /*.ne3 =*/ ne3,
3261
+ /*.nb0 =*/ nb0,
3262
+ /*.nb1 =*/ nb1,
3263
+ /*.nb2 =*/ nb2,
3264
+ /*.nb3 =*/ nb3,
3265
+ /*.eps =*/ eps,
2915
3266
  };
2916
3267
 
2917
3268
  auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
2918
3269
 
2919
- while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
3270
+ if (pipeline.c4) {
3271
+ args.ne00 = ne00/4;
3272
+ args.ne0 = ne0/4;
3273
+ }
3274
+
3275
+ int nth = 32; // SIMD width
3276
+
3277
+ while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
2920
3278
  nth *= 2;
2921
3279
  }
2922
3280
 
2923
3281
  nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
2924
- nth = std::min(nth, ne00/4);
2925
3282
 
2926
3283
  const size_t smem = pipeline.smem;
2927
3284
 
2928
- const int64_t nrows = ggml_nrows(op->src[0]);
2929
-
2930
3285
  ggml_metal_encoder_set_pipeline(enc, pipeline);
2931
3286
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2932
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
2933
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
3287
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
3288
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
2934
3289
 
2935
3290
  ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2936
3291
 
2937
- ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
3292
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
2938
3293
 
2939
3294
  return 1;
2940
3295
  }
@@ -3280,16 +3635,26 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
3280
3635
 
3281
3636
  auto pipeline = ggml_metal_library_get_pipeline_im2col(lib, op);
3282
3637
 
3283
- GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
3638
+ if (KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
3639
+ const uint64_t ntptg0 = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)/(KH*KW), N);
3284
3640
 
3285
- const uint64_t ntptg0 = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)/(KH*KW), N);
3641
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
3642
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3643
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1);
3644
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
3286
3645
 
3287
- ggml_metal_encoder_set_pipeline(enc, pipeline);
3288
- ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3289
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1);
3290
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
3646
+ ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW);
3647
+ } else {
3648
+ const uint64_t n_threads = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), N);
3649
+ const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0);
3650
+
3651
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
3652
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3653
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1);
3654
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
3291
3655
 
3292
- ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW);
3656
+ ggml_metal_encoder_dispatch_threadgroups(enc, quotient * CHW, OH, OW, n_threads, 1, 1);
3657
+ }
3293
3658
 
3294
3659
  return 1;
3295
3660
  }
@@ -3372,6 +3737,77 @@ int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) {
3372
3737
  return 1;
3373
3738
  }
3374
3739
 
3740
+ int ggml_metal_op_conv_3d(ggml_metal_op_t ctx, int idx) {
3741
+ ggml_tensor * op = ctx->node(idx);
3742
+
3743
+ ggml_metal_library_t lib = ctx->lib;
3744
+ ggml_metal_encoder_t enc = ctx->enc;
3745
+
3746
+ // 1. Extract standard dimensions and byte strides
3747
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3748
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
3749
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3750
+
3751
+ // 2. Extract hyperparams from op_params
3752
+ const int32_t s0 = ((const int32_t *)(op->op_params))[0];
3753
+ const int32_t s1 = ((const int32_t *)(op->op_params))[1];
3754
+ const int32_t s2 = ((const int32_t *)(op->op_params))[2];
3755
+ const int32_t p0 = ((const int32_t *)(op->op_params))[3];
3756
+ const int32_t p1 = ((const int32_t *)(op->op_params))[4];
3757
+ const int32_t p2 = ((const int32_t *)(op->op_params))[5];
3758
+ const int32_t d0 = ((const int32_t *)(op->op_params))[6];
3759
+ const int32_t d1 = ((const int32_t *)(op->op_params))[7];
3760
+ const int32_t d2 = ((const int32_t *)(op->op_params))[8];
3761
+ const int32_t IC = ((const int32_t *)(op->op_params))[9];
3762
+ const int32_t N = ((const int32_t *)(op->op_params))[10];
3763
+ const int32_t OC = ((const int32_t *)(op->op_params))[11];
3764
+
3765
+ // 3. Build the parameter struct using the macro-generated variables
3766
+ ggml_metal_kargs_conv_3d args = {
3767
+ /*.IW =*/ (int32_t)op->src[1]->ne[0],
3768
+ /*.IH =*/ (int32_t)op->src[1]->ne[1],
3769
+ /*.ID =*/ (int32_t)op->src[1]->ne[2],
3770
+ /*.OW =*/ (int32_t)op->ne[0],
3771
+ /*.OH =*/ (int32_t)op->ne[1],
3772
+ /*.OD =*/ (int32_t)op->ne[2],
3773
+ /*.KW =*/ (int32_t)op->src[0]->ne[0],
3774
+ /*.KH =*/ (int32_t)op->src[0]->ne[1],
3775
+ /*.KD =*/ (int32_t)op->src[0]->ne[2],
3776
+ s0, s1, s2,
3777
+ p0, p1, p2,
3778
+ d0, d1, d2,
3779
+ IC, N, OC,
3780
+ nb00, nb01, nb02, nb03, // Weight strides
3781
+ nb10, nb11, nb12, nb13, // Input strides
3782
+ nb0, nb1, nb2, nb3 // Output strides
3783
+ };
3784
+
3785
+ // 4. Fetch the JIT pipeline
3786
+ auto pipeline = ggml_metal_library_get_pipeline_conv_3d(lib, op);
3787
+
3788
+ // 5. Grid mapping
3789
+ int nth0 = 32; // Standard SIMD width for Apple Silicon
3790
+ int nth1 = 1;
3791
+ int nth2 = 1;
3792
+
3793
+ int64_t spatial_volume = args.OW * args.OH * args.OD;
3794
+
3795
+ int ntg0 = (spatial_volume + nth0 - 1) / nth0;
3796
+ int ntg1 = args.OC;
3797
+ int ntg2 = args.N;
3798
+
3799
+ // 6. Bind and Dispatch via the ggml C wrapper
3800
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
3801
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3802
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3803
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
3804
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
3805
+
3806
+ ggml_metal_encoder_dispatch_threadgroups(enc, ntg0, ntg1, ntg2, nth0, nth1, nth2);
3807
+
3808
+ return 1;
3809
+ }
3810
+
3375
3811
  int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
3376
3812
  ggml_tensor * op = ctx->node(idx);
3377
3813
 
@@ -3484,12 +3920,76 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
3484
3920
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3485
3921
  GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3486
3922
 
3487
- const float sf0 = (float)ne0/op->src[0]->ne[0];
3488
- const float sf1 = (float)ne1/op->src[0]->ne[1];
3489
- const float sf2 = (float)ne2/op->src[0]->ne[2];
3490
- const float sf3 = (float)ne3/op->src[0]->ne[3];
3923
+ float sf0 = (float)ne0/op->src[0]->ne[0];
3924
+ float sf1 = (float)ne1/op->src[0]->ne[1];
3925
+ float sf2 = (float)ne2/op->src[0]->ne[2];
3926
+ float sf3 = (float)ne3/op->src[0]->ne[3];
3927
+
3928
+ const int32_t mode_flags = ggml_get_op_params_i32(op, 0);
3929
+
3930
+ float poffs = 0.5f;
3931
+
3932
+ if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
3933
+ poffs = 0.0f;
3934
+ sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
3935
+ sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
3936
+ }
3491
3937
 
3492
3938
  ggml_metal_kargs_upscale args = {
3939
+ /*.ne00 =*/ ne00,
3940
+ /*.ne01 =*/ ne01,
3941
+ /*.ne02 =*/ ne02,
3942
+ /*.ne03 =*/ ne03,
3943
+ /*.nb00 =*/ nb00,
3944
+ /*.nb01 =*/ nb01,
3945
+ /*.nb02 =*/ nb02,
3946
+ /*.nb03 =*/ nb03,
3947
+ /*.ne0 =*/ ne0,
3948
+ /*.ne1 =*/ ne1,
3949
+ /*.ne2 =*/ ne2,
3950
+ /*.ne3 =*/ ne3,
3951
+ /*.nb0 =*/ nb0,
3952
+ /*.nb1 =*/ nb1,
3953
+ /*.nb2 =*/ nb2,
3954
+ /*.nb3 =*/ nb3,
3955
+ /*.sf0 =*/ sf0,
3956
+ /*.sf1 =*/ sf1,
3957
+ /*.sf2 =*/ sf2,
3958
+ /*.sf3 =*/ sf3,
3959
+ /*.poffs =*/ poffs,
3960
+ };
3961
+
3962
+ auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op);
3963
+
3964
+ const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
3965
+
3966
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
3967
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3968
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3969
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
3970
+
3971
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
3972
+
3973
+ return 1;
3974
+ }
3975
+
3976
+ int ggml_metal_op_roll(ggml_metal_op_t ctx, int idx) {
3977
+ ggml_tensor * op = ctx->node(idx);
3978
+
3979
+ ggml_metal_library_t lib = ctx->lib;
3980
+ ggml_metal_encoder_t enc = ctx->enc;
3981
+
3982
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3983
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3984
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3985
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3986
+
3987
+ const int32_t s0 = ggml_get_op_params_i32(op, 0);
3988
+ const int32_t s1 = ggml_get_op_params_i32(op, 1);
3989
+ const int32_t s2 = ggml_get_op_params_i32(op, 2);
3990
+ const int32_t s3 = ggml_get_op_params_i32(op, 3);
3991
+
3992
+ ggml_metal_kargs_roll args = {
3493
3993
  /*.ne00 =*/ ne00,
3494
3994
  /*.ne01 =*/ ne01,
3495
3995
  /*.ne02 =*/ ne02,
@@ -3498,23 +3998,23 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
3498
3998
  /*.nb01 =*/ nb01,
3499
3999
  /*.nb02 =*/ nb02,
3500
4000
  /*.nb03 =*/ nb03,
3501
- /*.ne0 =*/ ne0,
3502
- /*.ne1 =*/ ne1,
3503
- /*.ne2 =*/ ne2,
3504
- /*.ne3 =*/ ne3,
3505
- /*.nb0 =*/ nb0,
3506
- /*.nb1 =*/ nb1,
3507
- /*.nb2 =*/ nb2,
3508
- /*.nb3 =*/ nb3,
3509
- /*.sf0 =*/ sf0,
3510
- /*.sf1 =*/ sf1,
3511
- /*.sf2 =*/ sf2,
3512
- /*.sf3 =*/ sf3
4001
+ /*.ne0 =*/ ne0,
4002
+ /*.ne1 =*/ ne1,
4003
+ /*.ne2 =*/ ne2,
4004
+ /*.ne3 =*/ ne3,
4005
+ /*.nb0 =*/ nb0,
4006
+ /*.nb1 =*/ nb1,
4007
+ /*.nb2 =*/ nb2,
4008
+ /*.nb3 =*/ nb3,
4009
+ /*.s0 =*/ s0,
4010
+ /*.s1 =*/ s1,
4011
+ /*.s2 =*/ s2,
4012
+ /*.s3 =*/ s3
3513
4013
  };
3514
4014
 
3515
- auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op);
4015
+ auto pipeline = ggml_metal_library_get_pipeline_roll(lib, op);
3516
4016
 
3517
- const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
4017
+ const int nth = std::min(1024, ne0);
3518
4018
 
3519
4019
  ggml_metal_encoder_set_pipeline(enc, pipeline);
3520
4020
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -3558,14 +4058,21 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
3558
4058
 
3559
4059
  auto pipeline = ggml_metal_library_get_pipeline_pad(lib, op);
3560
4060
 
3561
- const int nth = std::min(1024, ne0);
4061
+ if (pipeline.c4) {
4062
+ args.ne00 = ne00/4;
4063
+ args.ne0 = ne0/4;
4064
+ }
4065
+
4066
+ const int nth_max = MIN(64, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
4067
+ const int nth = MIN(args.ne0, nth_max);
4068
+ const int nk0 = (args.ne0 + 1024 - 1)/1024; // note: 1024 is hardcoded in the kernel!
3562
4069
 
3563
4070
  ggml_metal_encoder_set_pipeline(enc, pipeline);
3564
4071
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3565
4072
  ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3566
4073
  ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
3567
4074
 
3568
- ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
4075
+ ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne1, ne2, ne3, nth, 1, 1);
3569
4076
 
3570
4077
  return 1;
3571
4078
  }
@@ -3942,42 +4449,6 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
3942
4449
  return 1;
3943
4450
  }
3944
4451
 
3945
- int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
3946
- ggml_tensor * op = ctx->node(idx);
3947
-
3948
- ggml_metal_library_t lib = ctx->lib;
3949
- ggml_metal_encoder_t enc = ctx->enc;
3950
-
3951
- GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3952
- GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3953
- GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3954
- GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3955
-
3956
- float slope;
3957
- memcpy(&slope, op->op_params, sizeof(float));
3958
-
3959
- ggml_metal_kargs_leaky_relu args = {
3960
- /*.slope =*/ slope
3961
- };
3962
-
3963
- auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
3964
-
3965
- int64_t n = ggml_nelements(op);
3966
-
3967
- if (n % 4 == 0) {
3968
- n /= 4;
3969
- }
3970
-
3971
- ggml_metal_encoder_set_pipeline(enc, pipeline);
3972
- ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3973
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3974
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
3975
-
3976
- ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
3977
-
3978
- return 1;
3979
- }
3980
-
3981
4452
  int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) {
3982
4453
  ggml_tensor * op = ctx->node(idx);
3983
4454