whispercpp 1.3.4 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (630)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +47 -23
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  23. data/ext/sources/examples/cli/cli.cpp +121 -112
  24. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  25. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  26. data/ext/sources/examples/server/server.cpp +10 -11
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
  31. data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
  33. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  34. data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
  35. data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
  36. data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
  37. data/ext/sources/examples/talk-llama/llama-context.h +57 -9
  38. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  40. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  41. data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
  42. data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
  43. data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
  44. data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
  45. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  46. data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
  49. data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
  50. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  51. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  52. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
  53. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  54. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  55. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  56. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
  57. data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
  58. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  59. data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
  60. data/ext/sources/examples/talk-llama/llama-model.h +44 -3
  61. data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
  62. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
  63. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
  66. data/ext/sources/examples/talk-llama/llama.cpp +729 -2
  67. data/ext/sources/examples/talk-llama/llama.h +152 -14
  68. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  69. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  70. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  71. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  72. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  73. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  74. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  75. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  76. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  77. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  78. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  79. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  80. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  81. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  82. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  83. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  84. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  85. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  86. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  88. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  89. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  90. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  91. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  92. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  108. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  109. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  110. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  111. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  112. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  113. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  114. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  116. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  117. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  118. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  119. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  120. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  121. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  122. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  123. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  124. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  125. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  126. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  127. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  128. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  129. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  130. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  131. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  132. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  133. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  134. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  135. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  136. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  137. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  138. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  139. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  140. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  141. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  142. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  143. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  144. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  145. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  146. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  147. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  148. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  149. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  150. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  151. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  153. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  154. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  155. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  156. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  157. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  158. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  159. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  160. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  161. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  162. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  163. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  165. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  166. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  167. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  168. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  169. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  170. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  171. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  172. data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
  173. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  174. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  175. data/ext/sources/ggml/CMakeLists.txt +82 -54
  176. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  177. data/ext/sources/ggml/include/ggml-backend.h +4 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +1 -0
  179. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  180. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  181. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  182. data/ext/sources/ggml/include/ggml.h +190 -12
  183. data/ext/sources/ggml/src/CMakeLists.txt +82 -11
  184. data/ext/sources/ggml/src/ggml-alloc.c +124 -41
  185. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  186. data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
  187. data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
  188. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  189. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
  190. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
  191. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  192. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
  193. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
  194. data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
  195. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
  196. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
  197. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
  198. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  199. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
  200. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
  201. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  202. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  203. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
  204. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  205. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
  206. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
  207. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
  208. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
  209. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  210. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
  211. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  212. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
  213. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  214. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
  215. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  216. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
  217. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
  218. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
  219. data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
  220. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
  221. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  222. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  223. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  224. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
  225. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  226. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  227. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  228. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  229. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  230. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
  231. data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
  232. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  233. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
  235. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  236. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  237. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  238. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  239. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
  241. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
  242. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  243. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
  244. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
  245. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
  246. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  247. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
  248. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  249. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  250. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
  251. data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
  252. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
  253. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
  254. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
  255. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  256. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  258. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
  259. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
  260. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
  261. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
  262. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
  264. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
  265. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  266. data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
  267. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  268. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  269. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  270. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  271. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
  272. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  273. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  274. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
  275. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  276. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  278. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  279. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  280. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  281. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  282. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  283. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  284. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
  286. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  287. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  288. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
  289. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
  290. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  291. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  292. data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
  293. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
  294. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  295. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  296. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
  297. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  317. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  321. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  322. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  323. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  324. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  325. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  326. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  327. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  328. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  329. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
  330. data/ext/sources/ggml/src/ggml-impl.h +67 -6
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
  335. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
  336. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
  337. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
  338. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
  339. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
  340. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
  341. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  342. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
  343. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
  365. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  366. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  367. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  368. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  369. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
  370. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
  371. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  372. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  373. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  374. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
  375. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  376. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  377. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  378. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  379. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
  380. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
  381. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
  382. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
  383. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  384. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
  385. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  386. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  387. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  388. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  389. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  390. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  391. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  392. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  393. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  394. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  395. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
  396. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  397. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  398. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  399. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  400. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  401. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  402. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
  403. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  404. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  484. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  485. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  486. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  487. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  488. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
  489. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  490. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  491. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  492. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  493. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  494. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  495. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  496. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  497. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  498. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  499. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  500. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  501. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  502. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  503. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  504. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  505. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  506. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
  507. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
  508. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  509. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  510. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
  511. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  512. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  513. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  514. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  515. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  516. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  517. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  518. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  519. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  520. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  521. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  522. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  523. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  524. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  525. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
  526. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  527. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  528. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  529. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  530. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  531. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
  532. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
  533. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  560. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  561. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  562. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  563. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  564. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  565. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  566. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  567. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  568. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  569. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
  570. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  571. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  572. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  573. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
  574. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  584. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  585. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  586. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  587. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  588. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  589. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  590. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  591. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  592. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  593. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  594. data/ext/sources/ggml/src/ggml.c +425 -33
  595. data/ext/sources/include/whisper.h +1 -0
  596. data/ext/sources/src/CMakeLists.txt +3 -1
  597. data/ext/sources/src/whisper.cpp +101 -35
  598. data/ext/sources/tests/CMakeLists.txt +2 -2
  599. data/ext/sources/tests/test-vad-full.cpp +4 -2
  600. data/ext/sources/tests/test-vad.cpp +1 -1
  601. data/extsources.rb +1 -0
  602. data/lib/whisper/model/uri.rb +17 -18
  603. data/sig/whisper.rbs +119 -2
  604. data/test/test_params.rb +16 -8
  605. data/test/test_segment.rb +0 -1
  606. data/test/test_token.rb +70 -0
  607. data/test/test_vad.rb +1 -1
  608. data/test/test_vad_context.rb +50 -0
  609. data/test/test_vad_segment.rb +19 -0
  610. data/test/test_vad_segments.rb +16 -0
  611. data/test/test_whisper.rb +7 -0
  612. data/whispercpp.gemspec +1 -1
  613. metadata +287 -34
  614. data/ext/sources/build-xcframework.sh +0 -571
  615. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  618. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  619. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  620. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  621. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  622. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  623. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  624. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  625. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  626. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  627. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  628. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  629. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  630. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -50,14 +50,14 @@ void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, gg
50
50
  }
51
51
 
52
52
  ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name) {
53
- if (ppls->data.find(name) == ppls->data.end()) {
53
+ if (ppls->data.find(name) == ppls->data.end()) {
54
54
  return nullptr;
55
55
  }
56
56
 
57
57
  return ppls->data[name];
58
58
  }
59
59
 
60
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base(ggml_metal_library_t lib, ggml_op op) {
60
+ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base(ggml_metal_library_t lib, ggml_op op) {
61
61
  char base[256];
62
62
  char name[256];
63
63
 
@@ -71,34 +71,30 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base(ggml_metal_library_t
71
71
  snprintf(base, 256, "kernel_%s", op_str);
72
72
  snprintf(name, 256, "%s", base);
73
73
 
74
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
75
- if (res) {
76
- return res;
74
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
75
+ if (!res.pipeline) {
76
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
77
77
  }
78
78
 
79
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
80
-
81
79
  return res;
82
80
  }
83
81
 
84
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cpy(ggml_metal_library_t lib, ggml_type tsrc, ggml_type tdst) {
82
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy(ggml_metal_library_t lib, ggml_type tsrc, ggml_type tdst) {
85
83
  char base[256];
86
84
  char name[256];
87
85
 
88
86
  snprintf(base, 256, "kernel_cpy_%s_%s", ggml_type_name(tsrc), ggml_type_name(tdst));
89
87
  snprintf(name, 256, "%s", base);
90
88
 
91
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
92
- if (res) {
93
- return res;
89
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
90
+ if (!res.pipeline) {
91
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
94
92
  }
95
93
 
96
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
97
-
98
94
  return res;
99
95
  }
100
96
 
101
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
97
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
102
98
  GGML_ASSERT(ggml_is_contiguous(op->src[0]));
103
99
  GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type);
104
100
 
@@ -115,68 +111,60 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library
115
111
  snprintf(base, 256, "kernel_pool_2d_%s_%s", pool_str, ggml_type_name(op->src[0]->type));
116
112
  snprintf(name, 256, "%s", base);
117
113
 
118
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
119
- if (res) {
120
- return res;
114
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
115
+ if (!res.pipeline) {
116
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
121
117
  }
122
118
 
123
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
124
-
125
119
  return res;
126
120
  }
127
121
 
128
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_get_rows(ggml_metal_library_t lib, ggml_type tsrc) {
122
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows(ggml_metal_library_t lib, ggml_type tsrc) {
129
123
  char base[256];
130
124
  char name[256];
131
125
 
132
126
  snprintf(base, 256, "kernel_get_rows_%s", ggml_type_name(tsrc));
133
127
  snprintf(name, 256, "%s", base);
134
128
 
135
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
136
- if (res) {
137
- return res;
129
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
130
+ if (!res.pipeline) {
131
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
138
132
  }
139
133
 
140
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
141
-
142
134
  return res;
143
135
  }
144
136
 
145
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows(ggml_metal_library_t lib, ggml_type tidx, ggml_type tdst) {
137
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows(ggml_metal_library_t lib, ggml_type tidx, ggml_type tdst) {
146
138
  char base[256];
147
139
  char name[256];
148
140
 
149
141
  snprintf(base, 256, "kernel_set_rows_%s_%s", ggml_type_name(tdst), ggml_type_name(tidx));
150
142
  snprintf(name, 256, "%s", base);
151
143
 
152
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
153
- if (res) {
154
- return res;
144
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
145
+ if (!res.pipeline) {
146
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
155
147
  }
156
148
 
157
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
158
-
159
149
  return res;
160
150
  }
161
151
 
162
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat(ggml_metal_library_t lib, ggml_type tsrc) {
152
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_metal_library_t lib, ggml_type tsrc) {
163
153
  char base[256];
164
154
  char name[256];
165
155
 
166
156
  snprintf(base, 256, "kernel_repeat_%s", ggml_type_name(tsrc));
167
157
  snprintf(name, 256, "%s", base);
168
158
 
169
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
170
- if (res) {
171
- return res;
159
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
160
+ if (!res.pipeline) {
161
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
172
162
  }
173
163
 
174
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
175
-
176
164
  return res;
177
165
  }
178
166
 
179
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) {
167
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) {
180
168
  GGML_ASSERT(ggml_is_contiguous(op->src[0]));
181
169
 
182
170
  char base[256];
@@ -187,6 +175,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary(ggml_metal_library_t
187
175
  const char * op_str = "undefined";
188
176
  switch (op->op) {
189
177
  case GGML_OP_SCALE: op_str = "scale"; break;
178
+ case GGML_OP_FILL: op_str = "fill"; break;
190
179
  case GGML_OP_CLAMP: op_str = "clamp"; break;
191
180
  case GGML_OP_SQR: op_str = "sqr"; break;
192
181
  case GGML_OP_SQRT: op_str = "sqrt"; break;
@@ -211,6 +200,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary(ggml_metal_library_t
211
200
  case GGML_UNARY_OP_HARDSWISH: op_str = "hardswish"; break;
212
201
  case GGML_UNARY_OP_HARDSIGMOID: op_str = "hardsigmoid"; break;
213
202
  case GGML_UNARY_OP_EXP: op_str = "exp"; break;
203
+ case GGML_UNARY_OP_SOFTPLUS: op_str = "softplus"; break;
204
+ case GGML_UNARY_OP_EXPM1: op_str = "expm1"; break;
214
205
  default: GGML_ABORT("fatal error");
215
206
  } break;
216
207
  default: GGML_ABORT("fatal error");
@@ -224,17 +215,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary(ggml_metal_library_t
224
215
  snprintf(base, 256, "kernel_%s_%s%s", op_str, ggml_type_name(op->src[0]->type), suffix);
225
216
  snprintf(name, 256, "%s", base);
226
217
 
227
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
228
- if (res) {
229
- return res;
218
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
219
+ if (!res.pipeline) {
220
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
230
221
  }
231
222
 
232
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
233
-
234
223
  return res;
235
224
  }
236
225
 
237
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu(ggml_metal_library_t lib, const ggml_tensor * op) {
226
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu(ggml_metal_library_t lib, const ggml_tensor * op) {
238
227
  GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
239
228
 
240
229
  char base[256];
@@ -258,17 +247,32 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu(ggml_metal_library_t l
258
247
  snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type));
259
248
  snprintf(name, 256, "%s", base);
260
249
 
261
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
262
- if (res) {
263
- return res;
250
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
251
+ if (!res.pipeline) {
252
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
264
253
  }
265
254
 
266
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
255
+ return res;
256
+ }
257
+
258
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) {
259
+ assert(op->op == GGML_OP_SUM);
260
+
261
+ char base[256];
262
+ char name[256];
263
+
264
+ snprintf(base, 256, "kernel_op_sum_%s", ggml_type_name(op->src[0]->type));
265
+ snprintf(name, 256, "%s", base);
266
+
267
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
268
+ if (!res.pipeline) {
269
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
270
+ }
267
271
 
268
272
  return res;
269
273
  }
270
274
 
271
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
275
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
272
276
  GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
273
277
 
274
278
  char base[256];
@@ -287,19 +291,73 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_librar
287
291
 
288
292
  snprintf(name, 256, "%s", base);
289
293
 
290
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
291
- if (res) {
292
- return res;
294
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
295
+ if (!res.pipeline) {
296
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
293
297
  }
294
298
 
295
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
299
+ res.smem = 32*sizeof(float);
296
300
 
297
- ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
301
+ return res;
302
+ }
303
+
304
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_blk(ggml_metal_library_t lib, const ggml_tensor * op) {
305
+ GGML_ASSERT(op->op == GGML_OP_CUMSUM);
306
+
307
+ char base[256];
308
+ char name[256];
309
+
310
+ snprintf(base, 256, "kernel_cumsum_blk_%s", ggml_type_name(op->src[0]->type));
311
+ snprintf(name, 256, "%s", base);
312
+
313
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
314
+ if (!res.pipeline) {
315
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
316
+ }
298
317
 
299
318
  return res;
300
319
  }
301
320
 
302
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) {
321
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_library_t lib, const ggml_tensor * op) {
322
+ GGML_ASSERT(op->op == GGML_OP_CUMSUM);
323
+
324
+ char base[256];
325
+ char name[256];
326
+
327
+ snprintf(base, 256, "kernel_cumsum_add_%s", ggml_type_name(op->src[0]->type));
328
+ snprintf(name, 256, "%s", base);
329
+
330
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
331
+ if (!res.pipeline) {
332
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
333
+ }
334
+
335
+ return res;
336
+ }
337
+
338
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
339
+ GGML_ASSERT(op->op == GGML_OP_TRI);
340
+ GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
341
+
342
+ char base[256];
343
+ char name[256];
344
+
345
+ const char * op_str = "tri";
346
+ const int ttype = op->op_params[0];
347
+
348
+ snprintf(base, 256, "kernel_%s_%s_%d", op_str, ggml_type_name(op->src[0]->type), ttype);
349
+
350
+ snprintf(name, 256, "%s", base);
351
+
352
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
353
+ if (!res.pipeline) {
354
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
355
+ }
356
+
357
+ return res;
358
+ }
359
+
360
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) {
303
361
  GGML_ASSERT(!op->src[1] || op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32);
304
362
 
305
363
  char base[256];
@@ -316,19 +374,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max(ggml_metal_librar
316
374
  snprintf(base, 256, "kernel_soft_max_%s%s", ggml_type_name(tsrc1), suffix);
317
375
  snprintf(name, 256, "%s", base);
318
376
 
319
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
320
- if (res) {
321
- return res;
377
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
378
+ if (!res.pipeline) {
379
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
322
380
  }
323
381
 
324
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
325
-
326
- ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
382
+ res.smem = 32*sizeof(float);
327
383
 
328
384
  return res;
329
385
  }
330
386
 
331
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) {
387
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) {
332
388
  GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
333
389
  GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
334
390
 
@@ -338,43 +394,82 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar
338
394
  char base[256];
339
395
  char name[256];
340
396
 
341
- snprintf(base, 256, "kernel_ssm_conv_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
342
- snprintf(name, 256, "%s", base);
397
+ const char * suffix = "";
343
398
 
344
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
345
- if (res) {
346
- return res;
399
+ if (op->src[1]->ne[0] % 4 == 0) {
400
+ suffix = "_4";
347
401
  }
348
402
 
349
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
403
+ snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix);
404
+ snprintf(name, 256, "%s", base);
405
+
406
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
407
+ if (!res.pipeline) {
408
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
409
+ }
350
410
 
351
411
  return res;
352
412
  }
353
413
 
354
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) {
414
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched(ggml_metal_library_t lib, const ggml_tensor * op, int ssm_conv_bs) {
415
+ GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
416
+ GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
417
+
418
+ GGML_ASSERT(ggml_is_contiguous(op->src[0]));
419
+ GGML_ASSERT(ggml_is_contiguous(op->src[1]));
420
+
355
421
  char base[256];
356
422
  char name[256];
357
423
 
358
- if (op->src[3]->ne[0] == 1) {
359
- snprintf(base, 256, "kernel_ssm_scan_group_%s", ggml_type_name(op->src[0]->type));
360
- } else {
361
- snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
424
+ const char * suffix = "";
425
+ if (op->src[1]->ne[0] % 4 == 0) {
426
+ suffix = "_4";
362
427
  }
363
- snprintf(name, 256, "%s", base);
364
428
 
365
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
366
- if (res) {
367
- return res;
429
+ snprintf(base, 256, "kernel_ssm_conv_%s_%s_batched%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix);
430
+ snprintf(name, 256, "%s_ssm_conv_bs=%d", base, ssm_conv_bs);
431
+
432
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
433
+ if (!res.pipeline) {
434
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
435
+
436
+ ggml_metal_cv_set_int16(cv, ssm_conv_bs, FC_SSM_CONV + 0);
437
+
438
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
439
+
440
+ ggml_metal_cv_free(cv);
368
441
  }
369
442
 
370
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
443
+ return res;
444
+ }
371
445
 
372
- ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
446
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) {
447
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
448
+
449
+ char base[256];
450
+ char name[256];
451
+
452
+ const int nsg = (ne00 + 31)/32;
453
+
454
+ snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
455
+ snprintf(name, 256, "%s_nsg=%d", base, nsg);
456
+
457
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
458
+ if (!res.pipeline) {
459
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
460
+ }
461
+
462
+ // Shared memory layout:
463
+ // - sgptg * NW floats for partial sums (nsg * 32)
464
+ // - sgptg floats for shared_x_dt (nsg)
465
+ // - sgptg floats for shared_dA (nsg)
466
+ // Total: nsg * (32 + 2) floats
467
+ res.smem = (32 + 2)*sizeof(float)*nsg;
373
468
 
374
469
  return res;
375
470
  }
376
471
 
377
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t lib, const ggml_tensor * op) {
472
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t lib, const ggml_tensor * op) {
378
473
  char base[256];
379
474
  char name[256];
380
475
 
@@ -404,41 +499,37 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t
404
499
 
405
500
  snprintf(name, 256, "%s", base);
406
501
 
407
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
408
- if (res) {
409
- return res;
502
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
503
+ if (!res.pipeline) {
504
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
410
505
  }
411
506
 
412
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
413
-
414
507
  return res;
415
508
  }
416
509
 
417
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
510
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
418
511
  char base[256];
419
512
  char name[256];
420
513
 
421
514
  snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg);
422
515
  snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg);
423
516
 
424
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
425
- if (res) {
426
- return res;
427
- }
517
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
518
+ if (!res.pipeline) {
519
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
428
520
 
429
- ggml_metal_cv_t cv = ggml_metal_cv_init();
521
+ ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
522
+ ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1);
430
523
 
431
- ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
432
- ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1);
524
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
433
525
 
434
- res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
435
-
436
- ggml_metal_cv_free(cv);
526
+ ggml_metal_cv_free(cv);
527
+ }
437
528
 
438
529
  return res;
439
530
  }
440
531
 
441
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, const ggml_tensor * op) {
532
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, const ggml_tensor * op) {
442
533
  char base[256];
443
534
  char name[256];
444
535
 
@@ -451,27 +542,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_
451
542
  snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
452
543
  snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out);
453
544
 
454
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
455
- if (res) {
456
- return res;
457
- }
458
-
459
- ggml_metal_cv_t cv = ggml_metal_cv_init();
545
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
546
+ if (!res.pipeline) {
547
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
460
548
 
461
- ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
462
- ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1);
549
+ ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
550
+ ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1);
463
551
 
464
- res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
552
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
465
553
 
466
- ggml_metal_cv_free(cv);
554
+ ggml_metal_cv_free(cv);
555
+ }
467
556
 
468
557
  // when the output size is not multiple of 64x32, we need extra smem to prevent out-of-bounds writes
469
- ggml_metal_pipeline_set_smem(res, bc_out ? 8192 : 4096 + 2048);
558
+ res.smem = bc_out ? 8192 : 4096 + 2048;
470
559
 
471
560
  return res;
472
561
  }
473
562
 
474
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_t lib, const ggml_tensor * op) {
563
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_t lib, const ggml_tensor * op) {
475
564
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
476
565
  GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
477
566
 
@@ -626,49 +715,43 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
626
715
  snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
627
716
  snprintf(name, 256, "%s_nsg=%d", base, nsg);
628
717
 
629
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
630
- if (res) {
631
- return res;
632
- }
633
-
634
- ggml_metal_cv_t cv = ggml_metal_cv_init();
718
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
719
+ if (!res.pipeline) {
720
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
635
721
 
636
- ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
722
+ ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
637
723
 
638
- res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
724
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
639
725
 
640
- ggml_metal_cv_free(cv);
726
+ ggml_metal_cv_free(cv);
727
+ }
641
728
 
642
- ggml_metal_pipeline_set_nr0 (res, nr0);
643
- ggml_metal_pipeline_set_nr1 (res, nr1);
644
- ggml_metal_pipeline_set_nsg (res, nsg);
645
- ggml_metal_pipeline_set_smem(res, smem);
729
+ res.nr0 = nr0;
730
+ res.nr1 = nr1;
731
+ res.nsg = nsg;
732
+ res.smem = smem;
646
733
 
647
734
  return res;
648
735
  }
649
736
 
650
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_library_t lib, int ne02, int ne20) {
737
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_library_t lib, int ne02, int ne20) {
651
738
  char base[256];
652
739
  char name[256];
653
740
 
654
741
  snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20);
655
- snprintf(name, 256, "%s", base);
742
+ snprintf(name, 256, "%s_ne02=%d", base, ne02);
656
743
 
657
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
658
- if (res) {
659
- return res;
744
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
745
+ if (!res.pipeline) {
746
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
660
747
  }
661
748
 
662
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
663
-
664
- const size_t smem = (size_t) ne02*ne20*sizeof(uint16_t);
665
-
666
- ggml_metal_pipeline_set_smem(res, smem);
749
+ res.smem = (size_t) ne02*ne20*sizeof(uint16_t);
667
750
 
668
751
  return res;
669
752
  }
670
753
 
671
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, const ggml_tensor * op) {
754
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, const ggml_tensor * op) {
672
755
  char base[256];
673
756
  char name[256];
674
757
 
@@ -680,25 +763,23 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_libra
680
763
  snprintf(base, 256, "kernel_mul_mm_id_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
681
764
  snprintf(name, 256, "%s_bci=%d", base, bc_inp);
682
765
 
683
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
684
- if (res) {
685
- return res;
686
- }
687
-
688
- ggml_metal_cv_t cv = ggml_metal_cv_init();
766
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
767
+ if (!res.pipeline) {
768
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
689
769
 
690
- ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
770
+ ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
691
771
 
692
- res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
772
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
693
773
 
694
- ggml_metal_cv_free(cv);
774
+ ggml_metal_cv_free(cv);
775
+ }
695
776
 
696
- ggml_metal_pipeline_set_smem(res, 8192);
777
+ res.smem = 8192;
697
778
 
698
779
  return res;
699
780
  }
700
781
 
701
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_library_t lib, const ggml_tensor * op) {
782
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_library_t lib, const ggml_tensor * op) {
702
783
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
703
784
  GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
704
785
 
@@ -846,28 +927,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
846
927
  snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
847
928
  snprintf(name, 256, "%s_nsg=%d", base, nsg);
848
929
 
849
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
850
- if (res) {
851
- return res;
852
- }
853
-
854
- ggml_metal_cv_t cv = ggml_metal_cv_init();
930
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
931
+ if (!res.pipeline) {
932
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
855
933
 
856
- ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
934
+ ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
857
935
 
858
- res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
936
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
859
937
 
860
- ggml_metal_cv_free(cv);
938
+ ggml_metal_cv_free(cv);
939
+ }
861
940
 
862
- ggml_metal_pipeline_set_nr0 (res, nr0);
863
- ggml_metal_pipeline_set_nr1 (res, nr1);
864
- ggml_metal_pipeline_set_nsg (res, nsg);
865
- ggml_metal_pipeline_set_smem(res, smem);
941
+ res.nr0 = nr0;
942
+ res.nr1 = nr1;
943
+ res.nsg = nsg;
944
+ res.smem = smem;
866
945
 
867
946
  return res;
868
947
  }
869
948
 
870
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax(ggml_metal_library_t lib, const ggml_tensor * op) {
949
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argmax(ggml_metal_library_t lib, const ggml_tensor * op) {
871
950
  GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
872
951
  GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
873
952
  GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
@@ -878,19 +957,43 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax(ggml_metal_library_
878
957
  snprintf(base, 256, "kernel_argmax_%s", ggml_type_name(op->src[0]->type));
879
958
  snprintf(name, 256, "%s", base);
880
959
 
881
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
882
- if (res) {
883
- return res;
960
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
961
+ if (!res.pipeline) {
962
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
884
963
  }
885
964
 
886
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
965
+ res.smem = 32*(sizeof(float) + sizeof(int32_t));
887
966
 
888
- ggml_metal_pipeline_set_smem(res, 32*(sizeof(float) + sizeof(int32_t)));
967
+ return res;
968
+ }
969
+
970
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort(ggml_metal_library_t lib, const ggml_tensor * op) {
971
+ assert(op->op == GGML_OP_ARGSORT);
972
+
973
+ char base[256];
974
+ char name[256];
975
+
976
+ ggml_sort_order order = (ggml_sort_order) op->op_params[0];
977
+
978
+ const char * order_str = "undefined";
979
+ switch (order) {
980
+ case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
981
+ case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
982
+ default: GGML_ABORT("fatal error");
983
+ };
984
+
985
+ snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
986
+ snprintf(name, 256, "%s", base);
987
+
988
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
989
+ if (!res.pipeline) {
990
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
991
+ }
889
992
 
890
993
  return res;
891
994
  }
892
995
 
893
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library_t lib, const ggml_tensor * op) {
996
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
894
997
  assert(op->op == GGML_OP_ARGSORT);
895
998
 
896
999
  char base[256];
@@ -905,26 +1008,165 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library
905
1008
  default: GGML_ABORT("fatal error");
906
1009
  };
907
1010
 
1011
+ snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
1012
+ snprintf(name, 256, "%s", base);
1013
+
1014
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1015
+ if (!res.pipeline) {
1016
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1017
+ }
1018
+
1019
+ return res;
1020
+ }
1021
+
1022
+ // note: reuse the argsort kernel for top_k
1023
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t lib, const ggml_tensor * op) {
1024
+ assert(op->op == GGML_OP_TOP_K);
1025
+
1026
+ char base[256];
1027
+ char name[256];
1028
+
1029
+ // note: the top_k kernel is always descending order
1030
+ ggml_sort_order order = GGML_SORT_ORDER_DESC;
1031
+
1032
+ const char * order_str = "undefined";
1033
+ switch (order) {
1034
+ case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
1035
+ case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
1036
+ default: GGML_ABORT("fatal error");
1037
+ };
1038
+
908
1039
  snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
909
1040
  snprintf(name, 256, "%s", base);
910
1041
 
911
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
912
- if (res) {
913
- return res;
1042
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1043
+ if (!res.pipeline) {
1044
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
914
1045
  }
915
1046
 
916
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1047
+ return res;
1048
+ }
1049
+
1050
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
1051
+ assert(op->op == GGML_OP_TOP_K);
1052
+
1053
+ char base[256];
1054
+ char name[256];
1055
+
1056
+ ggml_sort_order order = GGML_SORT_ORDER_DESC;
1057
+
1058
+ const char * order_str = "undefined";
1059
+ switch (order) {
1060
+ case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
1061
+ case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
1062
+ default: GGML_ABORT("fatal error");
1063
+ };
1064
+
1065
+ snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
1066
+ snprintf(name, 256, "%s", base);
1067
+
1068
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1069
+ if (!res.pipeline) {
1070
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1071
+ }
1072
+
1073
+ return res;
1074
+ }
1075
+
1076
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad(
1077
+ ggml_metal_library_t lib,
1078
+ const struct ggml_tensor * op,
1079
+ bool has_mask,
1080
+ int32_t ncpsg) {
1081
+ assert(op->op == GGML_OP_FLASH_ATTN_EXT);
1082
+ GGML_UNUSED(op);
1083
+
1084
+ char base[256];
1085
+ char name[256];
1086
+
1087
+ snprintf(base, 256, "kernel_%s",
1088
+ "flash_attn_ext_pad");
1089
+
1090
+ snprintf(name, 256, "%s_mask=%d_ncpsg=%d",
1091
+ base,
1092
+ has_mask,
1093
+ ncpsg);
1094
+
1095
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1096
+ if (!res.pipeline) {
1097
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
1098
+
1099
+ ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0);
1100
+ //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
1101
+ //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2);
1102
+ //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3);
1103
+
1104
+ //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
1105
+ //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
1106
+ //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
1107
+ //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
1108
+ //ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24);
1109
+ ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25);
1110
+
1111
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1112
+
1113
+ ggml_metal_cv_free(cv);
1114
+ }
1115
+
1116
+ return res;
1117
+ }
1118
+
1119
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_blk(
1120
+ ggml_metal_library_t lib,
1121
+ const struct ggml_tensor * op,
1122
+ int32_t nqptg,
1123
+ int32_t ncpsg) {
1124
+ assert(op->op == GGML_OP_FLASH_ATTN_EXT);
1125
+ GGML_UNUSED(op);
1126
+
1127
+ char base[256];
1128
+ char name[256];
1129
+
1130
+ snprintf(base, 256, "kernel_%s",
1131
+ "flash_attn_ext_blk");
1132
+
1133
+ snprintf(name, 256, "%s_nqptg=%d_ncpsg=%d",
1134
+ base,
1135
+ nqptg,
1136
+ ncpsg);
1137
+
1138
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1139
+ if (!res.pipeline) {
1140
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
1141
+
1142
+ //ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_BLK + 0);
1143
+ //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1);
1144
+ //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_BLK + 2);
1145
+ //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_BLK + 3);
1146
+
1147
+ //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20);
1148
+ //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21);
1149
+ //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22);
1150
+ //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23);
1151
+ ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24);
1152
+ ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25);
1153
+
1154
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1155
+
1156
+ ggml_metal_cv_free(cv);
1157
+ }
917
1158
 
918
1159
  return res;
919
1160
  }
920
1161
 
921
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
1162
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext(
922
1163
  ggml_metal_library_t lib,
923
1164
  const ggml_tensor * op,
924
1165
  bool has_mask,
925
1166
  bool has_sinks,
926
1167
  bool has_bias,
927
1168
  bool has_scap,
1169
+ bool has_kvpad,
928
1170
  int32_t nsg) {
929
1171
  assert(op->op == GGML_OP_FLASH_ATTN_EXT);
930
1172
 
@@ -937,52 +1179,59 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
937
1179
  const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
938
1180
  const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];
939
1181
 
1182
+ // do bounds checks for the mask?
1183
+ const bool bc_mask = op->src[3] && (op->src[3]->ne[1] % 8 != 0);
1184
+
940
1185
  snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
941
1186
  "flash_attn_ext",
942
1187
  ggml_type_name(op->src[1]->type),
943
1188
  dk,
944
1189
  dv);
945
1190
 
946
- snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
1191
+ snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=%d_ns10=%d_ns20=%d_nsg=%d",
947
1192
  base,
948
1193
  has_mask,
949
1194
  has_sinks,
950
1195
  has_bias,
951
1196
  has_scap,
1197
+ has_kvpad,
1198
+ bc_mask,
952
1199
  ns10,
953
1200
  ns20,
954
1201
  nsg);
955
1202
 
956
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
957
- if (res) {
958
- return res;
959
- }
1203
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1204
+ if (!res.pipeline) {
1205
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
960
1206
 
961
- ggml_metal_cv_t cv = ggml_metal_cv_init();
1207
+ ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT + 0);
1208
+ ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
1209
+ ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2);
1210
+ ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3);
1211
+ ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);
962
1212
 
963
- ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT + 0);
964
- ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
965
- ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2);
966
- ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3);
1213
+ ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10);
967
1214
 
968
- ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
969
- ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
970
- ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT + 22);
1215
+ ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
1216
+ ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
1217
+ ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT + 22);
971
1218
 
972
- res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1219
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
973
1220
 
974
- ggml_metal_cv_free(cv);
1221
+ ggml_metal_cv_free(cv);
1222
+ }
975
1223
 
976
1224
  return res;
977
1225
  }
978
1226
 
979
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
1227
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec(
980
1228
  ggml_metal_library_t lib,
981
1229
  const ggml_tensor * op,
982
1230
  bool has_mask,
983
1231
  bool has_sinks,
984
1232
  bool has_bias,
985
1233
  bool has_scap,
1234
+ bool has_kvpad,
986
1235
  int32_t nsg,
987
1236
  int32_t nwg) {
988
1237
  assert(op->op == GGML_OP_FLASH_ATTN_EXT);
@@ -1002,41 +1251,41 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
1002
1251
  dk,
1003
1252
  dv);
1004
1253
 
1005
- snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
1254
+ snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
1006
1255
  base,
1007
1256
  has_mask,
1008
1257
  has_sinks,
1009
1258
  has_bias,
1010
1259
  has_scap,
1260
+ has_kvpad,
1011
1261
  ns10,
1012
1262
  ns20,
1013
1263
  nsg, nwg);
1014
1264
 
1015
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
1016
- if (res) {
1017
- return res;
1018
- }
1019
-
1020
- ggml_metal_cv_t cv = ggml_metal_cv_init();
1265
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1266
+ if (!res.pipeline) {
1267
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
1021
1268
 
1022
- ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_VEC + 0);
1023
- ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
1024
- ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2);
1025
- ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3);
1269
+ ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_VEC + 0);
1270
+ ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
1271
+ ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2);
1272
+ ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3);
1273
+ ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4);
1026
1274
 
1027
- ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
1028
- ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);
1029
- ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_VEC + 22);
1030
- ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC + 23);
1275
+ ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
1276
+ ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);
1277
+ ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_VEC + 22);
1278
+ ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC + 23);
1031
1279
 
1032
- res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1280
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1033
1281
 
1034
- ggml_metal_cv_free(cv);
1282
+ ggml_metal_cv_free(cv);
1283
+ }
1035
1284
 
1036
1285
  return res;
1037
1286
  }
1038
1287
 
1039
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
1288
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
1040
1289
  ggml_metal_library_t lib,
1041
1290
  const ggml_tensor * op,
1042
1291
  int32_t dv,
@@ -1049,26 +1298,24 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
1049
1298
  snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce");
1050
1299
  snprintf(name, 256, "%s_dv=%d_nwg=%d", base, dv, nwg);
1051
1300
 
1052
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
1053
- if (res) {
1054
- return res;
1055
- }
1301
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1302
+ if (!res.pipeline) {
1303
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
1056
1304
 
1057
- ggml_metal_cv_t cv = ggml_metal_cv_init();
1305
+ ggml_metal_cv_set_int32(cv, dv, FC_FLASH_ATTN_EXT_VEC_REDUCE + 0);
1306
+ ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC_REDUCE + 1);
1058
1307
 
1059
- ggml_metal_cv_set_int32(cv, dv, FC_FLASH_ATTN_EXT_VEC_REDUCE + 0);
1060
- ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC_REDUCE + 1);
1308
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1061
1309
 
1062
- res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1063
-
1064
- ggml_metal_cv_free(cv);
1310
+ ggml_metal_cv_free(cv);
1311
+ }
1065
1312
 
1066
1313
  return res;
1067
1314
 
1068
1315
  GGML_UNUSED(op);
1069
1316
  }
1070
1317
 
1071
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin(
1318
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(
1072
1319
  ggml_metal_library_t lib,
1073
1320
  ggml_op op,
1074
1321
  int32_t n_fuse,
@@ -1093,17 +1340,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin(
1093
1340
 
1094
1341
  snprintf(name, 256, "%s", base);
1095
1342
 
1096
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
1097
- if (res) {
1098
- return res;
1343
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1344
+ if (!res.pipeline) {
1345
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1099
1346
  }
1100
1347
 
1101
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1102
-
1103
1348
  return res;
1104
1349
  }
1105
1350
 
1106
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
1351
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
1107
1352
  assert(op->op == GGML_OP_L2_NORM);
1108
1353
 
1109
1354
  GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
@@ -1115,19 +1360,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library
1115
1360
  snprintf(base, 256, "kernel_l2_norm_f32");
1116
1361
  snprintf(name, 256, "%s", base);
1117
1362
 
1118
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
1119
- if (res) {
1120
- return res;
1363
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1364
+ if (!res.pipeline) {
1365
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1121
1366
  }
1122
1367
 
1123
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1124
-
1125
- ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
1368
+ res.smem = 32*sizeof(float);
1126
1369
 
1127
1370
  return res;
1128
1371
  }
1129
1372
 
1130
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
1373
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
1131
1374
  assert(op->op == GGML_OP_GROUP_NORM);
1132
1375
 
1133
1376
  GGML_ASSERT(ggml_is_contiguous(op->src[0]));
@@ -1138,19 +1381,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm(ggml_metal_libr
1138
1381
  snprintf(base, 256, "kernel_group_norm_f32");
1139
1382
  snprintf(name, 256, "%s", base);
1140
1383
 
1141
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
1142
- if (res) {
1143
- return res;
1384
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1385
+ if (!res.pipeline) {
1386
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1144
1387
  }
1145
1388
 
1146
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1147
-
1148
- ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
1389
+ res.smem = 32*sizeof(float);
1149
1390
 
1150
1391
  return res;
1151
1392
  }
1152
1393
 
1153
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op, int n_fuse) {
1394
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op, int n_fuse) {
1154
1395
  assert(op->op == GGML_OP_NORM || op->op == GGML_OP_RMS_NORM);
1155
1396
 
1156
1397
  GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
@@ -1183,19 +1424,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm(ggml_metal_library_t
1183
1424
 
1184
1425
  snprintf(name, 256, "%s", base);
1185
1426
 
1186
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
1187
- if (res) {
1188
- return res;
1427
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1428
+ if (!res.pipeline) {
1429
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1189
1430
  }
1190
1431
 
1191
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1192
-
1193
- ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
1432
+ res.smem = 32*sizeof(float);
1194
1433
 
1195
1434
  return res;
1196
1435
  }
1197
1436
 
1198
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t lib, const ggml_tensor * op) {
1437
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope(ggml_metal_library_t lib, const ggml_tensor * op) {
1199
1438
  assert(op->op == GGML_OP_ROPE);
1200
1439
 
1201
1440
  char base[256];
@@ -1205,11 +1444,12 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t
1205
1444
 
1206
1445
  const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
1207
1446
  const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
1447
+ const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
1208
1448
  const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
1209
1449
 
1210
1450
  if (is_neox) {
1211
1451
  snprintf(base, 256, "kernel_rope_neox_%s", ggml_type_name(op->src[0]->type));
1212
- } else if (is_mrope && !is_vision) {
1452
+ } else if ((is_mrope || is_imrope) && !is_vision) {
1213
1453
  GGML_ASSERT(op->src[1]->ne[0]*4 >= op->src[0]->ne[2]); // need at least 4 pos per token
1214
1454
  snprintf(base, 256, "kernel_rope_multi_%s", ggml_type_name(op->src[0]->type));
1215
1455
  } else if (is_vision) {
@@ -1219,19 +1459,23 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t
1219
1459
  snprintf(base, 256, "kernel_rope_norm_%s", ggml_type_name(op->src[0]->type));
1220
1460
  }
1221
1461
 
1222
- snprintf(name, 256, "%s", base);
1462
+ snprintf(name, 256, "%s_imrope=%d", base, is_imrope ? 1 : 0);
1223
1463
 
1224
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
1225
- if (res) {
1226
- return res;
1227
- }
1464
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1465
+ if (!res.pipeline) {
1466
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
1228
1467
 
1229
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1468
+ ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0);
1469
+
1470
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1471
+
1472
+ ggml_metal_cv_free(cv);
1473
+ }
1230
1474
 
1231
1475
  return res;
1232
1476
  }
1233
1477
 
1234
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col(ggml_metal_library_t lib, const ggml_tensor * op) {
1478
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col(ggml_metal_library_t lib, const ggml_tensor * op) {
1235
1479
  assert(op->op == GGML_OP_IM2COL);
1236
1480
 
1237
1481
  GGML_ASSERT(ggml_is_contiguous(op->src[1]));
@@ -1244,17 +1488,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col(ggml_metal_library_
1244
1488
  snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type));
1245
1489
  snprintf(name, 256, "%s", base);
1246
1490
 
1247
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
1248
- if (res) {
1249
- return res;
1491
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1492
+ if (!res.pipeline) {
1493
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1250
1494
  }
1251
1495
 
1252
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1253
-
1254
1496
  return res;
1255
1497
  }
1256
1498
 
1257
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_metal_library_t lib, const ggml_tensor * op) {
1499
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_metal_library_t lib, const ggml_tensor * op) {
1258
1500
  assert(op->op == GGML_OP_CONV_TRANSPOSE_1D);
1259
1501
 
1260
1502
  GGML_ASSERT(ggml_is_contiguous(op->src[0]));
@@ -1269,17 +1511,60 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_met
1269
1511
  snprintf(base, 256, "kernel_conv_transpose_1d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
1270
1512
  snprintf(name, 256, "%s", base);
1271
1513
 
1272
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
1273
- if (res) {
1274
- return res;
1514
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1515
+ if (!res.pipeline) {
1516
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1275
1517
  }
1276
1518
 
1277
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1519
+ return res;
1520
+ }
1521
+
1522
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_metal_library_t lib, const ggml_tensor * op) {
1523
+ assert(op->op == GGML_OP_CONV_TRANSPOSE_2D);
1524
+
1525
+ GGML_ASSERT(ggml_is_contiguous(op->src[0]));
1526
+ GGML_ASSERT(ggml_is_contiguous(op->src[1]));
1527
+ GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
1528
+ GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
1529
+ GGML_ASSERT(op->type == GGML_TYPE_F32);
1530
+
1531
+ char base[256];
1532
+ char name[256];
1533
+
1534
+ snprintf(base, 256, "kernel_conv_transpose_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
1535
+ snprintf(name, 256, "%s", base);
1536
+
1537
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1538
+ if (!res.pipeline) {
1539
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1540
+ }
1541
+
1542
+ return res;
1543
+ }
1544
+
1545
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d(ggml_metal_library_t lib, const ggml_tensor * op) {
1546
+ assert(op->op == GGML_OP_CONV_2D);
1547
+
1548
+ GGML_ASSERT(ggml_is_contiguous(op->src[0]));
1549
+ GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
1550
+ GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
1551
+ GGML_ASSERT(op->type == GGML_TYPE_F32);
1552
+
1553
+ char base[256];
1554
+ char name[256];
1555
+
1556
+ snprintf(base, 256, "kernel_conv_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
1557
+ snprintf(name, 256, "%s", base);
1558
+
1559
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1560
+ if (!res.pipeline) {
1561
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1562
+ }
1278
1563
 
1279
1564
  return res;
1280
1565
  }
1281
1566
 
1282
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) {
1567
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) {
1283
1568
  assert(op->op == GGML_OP_UPSCALE);
1284
1569
 
1285
1570
  char base[256];
@@ -1288,17 +1573,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale(ggml_metal_library
1288
1573
  snprintf(base, 256, "kernel_upscale_%s", ggml_type_name(op->src[0]->type));
1289
1574
  snprintf(name, 256, "%s", base);
1290
1575
 
1291
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
1292
- if (res) {
1293
- return res;
1576
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1577
+ if (!res.pipeline) {
1578
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1294
1579
  }
1295
1580
 
1296
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1297
-
1298
1581
  return res;
1299
1582
  }
1300
1583
 
1301
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad(ggml_metal_library_t lib, const ggml_tensor * op) {
1584
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_library_t lib, const ggml_tensor * op) {
1302
1585
  assert(op->op == GGML_OP_PAD);
1303
1586
 
1304
1587
  char base[256];
@@ -1307,8 +1590,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad(ggml_metal_library_t l
1307
1590
  snprintf(base, 256, "kernel_pad_%s", ggml_type_name(op->src[0]->type));
1308
1591
  snprintf(name, 256, "%s", base);
1309
1592
 
1310
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
1311
- if (res) {
1593
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1594
+ if (res.pipeline) {
1312
1595
  return res;
1313
1596
  }
1314
1597
 
@@ -1317,7 +1600,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad(ggml_metal_library_t l
1317
1600
  return res;
1318
1601
  }
1319
1602
 
1320
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d(ggml_metal_library_t lib, const ggml_tensor * op) {
1603
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d(ggml_metal_library_t lib, const ggml_tensor * op) {
1321
1604
  assert(op->op == GGML_OP_PAD_REFLECT_1D);
1322
1605
 
1323
1606
  char base[256];
@@ -1326,17 +1609,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d(ggml_metal_
1326
1609
  snprintf(base, 256, "kernel_pad_reflect_1d_%s", ggml_type_name(op->src[0]->type));
1327
1610
  snprintf(name, 256, "%s", base);
1328
1611
 
1329
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
1330
- if (res) {
1331
- return res;
1612
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1613
+ if (!res.pipeline) {
1614
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1332
1615
  }
1333
1616
 
1334
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1335
-
1336
1617
  return res;
1337
1618
  }
1338
1619
 
1339
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange(ggml_metal_library_t lib, const ggml_tensor * op) {
1620
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange(ggml_metal_library_t lib, const ggml_tensor * op) {
1340
1621
  assert(op->op == GGML_OP_ARANGE);
1341
1622
 
1342
1623
  char base[256];
@@ -1345,17 +1626,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange(ggml_metal_library_
1345
1626
  snprintf(base, 256, "kernel_arange_%s", ggml_type_name(op->type));
1346
1627
  snprintf(name, 256, "%s", base);
1347
1628
 
1348
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
1349
- if (res) {
1350
- return res;
1629
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1630
+ if (!res.pipeline) {
1631
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1351
1632
  }
1352
1633
 
1353
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1354
-
1355
1634
  return res;
1356
1635
  }
1357
1636
 
1358
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const ggml_tensor * op) {
1637
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const ggml_tensor * op) {
1359
1638
  assert(op->op == GGML_OP_TIMESTEP_EMBEDDING);
1360
1639
 
1361
1640
  char base[256];
@@ -1364,13 +1643,101 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_me
1364
1643
  snprintf(base, 256, "kernel_timestep_embedding_%s", ggml_type_name(op->src[0]->type));
1365
1644
  snprintf(name, 256, "%s", base);
1366
1645
 
1367
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
1368
- if (res) {
1369
- return res;
1646
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1647
+ if (!res.pipeline) {
1648
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1370
1649
  }
1371
1650
 
1372
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1651
+ return res;
1652
+ }
1653
+
1654
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) {
1655
+ assert(op->op == GGML_OP_OPT_STEP_ADAMW);
1656
+
1657
+ char base[256];
1658
+ char name[256];
1659
+
1660
+ snprintf(base, 256, "kernel_opt_step_adamw_%s", ggml_type_name(op->src[0]->type));
1661
+ snprintf(name, 256, "%s", base);
1662
+
1663
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1664
+ if (!res.pipeline) {
1665
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1666
+ }
1667
+
1668
+ return res;
1669
+ }
1670
+
1671
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) {
1672
+ assert(op->op == GGML_OP_OPT_STEP_SGD);
1673
+
1674
+ char base[256];
1675
+ char name[256];
1676
+
1677
+ snprintf(base, 256, "kernel_opt_step_sgd_%s", ggml_type_name(op->src[0]->type));
1678
+ snprintf(name, 256, "%s", base);
1679
+
1680
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1681
+ if (!res.pipeline) {
1682
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1683
+ }
1684
+
1685
+ return res;
1686
+ }
1687
+
1688
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_memset(ggml_metal_library_t lib, const ggml_tensor * op) {
1689
+ GGML_ASSERT(op->type == GGML_TYPE_I64);
1690
+
1691
+ char base[256];
1692
+ char name[256];
1693
+
1694
+ snprintf(base, 256, "kernel_memset_%s", ggml_type_name(op->type));
1695
+ snprintf(name, 256, "%s", base);
1696
+
1697
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1698
+ if (!res.pipeline) {
1699
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1700
+ }
1373
1701
 
1374
1702
  return res;
1375
1703
  }
1376
1704
 
1705
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_count_equal(ggml_metal_library_t lib, const ggml_tensor * op) {
1706
+ assert(op->op == GGML_OP_COUNT_EQUAL);
1707
+
1708
+ GGML_TENSOR_LOCALS(int64_t, ne0, op->src[0], ne);
1709
+
1710
+ GGML_ASSERT(op->src[0]->type == op->src[1]->type);
1711
+ GGML_ASSERT(op->src[0]->type == GGML_TYPE_I32);
1712
+ GGML_ASSERT(op->type == GGML_TYPE_I64);
1713
+
1714
+ // note: the kernel only supports i32 output due to metal atomic add only supporting atomic_int
1715
+ GGML_ASSERT(ggml_nelements(op->src[0]) < (1LL << 31));
1716
+
1717
+ char base[256];
1718
+ char name[256];
1719
+
1720
+ int nsg = 1;
1721
+ while (32*nsg < ne00 && nsg < 32) {
1722
+ nsg *= 2;
1723
+ }
1724
+
1725
+ snprintf(base, 256, "kernel_count_equal_%s", ggml_type_name(op->src[0]->type));
1726
+ snprintf(name, 256, "%s_nsg=%d", base, nsg);
1727
+
1728
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1729
+ if (!res.pipeline) {
1730
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
1731
+
1732
+ ggml_metal_cv_set_int16(cv, nsg, FC_COUNT_EQUAL + 0);
1733
+
1734
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1735
+
1736
+ ggml_metal_cv_free(cv);
1737
+ }
1738
+
1739
+ res.smem = 32 * sizeof(int32_t);
1740
+ res.nsg = nsg;
1741
+
1742
+ return res;
1743
+ }