whispercpp 1.3.4 → 1.3.5

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (630)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +47 -23
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  23. data/ext/sources/examples/cli/cli.cpp +121 -112
  24. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  25. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  26. data/ext/sources/examples/server/server.cpp +10 -11
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
  31. data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
  33. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  34. data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
  35. data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
  36. data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
  37. data/ext/sources/examples/talk-llama/llama-context.h +57 -9
  38. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  40. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  41. data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
  42. data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
  43. data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
  44. data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
  45. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  46. data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
  49. data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
  50. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  51. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  52. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
  53. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  54. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  55. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  56. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
  57. data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
  58. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  59. data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
  60. data/ext/sources/examples/talk-llama/llama-model.h +44 -3
  61. data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
  62. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
  63. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
  66. data/ext/sources/examples/talk-llama/llama.cpp +729 -2
  67. data/ext/sources/examples/talk-llama/llama.h +152 -14
  68. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  69. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  70. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  71. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  72. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  73. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  74. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  75. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  76. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  77. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  78. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  79. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  80. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  81. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  82. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  83. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  84. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  85. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  86. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  88. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  89. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  90. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  91. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  92. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  108. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  109. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  110. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  111. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  112. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  113. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  114. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  116. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  117. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  118. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  119. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  120. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  121. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  122. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  123. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  124. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  125. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  126. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  127. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  128. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  129. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  130. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  131. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  132. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  133. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  134. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  135. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  136. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  137. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  138. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  139. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  140. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  141. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  142. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  143. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  144. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  145. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  146. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  147. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  148. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  149. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  150. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  151. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  153. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  154. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  155. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  156. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  157. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  158. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  159. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  160. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  161. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  162. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  163. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  165. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  166. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  167. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  168. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  169. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  170. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  171. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  172. data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
  173. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  174. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  175. data/ext/sources/ggml/CMakeLists.txt +82 -54
  176. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  177. data/ext/sources/ggml/include/ggml-backend.h +4 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +1 -0
  179. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  180. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  181. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  182. data/ext/sources/ggml/include/ggml.h +190 -12
  183. data/ext/sources/ggml/src/CMakeLists.txt +82 -11
  184. data/ext/sources/ggml/src/ggml-alloc.c +124 -41
  185. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  186. data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
  187. data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
  188. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  189. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
  190. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
  191. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  192. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
  193. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
  194. data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
  195. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
  196. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
  197. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
  198. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  199. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
  200. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
  201. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  202. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  203. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
  204. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  205. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
  206. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
  207. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
  208. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
  209. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  210. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
  211. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  212. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
  213. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  214. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
  215. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  216. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
  217. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
  218. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
  219. data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
  220. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
  221. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  222. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  223. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  224. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
  225. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  226. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  227. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  228. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  229. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  230. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
  231. data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
  232. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  233. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
  235. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  236. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  237. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  238. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  239. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
  241. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
  242. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  243. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
  244. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
  245. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
  246. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  247. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
  248. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  249. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  250. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
  251. data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
  252. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
  253. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
  254. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
  255. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  256. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  258. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
  259. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
  260. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
  261. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
  262. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
  264. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
  265. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  266. data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
  267. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  268. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  269. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  270. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  271. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
  272. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  273. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  274. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
  275. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  276. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  278. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  279. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  280. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  281. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  282. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  283. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  284. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
  286. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  287. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  288. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
  289. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
  290. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  291. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  292. data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
  293. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
  294. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  295. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  296. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
  297. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  317. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  321. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  322. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  323. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  324. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  325. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  326. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  327. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  328. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  329. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
  330. data/ext/sources/ggml/src/ggml-impl.h +67 -6
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
  335. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
  336. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
  337. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
  338. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
  339. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
  340. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
  341. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  342. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
  343. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
  365. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  366. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  367. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  368. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  369. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
  370. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
  371. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  372. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  373. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  374. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
  375. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  376. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  377. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  378. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  379. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
  380. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
  381. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
  382. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
  383. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  384. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
  385. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  386. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  387. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  388. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  389. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  390. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  391. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  392. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  393. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  394. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  395. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
  396. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  397. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  398. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  399. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  400. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  401. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  402. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
  403. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  404. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  484. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  485. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  486. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  487. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  488. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
  489. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  490. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  491. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  492. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  493. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  494. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  495. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  496. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  497. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  498. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  499. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  500. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  501. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  502. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  503. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  504. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  505. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  506. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
  507. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
  508. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  509. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  510. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
  511. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  512. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  513. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  514. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  515. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  516. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  517. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  518. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  519. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  520. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  521. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  522. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  523. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  524. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  525. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
  526. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  527. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  528. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  529. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  530. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  531. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
  532. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
  533. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  560. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  561. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  562. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  563. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  564. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  565. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  566. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  567. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  568. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  569. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
  570. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  571. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  572. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  573. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
  574. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  584. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  585. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  586. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  587. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  588. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  589. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  590. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  591. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  592. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  593. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  594. data/ext/sources/ggml/src/ggml.c +425 -33
  595. data/ext/sources/include/whisper.h +1 -0
  596. data/ext/sources/src/CMakeLists.txt +3 -1
  597. data/ext/sources/src/whisper.cpp +101 -35
  598. data/ext/sources/tests/CMakeLists.txt +2 -2
  599. data/ext/sources/tests/test-vad-full.cpp +4 -2
  600. data/ext/sources/tests/test-vad.cpp +1 -1
  601. data/extsources.rb +1 -0
  602. data/lib/whisper/model/uri.rb +17 -18
  603. data/sig/whisper.rbs +119 -2
  604. data/test/test_params.rb +16 -8
  605. data/test/test_segment.rb +0 -1
  606. data/test/test_token.rb +70 -0
  607. data/test/test_vad.rb +1 -1
  608. data/test/test_vad_context.rb +50 -0
  609. data/test/test_vad_segment.rb +19 -0
  610. data/test/test_vad_segments.rb +16 -0
  611. data/test/test_whisper.rb +7 -0
  612. data/whispercpp.gemspec +1 -1
  613. metadata +287 -34
  614. data/ext/sources/build-xcframework.sh +0 -571
  615. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  618. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  619. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  620. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  621. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  622. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  623. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  624. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  625. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  626. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  627. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  628. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  629. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  630. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -1,12 +1,13 @@
1
1
  #import "ggml-metal-device.h"
2
2
 
3
3
  #import "ggml-impl.h"
4
- #import "ggml-threading.h"
5
4
 
6
5
  #include <Foundation/Foundation.h>
7
6
 
8
7
  #include <Metal/Metal.h>
9
8
 
9
+ #include <stdatomic.h>
10
+
10
11
  #ifndef TARGET_OS_VISION
11
12
  #define TARGET_OS_VISION 0
12
13
  #endif
@@ -19,8 +20,12 @@
19
20
  #define GGML_METAL_HAS_RESIDENCY_SETS 1
20
21
  #endif
21
22
 
22
- // overload of MTLGPUFamilyMetal3 (not available in some environments)
23
+ // overload of MTLGPUFamilyMetalX (not available in some environments)
23
24
  static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;
25
+ static const NSInteger MTLGPUFamilyMetal4_GGML = 5002;
26
+
27
+ // virtual address for GPU memory allocations
28
+ static atomic_uintptr_t g_addr_device = 0x000000400ULL;
24
29
 
25
30
  #if !GGML_METAL_EMBED_LIBRARY
26
31
  // Here to assist with NSBundle Path Hack
@@ -69,14 +74,6 @@ void ggml_metal_cv_set_bool(ggml_metal_cv_t cv, bool value, int32_t idx) {
69
74
 
70
75
  struct ggml_metal_pipeline {
71
76
  id<MTLComputePipelineState> obj;
72
-
73
- // suggested dispatch sizes
74
- int nsg;
75
-
76
- int nr0;
77
- int nr1;
78
-
79
- size_t smem;
80
77
  };
81
78
 
82
79
  ggml_metal_pipeline_t ggml_metal_pipeline_init(void) {
@@ -84,10 +81,6 @@ ggml_metal_pipeline_t ggml_metal_pipeline_init(void) {
84
81
 
85
82
  *res = (struct ggml_metal_pipeline) {
86
83
  /*.obj =*/ nil,
87
- /*.nsg =*/ 0,
88
- /*.nr0 =*/ 0,
89
- /*.nr1 =*/ 0,
90
- /*.smem =*/ 0,
91
84
  };
92
85
 
93
86
  return res;
@@ -99,40 +92,8 @@ void ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline) {
99
92
  free(pipeline);
100
93
  }
101
94
 
102
- void ggml_metal_pipeline_set_nsg(ggml_metal_pipeline_t pipeline, int nsg) {
103
- pipeline->nsg = nsg;
104
- }
105
-
106
- int ggml_metal_pipeline_get_nsg(ggml_metal_pipeline_t pipeline) {
107
- return pipeline->nsg;
108
- }
109
-
110
- void ggml_metal_pipeline_set_nr0(ggml_metal_pipeline_t pipeline, int nr0) {
111
- pipeline->nr0 = nr0;
112
- }
113
-
114
- int ggml_metal_pipeline_get_nr0(ggml_metal_pipeline_t pipeline) {
115
- return pipeline->nr0;
116
- }
117
-
118
- void ggml_metal_pipeline_set_nr1(ggml_metal_pipeline_t pipeline, int nr1) {
119
- pipeline->nr1 = nr1;
120
- }
121
-
122
- int ggml_metal_pipeline_get_nr1(ggml_metal_pipeline_t pipeline) {
123
- return pipeline->nr1;
124
- }
125
-
126
- void ggml_metal_pipeline_set_smem(ggml_metal_pipeline_t pipeline, size_t smem) {
127
- pipeline->smem = smem;
128
- }
129
-
130
- size_t ggml_metal_pipeline_get_smem(ggml_metal_pipeline_t pipeline) {
131
- return pipeline->smem;
132
- }
133
-
134
- int ggml_metal_pipeline_max_theads_per_threadgroup(ggml_metal_pipeline_t pipeline) {
135
- return pipeline->obj.maxTotalThreadsPerThreadgroup;
95
+ int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline) {
96
+ return pipeline.pipeline->obj.maxTotalThreadsPerThreadgroup;
136
97
  }
137
98
 
138
99
  struct ggml_metal_library {
@@ -140,6 +101,8 @@ struct ggml_metal_library {
140
101
  id<MTLDevice> device;
141
102
 
142
103
  ggml_metal_pipelines_t pipelines; // cache of compiled pipelines
104
+
105
+ NSLock * lock;
143
106
  };
144
107
 
145
108
  ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
@@ -256,6 +219,10 @@ ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
256
219
  [prep setObject:@"1" forKey:@"GGML_METAL_HAS_BF16"];
257
220
  }
258
221
 
222
+ if (ggml_metal_device_get_props(dev)->has_tensor) {
223
+ [prep setObject:@"1" forKey:@"GGML_METAL_HAS_TENSOR"];
224
+ }
225
+
259
226
  #if GGML_METAL_EMBED_LIBRARY
260
227
  [prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"];
261
228
  #endif
@@ -286,9 +253,77 @@ ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
286
253
 
287
254
  ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library));
288
255
 
289
- res->obj = library;
290
- res->device = device;
256
+ res->obj = library;
257
+ res->device = device;
258
+ res->pipelines = ggml_metal_pipelines_init();
259
+ res->lock = [NSLock new];
260
+
261
+ return res;
262
+ }
263
+
264
+ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev, const char * source, bool verbose) {
265
+ if (source == NULL) {
266
+ GGML_LOG_ERROR("%s: source is NULL\n", __func__);
267
+ return NULL;
268
+ }
269
+
270
+ id<MTLDevice> device = ggml_metal_device_get_obj(dev);
271
+ id<MTLLibrary> library = nil;
272
+ NSError * error = nil;
273
+
274
+ const int64_t t_start = ggml_time_us();
275
+
276
+ NSString * src = [[NSString alloc] initWithBytes:source
277
+ length:strlen(source)
278
+ encoding:NSUTF8StringEncoding];
279
+ if (!src) {
280
+ GGML_LOG_ERROR("%s: failed to create NSString from source\n", __func__);
281
+ return NULL;
282
+ }
283
+
284
+ @autoreleasepool {
285
+ NSMutableDictionary * prep = [NSMutableDictionary dictionary];
286
+
287
+ MTLCompileOptions * options = [MTLCompileOptions new];
288
+ options.preprocessorMacros = prep;
289
+
290
+ library = [device newLibraryWithSource:src options:options error:&error];
291
+ if (error) {
292
+ if (verbose) {
293
+ GGML_LOG_ERROR("%s: error compiling source: %s\n", __func__, [[error description] UTF8String]);
294
+ } else {
295
+ GGML_LOG_ERROR("%s: error compiling source\n", __func__);
296
+ }
297
+ library = nil;
298
+ }
299
+
300
+ [options release];
301
+ }
302
+
303
+ [src release];
304
+
305
+ if (!library) {
306
+ if (verbose) {
307
+ GGML_LOG_ERROR("%s: failed to create Metal library from source\n", __func__);
308
+ }
309
+
310
+ return NULL;
311
+ }
312
+
313
+ if (verbose) {
314
+ GGML_LOG_INFO("%s: compiled in %.3f sec\n", __func__, (ggml_time_us() - t_start) / 1e6);
315
+ }
316
+
317
+ ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library));
318
+ if (!res) {
319
+ GGML_LOG_ERROR("%s: calloc failed\n", __func__);
320
+ return NULL;
321
+ }
322
+
323
+ res->obj = library;
324
+ res->device = device;
291
325
  res->pipelines = ggml_metal_pipelines_init();
326
+ res->lock = [NSLock new];
292
327
 
293
328
  return res;
294
329
  }
@@ -304,26 +339,47 @@ void ggml_metal_library_free(ggml_metal_library_t lib) {
304
339
 
305
340
  ggml_metal_pipelines_free(lib->pipelines);
306
341
 
342
+ [lib->lock release];
343
+
307
344
  free(lib);
308
345
  }
309
346
 
310
- ggml_metal_pipeline_t ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) {
311
- return ggml_metal_pipelines_get(lib->pipelines, name);
347
+ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) {
348
+ [lib->lock lock];
349
+
350
+ struct ggml_metal_pipeline_with_params res = {
351
+ /*.pipeline =*/ nil,
352
+ /*.nr0 =*/ 0,
353
+ /*.nr1 =*/ 0,
354
+ /*.nsg =*/ 0,
355
+ /*.smem =*/ 0,
356
+ };
357
+
358
+ res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name);
359
+
360
+ [lib->lock unlock];
361
+
362
+ return res;
312
363
  }
313
364
 
314
- ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) {
315
- // note: the pipelines are cached in the library per device, so they are shared across all metal contexts
316
- ggml_critical_section_start();
365
+ struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) {
366
+ struct ggml_metal_pipeline_with_params res = {
367
+ /*.pipeline =*/ nil,
368
+ /*.nr0 =*/ 0,
369
+ /*.nr1 =*/ 0,
370
+ /*.nsg =*/ 0,
371
+ /*.smem =*/ 0,
372
+ };
373
+
374
+ [lib->lock lock];
317
375
 
318
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
319
- if (res) {
320
- ggml_critical_section_end();
376
+ res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name);
377
+ if (res.pipeline) {
378
+ [lib->lock unlock];
321
379
 
322
380
  return res;
323
381
  }
324
382
 
325
- res = ggml_metal_pipeline_init();
326
-
327
383
  @autoreleasepool {
328
384
  NSError * error = nil;
329
385
 
@@ -338,28 +394,53 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l
338
394
  mtl_function = [lib->obj newFunctionWithName:base_func constantValues:cv->obj error:&error];
339
395
  }
340
396
  if (!mtl_function) {
341
- ggml_critical_section_end();
397
+ [lib->lock unlock];
342
398
 
343
- GGML_LOG_ERROR("%s: error: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name);
399
+ GGML_LOG_ERROR("%s: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name);
344
400
  if (error) {
345
- GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
401
+ GGML_LOG_ERROR("%s: %s\n", __func__, [[error description] UTF8String]);
346
402
  }
347
403
 
348
- return nil;
404
+ return res;
349
405
  }
350
406
 
351
- res->obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error];
352
-
353
- ggml_metal_pipelines_add(lib->pipelines, name, res);
407
+ id<MTLComputePipelineState> obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error];
354
408
 
355
409
  [mtl_function release];
356
410
 
357
- GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name, (void *) res->obj,
358
- (int) res->obj.maxTotalThreadsPerThreadgroup,
359
- (int) res->obj.threadExecutionWidth);
411
+ if (!obj) {
412
+ [lib->lock unlock];
413
+
414
+ GGML_LOG_ERROR("%s: failed to create pipeline state: base = '%s', name = '%s'\n", __func__, base, name);
415
+ if (error) {
416
+ GGML_LOG_ERROR("%s: %s\n", __func__, [[error description] UTF8String]);
417
+ }
418
+
419
+ return res;
420
+ }
421
+
422
+ GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name,
423
+ (void *) obj,
424
+ (int) obj.maxTotalThreadsPerThreadgroup,
425
+ (int) obj.threadExecutionWidth);
426
+
427
+ if (obj.maxTotalThreadsPerThreadgroup == 0 || obj.threadExecutionWidth == 0) {
428
+ [obj release];
429
+
430
+ [lib->lock unlock];
431
+
432
+ GGML_LOG_ERROR("%s: incompatible pipeline %s\n", __func__, name);
433
+
434
+ return res;
435
+ }
436
+
437
+ res.pipeline = ggml_metal_pipeline_init();
438
+ res.pipeline->obj = obj;
439
+
440
+ ggml_metal_pipelines_add(lib->pipelines, name, res.pipeline);
360
441
  }
361
442
 
362
- ggml_critical_section_end();
443
+ [lib->lock unlock];
363
444
 
364
445
  return res;
365
446
  }
@@ -401,8 +482,8 @@ void ggml_metal_encoder_debug_group_pop (ggml_metal_encoder_t encoder) {
401
482
  [encoder->obj popDebugGroup];
402
483
  }
403
484
 
404
- void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, ggml_metal_pipeline_t pipeline) {
405
- [encoder->obj setComputePipelineState:pipeline->obj];
485
+ void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, struct ggml_metal_pipeline_with_params pipeline) {
486
+ [encoder->obj setComputePipelineState:pipeline.pipeline->obj];
406
487
  }
407
488
 
408
489
  void ggml_metal_encoder_set_bytes(ggml_metal_encoder_t encoder, void * data, size_t size, int idx) {
@@ -437,11 +518,106 @@ struct ggml_metal_device {
437
518
  // ref: https://github.com/ggml-org/llama.cpp/pull/15906
438
519
  id<MTLCommandQueue> mtl_queue;
439
520
 
521
+ ggml_metal_rsets_t rsets;
522
+
440
523
  ggml_metal_library_t library;
441
524
 
442
525
  struct ggml_metal_device_props props;
443
526
  };
444
527
 
528
+ //
529
+ // MTLResidencySet wrapper
530
+ //
531
+
532
+ struct ggml_metal_rsets {
533
+ NSLock * lock;
534
+
535
+ NSMutableArray * data;
536
+
537
+ // keep-alive duration in seconds, counted from the last graph computation
538
+ // keep the residency sets wired for that amount of time to avoid being collected by the OS
539
+ int keep_alive_s;
540
+
541
+ // background heartbeat thread to keep the residency sets alive
542
+ atomic_bool d_stop;
543
+ atomic_int d_loop;
544
+
545
+ dispatch_group_t d_group;
546
+ };
547
+
548
+ ggml_metal_rsets_t ggml_metal_rsets_init(void) {
549
+ ggml_metal_rsets_t res = calloc(1, sizeof(struct ggml_metal_rsets));
550
+
551
+ res->lock = [[NSLock alloc] init];
552
+ res->data = [[NSMutableArray alloc] init];
553
+
554
+ // by default keep the memory wired for 3 minutes
555
+ res->keep_alive_s = 3*60;
556
+
557
+ const char * GGML_METAL_RESIDENCY_KEEP_ALIVE_S = getenv("GGML_METAL_RESIDENCY_KEEP_ALIVE_S");
558
+ if (GGML_METAL_RESIDENCY_KEEP_ALIVE_S) {
559
+ res->keep_alive_s = atoi(GGML_METAL_RESIDENCY_KEEP_ALIVE_S);
560
+ }
561
+
562
+ if (res->keep_alive_s <= 0) {
563
+ res->keep_alive_s = 3*60;
564
+ }
565
+
566
+ GGML_LOG_INFO("%s: creating a residency set collection (keep_alive = %d s)\n", __func__, res->keep_alive_s);
567
+
568
+ atomic_store_explicit(&res->d_stop, false, memory_order_relaxed);
569
+ atomic_store_explicit(&res->d_loop, 2*res->keep_alive_s, memory_order_relaxed);
570
+
571
+ res->d_group = dispatch_group_create();
572
+
573
+ // start a background thread that periodically requests residency for all the currently active sets in the collection
574
+ // the requests stop after a certain amount of time (keep_alive_s) of inactivity
575
+ dispatch_queue_t d_queue = dispatch_get_global_queue(QOS_CLASS_DEFAULT, 0);
576
+ dispatch_group_async(res->d_group, d_queue, ^{
577
+ #if defined(GGML_METAL_HAS_RESIDENCY_SETS)
578
+ if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) {
579
+ while (!atomic_load_explicit(&res->d_stop, memory_order_relaxed)) {
580
+ if (atomic_load_explicit(&res->d_loop, memory_order_relaxed) > 0) {
581
+ [res->lock lock];
582
+
583
+ for (int i = 0; i < (int) res->data.count; ++i) {
584
+ [res->data[i] requestResidency];
585
+ }
586
+
587
+ atomic_fetch_sub_explicit(&res->d_loop, 1, memory_order_relaxed);
588
+
589
+ [res->lock unlock];
590
+ }
591
+
592
+ // half a second
593
+ usleep(500 * 1000);
594
+ }
595
+ }
596
+ #endif
597
+ });
598
+
599
+ return res;
600
+ }
601
+
602
+ void ggml_metal_rsets_free(ggml_metal_rsets_t rsets) {
603
+ if (rsets == NULL) {
604
+ return;
605
+ }
606
+
607
+ // note: if you hit this assert, most likely you haven't deallocated all Metal resources before exiting
608
+ GGML_ASSERT([rsets->data count] == 0);
609
+
610
+ atomic_store_explicit(&rsets->d_stop, true, memory_order_relaxed);
611
+
612
+ dispatch_group_wait(rsets->d_group, DISPATCH_TIME_FOREVER);
613
+ dispatch_release(rsets->d_group);
614
+
615
+ [rsets->data release];
616
+ [rsets->lock release];
617
+
618
+ free(rsets);
619
+ }
620
+
445
621
  ggml_metal_device_t ggml_metal_device_init(void) {
446
622
  ggml_metal_device_t dev = calloc(1, sizeof(struct ggml_metal_device));
447
623
 
@@ -464,6 +640,128 @@ ggml_metal_device_t ggml_metal_device_init(void) {
464
640
 
465
641
  dev->props.has_bfloat = [dev->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
466
642
  dev->props.has_bfloat |= [dev->mtl_device supportsFamily:MTLGPUFamilyApple6];
643
+ if (getenv("GGML_METAL_BF16_DISABLE") != NULL) {
644
+ dev->props.has_bfloat = false;
645
+ }
646
+
647
+ dev->props.has_tensor = [dev->mtl_device supportsFamily:MTLGPUFamilyMetal4_GGML];
648
+ if (getenv("GGML_METAL_TENSOR_DISABLE") != NULL) {
649
+ dev->props.has_tensor = false;
650
+ }
651
+
652
+ // note: disable the tensor API by default for old chips because with the current implementation it is not useful
653
+ // - M2 Ultra: ~5% slower
654
+ // - M4, M4 Max: no significant difference
655
+ //
656
+ // TODO: try to update the tensor API kernels to at least match the simdgroup performance
657
+ if (getenv("GGML_METAL_TENSOR_ENABLE") == NULL &&
658
+ ![[dev->mtl_device name] containsString:@"M5"] &&
659
+ ![[dev->mtl_device name] containsString:@"M6"] &&
660
+ ![[dev->mtl_device name] containsString:@"A19"] &&
661
+ ![[dev->mtl_device name] containsString:@"A20"]) {
662
+ GGML_LOG_WARN("%s: tensor API disabled for pre-M5 and pre-A19 devices\n", __func__);
663
+ dev->props.has_tensor = false;
664
+ }
665
+
666
+ // double-check that the tensor API compiles
667
+ if (dev->props.has_tensor) {
668
+ const char * src_tensor_f16 = "\n"
669
+ "#include <metal_stdlib> \n"
670
+ "#include <metal_tensor> \n"
671
+ "#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h> \n"
672
+ " \n"
673
+ "using namespace metal; \n"
674
+ "using namespace mpp::tensor_ops; \n"
675
+ " \n"
676
+ "kernel void dummy_kernel( \n"
677
+ " tensor<device half, dextents<int32_t, 2>> A [[buffer(0)]], \n"
678
+ " tensor<device half, dextents<int32_t, 2>> B [[buffer(1)]], \n"
679
+ " device float * C [[buffer(2)]], \n"
680
+ " uint2 tgid [[threadgroup_position_in_grid]]) \n"
681
+ "{ \n"
682
+ " auto tA = A.slice(0, (int)tgid.y); \n"
683
+ " auto tB = B.slice((int)tgid.x, 0); \n"
684
+ " \n"
685
+ " matmul2d< \n"
686
+ " matmul2d_descriptor(8, 8, dynamic_extent), \n"
687
+ " execution_simdgroups<4>> mm; \n"
688
+ " \n"
689
+ " auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>(); \n"
690
+ " \n"
691
+ " auto sA = tA.slice(0, 0); \n"
692
+ " auto sB = tB.slice(0, 0); \n"
693
+ " mm.run(sB, sA, cT); \n"
694
+ " \n"
695
+ " auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(4, 4)); \n"
696
+ " \n"
697
+ " cT.store(tC); \n"
698
+ "}";
699
+
700
+ GGML_LOG_INFO("%s: testing tensor API for f16 support\n", __func__);
701
+ ggml_metal_library_t lib = ggml_metal_library_init_from_source(dev, src_tensor_f16, false);
702
+ if (lib == NULL) {
703
+ GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__);
704
+ dev->props.has_tensor = false;
705
+ } else {
706
+ struct ggml_metal_pipeline_with_params ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil);
707
+ if (!ppl.pipeline) {
708
+ GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__);
709
+ dev->props.has_tensor = false;
710
+ }
711
+
712
+ ggml_metal_library_free(lib);
713
+ }
714
+ }
715
+
716
+ // try to compile a dummy kernel to determine if the tensor API is supported for bfloat
717
+ if (dev->props.has_tensor && dev->props.has_bfloat) {
718
+ const char * src_tensor_bf16 = "\n"
719
+ "#include <metal_stdlib> \n"
720
+ "#include <metal_tensor> \n"
721
+ "#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h> \n"
722
+ " \n"
723
+ "using namespace metal; \n"
724
+ "using namespace mpp::tensor_ops; \n"
725
+ " \n"
726
+ "kernel void dummy_kernel( \n"
727
+ " tensor<device bfloat, dextents<int32_t, 2>> A [[buffer(0)]], \n"
728
+ " tensor<device bfloat, dextents<int32_t, 2>> B [[buffer(1)]], \n"
729
+ " device float * C [[buffer(2)]], \n"
730
+ " uint2 tgid [[threadgroup_position_in_grid]]) \n"
731
+ "{ \n"
732
+ " auto tA = A.slice(0, (int)tgid.y); \n"
733
+ " auto tB = B.slice((int)tgid.x, 0); \n"
734
+ " \n"
735
+ " matmul2d< \n"
736
+ " matmul2d_descriptor(8, 8, dynamic_extent), \n"
737
+ " execution_simdgroups<4>> mm; \n"
738
+ " \n"
739
+ " auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>(); \n"
740
+ " \n"
741
+ " auto sA = tA.slice(0, 0); \n"
742
+ " auto sB = tB.slice(0, 0); \n"
743
+ " mm.run(sB, sA, cT); \n"
744
+ " \n"
745
+ " auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(4, 4)); \n"
746
+ " \n"
747
+ " cT.store(tC); \n"
748
+ "}";
749
+
750
+ GGML_LOG_INFO("%s: testing tensor API for bfloat support\n", __func__);
751
+ ggml_metal_library_t lib = ggml_metal_library_init_from_source(dev, src_tensor_bf16, false);
752
+ if (lib == NULL) {
753
+ GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__);
754
+ dev->props.has_bfloat = false;
755
+ } else {
756
+ struct ggml_metal_pipeline_with_params ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil);
757
+ if (!ppl.pipeline) {
758
+ GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__);
759
+ dev->props.has_bfloat = false;
760
+ }
761
+
762
+ ggml_metal_library_free(lib);
763
+ }
764
+ }
467
765
 
468
766
  dev->props.use_residency_sets = true;
469
767
  #if defined(GGML_METAL_HAS_RESIDENCY_SETS)
@@ -471,13 +769,21 @@ ggml_metal_device_t ggml_metal_device_init(void) {
471
769
  #endif
472
770
 
473
771
  dev->props.use_shared_buffers = dev->props.has_unified_memory;
474
-
772
+ #if TARGET_OS_OSX
773
+ // In case of eGPU, shared memory may be preferable.
774
+ dev->props.use_shared_buffers |= [dev->mtl_device location] == MTLDeviceLocationExternal;
775
+ #endif
475
776
  if (getenv("GGML_METAL_SHARED_BUFFERS_DISABLE") != NULL) {
476
777
  dev->props.use_shared_buffers = false;
477
778
  }
779
+ if (getenv("GGML_METAL_SHARED_BUFFERS_ENABLE") != NULL) {
780
+ dev->props.use_shared_buffers = true;
781
+ }
478
782
 
479
783
  dev->props.supports_gpu_family_apple7 = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7];
480
784
 
785
+ dev->props.op_offload_min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
786
+
481
787
  dev->props.max_buffer_size = dev->mtl_device.maxBufferLength;
482
788
  dev->props.max_working_set_size = dev->mtl_device.recommendedMaxWorkingSetSize;
483
789
  dev->props.max_theadgroup_memory_size = dev->mtl_device.maxThreadgroupMemoryLength;
@@ -489,7 +795,11 @@ ggml_metal_device_t ggml_metal_device_init(void) {
489
795
  GGML_LOG_ERROR("%s: error: failed to create library\n", __func__);
490
796
  }
491
797
 
492
- // --------------------------------------------------
798
+ if (dev->props.use_residency_sets) {
799
+ dev->rsets = ggml_metal_rsets_init();
800
+ } else {
801
+ dev->rsets = nil;
802
+ }
493
803
 
494
804
  // print MTL GPU family:
495
805
  GGML_LOG_INFO("%s: GPU name: %s\n", __func__, dev->props.name);
@@ -524,6 +834,7 @@ ggml_metal_device_t ggml_metal_device_init(void) {
524
834
  GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, dev->props.has_simdgroup_mm ? "true" : "false");
525
835
  GGML_LOG_INFO("%s: has unified memory = %s\n", __func__, dev->props.has_unified_memory ? "true" : "false");
526
836
  GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, dev->props.has_bfloat ? "true" : "false");
837
+ GGML_LOG_INFO("%s: has tensor = %s\n", __func__, dev->props.has_tensor ? "true" : "false");
527
838
  GGML_LOG_INFO("%s: use residency sets = %s\n", __func__, dev->props.use_residency_sets ? "true" : "false");
528
839
  GGML_LOG_INFO("%s: use shared buffers = %s\n", __func__, dev->props.use_shared_buffers ? "true" : "false");
529
840
 
@@ -541,6 +852,8 @@ ggml_metal_device_t ggml_metal_device_init(void) {
541
852
  void ggml_metal_device_free(ggml_metal_device_t dev) {
542
853
  assert(dev != NULL);
543
854
 
855
+ ggml_metal_rsets_free(dev->rsets);
856
+
544
857
  ggml_metal_library_free(dev->library);
545
858
  dev->library = NULL;
546
859
 
@@ -569,6 +882,42 @@ ggml_metal_library_t ggml_metal_device_get_library(ggml_metal_device_t dev) {
569
882
  return dev->library;
570
883
  }
571
884
 
885
+ void ggml_metal_device_rsets_add(ggml_metal_device_t dev, ggml_metal_rset_t rset) { // adds rset to the device-level collection dev->rsets->data
886
+ if (rset == nil) { // no residency set was provided - nothing to add
887
+ return;
888
+ }
889
+
890
+ GGML_ASSERT(dev->rsets); // a non-nil rset requires the device-level container to exist
891
+
892
+ [dev->rsets->lock lock]; // the collection is only modified while the lock is held
893
+
894
+ [dev->rsets->data addObject:rset];
895
+
896
+ [dev->rsets->lock unlock];
897
+ }
898
+
899
+ void ggml_metal_device_rsets_rm(ggml_metal_device_t dev, ggml_metal_rset_t rset) { // removes rset from the device-level collection dev->rsets->data
900
+ if (rset == nil) { // no residency set was provided - nothing to remove
901
+ return;
902
+ }
903
+
904
+ GGML_ASSERT(dev->rsets); // a non-nil rset requires the device-level container to exist
905
+
906
+ [dev->rsets->lock lock]; // the collection is only modified while the lock is held
907
+
908
+ [dev->rsets->data removeObject:rset];
909
+
910
+ [dev->rsets->lock unlock];
911
+ }
912
+
913
+ void ggml_metal_device_rsets_keep_alive(ggml_metal_device_t dev) { // resets the d_loop value of the device-level residency sets container
914
+ if (dev->rsets == NULL) { // no residency sets container on this device - nothing to refresh
915
+ return;
916
+ }
917
+
918
+ atomic_store_explicit(&dev->rsets->d_loop, 2*dev->rsets->keep_alive_s, memory_order_relaxed); // stores twice keep_alive_s with relaxed ordering; NOTE(review): d_loop is presumably a countdown consumed elsewhere - confirm where it is read
919
+ }
920
+
572
921
  void ggml_metal_device_get_memory(ggml_metal_device_t dev, size_t * free, size_t * total) {
573
922
  if (@available(macOS 10.12, iOS 16.0, *)) {
574
923
  *total = dev->mtl_device.recommendedMaxWorkingSetSize;
@@ -614,6 +963,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
614
963
  case GGML_UNARY_OP_HARDSWISH:
615
964
  case GGML_UNARY_OP_HARDSIGMOID:
616
965
  case GGML_UNARY_OP_EXP:
966
+ case GGML_UNARY_OP_SOFTPLUS:
967
+ case GGML_UNARY_OP_EXPM1:
617
968
  return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
618
969
  default:
619
970
  return false;
@@ -646,8 +997,14 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
646
997
  case GGML_OP_ACC:
647
998
  case GGML_OP_REPEAT:
648
999
  case GGML_OP_SCALE:
1000
+ case GGML_OP_FILL:
649
1001
  case GGML_OP_CONV_TRANSPOSE_1D:
650
1002
  return true;
1003
+ case GGML_OP_CONV_TRANSPOSE_2D:
1004
+ return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) &&
1005
+ (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) &&
1006
+ op->src[1]->type == GGML_TYPE_F32 &&
1007
+ op->type == GGML_TYPE_F32;
651
1008
  case GGML_OP_CLAMP:
652
1009
  return op->src[0]->type == GGML_TYPE_F32;
653
1010
  case GGML_OP_SQR:
@@ -656,13 +1013,23 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
656
1013
  case GGML_OP_COS:
657
1014
  case GGML_OP_LOG:
658
1015
  return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1016
+ case GGML_OP_SUM:
1017
+ return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
1018
+ case GGML_OP_TRI:
1019
+ return ggml_is_contiguous_rows(op->src[0]);
659
1020
  case GGML_OP_SUM_ROWS:
1021
+ case GGML_OP_CUMSUM:
660
1022
  case GGML_OP_MEAN:
661
1023
  case GGML_OP_SOFT_MAX:
662
1024
  case GGML_OP_GROUP_NORM:
663
1025
  return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
664
1026
  case GGML_OP_L2_NORM:
665
1027
  return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
1028
+ case GGML_OP_COUNT_EQUAL:
1029
+ return has_simdgroup_reduction &&
1030
+ op->src[0]->type == GGML_TYPE_I32 &&
1031
+ op->src[1]->type == GGML_TYPE_I32 &&
1032
+ op->type == GGML_TYPE_I64;
666
1033
  case GGML_OP_ARGMAX:
667
1034
  return has_simdgroup_reduction;
668
1035
  case GGML_OP_NORM:
@@ -672,13 +1039,23 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
672
1039
  return true;
673
1040
  case GGML_OP_IM2COL:
674
1041
  return ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
1042
+ case GGML_OP_CONV_2D:
1043
+ return ggml_is_contiguous(op->src[0]) &&
1044
+ op->src[1]->type == GGML_TYPE_F32 &&
1045
+ op->type == GGML_TYPE_F32 &&
1046
+ (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
675
1047
  case GGML_OP_POOL_1D:
676
1048
  return false;
677
1049
  case GGML_OP_UPSCALE:
678
- return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
1050
+ return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
679
1051
  case GGML_OP_POOL_2D:
680
1052
  return op->src[0]->type == GGML_TYPE_F32;
681
1053
  case GGML_OP_PAD:
1054
+ // TODO: add circular padding support for metal, see https://github.com/ggml-org/llama.cpp/pull/16985
1055
+ if (ggml_get_op_params_i32(op, 8) != 0) {
1056
+ return false;
1057
+ }
1058
+
682
1059
  return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
683
1060
  (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
684
1061
  case GGML_OP_PAD_REFLECT_1D:
@@ -686,14 +1063,16 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
686
1063
  case GGML_OP_LEAKY_RELU:
687
1064
  return op->src[0]->type == GGML_TYPE_F32;
688
1065
  case GGML_OP_ARGSORT:
689
- // TODO: Support arbitrary column width
690
- return op->src[0]->ne[0] <= 1024;
1066
+ case GGML_OP_TOP_K:
691
1067
  case GGML_OP_ARANGE:
692
1068
  return true;
693
1069
  case GGML_OP_FLASH_ATTN_EXT:
694
1070
  // for new head sizes, add checks here
695
- if (op->src[0]->ne[0] != 40 &&
1071
+ if (op->src[0]->ne[0] != 32 &&
1072
+ op->src[0]->ne[0] != 40 &&
1073
+ op->src[0]->ne[0] != 48 &&
696
1074
  op->src[0]->ne[0] != 64 &&
1075
+ op->src[0]->ne[0] != 72 &&
697
1076
  op->src[0]->ne[0] != 80 &&
698
1077
  op->src[0]->ne[0] != 96 &&
699
1078
  op->src[0]->ne[0] != 112 &&
@@ -770,15 +1149,13 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
770
1149
  return false;
771
1150
  }
772
1151
  case GGML_TYPE_I32:
773
- return op->type == GGML_TYPE_F32;
1152
+ return op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32;
774
1153
  default:
775
1154
  return false;
776
1155
  };
777
1156
  }
778
1157
  case GGML_OP_GET_ROWS:
779
- {
780
- return op->ne[3] == 1;
781
- }
1158
+ return true;
782
1159
  case GGML_OP_SET_ROWS:
783
1160
  {
784
1161
  if (op->src[0]->type != GGML_TYPE_F32) {
@@ -800,6 +1177,9 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
800
1177
  return false;
801
1178
  };
802
1179
  }
1180
+ case GGML_OP_OPT_STEP_ADAMW:
1181
+ case GGML_OP_OPT_STEP_SGD:
1182
+ return has_simdgroup_reduction;
803
1183
  default:
804
1184
  return false;
805
1185
  }
@@ -824,7 +1204,7 @@ struct ggml_metal_buffer_wrapper {
824
1204
  };
825
1205
 
826
1206
  struct ggml_metal_buffer {
827
- void * all_data; // TODO: https://github.com/ggml-org/llama.cpp/pull/15985
1207
+ void * all_data;
828
1208
  size_t all_size;
829
1209
 
830
1210
  // if false, the Metal buffer data is allocated in private GPU memory and is not shared with the host
@@ -841,9 +1221,8 @@ struct ggml_metal_buffer {
841
1221
  // note: cannot explicitly use "id<MTLResidencySet>" here because it is not available on certain OSes
842
1222
  id rset;
843
1223
 
844
- // pointers to global device objects
845
- id<MTLDevice> device;
846
- id<MTLCommandQueue> queue;
1224
+ // pointer to the global device
1225
+ ggml_metal_device_t dev;
847
1226
  };
848
1227
 
849
1228
  static void ggml_metal_log_allocated_size(id<MTLDevice> device, size_t size_aligned) {
@@ -886,7 +1265,7 @@ static bool ggml_metal_buffer_rset_init(ggml_metal_buffer_t buf) {
886
1265
  desc.initialCapacity = buf->n_buffers;
887
1266
 
888
1267
  NSError * error;
889
- buf->rset = [buf->device newResidencySetWithDescriptor:desc error:&error];
1268
+ buf->rset = [buf->dev->mtl_device newResidencySetWithDescriptor:desc error:&error];
890
1269
  if (error) {
891
1270
  GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
892
1271
  [desc release];
@@ -947,6 +1326,8 @@ static void * ggml_metal_host_malloc(size_t n) {
947
1326
  ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size, bool shared) {
948
1327
  ggml_metal_buffer_t res = calloc(1, sizeof(struct ggml_metal_buffer));
949
1328
 
1329
+ res->dev = dev;
1330
+
950
1331
  const size_t size_page = sysconf(_SC_PAGESIZE);
951
1332
 
952
1333
  size_t size_aligned = size;
@@ -962,16 +1343,14 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size,
962
1343
  if (shared) {
963
1344
  res->all_data = ggml_metal_host_malloc(size_aligned);
964
1345
  res->is_shared = true;
965
- res->owned = true;
966
1346
  } else {
967
- // dummy, non-NULL value - we'll populate this after creating the Metal buffer below
968
- res->all_data = (void *) 0x000000400ULL;
1347
+ // use virtual address from g_addr_device counter
1348
+ res->all_data = (void *) atomic_fetch_add_explicit(&g_addr_device, size_aligned, memory_order_relaxed);
969
1349
  res->is_shared = false;
970
1350
  }
971
1351
  res->all_size = size_aligned;
972
1352
 
973
- res->device = ggml_metal_device_get_obj(dev);
974
- res->queue = ggml_metal_device_get_queue(dev);
1353
+ res->owned = true;
975
1354
 
976
1355
  res->n_buffers = 1;
977
1356
 
@@ -980,15 +1359,13 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size,
980
1359
  res->buffers[0].metal = nil;
981
1360
 
982
1361
  if (size_aligned > 0) {
983
- if (props_dev->use_shared_buffers &&shared) {
984
- res->buffers[0].metal = [res->device newBufferWithBytesNoCopy:res->all_data
1362
+ if (props_dev->use_shared_buffers && shared) {
1363
+ res->buffers[0].metal = [res->dev->mtl_device newBufferWithBytesNoCopy:res->all_data
985
1364
  length:size_aligned
986
1365
  options:MTLResourceStorageModeShared
987
1366
  deallocator:nil];
988
1367
  } else {
989
- res->buffers[0].metal = [res->device newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate];
990
-
991
- res->all_data = (void *) (res->buffers[0].metal.gpuAddress);
1368
+ res->buffers[0].metal = [res->dev->mtl_device newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate];
992
1369
  }
993
1370
  }
994
1371
 
@@ -1009,6 +1386,8 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size,
1009
1386
  return NULL;
1010
1387
  }
1011
1388
 
1389
+ ggml_metal_device_rsets_add(dev, res->rset);
1390
+
1012
1391
  //ggml_metal_log_allocated_size(device, size_aligned);
1013
1392
 
1014
1393
  return res;
@@ -1017,6 +1396,8 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size,
1017
1396
  ggml_metal_buffer_t ggml_metal_buffer_map(ggml_metal_device_t dev, void * ptr, size_t size, size_t max_tensor_size) {
1018
1397
  ggml_metal_buffer_t res = calloc(1, sizeof(struct ggml_metal_buffer));
1019
1398
 
1399
+ res->dev = dev;
1400
+
1020
1401
  res->all_data = ptr;
1021
1402
  res->all_size = size;
1022
1403
 
@@ -1039,9 +1420,6 @@ ggml_metal_buffer_t ggml_metal_buffer_map(ggml_metal_device_t dev, void * ptr, s
1039
1420
  size_aligned += (size_page - (size_aligned % size_page));
1040
1421
  }
1041
1422
 
1042
- res->device = ggml_metal_device_get_obj(dev);
1043
- res->queue = ggml_metal_device_get_queue(dev);
1044
-
1045
1423
  const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev);
1046
1424
 
1047
1425
  // the buffer fits into the max buffer size allowed by the device
@@ -1051,7 +1429,7 @@ ggml_metal_buffer_t ggml_metal_buffer_map(ggml_metal_device_t dev, void * ptr, s
1051
1429
  res->buffers[res->n_buffers].metal = nil;
1052
1430
 
1053
1431
  if (size_aligned > 0) {
1054
- res->buffers[res->n_buffers].metal = [res->device newBufferWithBytesNoCopy:ptr length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
1432
+ res->buffers[res->n_buffers].metal = [res->dev->mtl_device newBufferWithBytesNoCopy:ptr length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
1055
1433
 
1056
1434
  if (res->buffers[res->n_buffers].metal == nil) {
1057
1435
  GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
@@ -1060,7 +1438,7 @@ ggml_metal_buffer_t ggml_metal_buffer_map(ggml_metal_device_t dev, void * ptr, s
1060
1438
  }
1061
1439
  }
1062
1440
 
1063
- ggml_metal_log_allocated_size(res->device, size_aligned);
1441
+ ggml_metal_log_allocated_size(res->dev->mtl_device, size_aligned);
1064
1442
 
1065
1443
  ++res->n_buffers;
1066
1444
  } else {
@@ -1078,7 +1456,7 @@ ggml_metal_buffer_t ggml_metal_buffer_map(ggml_metal_device_t dev, void * ptr, s
1078
1456
  res->buffers[res->n_buffers].metal = nil;
1079
1457
 
1080
1458
  if (size_step_aligned > 0) {
1081
- res->buffers[res->n_buffers].metal = [res->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) ptr + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
1459
+ res->buffers[res->n_buffers].metal = [res->dev->mtl_device newBufferWithBytesNoCopy:(void *) ((uint8_t *) ptr + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
1082
1460
 
1083
1461
  if (res->buffers[res->n_buffers].metal == nil) {
1084
1462
  GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0);
@@ -1087,7 +1465,7 @@ ggml_metal_buffer_t ggml_metal_buffer_map(ggml_metal_device_t dev, void * ptr, s
1087
1465
  }
1088
1466
  }
1089
1467
 
1090
- ggml_metal_log_allocated_size(res->device, size_step_aligned);
1468
+ ggml_metal_log_allocated_size(res->dev->mtl_device, size_step_aligned);
1091
1469
 
1092
1470
  if (i + size_step < size) {
1093
1471
  GGML_LOG_INFO("\n");
@@ -1105,10 +1483,14 @@ ggml_metal_buffer_t ggml_metal_buffer_map(ggml_metal_device_t dev, void * ptr, s
1105
1483
  return NULL;
1106
1484
  }
1107
1485
 
1486
+ ggml_metal_device_rsets_add(dev, res->rset);
1487
+
1108
1488
  return res;
1109
1489
  }
1110
1490
 
1111
1491
  void ggml_metal_buffer_free(ggml_metal_buffer_t buf) {
1492
+ ggml_metal_device_rsets_rm(buf->dev, buf->rset);
1493
+
1112
1494
  for (int i = 0; i < buf->n_buffers; i++) {
1113
1495
  [buf->buffers[i].metal release];
1114
1496
  }
@@ -1136,7 +1518,7 @@ bool ggml_metal_buffer_is_shared(ggml_metal_buffer_t buf) {
1136
1518
 
1137
1519
  void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
1138
1520
  if (buf->is_shared) {
1139
- memset((char *)tensor->data + offset, value, size);
1521
+ memset((char *) tensor->data + offset, value, size);
1140
1522
  return;
1141
1523
  }
1142
1524
 
@@ -1145,8 +1527,7 @@ void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor
1145
1527
  struct ggml_metal_buffer_id bid_dst = ggml_metal_buffer_get_id(buf, tensor);
1146
1528
  bid_dst.offs += offset;
1147
1529
 
1148
- id<MTLCommandQueue> queue = buf->queue;
1149
- id<MTLCommandBuffer> cmd_buf = [queue commandBufferWithUnretainedReferences];
1530
+ id<MTLCommandBuffer> cmd_buf = [buf->dev->mtl_queue commandBufferWithUnretainedReferences];
1150
1531
 
1151
1532
  {
1152
1533
  id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
@@ -1165,14 +1546,14 @@ void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor
1165
1546
 
1166
1547
  void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
1167
1548
  if (buf->is_shared) {
1168
- memcpy((char *)tensor->data + offset, data, size);
1549
+ memcpy((char *) tensor->data + offset, data, size);
1169
1550
  return;
1170
1551
  }
1171
1552
 
1172
1553
  @autoreleasepool {
1173
1554
  // src
1174
1555
  void * data_ptr = (void *)(uintptr_t) data; // "const cast" the src data
1175
- id<MTLBuffer> buf_src = [buf->device newBufferWithBytesNoCopy:data_ptr
1556
+ id<MTLBuffer> buf_src = [buf->dev->mtl_device newBufferWithBytesNoCopy:data_ptr
1176
1557
  length:size
1177
1558
  options:MTLResourceStorageModeShared
1178
1559
  deallocator:nil];
@@ -1187,8 +1568,7 @@ void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor *
1187
1568
  // this is an alternative to waitUntilCompleted, which should be faster, but doesn't seem to make much difference
1188
1569
  dispatch_semaphore_t completion_semaphore = dispatch_semaphore_create(0);
1189
1570
 
1190
- id<MTLCommandQueue> queue = buf->queue;
1191
- id<MTLCommandBuffer> cmd_buf = [queue commandBufferWithUnretainedReferences];
1571
+ id<MTLCommandBuffer> cmd_buf = [buf->dev->mtl_queue commandBufferWithUnretainedReferences];
1192
1572
 
1193
1573
  {
1194
1574
  id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
@@ -1220,7 +1600,7 @@ void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor *
1220
1600
 
1221
1601
  void ggml_metal_buffer_get_tensor(ggml_metal_buffer_t buf, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
1222
1602
  if (buf->is_shared) {
1223
- memcpy(data, (const char *)tensor->data + offset, size);
1603
+ memcpy(data, (const char *) tensor->data + offset, size);
1224
1604
  return;
1225
1605
  }
1226
1606
 
@@ -1230,15 +1610,14 @@ void ggml_metal_buffer_get_tensor(ggml_metal_buffer_t buf, const struct ggml_ten
1230
1610
  bid_src.offs += offset;
1231
1611
 
1232
1612
  // dst
1233
- id<MTLBuffer> buf_dst = [buf->device newBufferWithBytesNoCopy:data
1613
+ id<MTLBuffer> buf_dst = [buf->dev->mtl_device newBufferWithBytesNoCopy:data
1234
1614
  length:size
1235
1615
  options:MTLResourceStorageModeShared
1236
1616
  deallocator:nil];
1237
1617
 
1238
1618
  GGML_ASSERT(buf_dst);
1239
1619
 
1240
- id<MTLCommandQueue> queue = buf->queue;
1241
- id<MTLCommandBuffer> cmd_buf = [queue commandBufferWithUnretainedReferences];
1620
+ id<MTLCommandBuffer> cmd_buf = [buf->dev->mtl_queue commandBufferWithUnretainedReferences];
1242
1621
 
1243
1622
  {
1244
1623
  id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
@@ -1264,8 +1643,7 @@ void ggml_metal_buffer_clear(ggml_metal_buffer_t buf, uint8_t value) {
1264
1643
  }
1265
1644
 
1266
1645
  @autoreleasepool {
1267
- id<MTLCommandQueue> queue = buf->queue;
1268
- id<MTLCommandBuffer> cmd_buf = [queue commandBufferWithUnretainedReferences];
1646
+ id<MTLCommandBuffer> cmd_buf = [buf->dev->mtl_queue commandBufferWithUnretainedReferences];
1269
1647
 
1270
1648
  {
1271
1649
  id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];