whispercpp 1.3.4 → 1.3.6

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (891)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +158 -44
  4. data/ext/extconf.rb +3 -2
  5. data/ext/ruby_whisper.c +34 -6
  6. data/ext/ruby_whisper.h +67 -0
  7. data/ext/ruby_whisper_context.c +236 -144
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +12 -13
  10. data/ext/ruby_whisper_params.c +47 -24
  11. data/ext/ruby_whisper_segment.c +84 -20
  12. data/ext/ruby_whisper_token.c +371 -0
  13. data/ext/ruby_whisper_transcribe.cpp +5 -2
  14. data/ext/ruby_whisper_vad_context.c +122 -0
  15. data/ext/ruby_whisper_vad_context_detect.cpp +51 -0
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +138 -0
  18. data/ext/ruby_whisper_vad_segments.c +105 -0
  19. data/ext/sources/CMakeLists.txt +4 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  22. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  23. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  24. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  25. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  26. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  27. data/ext/sources/examples/bench/bench.cpp +23 -18
  28. data/ext/sources/examples/cli/cli.cpp +129 -112
  29. data/ext/sources/examples/common-ggml.cpp +2 -0
  30. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  31. data/ext/sources/examples/miniaudio.h +4507 -2131
  32. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  33. data/ext/sources/examples/server/server.cpp +28 -15
  34. data/ext/sources/examples/talk-llama/CMakeLists.txt +8 -3
  35. data/ext/sources/examples/talk-llama/llama-adapter.cpp +5 -2
  36. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -0
  37. data/ext/sources/examples/talk-llama/llama-arch.cpp +2378 -1988
  38. data/ext/sources/examples/talk-llama/llama-arch.h +109 -2
  39. data/ext/sources/examples/talk-llama/llama-batch.cpp +78 -34
  40. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  41. data/ext/sources/examples/talk-llama/llama-chat.cpp +100 -4
  42. data/ext/sources/examples/talk-llama/llama-chat.h +5 -0
  43. data/ext/sources/examples/talk-llama/llama-context.cpp +1088 -403
  44. data/ext/sources/examples/talk-llama/llama-context.h +70 -23
  45. data/ext/sources/examples/talk-llama/llama-cparams.h +6 -0
  46. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  47. data/ext/sources/examples/talk-llama/llama-grammar.cpp +295 -60
  48. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  49. data/ext/sources/examples/talk-llama/llama-graph.cpp +925 -155
  50. data/ext/sources/examples/talk-llama/llama-graph.h +234 -23
  51. data/ext/sources/examples/talk-llama/llama-hparams.cpp +79 -38
  52. data/ext/sources/examples/talk-llama/llama-hparams.h +118 -18
  53. data/ext/sources/examples/talk-llama/llama-impl.cpp +11 -7
  54. data/ext/sources/examples/talk-llama/llama-impl.h +14 -2
  55. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +8 -4
  56. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +405 -140
  57. data/ext/sources/examples/talk-llama/llama-kv-cache.h +24 -10
  58. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  59. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  60. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  61. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  62. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +42 -31
  63. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  64. data/ext/sources/examples/talk-llama/llama-mmap.cpp +197 -45
  65. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  66. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +606 -116
  67. data/ext/sources/examples/talk-llama/llama-model-loader.h +41 -5
  68. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +61 -44
  69. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  70. data/ext/sources/examples/talk-llama/llama-model.cpp +2756 -13643
  71. data/ext/sources/examples/talk-llama/llama-model.h +112 -18
  72. data/ext/sources/examples/talk-llama/llama-quant.cpp +582 -365
  73. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +1409 -199
  74. data/ext/sources/examples/talk-llama/llama-sampler.h +42 -0
  75. data/ext/sources/examples/talk-llama/llama-vocab.cpp +248 -82
  76. data/ext/sources/examples/talk-llama/llama-vocab.h +50 -40
  77. data/ext/sources/examples/talk-llama/llama.cpp +802 -21
  78. data/ext/sources/examples/talk-llama/llama.h +210 -39
  79. data/ext/sources/examples/talk-llama/models/afmoe.cpp +190 -0
  80. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  81. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  82. data/ext/sources/examples/talk-llama/models/arctic.cpp +137 -0
  83. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  84. data/ext/sources/examples/talk-llama/models/baichuan.cpp +123 -0
  85. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +143 -0
  86. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +133 -0
  87. data/ext/sources/examples/talk-llama/models/bert.cpp +184 -0
  88. data/ext/sources/examples/talk-llama/models/bitnet.cpp +145 -0
  89. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  90. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  91. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  92. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  93. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  94. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  95. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  96. data/ext/sources/examples/talk-llama/models/dbrx.cpp +122 -0
  97. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  98. data/ext/sources/examples/talk-llama/models/deepseek.cpp +142 -0
  99. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +262 -0
  100. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  101. data/ext/sources/examples/talk-llama/models/dots1.cpp +132 -0
  102. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  103. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +148 -0
  104. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  105. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  106. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  107. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  108. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  109. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +111 -0
  110. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  111. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  112. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  113. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  114. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  115. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  116. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  117. data/ext/sources/examples/talk-llama/models/glm4.cpp +157 -0
  118. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  119. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  120. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +195 -0
  121. data/ext/sources/examples/talk-llama/models/granite.cpp +210 -0
  122. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  123. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +139 -0
  124. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  125. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +153 -0
  126. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  127. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  128. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  129. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  130. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  131. data/ext/sources/examples/talk-llama/models/lfm2.cpp +196 -0
  132. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  133. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  134. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  135. data/ext/sources/examples/talk-llama/models/llama.cpp +175 -0
  136. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  137. data/ext/sources/examples/talk-llama/models/mamba-base.cpp +289 -0
  138. data/ext/sources/examples/talk-llama/models/mamba.cpp +54 -0
  139. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +129 -0
  140. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +200 -0
  141. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +123 -0
  142. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  143. data/ext/sources/examples/talk-llama/models/models.h +704 -0
  144. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +109 -0
  145. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  146. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +162 -0
  147. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  148. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  149. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  150. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  151. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  153. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  154. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  155. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  156. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  157. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  158. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  159. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  160. data/ext/sources/examples/talk-llama/models/plamo2.cpp +320 -0
  161. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  162. data/ext/sources/examples/talk-llama/models/plm.cpp +169 -0
  163. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  164. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  165. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  166. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  167. data/ext/sources/examples/talk-llama/models/qwen3.cpp +120 -0
  168. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  169. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  170. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +131 -0
  171. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +525 -0
  172. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +140 -0
  173. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +132 -0
  174. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  175. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  176. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +164 -0
  177. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  178. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  179. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +137 -0
  180. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  181. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  182. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  183. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  184. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  185. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  186. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  187. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  188. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  189. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  190. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  191. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  192. data/ext/sources/examples/talk-llama/unicode.cpp +121 -79
  193. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  194. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  195. data/ext/sources/ggml/CMakeLists.txt +90 -56
  196. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  197. data/ext/sources/ggml/include/ggml-backend.h +5 -2
  198. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  199. data/ext/sources/ggml/include/ggml-cpu.h +6 -0
  200. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  201. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  202. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  203. data/ext/sources/ggml/include/ggml-rpc.h +14 -12
  204. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  205. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  206. data/ext/sources/ggml/include/ggml.h +246 -21
  207. data/ext/sources/ggml/src/CMakeLists.txt +85 -11
  208. data/ext/sources/ggml/src/ggml-alloc.c +128 -50
  209. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  210. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  211. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  212. data/ext/sources/ggml/src/ggml-backend-reg.cpp +54 -88
  213. data/ext/sources/ggml/src/ggml-backend.cpp +76 -23
  214. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +18 -4
  215. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +11 -11
  216. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +58 -46
  217. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +139 -48
  218. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2427 -1785
  219. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -362
  220. data/ext/sources/ggml/src/ggml-cann/common.h +285 -211
  221. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +663 -831
  222. data/ext/sources/ggml/src/ggml-common.h +11 -0
  223. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +170 -95
  224. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -18
  225. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  226. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  227. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  228. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +513 -27
  229. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +4192 -992
  230. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  231. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  232. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1761 -49
  233. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  234. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  235. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  236. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  237. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +124 -24
  238. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +157 -28
  239. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  240. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  241. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -3
  242. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +251 -80
  243. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +19 -0
  244. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +587 -119
  245. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  246. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1093 -194
  247. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1284 -203
  248. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  249. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1519 -527
  250. data/ext/sources/ggml/src/ggml-cpu/ops.h +6 -4
  251. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  252. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  253. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3632 -781
  254. data/ext/sources/ggml/src/ggml-cpu/repack.h +129 -4
  255. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  256. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +152 -46
  257. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  258. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +152 -1
  259. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  260. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +140 -0
  261. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  262. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  263. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  264. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +132 -6
  265. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  266. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +33 -31
  267. data/ext/sources/ggml/src/ggml-cuda/common.cuh +474 -85
  268. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  269. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  270. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  271. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +342 -246
  272. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  273. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  274. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  275. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  276. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +98 -74
  278. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +973 -665
  279. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  280. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1255 -0
  281. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +33 -40
  282. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +40 -18
  283. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  284. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +206 -45
  285. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  286. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  287. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  288. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  289. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1688 -302
  290. data/ext/sources/ggml/src/ggml-cuda/mean.cu +12 -10
  291. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +908 -48
  292. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +88 -20
  293. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +502 -90
  294. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  295. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  296. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  297. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +532 -193
  298. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +460 -104
  299. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +5 -2
  300. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +360 -122
  301. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +2 -1
  302. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  303. data/ext/sources/ggml/src/ggml-cuda/pad.cu +73 -39
  304. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +152 -1
  305. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  306. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  307. data/ext/sources/ggml/src/ggml-cuda/rope.cu +364 -149
  308. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  309. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  310. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  311. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  312. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +163 -41
  313. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  314. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  315. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +68 -50
  316. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  317. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  318. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  319. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  320. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  321. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  322. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  323. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  324. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  325. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  326. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  328. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  329. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  330. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  331. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  332. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  333. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +22 -4
  334. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +95 -0
  335. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  336. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +275 -119
  337. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -7
  338. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  339. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  340. data/ext/sources/ggml/src/ggml-cuda/unary.cu +160 -11
  341. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +38 -0
  342. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  343. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  344. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  345. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -1
  346. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  347. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
  348. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3325 -0
  349. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +46 -0
  350. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +813 -0
  351. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  352. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +891 -0
  353. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  354. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  355. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +713 -0
  356. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  357. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
  358. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
  359. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  360. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  361. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  362. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  363. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +155 -0
  364. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +63 -0
  365. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  366. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  367. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  368. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  369. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  370. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  371. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  372. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  373. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  374. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  375. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  376. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  377. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  378. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  379. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +26 -0
  380. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1199 -0
  381. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2670 -0
  382. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +497 -0
  383. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  384. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +419 -0
  385. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  386. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  387. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +382 -0
  388. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
  389. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  390. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  391. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  392. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  393. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  394. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  395. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -13
  396. data/ext/sources/ggml/src/ggml-impl.h +129 -6
  397. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  398. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +15 -4
  399. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  400. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +173 -34
  401. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +912 -344
  402. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +124 -59
  403. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +588 -144
  404. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +396 -23
  405. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1724 -421
  406. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +16 -3
  407. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +333 -114
  408. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3050 -1539
  409. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  410. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +30 -1
  411. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4279 -497
  412. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  413. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  414. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  415. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +267 -0
  416. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  417. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  418. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
  419. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  420. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  421. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  422. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  423. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  424. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  425. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  426. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  427. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  428. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
  429. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  430. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  431. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  432. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  433. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  434. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  435. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  436. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  437. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  438. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  439. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  440. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  441. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  442. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  443. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  444. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  445. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  446. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  447. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  448. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  449. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
  450. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  451. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  452. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  453. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  454. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  455. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  456. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +39 -0
  457. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  458. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  459. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  460. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  461. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  462. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  463. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  464. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  465. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  466. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  467. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  468. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  469. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  470. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  471. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  472. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  473. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  474. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  475. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  476. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  477. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  478. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  479. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  480. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  481. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  482. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  483. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  484. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  485. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  486. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  487. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  488. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  489. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  490. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  491. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  492. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  493. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  494. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  495. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  496. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  497. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  498. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  499. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  500. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  501. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  502. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  503. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  504. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  505. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  506. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  507. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  508. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +59 -87
  509. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +81 -0
  510. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  511. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +7 -0
  512. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -29
  513. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  514. data/ext/sources/ggml/src/ggml-sycl/common.hpp +427 -20
  515. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  516. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +103 -1
  517. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  518. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  519. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  520. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  521. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  522. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +867 -50
  523. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +401 -358
  524. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
  525. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  526. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  527. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  528. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  529. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  530. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  531. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  532. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +645 -155
  534. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  535. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +221 -66
  536. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  537. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  538. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  539. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  540. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  541. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  542. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +5 -0
  543. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  544. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  545. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  547. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  548. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +457 -281
  549. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  550. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  551. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  552. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  553. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  554. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  555. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  556. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  557. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  558. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  559. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  560. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  561. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  562. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  563. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  564. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  565. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  566. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  567. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  568. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  569. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  570. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  571. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  572. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  573. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  574. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  575. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  576. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  577. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  578. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  579. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  580. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  581. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  582. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  583. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  584. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  585. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  586. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  587. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  588. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  589. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  590. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  591. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  592. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  593. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  594. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  595. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  596. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  597. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  598. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  599. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  600. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  601. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +71 -0
  602. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  603. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  604. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  605. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  606. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  607. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  608. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  609. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  610. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  611. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  612. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  613. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  614. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  615. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  616. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  617. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  618. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  619. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  620. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  621. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  622. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  623. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  624. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  625. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  626. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  627. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  628. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  629. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  630. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  631. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  632. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  633. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  634. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  635. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  636. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  637. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  638. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  639. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  640. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  641. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  642. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  643. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  644. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  645. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +39 -19
  646. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5994 -3055
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +18 -10
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +386 -160
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +82 -20
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +400 -174
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +123 -37
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +10 -9
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +13 -10
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +77 -29
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  740. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  743. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  744. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  745. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  746. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  747. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  748. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  749. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  750. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  751. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +88 -105
  752. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +41 -26
  753. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  754. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +74 -0
  755. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +92 -230
  756. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  757. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  758. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  759. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  760. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  761. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  762. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  763. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  764. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  765. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  766. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  767. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  768. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  769. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  770. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -4
  771. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  772. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  773. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  774. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +207 -0
  775. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  776. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +8 -49
  777. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +8 -32
  778. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +8 -32
  779. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +33 -0
  780. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +8 -38
  781. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  782. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  783. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  784. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  785. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  786. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  787. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  788. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  789. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  790. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  791. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  792. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  793. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  794. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  795. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  796. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  797. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  798. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +50 -0
  799. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  800. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  801. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  802. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  803. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  804. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  805. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  806. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  807. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  808. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  809. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  810. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  811. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  812. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  813. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  814. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  815. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +384 -180
  816. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  817. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  818. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1374 -0
  819. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2544 -726
  820. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  821. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  822. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  823. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  824. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  825. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  826. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  827. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
  828. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  829. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +73 -15
  830. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
  831. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  832. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  833. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +72 -261
  834. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +766 -0
  835. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +147 -0
  836. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +196 -0
  837. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  838. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  839. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  840. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  841. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  842. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
  843. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  844. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  845. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  846. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  847. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  848. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +91 -0
  849. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
  850. data/ext/sources/ggml/src/ggml.c +590 -64
  851. data/ext/sources/ggml/src/gguf.cpp +229 -44
  852. data/ext/sources/include/whisper.h +1 -0
  853. data/ext/sources/src/CMakeLists.txt +3 -1
  854. data/ext/sources/src/whisper.cpp +106 -62
  855. data/ext/sources/tests/CMakeLists.txt +2 -2
  856. data/ext/sources/tests/test-vad-full.cpp +4 -2
  857. data/ext/sources/tests/test-vad.cpp +1 -1
  858. data/extsources.rb +1 -0
  859. data/lib/whisper/model/uri.rb +17 -18
  860. data/sig/whisper.rbs +162 -4
  861. data/test/test_context_params.rb +82 -0
  862. data/test/test_params.rb +16 -8
  863. data/test/test_segment.rb +0 -1
  864. data/test/test_token.rb +81 -0
  865. data/test/test_vad.rb +1 -1
  866. data/test/test_vad_context.rb +100 -0
  867. data/test/test_vad_segment.rb +19 -0
  868. data/test/test_vad_segments.rb +16 -0
  869. data/test/test_whisper.rb +27 -0
  870. data/whispercpp.gemspec +1 -1
  871. metadata +502 -37
  872. data/ext/sources/build-xcframework.sh +0 -571
  873. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -32
  874. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  875. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  876. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  877. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  878. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  879. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  880. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  881. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  882. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  883. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  884. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  885. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  886. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  887. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  888. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  889. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  890. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  891. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -17,10 +17,12 @@ struct ggml_metal_device_deleter {
17
17
 
18
18
  typedef std::unique_ptr<ggml_metal_device, ggml_metal_device_deleter> ggml_metal_device_ptr;
19
19
 
20
- ggml_metal_device_t ggml_metal_device_get(void) {
21
- static ggml_metal_device_ptr ctx { ggml_metal_device_init() };
20
+ ggml_metal_device_t ggml_metal_device_get(int device) {
21
+ static std::vector<ggml_metal_device_ptr> devs;
22
22
 
23
- return ctx.get();
23
+ devs.emplace_back(ggml_metal_device_init(device));
24
+
25
+ return devs.back().get();
24
26
  }
25
27
 
26
28
  struct ggml_metal_pipelines {
@@ -50,14 +52,14 @@ void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, gg
50
52
  }
51
53
 
52
54
  ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name) {
53
- if (ppls->data.find(name) == ppls->data.end()) {
55
+ if (ppls->data.find(name) == ppls->data.end()) {
54
56
  return nullptr;
55
57
  }
56
58
 
57
59
  return ppls->data[name];
58
60
  }
59
61
 
60
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base(ggml_metal_library_t lib, ggml_op op) {
62
+ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base(ggml_metal_library_t lib, ggml_op op) {
61
63
  char base[256];
62
64
  char name[256];
63
65
 
@@ -71,34 +73,55 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base(ggml_metal_library_t
71
73
  snprintf(base, 256, "kernel_%s", op_str);
72
74
  snprintf(name, 256, "%s", base);
73
75
 
74
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
75
- if (res) {
76
- return res;
76
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
77
+ if (!res.pipeline) {
78
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
77
79
  }
78
80
 
79
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
80
-
81
81
  return res;
82
82
  }
83
83
 
84
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cpy(ggml_metal_library_t lib, ggml_type tsrc, ggml_type tdst) {
84
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy(ggml_metal_library_t lib, ggml_type tsrc, ggml_type tdst) {
85
85
  char base[256];
86
86
  char name[256];
87
87
 
88
88
  snprintf(base, 256, "kernel_cpy_%s_%s", ggml_type_name(tsrc), ggml_type_name(tdst));
89
89
  snprintf(name, 256, "%s", base);
90
90
 
91
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
92
- if (res) {
93
- return res;
91
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
92
+ if (!res.pipeline) {
93
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
94
94
  }
95
95
 
96
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
96
+ return res;
97
+ }
98
+
99
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
100
+ GGML_ASSERT(ggml_is_contiguous(op->src[0]));
101
+ GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type);
102
+
103
+ const char * pool_str = "undefined";
104
+ switch (op_pool) {
105
+ case GGML_OP_POOL_AVG: pool_str = "avg"; break;
106
+ case GGML_OP_POOL_MAX: pool_str = "max"; break;
107
+ default: GGML_ASSERT(false && "not implemented");
108
+ };
109
+
110
+ char base[256];
111
+ char name[256];
112
+
113
+ snprintf(base, sizeof(base), "kernel_pool_1d_%s_%s", pool_str, ggml_type_name(op->src[0]->type));
114
+ snprintf(name, sizeof(name), "%s", base);
115
+
116
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
117
+ if (!res.pipeline) {
118
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
119
+ }
97
120
 
98
121
  return res;
99
122
  }
100
123
 
101
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
124
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
102
125
  GGML_ASSERT(ggml_is_contiguous(op->src[0]));
103
126
  GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type);
104
127
 
@@ -115,126 +138,147 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library
115
138
  snprintf(base, 256, "kernel_pool_2d_%s_%s", pool_str, ggml_type_name(op->src[0]->type));
116
139
  snprintf(name, 256, "%s", base);
117
140
 
118
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
119
- if (res) {
120
- return res;
141
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
142
+ if (!res.pipeline) {
143
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
121
144
  }
122
145
 
123
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
124
-
125
146
  return res;
126
147
  }
127
148
 
128
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_get_rows(ggml_metal_library_t lib, ggml_type tsrc) {
149
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows(ggml_metal_library_t lib, ggml_type tsrc) {
129
150
  char base[256];
130
151
  char name[256];
131
152
 
132
153
  snprintf(base, 256, "kernel_get_rows_%s", ggml_type_name(tsrc));
133
154
  snprintf(name, 256, "%s", base);
134
155
 
135
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
136
- if (res) {
137
- return res;
156
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
157
+ if (!res.pipeline) {
158
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
138
159
  }
139
160
 
140
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
141
-
142
161
  return res;
143
162
  }
144
163
 
145
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows(ggml_metal_library_t lib, ggml_type tidx, ggml_type tdst) {
164
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows(ggml_metal_library_t lib, ggml_type tidx, ggml_type tdst) {
146
165
  char base[256];
147
166
  char name[256];
148
167
 
149
168
  snprintf(base, 256, "kernel_set_rows_%s_%s", ggml_type_name(tdst), ggml_type_name(tidx));
150
169
  snprintf(name, 256, "%s", base);
151
170
 
152
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
153
- if (res) {
154
- return res;
171
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
172
+ if (!res.pipeline) {
173
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
155
174
  }
156
175
 
157
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
176
+ return res;
177
+ }
178
+
179
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_diag(ggml_metal_library_t lib, const ggml_tensor * op) {
180
+ char base[256];
181
+ char name[256];
182
+
183
+ const int n = op->src[0]->ne[0];
184
+
185
+ snprintf(base, 256, "kernel_diag_%s", ggml_type_name(op->src[0]->type));
186
+ snprintf(name, 256, "%s_n=%d", base, n);
187
+
188
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
189
+ if (!res.pipeline) {
190
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
191
+ }
192
+
193
+ res.nsg = 1;
194
+ res.smem = 0;
158
195
 
159
196
  return res;
160
197
  }
161
198
 
162
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat(ggml_metal_library_t lib, ggml_type tsrc) {
199
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_metal_library_t lib, ggml_type tsrc) {
163
200
  char base[256];
164
201
  char name[256];
165
202
 
166
203
  snprintf(base, 256, "kernel_repeat_%s", ggml_type_name(tsrc));
167
204
  snprintf(name, 256, "%s", base);
168
205
 
169
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
170
- if (res) {
171
- return res;
206
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
207
+ if (!res.pipeline) {
208
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
172
209
  }
173
210
 
174
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
175
-
176
211
  return res;
177
212
  }
178
213
 
179
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) {
180
- GGML_ASSERT(ggml_is_contiguous(op->src[0]));
181
-
214
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) {
182
215
  char base[256];
183
216
  char name[256];
184
217
 
185
- const int64_t n = ggml_nelements(op);
218
+ int op_num = -1;
186
219
 
187
- const char * op_str = "undefined";
188
220
  switch (op->op) {
189
- case GGML_OP_SCALE: op_str = "scale"; break;
190
- case GGML_OP_CLAMP: op_str = "clamp"; break;
191
- case GGML_OP_SQR: op_str = "sqr"; break;
192
- case GGML_OP_SQRT: op_str = "sqrt"; break;
193
- case GGML_OP_SIN: op_str = "sin"; break;
194
- case GGML_OP_COS: op_str = "cos"; break;
195
- case GGML_OP_LOG: op_str = "log"; break;
196
- case GGML_OP_LEAKY_RELU: op_str = "leaky_relu"; break;
221
+ case GGML_OP_SCALE: op_num = OP_UNARY_NUM_SCALE; break;
222
+ case GGML_OP_FILL: op_num = OP_UNARY_NUM_FILL; break;
223
+ case GGML_OP_CLAMP: op_num = OP_UNARY_NUM_CLAMP; break;
224
+ case GGML_OP_SQR: op_num = OP_UNARY_NUM_SQR; break;
225
+ case GGML_OP_SQRT: op_num = OP_UNARY_NUM_SQRT; break;
226
+ case GGML_OP_SIN: op_num = OP_UNARY_NUM_SIN; break;
227
+ case GGML_OP_COS: op_num = OP_UNARY_NUM_COS; break;
228
+ case GGML_OP_LOG: op_num = OP_UNARY_NUM_LOG; break;
229
+ case GGML_OP_LEAKY_RELU: op_num = OP_UNARY_NUM_LEAKY_RELU; break;
197
230
  case GGML_OP_UNARY:
198
231
  switch (ggml_get_unary_op(op)) {
199
- case GGML_UNARY_OP_TANH: op_str = "tanh"; break;
200
- case GGML_UNARY_OP_RELU: op_str = "relu"; break;
201
- case GGML_UNARY_OP_SIGMOID: op_str = "sigmoid"; break;
202
- case GGML_UNARY_OP_GELU: op_str = "gelu"; break;
203
- case GGML_UNARY_OP_GELU_ERF: op_str = "gelu_erf"; break;
204
- case GGML_UNARY_OP_GELU_QUICK: op_str = "gelu_quick"; break;
205
- case GGML_UNARY_OP_SILU: op_str = "silu"; break;
206
- case GGML_UNARY_OP_ELU: op_str = "elu"; break;
207
- case GGML_UNARY_OP_NEG: op_str = "neg"; break;
208
- case GGML_UNARY_OP_ABS: op_str = "abs"; break;
209
- case GGML_UNARY_OP_SGN: op_str = "sgn"; break;
210
- case GGML_UNARY_OP_STEP: op_str = "step"; break;
211
- case GGML_UNARY_OP_HARDSWISH: op_str = "hardswish"; break;
212
- case GGML_UNARY_OP_HARDSIGMOID: op_str = "hardsigmoid"; break;
213
- case GGML_UNARY_OP_EXP: op_str = "exp"; break;
232
+ case GGML_UNARY_OP_TANH: op_num = OP_UNARY_NUM_TANH; break;
233
+ case GGML_UNARY_OP_RELU: op_num = OP_UNARY_NUM_RELU; break;
234
+ case GGML_UNARY_OP_SIGMOID: op_num = OP_UNARY_NUM_SIGMOID; break;
235
+ case GGML_UNARY_OP_GELU: op_num = OP_UNARY_NUM_GELU; break;
236
+ case GGML_UNARY_OP_GELU_ERF: op_num = OP_UNARY_NUM_GELU_ERF; break;
237
+ case GGML_UNARY_OP_GELU_QUICK: op_num = OP_UNARY_NUM_GELU_QUICK; break;
238
+ case GGML_UNARY_OP_SILU: op_num = OP_UNARY_NUM_SILU; break;
239
+ case GGML_UNARY_OP_ELU: op_num = OP_UNARY_NUM_ELU; break;
240
+ case GGML_UNARY_OP_NEG: op_num = OP_UNARY_NUM_NEG; break;
241
+ case GGML_UNARY_OP_ABS: op_num = OP_UNARY_NUM_ABS; break;
242
+ case GGML_UNARY_OP_SGN: op_num = OP_UNARY_NUM_SGN; break;
243
+ case GGML_UNARY_OP_STEP: op_num = OP_UNARY_NUM_STEP; break;
244
+ case GGML_UNARY_OP_HARDSWISH: op_num = OP_UNARY_NUM_HARDSWISH; break;
245
+ case GGML_UNARY_OP_HARDSIGMOID: op_num = OP_UNARY_NUM_HARDSIGMOID; break;
246
+ case GGML_UNARY_OP_EXP: op_num = OP_UNARY_NUM_EXP; break;
247
+ case GGML_UNARY_OP_SOFTPLUS: op_num = OP_UNARY_NUM_SOFTPLUS; break;
248
+ case GGML_UNARY_OP_EXPM1: op_num = OP_UNARY_NUM_EXPM1; break;
214
249
  default: GGML_ABORT("fatal error");
215
250
  } break;
216
251
  default: GGML_ABORT("fatal error");
217
252
  };
218
253
 
219
- const char * suffix = "";
220
- if (n % 4 == 0) {
221
- suffix = "_4";
222
- }
254
+ const char * t0_str = ggml_type_name(op->src[0]->type);
255
+ const char * t_str = ggml_type_name(op->type);
223
256
 
224
- snprintf(base, 256, "kernel_%s_%s%s", op_str, ggml_type_name(op->src[0]->type), suffix);
225
- snprintf(name, 256, "%s", base);
257
+ const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
258
+ const bool is_cnt = ggml_is_contiguous(op->src[0]) && ggml_nelements(op) < 32768;
226
259
 
227
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
228
- if (res) {
229
- return res;
260
+ snprintf(base, 256, "kernel_unary_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
261
+ snprintf(name, 256, "%s_op=%d_cnt=%d", base, op_num, is_cnt);
262
+
263
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
264
+ if (!res.pipeline) {
265
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
266
+
267
+ ggml_metal_cv_set_int16(cv, op_num, FC_UNARY + 0);
268
+ ggml_metal_cv_set_bool (cv, is_cnt, FC_UNARY + 1);
269
+
270
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
271
+
272
+ ggml_metal_cv_free(cv);
230
273
  }
231
274
 
232
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
275
+ res.c4 = is_c4;
276
+ res.cnt = is_cnt;
233
277
 
234
278
  return res;
235
279
  }
236
280
 
237
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu(ggml_metal_library_t lib, const ggml_tensor * op) {
281
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu(ggml_metal_library_t lib, const ggml_tensor * op) {
238
282
  GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
239
283
 
240
284
  char base[256];
@@ -258,48 +302,132 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu(ggml_metal_library_t l
258
302
  snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type));
259
303
  snprintf(name, 256, "%s", base);
260
304
 
261
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
262
- if (res) {
263
- return res;
305
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
306
+ if (!res.pipeline) {
307
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
264
308
  }
265
309
 
266
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
310
+ return res;
311
+ }
312
+
313
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) {
314
+ assert(op->op == GGML_OP_SUM);
315
+
316
+ char base[256];
317
+ char name[256];
318
+
319
+ snprintf(base, 256, "kernel_op_sum_%s", ggml_type_name(op->src[0]->type));
320
+ snprintf(name, 256, "%s", base);
321
+
322
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
323
+ if (!res.pipeline) {
324
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
325
+ }
267
326
 
268
327
  return res;
269
328
  }
270
329
 
271
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
272
- GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
330
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
331
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
273
332
 
274
333
  char base[256];
275
334
  char name[256];
276
335
 
277
- const char * op_str = "undefined";
336
+ int op_num = -1;
337
+
278
338
  switch (op->op) {
279
- case GGML_OP_SUM_ROWS:
280
- op_str = "sum_rows"; break;
281
- case GGML_OP_MEAN:
282
- op_str = "mean"; break;
339
+ case GGML_OP_SUM_ROWS: op_num = OP_SUM_ROWS_NUM_SUM_ROWS; break;
340
+ case GGML_OP_MEAN: op_num = OP_SUM_ROWS_NUM_MEAN; break;
283
341
  default: GGML_ABORT("fatal error");
284
342
  };
285
343
 
286
- snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type));
344
+ const char * t0_str = ggml_type_name(op->src[0]->type);
345
+ const char * t_str = ggml_type_name(op->type);
346
+
347
+ const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
348
+
349
+ snprintf(base, 256, "kernel_sum_rows_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
350
+ snprintf(name, 256, "%s_op=%d", base, op_num);
351
+
352
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
353
+ if (!res.pipeline) {
354
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
355
+
356
+ ggml_metal_cv_set_int16(cv, op_num, FC_SUM_ROWS + 0);
357
+
358
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
359
+
360
+ ggml_metal_cv_free(cv);
361
+ }
362
+
363
+ res.smem = 32*sizeof(float);
364
+
365
+ if (is_c4) {
366
+ res.smem *= 4;
367
+ }
368
+
369
+ res.c4 = is_c4;
370
+
371
+ return res;
372
+ }
373
+
374
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_blk(ggml_metal_library_t lib, const ggml_tensor * op) {
375
+ GGML_ASSERT(op->op == GGML_OP_CUMSUM);
376
+
377
+ char base[256];
378
+ char name[256];
287
379
 
380
+ snprintf(base, 256, "kernel_cumsum_blk_%s", ggml_type_name(op->src[0]->type));
288
381
  snprintf(name, 256, "%s", base);
289
382
 
290
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
291
- if (res) {
292
- return res;
383
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
384
+ if (!res.pipeline) {
385
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
293
386
  }
294
387
 
295
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
388
+ return res;
389
+ }
390
+
391
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_library_t lib, const ggml_tensor * op) {
392
+ GGML_ASSERT(op->op == GGML_OP_CUMSUM);
393
+
394
+ char base[256];
395
+ char name[256];
396
+
397
+ snprintf(base, 256, "kernel_cumsum_add_%s", ggml_type_name(op->src[0]->type));
398
+ snprintf(name, 256, "%s", base);
399
+
400
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
401
+ if (!res.pipeline) {
402
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
403
+ }
404
+
405
+ return res;
406
+ }
407
+
408
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
409
+ GGML_ASSERT(op->op == GGML_OP_TRI);
410
+ GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
411
+
412
+ char base[256];
413
+ char name[256];
414
+
415
+ const char * op_str = "tri";
416
+ const int ttype = op->op_params[0];
296
417
 
297
- ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
418
+ snprintf(base, 256, "kernel_%s_%s_%d", op_str, ggml_type_name(op->src[0]->type), ttype);
419
+
420
+ snprintf(name, 256, "%s", base);
421
+
422
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
423
+ if (!res.pipeline) {
424
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
425
+ }
298
426
 
299
427
  return res;
300
428
  }
301
429
 
302
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) {
430
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) {
303
431
  GGML_ASSERT(!op->src[1] || op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32);
304
432
 
305
433
  char base[256];
@@ -316,19 +444,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max(ggml_metal_librar
316
444
  snprintf(base, 256, "kernel_soft_max_%s%s", ggml_type_name(tsrc1), suffix);
317
445
  snprintf(name, 256, "%s", base);
318
446
 
319
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
320
- if (res) {
321
- return res;
447
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
448
+ if (!res.pipeline) {
449
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
322
450
  }
323
451
 
324
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
325
-
326
- ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
452
+ res.smem = 32*sizeof(float);
327
453
 
328
454
  return res;
329
455
  }
330
456
 
331
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) {
457
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) {
332
458
  GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
333
459
  GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
334
460
 
@@ -338,43 +464,82 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar
338
464
  char base[256];
339
465
  char name[256];
340
466
 
341
- snprintf(base, 256, "kernel_ssm_conv_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
342
- snprintf(name, 256, "%s", base);
467
+ const char * suffix = "";
343
468
 
344
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
345
- if (res) {
346
- return res;
469
+ if (op->src[1]->ne[0] % 4 == 0) {
470
+ suffix = "_4";
347
471
  }
348
472
 
349
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
473
+ snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix);
474
+ snprintf(name, 256, "%s", base);
475
+
476
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
477
+ if (!res.pipeline) {
478
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
479
+ }
350
480
 
351
481
  return res;
352
482
  }
353
483
 
354
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) {
484
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched(ggml_metal_library_t lib, const ggml_tensor * op, int ssm_conv_bs) {
485
+ GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
486
+ GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
487
+
488
+ GGML_ASSERT(ggml_is_contiguous(op->src[0]));
489
+ GGML_ASSERT(ggml_is_contiguous(op->src[1]));
490
+
355
491
  char base[256];
356
492
  char name[256];
357
493
 
358
- if (op->src[3]->ne[0] == 1) {
359
- snprintf(base, 256, "kernel_ssm_scan_group_%s", ggml_type_name(op->src[0]->type));
360
- } else {
361
- snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
494
+ const char * suffix = "";
495
+ if (op->src[1]->ne[0] % 4 == 0) {
496
+ suffix = "_4";
362
497
  }
363
- snprintf(name, 256, "%s", base);
364
498
 
365
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
366
- if (res) {
367
- return res;
499
+ snprintf(base, 256, "kernel_ssm_conv_%s_%s_batched%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix);
500
+ snprintf(name, 256, "%s_ssm_conv_bs=%d", base, ssm_conv_bs);
501
+
502
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
503
+ if (!res.pipeline) {
504
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
505
+
506
+ ggml_metal_cv_set_int16(cv, ssm_conv_bs, FC_SSM_CONV + 0);
507
+
508
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
509
+
510
+ ggml_metal_cv_free(cv);
368
511
  }
369
512
 
370
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
513
+ return res;
514
+ }
515
+
516
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) {
517
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
518
+
519
+ char base[256];
520
+ char name[256];
371
521
 
372
- ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
522
+ const int nsg = (ne00 + 31)/32;
523
+
524
+ snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
525
+ snprintf(name, 256, "%s_nsg=%d", base, nsg);
526
+
527
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
528
+ if (!res.pipeline) {
529
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
530
+ }
531
+
532
+ // Shared memory layout:
533
+ // - sgptg * NW floats for partial sums (nsg * 32)
534
+ // - sgptg floats for shared_x_dt (nsg)
535
+ // - sgptg floats for shared_dA (nsg)
536
+ // Total: nsg * (32 + 2) floats
537
+ res.smem = (32 + 2)*sizeof(float)*nsg;
373
538
 
374
539
  return res;
375
540
  }
376
541
 
377
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t lib, const ggml_tensor * op) {
542
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t lib, const ggml_tensor * op) {
378
543
  char base[256];
379
544
  char name[256];
380
545
 
@@ -404,41 +569,102 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t
404
569
 
405
570
  snprintf(name, 256, "%s", base);
406
571
 
407
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
408
- if (res) {
409
- return res;
572
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
573
+ if (!res.pipeline) {
574
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
410
575
  }
411
576
 
412
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
577
+ return res;
578
+ }
579
+
580
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net(ggml_metal_library_t lib, const ggml_tensor * op) {
581
+ char base[256];
582
+ char name[256];
583
+
584
+ // v is src[2], dimensions: S_v = ne[0], H = ne[1]
585
+ const int ne20 = op->src[2]->ne[0]; // S_v
586
+ const int ne21 = op->src[2]->ne[1]; // H
587
+ const int ne30 = op->src[3]->ne[0]; // G
588
+
589
+ const int nsg = op->src[2]->ne[0]/32;
590
+
591
+ GGML_ASSERT(op->src[5]->type == GGML_TYPE_F32);
592
+ GGML_ASSERT(op->ne[0] == ne20 * ne21);
593
+ GGML_ASSERT(ne20 % 32 == 0);
594
+
595
+ snprintf(base, 256, "kernel_gated_delta_net_%s_%d", ggml_type_name(op->src[0]->type), nsg);
596
+ snprintf(name, 256, "%s_ne20=%d_ne30=%d", base, ne20, ne30);
597
+
598
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
599
+ if (!res.pipeline) {
600
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
601
+
602
+ ggml_metal_cv_set_int16(cv, ne20, FC_GATED_DELTA_NET + 0);
603
+ ggml_metal_cv_set_int16(cv, ne30, FC_GATED_DELTA_NET + 1);
604
+
605
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
606
+
607
+ ggml_metal_cv_free(cv);
608
+ }
609
+
610
+ res.nsg = nsg;
413
611
 
414
612
  return res;
415
613
  }
416
614
 
417
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
615
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
418
616
  char base[256];
419
617
  char name[256];
420
618
 
421
- snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg);
422
- snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg);
619
+ const int nsg = 8;
620
+ const int n = op->src[1]->ne[1];
621
+ const int k = op->src[1]->ne[0];
423
622
 
424
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
425
- if (res) {
426
- return res;
623
+ snprintf(base, 256, "kernel_solve_tri_%s", ggml_type_name(op->src[0]->type));
624
+ snprintf(name, 256, "%s_nsg=%d_n=%d_k=%d", base, nsg, n, k);
625
+
626
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
627
+ if (!res.pipeline) {
628
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
629
+
630
+ ggml_metal_cv_set_int16(cv, nsg, FC_SOLVE_TRI + 0);
631
+ ggml_metal_cv_set_int16(cv, n, FC_SOLVE_TRI + 1);
632
+ ggml_metal_cv_set_int16(cv, k, FC_SOLVE_TRI + 2);
633
+
634
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
635
+
636
+ ggml_metal_cv_free(cv);
427
637
  }
428
638
 
429
- ggml_metal_cv_t cv = ggml_metal_cv_init();
639
+ res.nsg = nsg;
640
+ res.smem = GGML_PAD(GGML_PAD(n, 32)*nsg*sizeof(float), 16);
430
641
 
431
- ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
432
- ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1);
642
+ return res;
643
+ }
644
+
645
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
646
+ char base[256];
647
+ char name[256];
648
+
649
+ snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg);
650
+ snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg);
651
+
652
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
653
+ if (!res.pipeline) {
654
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
433
655
 
434
- res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
656
+ ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
657
+ ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1);
435
658
 
436
- ggml_metal_cv_free(cv);
659
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
660
+
661
+ ggml_metal_cv_free(cv);
662
+ }
437
663
 
438
664
  return res;
439
665
  }
440
666
 
441
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, const ggml_tensor * op) {
667
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, const ggml_tensor * op) {
442
668
  char base[256];
443
669
  char name[256];
444
670
 
@@ -451,27 +677,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_
451
677
  snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
452
678
  snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out);
453
679
 
454
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
455
- if (res) {
456
- return res;
457
- }
680
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
681
+ if (!res.pipeline) {
682
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
458
683
 
459
- ggml_metal_cv_t cv = ggml_metal_cv_init();
684
+ ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
685
+ ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1);
460
686
 
461
- ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
462
- ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1);
687
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
463
688
 
464
- res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
465
-
466
- ggml_metal_cv_free(cv);
689
+ ggml_metal_cv_free(cv);
690
+ }
467
691
 
468
692
  // when the output size is not multiple of 64x32, we need extra smem to prevent out-of-bounds writes
469
- ggml_metal_pipeline_set_smem(res, bc_out ? 8192 : 4096 + 2048);
693
+ res.smem = bc_out ? 8192 : 4096 + 2048;
470
694
 
471
695
  return res;
472
696
  }
473
697
 
474
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_t lib, const ggml_tensor * op) {
698
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_t lib, const ggml_tensor * op) {
475
699
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
476
700
  GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
477
701
 
@@ -626,49 +850,43 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
626
850
  snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
627
851
  snprintf(name, 256, "%s_nsg=%d", base, nsg);
628
852
 
629
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
630
- if (res) {
631
- return res;
632
- }
633
-
634
- ggml_metal_cv_t cv = ggml_metal_cv_init();
853
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
854
+ if (!res.pipeline) {
855
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
635
856
 
636
- ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
857
+ ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
637
858
 
638
- res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
859
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
639
860
 
640
- ggml_metal_cv_free(cv);
861
+ ggml_metal_cv_free(cv);
862
+ }
641
863
 
642
- ggml_metal_pipeline_set_nr0 (res, nr0);
643
- ggml_metal_pipeline_set_nr1 (res, nr1);
644
- ggml_metal_pipeline_set_nsg (res, nsg);
645
- ggml_metal_pipeline_set_smem(res, smem);
864
+ res.nr0 = nr0;
865
+ res.nr1 = nr1;
866
+ res.nsg = nsg;
867
+ res.smem = smem;
646
868
 
647
869
  return res;
648
870
  }
649
871
 
650
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_library_t lib, int ne02, int ne20) {
872
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_library_t lib, int ne02, int ne20) {
651
873
  char base[256];
652
874
  char name[256];
653
875
 
654
876
  snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20);
655
- snprintf(name, 256, "%s", base);
877
+ snprintf(name, 256, "%s_ne02=%d", base, ne02);
656
878
 
657
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
658
- if (res) {
659
- return res;
879
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
880
+ if (!res.pipeline) {
881
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
660
882
  }
661
883
 
662
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
663
-
664
- const size_t smem = (size_t) ne02*ne20*sizeof(uint16_t);
665
-
666
- ggml_metal_pipeline_set_smem(res, smem);
884
+ res.smem = (size_t) ne02*ne20*sizeof(uint16_t);
667
885
 
668
886
  return res;
669
887
  }
670
888
 
671
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, const ggml_tensor * op) {
889
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, const ggml_tensor * op) {
672
890
  char base[256];
673
891
  char name[256];
674
892
 
@@ -680,25 +898,23 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_libra
680
898
  snprintf(base, 256, "kernel_mul_mm_id_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
681
899
  snprintf(name, 256, "%s_bci=%d", base, bc_inp);
682
900
 
683
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
684
- if (res) {
685
- return res;
686
- }
901
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
902
+ if (!res.pipeline) {
903
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
687
904
 
688
- ggml_metal_cv_t cv = ggml_metal_cv_init();
905
+ ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
689
906
 
690
- ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
907
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
691
908
 
692
- res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
693
-
694
- ggml_metal_cv_free(cv);
909
+ ggml_metal_cv_free(cv);
910
+ }
695
911
 
696
- ggml_metal_pipeline_set_smem(res, 8192);
912
+ res.smem = 8192;
697
913
 
698
914
  return res;
699
915
  }
700
916
 
701
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_library_t lib, const ggml_tensor * op) {
917
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_library_t lib, const ggml_tensor * op) {
702
918
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
703
919
  GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
704
920
 
@@ -846,28 +1062,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
846
1062
  snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
847
1063
  snprintf(name, 256, "%s_nsg=%d", base, nsg);
848
1064
 
849
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
850
- if (res) {
851
- return res;
852
- }
853
-
854
- ggml_metal_cv_t cv = ggml_metal_cv_init();
1065
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1066
+ if (!res.pipeline) {
1067
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
855
1068
 
856
- ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
1069
+ ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
857
1070
 
858
- res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1071
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
859
1072
 
860
- ggml_metal_cv_free(cv);
1073
+ ggml_metal_cv_free(cv);
1074
+ }
861
1075
 
862
- ggml_metal_pipeline_set_nr0 (res, nr0);
863
- ggml_metal_pipeline_set_nr1 (res, nr1);
864
- ggml_metal_pipeline_set_nsg (res, nsg);
865
- ggml_metal_pipeline_set_smem(res, smem);
1076
+ res.nr0 = nr0;
1077
+ res.nr1 = nr1;
1078
+ res.nsg = nsg;
1079
+ res.smem = smem;
866
1080
 
867
1081
  return res;
868
1082
  }
869
1083
 
870
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax(ggml_metal_library_t lib, const ggml_tensor * op) {
1084
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argmax(ggml_metal_library_t lib, const ggml_tensor * op) {
871
1085
  GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
872
1086
  GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
873
1087
  GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
@@ -878,19 +1092,43 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax(ggml_metal_library_
878
1092
  snprintf(base, 256, "kernel_argmax_%s", ggml_type_name(op->src[0]->type));
879
1093
  snprintf(name, 256, "%s", base);
880
1094
 
881
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
882
- if (res) {
883
- return res;
1095
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1096
+ if (!res.pipeline) {
1097
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
884
1098
  }
885
1099
 
886
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1100
+ res.smem = 32*(sizeof(float) + sizeof(int32_t));
1101
+
1102
+ return res;
1103
+ }
1104
+
1105
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort(ggml_metal_library_t lib, const ggml_tensor * op) {
1106
+ assert(op->op == GGML_OP_ARGSORT);
1107
+
1108
+ char base[256];
1109
+ char name[256];
1110
+
1111
+ ggml_sort_order order = (ggml_sort_order) op->op_params[0];
1112
+
1113
+ const char * order_str = "undefined";
1114
+ switch (order) {
1115
+ case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
1116
+ case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
1117
+ default: GGML_ABORT("fatal error");
1118
+ };
887
1119
 
888
- ggml_metal_pipeline_set_smem(res, 32*(sizeof(float) + sizeof(int32_t)));
1120
+ snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
1121
+ snprintf(name, 256, "%s", base);
1122
+
1123
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1124
+ if (!res.pipeline) {
1125
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1126
+ }
889
1127
 
890
1128
  return res;
891
1129
  }
892
1130
 
893
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library_t lib, const ggml_tensor * op) {
1131
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
894
1132
  assert(op->op == GGML_OP_ARGSORT);
895
1133
 
896
1134
  char base[256];
@@ -905,26 +1143,165 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library
905
1143
  default: GGML_ABORT("fatal error");
906
1144
  };
907
1145
 
1146
+ snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
1147
+ snprintf(name, 256, "%s", base);
1148
+
1149
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1150
+ if (!res.pipeline) {
1151
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1152
+ }
1153
+
1154
+ return res;
1155
+ }
1156
+
1157
+ // note: reuse the argsort kernel for top_k
1158
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t lib, const ggml_tensor * op) {
1159
+ assert(op->op == GGML_OP_TOP_K);
1160
+
1161
+ char base[256];
1162
+ char name[256];
1163
+
1164
+ // note: the top_k kernel is always descending order
1165
+ ggml_sort_order order = GGML_SORT_ORDER_DESC;
1166
+
1167
+ const char * order_str = "undefined";
1168
+ switch (order) {
1169
+ case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
1170
+ case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
1171
+ default: GGML_ABORT("fatal error");
1172
+ };
1173
+
908
1174
  snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
909
1175
  snprintf(name, 256, "%s", base);
910
1176
 
911
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
912
- if (res) {
913
- return res;
1177
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1178
+ if (!res.pipeline) {
1179
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
914
1180
  }
915
1181
 
916
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1182
+ return res;
1183
+ }
1184
+
1185
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
1186
+ assert(op->op == GGML_OP_TOP_K);
1187
+
1188
+ char base[256];
1189
+ char name[256];
1190
+
1191
+ ggml_sort_order order = GGML_SORT_ORDER_DESC;
1192
+
1193
+ const char * order_str = "undefined";
1194
+ switch (order) {
1195
+ case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
1196
+ case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
1197
+ default: GGML_ABORT("fatal error");
1198
+ };
1199
+
1200
+ snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
1201
+ snprintf(name, 256, "%s", base);
1202
+
1203
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1204
+ if (!res.pipeline) {
1205
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1206
+ }
1207
+
1208
+ return res;
1209
+ }
1210
+
1211
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad(
1212
+ ggml_metal_library_t lib,
1213
+ const struct ggml_tensor * op,
1214
+ bool has_mask,
1215
+ int32_t ncpsg) {
1216
+ assert(op->op == GGML_OP_FLASH_ATTN_EXT);
1217
+ GGML_UNUSED(op);
1218
+
1219
+ char base[256];
1220
+ char name[256];
1221
+
1222
+ snprintf(base, 256, "kernel_%s",
1223
+ "flash_attn_ext_pad");
1224
+
1225
+ snprintf(name, 256, "%s_mask=%d_ncpsg=%d",
1226
+ base,
1227
+ has_mask,
1228
+ ncpsg);
1229
+
1230
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1231
+ if (!res.pipeline) {
1232
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
1233
+
1234
+ ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0);
1235
+ //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
1236
+ //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2);
1237
+ //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3);
1238
+
1239
+ //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
1240
+ //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
1241
+ //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
1242
+ //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
1243
+ //ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24);
1244
+ ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25);
1245
+
1246
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1247
+
1248
+ ggml_metal_cv_free(cv);
1249
+ }
917
1250
 
918
1251
  return res;
919
1252
  }
920
1253
 
921
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
1254
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_blk(
1255
+ ggml_metal_library_t lib,
1256
+ const struct ggml_tensor * op,
1257
+ int32_t nqptg,
1258
+ int32_t ncpsg) {
1259
+ assert(op->op == GGML_OP_FLASH_ATTN_EXT);
1260
+ GGML_UNUSED(op);
1261
+
1262
+ char base[256];
1263
+ char name[256];
1264
+
1265
+ snprintf(base, 256, "kernel_%s",
1266
+ "flash_attn_ext_blk");
1267
+
1268
+ snprintf(name, 256, "%s_nqptg=%d_ncpsg=%d",
1269
+ base,
1270
+ nqptg,
1271
+ ncpsg);
1272
+
1273
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1274
+ if (!res.pipeline) {
1275
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
1276
+
1277
+ //ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_BLK + 0);
1278
+ //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1);
1279
+ //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_BLK + 2);
1280
+ //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_BLK + 3);
1281
+
1282
+ //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20);
1283
+ //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21);
1284
+ //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22);
1285
+ //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23);
1286
+ ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24);
1287
+ ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25);
1288
+
1289
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1290
+
1291
+ ggml_metal_cv_free(cv);
1292
+ }
1293
+
1294
+ return res;
1295
+ }
1296
+
1297
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext(
922
1298
  ggml_metal_library_t lib,
923
1299
  const ggml_tensor * op,
924
1300
  bool has_mask,
925
1301
  bool has_sinks,
926
1302
  bool has_bias,
927
1303
  bool has_scap,
1304
+ bool has_kvpad,
928
1305
  int32_t nsg) {
929
1306
  assert(op->op == GGML_OP_FLASH_ATTN_EXT);
930
1307
 
@@ -937,52 +1314,59 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
937
1314
  const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
938
1315
  const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];
939
1316
 
1317
+ // do bounds checks for the mask?
1318
+ const bool bc_mask = op->src[3] && (op->src[3]->ne[1] % 8 != 0);
1319
+
940
1320
  snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
941
1321
  "flash_attn_ext",
942
1322
  ggml_type_name(op->src[1]->type),
943
1323
  dk,
944
1324
  dv);
945
1325
 
946
- snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
1326
+ snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=%d_ns10=%d_ns20=%d_nsg=%d",
947
1327
  base,
948
1328
  has_mask,
949
1329
  has_sinks,
950
1330
  has_bias,
951
1331
  has_scap,
1332
+ has_kvpad,
1333
+ bc_mask,
952
1334
  ns10,
953
1335
  ns20,
954
1336
  nsg);
955
1337
 
956
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
957
- if (res) {
958
- return res;
959
- }
1338
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1339
+ if (!res.pipeline) {
1340
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
960
1341
 
961
- ggml_metal_cv_t cv = ggml_metal_cv_init();
1342
+ ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT + 0);
1343
+ ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
1344
+ ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2);
1345
+ ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3);
1346
+ ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);
962
1347
 
963
- ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT + 0);
964
- ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
965
- ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2);
966
- ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3);
1348
+ ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10);
967
1349
 
968
- ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
969
- ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
970
- ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT + 22);
1350
+ ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
1351
+ ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
1352
+ ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT + 22);
971
1353
 
972
- res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1354
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
973
1355
 
974
- ggml_metal_cv_free(cv);
1356
+ ggml_metal_cv_free(cv);
1357
+ }
975
1358
 
976
1359
  return res;
977
1360
  }
978
1361
 
979
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
1362
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec(
980
1363
  ggml_metal_library_t lib,
981
1364
  const ggml_tensor * op,
982
1365
  bool has_mask,
983
1366
  bool has_sinks,
984
1367
  bool has_bias,
985
1368
  bool has_scap,
1369
+ bool has_kvpad,
986
1370
  int32_t nsg,
987
1371
  int32_t nwg) {
988
1372
  assert(op->op == GGML_OP_FLASH_ATTN_EXT);
@@ -1002,41 +1386,41 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
1002
1386
  dk,
1003
1387
  dv);
1004
1388
 
1005
- snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
1389
+ snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
1006
1390
  base,
1007
1391
  has_mask,
1008
1392
  has_sinks,
1009
1393
  has_bias,
1010
1394
  has_scap,
1395
+ has_kvpad,
1011
1396
  ns10,
1012
1397
  ns20,
1013
1398
  nsg, nwg);
1014
1399
 
1015
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
1016
- if (res) {
1017
- return res;
1018
- }
1019
-
1020
- ggml_metal_cv_t cv = ggml_metal_cv_init();
1400
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1401
+ if (!res.pipeline) {
1402
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
1021
1403
 
1022
- ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_VEC + 0);
1023
- ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
1024
- ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2);
1025
- ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3);
1404
+ ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_VEC + 0);
1405
+ ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
1406
+ ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2);
1407
+ ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3);
1408
+ ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4);
1026
1409
 
1027
- ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
1028
- ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);
1029
- ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_VEC + 22);
1030
- ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC + 23);
1410
+ ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
1411
+ ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);
1412
+ ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_VEC + 22);
1413
+ ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC + 23);
1031
1414
 
1032
- res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1415
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1033
1416
 
1034
- ggml_metal_cv_free(cv);
1417
+ ggml_metal_cv_free(cv);
1418
+ }
1035
1419
 
1036
1420
  return res;
1037
1421
  }
1038
1422
 
1039
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
1423
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
1040
1424
  ggml_metal_library_t lib,
1041
1425
  const ggml_tensor * op,
1042
1426
  int32_t dv,
@@ -1049,85 +1433,128 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
1049
1433
  snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce");
1050
1434
  snprintf(name, 256, "%s_dv=%d_nwg=%d", base, dv, nwg);
1051
1435
 
1052
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
1053
- if (res) {
1054
- return res;
1055
- }
1056
-
1057
- ggml_metal_cv_t cv = ggml_metal_cv_init();
1436
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1437
+ if (!res.pipeline) {
1438
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
1058
1439
 
1059
- ggml_metal_cv_set_int32(cv, dv, FC_FLASH_ATTN_EXT_VEC_REDUCE + 0);
1060
- ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC_REDUCE + 1);
1440
+ ggml_metal_cv_set_int32(cv, dv, FC_FLASH_ATTN_EXT_VEC_REDUCE + 0);
1441
+ ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC_REDUCE + 1);
1061
1442
 
1062
- res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1443
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1063
1444
 
1064
- ggml_metal_cv_free(cv);
1445
+ ggml_metal_cv_free(cv);
1446
+ }
1065
1447
 
1066
1448
  return res;
1067
1449
 
1068
1450
  GGML_UNUSED(op);
1069
1451
  }
1070
1452
 
1071
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin(
1072
- ggml_metal_library_t lib,
1073
- ggml_op op,
1074
- int32_t n_fuse,
1075
- bool row) {
1453
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) {
1076
1454
  char base[256];
1077
1455
  char name[256];
1078
1456
 
1079
- const char * op_str = "undefined";
1080
- switch (op) {
1081
- case GGML_OP_ADD: op_str = "add"; break;
1082
- case GGML_OP_SUB: op_str = "sub"; break;
1083
- case GGML_OP_MUL: op_str = "mul"; break;
1084
- case GGML_OP_DIV: op_str = "div"; break;
1457
+ int op_num = -1;
1458
+
1459
+ switch (op->op) {
1460
+ case GGML_OP_ADD: op_num = 0; break;
1461
+ case GGML_OP_SUB: op_num = 1; break;
1462
+ case GGML_OP_MUL: op_num = 2; break;
1463
+ case GGML_OP_DIV: op_num = 3; break;
1085
1464
  default: GGML_ABORT("fatal error");
1086
1465
  };
1087
1466
 
1088
- if (row) {
1089
- snprintf(base, 256, "kernel_%s_row_c4_fuse_%d", op_str, n_fuse);
1090
- } else {
1091
- snprintf(base, 256, "kernel_%s_fuse_%d", op_str, n_fuse);
1092
- }
1467
+ const char * t0_str = ggml_type_name(op->src[0]->type);
1468
+ const char * t1_str = ggml_type_name(op->src[1]->type);
1469
+ const char * t_str = ggml_type_name(op->type);
1093
1470
 
1094
- snprintf(name, 256, "%s", base);
1471
+ const bool is_c4 = (op->src[0]->ne[0] % 4 == 0) && (op->src[1]->ne[0] % 4 == 0);
1095
1472
 
1096
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
1097
- if (res) {
1098
- return res;
1473
+ const bool is_cb = op->src[0]->ne[0] != op->src[1]->ne[0];
1474
+ const bool is_rb = ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && (ggml_nrows(op->src[1]) == 1) && ggml_nelements(op) < 65536;
1475
+
1476
+ snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s%s", t0_str, t1_str, t_str, is_c4 ? "_4" : "");
1477
+ snprintf(name, 256, "%s_op=%d_nf=%d_rb=%d_cb=%d", base, op_num, n_fuse, is_rb, is_cb);
1478
+
1479
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1480
+ if (!res.pipeline) {
1481
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
1482
+
1483
+ ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0);
1484
+ ggml_metal_cv_set_int16(cv, n_fuse, FC_BIN + 1);
1485
+ ggml_metal_cv_set_bool (cv, is_rb, FC_BIN + 2);
1486
+ ggml_metal_cv_set_bool (cv, is_cb, FC_BIN + 3);
1487
+
1488
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1489
+
1490
+ ggml_metal_cv_free(cv);
1099
1491
  }
1100
1492
 
1101
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1493
+ res.c4 = is_c4;
1494
+ res.cnt = is_rb;
1102
1495
 
1103
1496
  return res;
1104
1497
  }
1105
1498
 
1106
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
1107
- assert(op->op == GGML_OP_L2_NORM);
1499
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_metal_library_t lib, ggml_op op) {
1500
+ char base[256];
1501
+ char name[256];
1108
1502
 
1109
- GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
1110
- GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
1503
+ int op_num = -1;
1504
+
1505
+ switch (op) {
1506
+ case GGML_OP_ADD: op_num = 0; break;
1507
+ case GGML_OP_SUB: op_num = 1; break;
1508
+ case GGML_OP_MUL: op_num = 2; break;
1509
+ case GGML_OP_DIV: op_num = 3; break;
1510
+ default: GGML_ABORT("fatal error");
1511
+ };
1512
+
1513
+ snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s", "f32", "f32", "f32");
1514
+ snprintf(name, 256, "%s_op=%d_nf=%d", base, op_num, 1);
1515
+
1516
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1517
+ if (!res.pipeline) {
1518
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
1519
+
1520
+ ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0);
1521
+ ggml_metal_cv_set_int16(cv, 1, FC_BIN + 1);
1522
+ ggml_metal_cv_set_bool (cv, false, FC_BIN + 2);
1523
+
1524
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1525
+
1526
+ ggml_metal_cv_free(cv);
1527
+ }
1528
+
1529
+ return res;
1530
+ }
1531
+
1532
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
1533
+ assert(op->op == GGML_OP_L2_NORM);
1111
1534
 
1112
1535
  char base[256];
1113
1536
  char name[256];
1114
1537
 
1115
- snprintf(base, 256, "kernel_l2_norm_f32");
1538
+ const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
1539
+
1540
+ const char * t0_str = ggml_type_name(op->src[0]->type);
1541
+ const char * t_str = ggml_type_name(op->type);
1542
+
1543
+ snprintf(base, 256, "kernel_l2_norm_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
1116
1544
  snprintf(name, 256, "%s", base);
1117
1545
 
1118
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
1119
- if (res) {
1120
- return res;
1546
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1547
+ if (!res.pipeline) {
1548
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1121
1549
  }
1122
1550
 
1123
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1124
-
1125
- ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
1551
+ res.c4 = is_c4;
1552
+ res.smem = 32*sizeof(float);
1126
1553
 
1127
1554
  return res;
1128
1555
  }
1129
1556
 
1130
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
1557
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
1131
1558
  assert(op->op == GGML_OP_GROUP_NORM);
1132
1559
 
1133
1560
  GGML_ASSERT(ggml_is_contiguous(op->src[0]));
@@ -1138,19 +1565,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm(ggml_metal_libr
1138
1565
  snprintf(base, 256, "kernel_group_norm_f32");
1139
1566
  snprintf(name, 256, "%s", base);
1140
1567
 
1141
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
1142
- if (res) {
1143
- return res;
1568
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1569
+ if (!res.pipeline) {
1570
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1144
1571
  }
1145
1572
 
1146
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1147
-
1148
- ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
1573
+ res.smem = 32*sizeof(float);
1149
1574
 
1150
1575
  return res;
1151
1576
  }
1152
1577
 
1153
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op, int n_fuse) {
1578
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op, int n_fuse) {
1154
1579
  assert(op->op == GGML_OP_NORM || op->op == GGML_OP_RMS_NORM);
1155
1580
 
1156
1581
  GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
@@ -1183,19 +1608,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm(ggml_metal_library_t
1183
1608
 
1184
1609
  snprintf(name, 256, "%s", base);
1185
1610
 
1186
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
1187
- if (res) {
1188
- return res;
1611
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1612
+ if (!res.pipeline) {
1613
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1189
1614
  }
1190
1615
 
1191
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1192
-
1193
- ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
1616
+ res.smem = 32*sizeof(float);
1194
1617
 
1195
1618
  return res;
1196
1619
  }
1197
1620
 
1198
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t lib, const ggml_tensor * op) {
1621
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope(ggml_metal_library_t lib, const ggml_tensor * op) {
1199
1622
  assert(op->op == GGML_OP_ROPE);
1200
1623
 
1201
1624
  char base[256];
@@ -1205,11 +1628,12 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t
1205
1628
 
1206
1629
  const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
1207
1630
  const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
1631
+ const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
1208
1632
  const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
1209
1633
 
1210
1634
  if (is_neox) {
1211
1635
  snprintf(base, 256, "kernel_rope_neox_%s", ggml_type_name(op->src[0]->type));
1212
- } else if (is_mrope && !is_vision) {
1636
+ } else if ((is_mrope || is_imrope) && !is_vision) {
1213
1637
  GGML_ASSERT(op->src[1]->ne[0]*4 >= op->src[0]->ne[2]); // need at least 4 pos per token
1214
1638
  snprintf(base, 256, "kernel_rope_multi_%s", ggml_type_name(op->src[0]->type));
1215
1639
  } else if (is_vision) {
@@ -1219,19 +1643,23 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t
1219
1643
  snprintf(base, 256, "kernel_rope_norm_%s", ggml_type_name(op->src[0]->type));
1220
1644
  }
1221
1645
 
1222
- snprintf(name, 256, "%s", base);
1646
+ snprintf(name, 256, "%s_imrope=%d", base, is_imrope ? 1 : 0);
1223
1647
 
1224
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
1225
- if (res) {
1226
- return res;
1227
- }
1648
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1649
+ if (!res.pipeline) {
1650
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
1228
1651
 
1229
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1652
+ ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0);
1653
+
1654
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1655
+
1656
+ ggml_metal_cv_free(cv);
1657
+ }
1230
1658
 
1231
1659
  return res;
1232
1660
  }
1233
1661
 
1234
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col(ggml_metal_library_t lib, const ggml_tensor * op) {
1662
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col(ggml_metal_library_t lib, const ggml_tensor * op) {
1235
1663
  assert(op->op == GGML_OP_IM2COL);
1236
1664
 
1237
1665
  GGML_ASSERT(ggml_is_contiguous(op->src[1]));
@@ -1244,17 +1672,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col(ggml_metal_library_
1244
1672
  snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type));
1245
1673
  snprintf(name, 256, "%s", base);
1246
1674
 
1247
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
1248
- if (res) {
1249
- return res;
1675
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1676
+ if (!res.pipeline) {
1677
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1250
1678
  }
1251
1679
 
1252
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1253
-
1254
1680
  return res;
1255
1681
  }
1256
1682
 
1257
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_metal_library_t lib, const ggml_tensor * op) {
1683
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_metal_library_t lib, const ggml_tensor * op) {
1258
1684
  assert(op->op == GGML_OP_CONV_TRANSPOSE_1D);
1259
1685
 
1260
1686
  GGML_ASSERT(ggml_is_contiguous(op->src[0]));
@@ -1269,36 +1695,94 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_met
1269
1695
  snprintf(base, 256, "kernel_conv_transpose_1d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
1270
1696
  snprintf(name, 256, "%s", base);
1271
1697
 
1272
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
1273
- if (res) {
1274
- return res;
1698
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1699
+ if (!res.pipeline) {
1700
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1275
1701
  }
1276
1702
 
1277
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1703
+ return res;
1704
+ }
1705
+
1706
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_metal_library_t lib, const ggml_tensor * op) {
1707
+ assert(op->op == GGML_OP_CONV_TRANSPOSE_2D);
1708
+
1709
+ GGML_ASSERT(ggml_is_contiguous(op->src[0]));
1710
+ GGML_ASSERT(ggml_is_contiguous(op->src[1]));
1711
+ GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
1712
+ GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
1713
+ GGML_ASSERT(op->type == GGML_TYPE_F32);
1714
+
1715
+ char base[256];
1716
+ char name[256];
1717
+
1718
+ snprintf(base, 256, "kernel_conv_transpose_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
1719
+ snprintf(name, 256, "%s", base);
1720
+
1721
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1722
+ if (!res.pipeline) {
1723
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1724
+ }
1278
1725
 
1279
1726
  return res;
1280
1727
  }
1281
1728
 
1282
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) {
1283
- assert(op->op == GGML_OP_UPSCALE);
1729
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d(ggml_metal_library_t lib, const ggml_tensor * op) {
1730
+ assert(op->op == GGML_OP_CONV_2D);
1731
+
1732
+ GGML_ASSERT(ggml_is_contiguous(op->src[0]));
1733
+ GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
1734
+ GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
1735
+ GGML_ASSERT(op->type == GGML_TYPE_F32);
1284
1736
 
1285
1737
  char base[256];
1286
1738
  char name[256];
1287
1739
 
1288
- snprintf(base, 256, "kernel_upscale_%s", ggml_type_name(op->src[0]->type));
1740
+ snprintf(base, 256, "kernel_conv_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
1289
1741
  snprintf(name, 256, "%s", base);
1290
1742
 
1291
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
1292
- if (res) {
1293
- return res;
1743
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1744
+ if (!res.pipeline) {
1745
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1294
1746
  }
1295
1747
 
1296
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1748
+ return res;
1749
+ }
1750
+
1751
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) {
1752
+ assert(op->op == GGML_OP_UPSCALE);
1753
+
1754
+ char base[256];
1755
+ char name[256];
1756
+
1757
+ const int32_t mode_flags = ggml_get_op_params_i32(op, 0);
1758
+ const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
1759
+
1760
+ const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS);
1761
+
1762
+ if (mode == GGML_SCALE_MODE_BILINEAR) {
1763
+ snprintf(base, 256, "kernel_upscale_bilinear_%s", ggml_type_name(op->src[0]->type));
1764
+ } else if (mode == GGML_SCALE_MODE_BICUBIC) {
1765
+ snprintf(base, 256, "kernel_upscale_bicubic_%s", ggml_type_name(op->src[0]->type));
1766
+ } else {
1767
+ snprintf(base, 256, "kernel_upscale_nearest_%s", ggml_type_name(op->src[0]->type));
1768
+ }
1769
+ snprintf(name, 256, "%s_aa=%d", base, antialias);
1770
+
1771
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1772
+ if (!res.pipeline) {
1773
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
1774
+
1775
+ ggml_metal_cv_set_bool(cv, antialias, FC_UPSCALE + 0);
1776
+
1777
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1778
+
1779
+ ggml_metal_cv_free(cv);
1780
+ }
1297
1781
 
1298
1782
  return res;
1299
1783
  }
1300
1784
 
1301
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad(ggml_metal_library_t lib, const ggml_tensor * op) {
1785
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_library_t lib, const ggml_tensor * op) {
1302
1786
  assert(op->op == GGML_OP_PAD);
1303
1787
 
1304
1788
  char base[256];
@@ -1307,8 +1791,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad(ggml_metal_library_t l
1307
1791
  snprintf(base, 256, "kernel_pad_%s", ggml_type_name(op->src[0]->type));
1308
1792
  snprintf(name, 256, "%s", base);
1309
1793
 
1310
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
1311
- if (res) {
1794
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1795
+ if (res.pipeline) {
1312
1796
  return res;
1313
1797
  }
1314
1798
 
@@ -1317,7 +1801,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad(ggml_metal_library_t l
1317
1801
  return res;
1318
1802
  }
1319
1803
 
1320
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d(ggml_metal_library_t lib, const ggml_tensor * op) {
1804
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d(ggml_metal_library_t lib, const ggml_tensor * op) {
1321
1805
  assert(op->op == GGML_OP_PAD_REFLECT_1D);
1322
1806
 
1323
1807
  char base[256];
@@ -1326,17 +1810,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d(ggml_metal_
1326
1810
  snprintf(base, 256, "kernel_pad_reflect_1d_%s", ggml_type_name(op->src[0]->type));
1327
1811
  snprintf(name, 256, "%s", base);
1328
1812
 
1329
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
1330
- if (res) {
1331
- return res;
1813
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1814
+ if (!res.pipeline) {
1815
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1332
1816
  }
1333
1817
 
1334
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1335
-
1336
1818
  return res;
1337
1819
  }
1338
1820
 
1339
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange(ggml_metal_library_t lib, const ggml_tensor * op) {
1821
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange(ggml_metal_library_t lib, const ggml_tensor * op) {
1340
1822
  assert(op->op == GGML_OP_ARANGE);
1341
1823
 
1342
1824
  char base[256];
@@ -1345,17 +1827,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange(ggml_metal_library_
1345
1827
  snprintf(base, 256, "kernel_arange_%s", ggml_type_name(op->type));
1346
1828
  snprintf(name, 256, "%s", base);
1347
1829
 
1348
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
1349
- if (res) {
1350
- return res;
1830
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1831
+ if (!res.pipeline) {
1832
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1351
1833
  }
1352
1834
 
1353
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1354
-
1355
1835
  return res;
1356
1836
  }
1357
1837
 
1358
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const ggml_tensor * op) {
1838
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const ggml_tensor * op) {
1359
1839
  assert(op->op == GGML_OP_TIMESTEP_EMBEDDING);
1360
1840
 
1361
1841
  char base[256];
@@ -1364,13 +1844,101 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_me
1364
1844
  snprintf(base, 256, "kernel_timestep_embedding_%s", ggml_type_name(op->src[0]->type));
1365
1845
  snprintf(name, 256, "%s", base);
1366
1846
 
1367
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
1368
- if (res) {
1369
- return res;
1847
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1848
+ if (!res.pipeline) {
1849
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1370
1850
  }
1371
1851
 
1372
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1852
+ return res;
1853
+ }
1854
+
1855
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) {
1856
+ assert(op->op == GGML_OP_OPT_STEP_ADAMW);
1857
+
1858
+ char base[256];
1859
+ char name[256];
1860
+
1861
+ snprintf(base, 256, "kernel_opt_step_adamw_%s", ggml_type_name(op->src[0]->type));
1862
+ snprintf(name, 256, "%s", base);
1863
+
1864
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1865
+ if (!res.pipeline) {
1866
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1867
+ }
1868
+
1869
+ return res;
1870
+ }
1871
+
1872
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) {
1873
+ assert(op->op == GGML_OP_OPT_STEP_SGD);
1874
+
1875
+ char base[256];
1876
+ char name[256];
1877
+
1878
+ snprintf(base, 256, "kernel_opt_step_sgd_%s", ggml_type_name(op->src[0]->type));
1879
+ snprintf(name, 256, "%s", base);
1880
+
1881
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1882
+ if (!res.pipeline) {
1883
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1884
+ }
1373
1885
 
1374
1886
  return res;
1375
1887
  }
1376
1888
 
1889
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_memset(ggml_metal_library_t lib, const ggml_tensor * op) {
1890
+ GGML_ASSERT(op->type == GGML_TYPE_I64);
1891
+
1892
+ char base[256];
1893
+ char name[256];
1894
+
1895
+ snprintf(base, 256, "kernel_memset_%s", ggml_type_name(op->type));
1896
+ snprintf(name, 256, "%s", base);
1897
+
1898
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1899
+ if (!res.pipeline) {
1900
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1901
+ }
1902
+
1903
+ return res;
1904
+ }
1905
+
1906
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_count_equal(ggml_metal_library_t lib, const ggml_tensor * op) {
1907
+ assert(op->op == GGML_OP_COUNT_EQUAL);
1908
+
1909
+ GGML_TENSOR_LOCALS(int64_t, ne0, op->src[0], ne);
1910
+
1911
+ GGML_ASSERT(op->src[0]->type == op->src[1]->type);
1912
+ GGML_ASSERT(op->src[0]->type == GGML_TYPE_I32);
1913
+ GGML_ASSERT(op->type == GGML_TYPE_I64);
1914
+
1915
+ // note: the kernel only supports i32 output due to metal atomic add only supporting atomic_int
1916
+ GGML_ASSERT(ggml_nelements(op->src[0]) < (1LL << 31));
1917
+
1918
+ char base[256];
1919
+ char name[256];
1920
+
1921
+ int nsg = 1;
1922
+ while (32*nsg < ne00 && nsg < 32) {
1923
+ nsg *= 2;
1924
+ }
1925
+
1926
+ snprintf(base, 256, "kernel_count_equal_%s", ggml_type_name(op->src[0]->type));
1927
+ snprintf(name, 256, "%s_nsg=%d", base, nsg);
1928
+
1929
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1930
+ if (!res.pipeline) {
1931
+ ggml_metal_cv_t cv = ggml_metal_cv_init();
1932
+
1933
+ ggml_metal_cv_set_int16(cv, nsg, FC_COUNT_EQUAL + 0);
1934
+
1935
+ res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1936
+
1937
+ ggml_metal_cv_free(cv);
1938
+ }
1939
+
1940
+ res.smem = 32 * sizeof(int32_t);
1941
+ res.nsg = nsg;
1942
+
1943
+ return res;
1944
+ }