whispercpp 1.3.5 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (610)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +99 -2
  4. data/ext/extconf.rb +1 -0
  5. data/ext/ruby_whisper.c +20 -4
  6. data/ext/ruby_whisper.h +30 -2
  7. data/ext/ruby_whisper_context.c +216 -124
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +0 -1
  10. data/ext/ruby_whisper_params.c +0 -1
  11. data/ext/ruby_whisper_segment.c +0 -1
  12. data/ext/ruby_whisper_token.c +29 -9
  13. data/ext/ruby_whisper_transcribe.cpp +4 -1
  14. data/ext/ruby_whisper_vad_context.c +48 -1
  15. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +0 -1
  18. data/ext/ruby_whisper_vad_segments.c +0 -1
  19. data/ext/sources/CMakeLists.txt +1 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  22. data/ext/sources/examples/bench/bench.cpp +23 -18
  23. data/ext/sources/examples/cli/cli.cpp +8 -0
  24. data/ext/sources/examples/common-ggml.cpp +2 -0
  25. data/ext/sources/examples/miniaudio.h +4507 -2131
  26. data/ext/sources/examples/server/server.cpp +18 -4
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
  31. data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
  33. data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
  34. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  35. data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
  36. data/ext/sources/examples/talk-llama/llama-context.h +27 -28
  37. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
  38. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
  40. data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
  41. data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
  42. data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
  43. data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
  44. data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
  45. data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
  49. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  50. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  51. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
  52. data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
  53. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
  54. data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
  55. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
  56. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  57. data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
  58. data/ext/sources/examples/talk-llama/llama-model.h +72 -19
  59. data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
  60. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
  61. data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
  62. data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
  63. data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
  64. data/ext/sources/examples/talk-llama/llama.cpp +76 -22
  65. data/ext/sources/examples/talk-llama/llama.h +63 -30
  66. data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
  67. data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
  68. data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
  69. data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
  70. data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
  71. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
  72. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
  73. data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
  74. data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
  75. data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
  76. data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
  77. data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
  78. data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
  79. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
  80. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
  81. data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
  82. data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
  83. data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
  84. data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
  85. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
  86. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  87. data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
  88. data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
  89. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
  90. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
  91. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  92. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
  108. data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
  109. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
  110. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
  111. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
  112. data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
  113. data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
  114. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
  116. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  117. data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
  118. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
  119. data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
  120. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
  121. data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
  122. data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
  123. data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
  124. data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
  125. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
  126. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
  127. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
  128. data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
  129. data/ext/sources/examples/talk-llama/models/models.h +181 -46
  130. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
  131. data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
  132. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
  133. data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
  134. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
  135. data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
  136. data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
  137. data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
  138. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
  139. data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
  140. data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
  141. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  142. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
  143. data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
  144. data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
  145. data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
  146. data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
  147. data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
  148. data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
  149. data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
  150. data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
  151. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
  152. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
  153. data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
  154. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  155. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  156. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
  157. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
  158. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
  159. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
  160. data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
  161. data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
  162. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
  163. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
  164. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
  165. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
  166. data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
  167. data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
  168. data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
  169. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
  170. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  171. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
  172. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
  173. data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
  174. data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
  175. data/ext/sources/ggml/CMakeLists.txt +9 -3
  176. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  177. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  179. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  180. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  181. data/ext/sources/ggml/include/ggml-rpc.h +6 -1
  182. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  183. data/ext/sources/ggml/include/ggml.h +56 -9
  184. data/ext/sources/ggml/src/CMakeLists.txt +3 -0
  185. data/ext/sources/ggml/src/ggml-alloc.c +4 -9
  186. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  187. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  188. data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
  189. data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
  190. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  191. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
  192. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  193. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  194. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
  195. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
  196. data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
  197. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
  198. data/ext/sources/ggml/src/ggml-common.h +11 -0
  199. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
  200. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
  201. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  202. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  203. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
  204. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
  205. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
  206. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  207. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  208. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  209. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
  210. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
  211. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  212. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  213. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
  214. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
  215. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
  216. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  217. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
  218. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
  219. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
  220. data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
  221. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  222. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  223. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
  224. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  225. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  226. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
  227. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  228. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
  229. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
  230. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
  231. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
  232. data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
  233. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
  235. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
  236. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
  237. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
  238. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
  239. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  241. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
  242. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  243. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  244. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
  245. data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
  246. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
  247. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  248. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
  249. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
  250. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
  251. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  252. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
  253. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
  254. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  255. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  256. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
  257. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  258. data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
  259. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  260. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  261. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
  262. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  269. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
  270. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
  271. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
  272. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
  274. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
  278. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  297. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  317. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  321. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
  322. data/ext/sources/ggml/src/ggml-impl.h +62 -0
  323. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  324. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  325. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  326. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  327. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
  328. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
  329. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
  330. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
  335. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
  336. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  367. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  368. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  369. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  370. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  371. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  372. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  373. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  374. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  375. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  376. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  377. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  378. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  379. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  380. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  381. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  382. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  383. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  384. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  385. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  386. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  387. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  388. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  389. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  390. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  391. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  392. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  393. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  394. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  395. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  396. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  397. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  398. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  399. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  400. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  401. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  402. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  403. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  404. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  405. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  406. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  407. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  412. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  413. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  414. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  415. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  416. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
  417. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  418. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  420. data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
  421. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
  422. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  423. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  424. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  425. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
  426. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  427. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  428. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  429. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  430. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  431. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  432. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  433. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  434. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  435. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
  436. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  437. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  438. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  439. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  440. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  441. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  442. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  443. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  444. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  445. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  446. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  447. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  448. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  449. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  450. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  451. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  452. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  453. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  454. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  455. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  456. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  457. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  458. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  459. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  460. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  461. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  462. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  463. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  464. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  465. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  466. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  467. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  468. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  469. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  470. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  471. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  472. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  473. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  474. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  475. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  476. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  477. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  478. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  479. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  480. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  481. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  482. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  483. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  484. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  485. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  486. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  487. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  488. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
  489. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  490. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  491. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  492. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  493. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  494. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  495. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  496. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  497. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  498. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  499. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  500. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  501. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  502. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  503. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  504. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  505. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  506. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  507. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  508. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  509. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  510. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  511. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  512. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  513. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  514. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  515. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  516. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  517. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  518. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  519. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  520. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  521. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  522. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  523. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  524. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  525. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  526. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  527. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  528. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  529. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  530. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  531. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  532. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
  533. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
  560. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
  561. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
  562. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  563. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  564. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  565. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  566. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  567. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  568. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
  569. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  570. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
  571. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
  572. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  573. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
  574. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  584. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  585. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  586. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
  587. data/ext/sources/ggml/src/ggml.c +167 -33
  588. data/ext/sources/ggml/src/gguf.cpp +229 -44
  589. data/ext/sources/src/whisper.cpp +6 -28
  590. data/sig/whisper.rbs +43 -2
  591. data/test/test_context_params.rb +82 -0
  592. data/test/test_token.rb +11 -0
  593. data/test/test_vad_context.rb +58 -8
  594. data/test/test_whisper.rb +20 -0
  595. data/whispercpp.gemspec +1 -1
  596. metadata +240 -28
  597. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  598. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  599. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  600. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  601. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  602. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  603. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  604. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  605. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  606. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  607. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  608. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  609. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  610. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
@@ -203,6 +203,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
203
203
  GGML_ABORT("unsupported op");
204
204
  }
205
205
 
206
+ if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
207
+ return 1;
208
+ }
209
+
206
210
  int n_fuse = 1;
207
211
 
208
212
  // check if the current node can run concurrently with other nodes before it
@@ -283,17 +287,9 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
283
287
  n_fuse = ggml_metal_op_acc(ctx, idx);
284
288
  } break;
285
289
  case GGML_OP_SCALE:
286
- {
287
- n_fuse = ggml_metal_op_scale(ctx, idx);
288
- } break;
289
290
  case GGML_OP_FILL:
290
- {
291
- n_fuse = ggml_metal_op_fill(ctx, idx);
292
- } break;
293
291
  case GGML_OP_CLAMP:
294
- {
295
- n_fuse = ggml_metal_op_clamp(ctx, idx);
296
- } break;
292
+ case GGML_OP_LEAKY_RELU:
297
293
  case GGML_OP_SQR:
298
294
  case GGML_OP_SQRT:
299
295
  case GGML_OP_SIN:
@@ -337,6 +333,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
337
333
  {
338
334
  n_fuse = ggml_metal_op_rwkv(ctx, idx);
339
335
  } break;
336
+ case GGML_OP_GATED_DELTA_NET:
337
+ {
338
+ n_fuse = ggml_metal_op_gated_delta_net(ctx, idx);
339
+ } break;
340
+ case GGML_OP_SOLVE_TRI:
341
+ {
342
+ n_fuse = ggml_metal_op_solve_tri(ctx, idx);
343
+ } break;
340
344
  case GGML_OP_MUL_MAT:
341
345
  {
342
346
  n_fuse = ggml_metal_op_mul_mat(ctx, idx);
@@ -353,6 +357,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
353
357
  {
354
358
  n_fuse = ggml_metal_op_set_rows(ctx, idx);
355
359
  } break;
360
+ case GGML_OP_DIAG:
361
+ {
362
+ n_fuse = ggml_metal_op_diag(ctx, idx);
363
+ } break;
356
364
  case GGML_OP_L2_NORM:
357
365
  {
358
366
  n_fuse = ggml_metal_op_l2_norm(ctx, idx);
@@ -414,10 +422,6 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
414
422
  {
415
423
  n_fuse = ggml_metal_op_top_k(ctx, idx);
416
424
  } break;
417
- case GGML_OP_LEAKY_RELU:
418
- {
419
- n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
420
- } break;
421
425
  case GGML_OP_TRI:
422
426
  {
423
427
  n_fuse = ggml_metal_op_tri(ctx, idx);
@@ -426,12 +430,20 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
426
430
  {
427
431
  n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx);
428
432
  } break;
433
+ case GGML_OP_SET:
434
+ {
435
+ n_fuse = ggml_metal_op_set(ctx, idx);
436
+ } break;
429
437
  case GGML_OP_DUP:
430
438
  case GGML_OP_CPY:
431
439
  case GGML_OP_CONT:
432
440
  {
433
441
  n_fuse = ggml_metal_op_cpy(ctx, idx);
434
442
  } break;
443
+ case GGML_OP_POOL_1D:
444
+ {
445
+ n_fuse = ggml_metal_op_pool_1d(ctx, idx);
446
+ } break;
435
447
  case GGML_OP_POOL_2D:
436
448
  {
437
449
  n_fuse = ggml_metal_op_pool_2d(ctx, idx);
@@ -612,8 +624,8 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
612
624
  GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
613
625
  GGML_ASSERT(op->type == GGML_TYPE_F32);
614
626
 
615
- GGML_ASSERT(ggml_is_contiguous(op->src[0]));
616
- GGML_ASSERT(ggml_is_contiguous(op->src[1]));
627
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
628
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
617
629
 
618
630
  const size_t pnb1 = ((const int32_t *) op->op_params)[0];
619
631
  const size_t pnb2 = ((const int32_t *) op->op_params)[1];
@@ -623,7 +635,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
623
635
  const bool inplace = (bool) ((const int32_t *) op->op_params)[4];
624
636
 
625
637
  if (!inplace) {
626
- // run a separete kernel to cpy src->dst
638
+ // run a separate kernel to cpy src->dst
627
639
  // not sure how to avoid this
628
640
  // TODO: make a simpler cpy_bytes kernel
629
641
 
@@ -663,10 +675,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
663
675
  }
664
676
 
665
677
  ggml_metal_kargs_bin args = {
666
- /*.ne00 =*/ ne00,
667
- /*.ne01 =*/ ne01,
668
- /*.ne02 =*/ ne02,
669
- /*.ne03 =*/ ne03,
678
+ /*.ne00 =*/ ne10,
679
+ /*.ne01 =*/ ne11,
680
+ /*.ne02 =*/ ne12,
681
+ /*.ne03 =*/ ne13,
670
682
  /*.nb00 =*/ nb00,
671
683
  /*.nb01 =*/ pnb1,
672
684
  /*.nb02 =*/ pnb2,
@@ -679,10 +691,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
679
691
  /*.nb11 =*/ nb11,
680
692
  /*.nb12 =*/ nb12,
681
693
  /*.nb13 =*/ nb13,
682
- /*.ne0 =*/ ne0,
683
- /*.ne1 =*/ ne1,
684
- /*.ne2 =*/ ne2,
685
- /*.ne3 =*/ ne3,
694
+ /*.ne0 =*/ ne10,
695
+ /*.ne1 =*/ ne11,
696
+ /*.ne2 =*/ ne12,
697
+ /*.ne3 =*/ ne13,
686
698
  /*.nb0 =*/ nb0,
687
699
  /*.nb1 =*/ pnb1,
688
700
  /*.nb2 =*/ pnb2,
@@ -691,7 +703,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
691
703
  /*.o1 =*/ { 0 },
692
704
  };
693
705
 
694
- auto pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false);
706
+ auto pipeline = ggml_metal_library_get_pipeline_bin_one(lib, GGML_OP_ADD);
695
707
 
696
708
  ggml_metal_encoder_set_pipeline(enc, pipeline);
697
709
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -699,53 +711,20 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
699
711
  ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
700
712
  ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
701
713
 
702
- const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
703
-
704
- ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1);
705
-
706
- return 1;
707
- }
708
-
709
- int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
710
- ggml_tensor * op = ctx->node(idx);
711
-
712
- ggml_metal_library_t lib = ctx->lib;
713
- ggml_metal_encoder_t enc = ctx->enc;
714
-
715
- GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
716
- GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
717
- GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
718
- GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
719
-
720
- float scale;
721
- float bias;
722
- memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(float));
723
- memcpy(&bias, ((const int32_t *) op->op_params) + 1, sizeof(float));
724
-
725
- ggml_metal_kargs_scale args = {
726
- /*.scale =*/ scale,
727
- /*.bias =*/ bias,
728
- };
714
+ const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
729
715
 
730
- int64_t n = ggml_nelements(op);
716
+ int nth = 1;
731
717
 
732
- if (n % 4 == 0) {
733
- n /= 4;
718
+ while (2*nth < args.ne0 && nth < nth_max) {
719
+ nth *= 2;
734
720
  }
735
721
 
736
- auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
737
-
738
- ggml_metal_encoder_set_pipeline(enc, pipeline);
739
- ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
740
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
741
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
742
-
743
- ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
722
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1);
744
723
 
745
724
  return 1;
746
725
  }
747
726
 
748
- int ggml_metal_op_fill(ggml_metal_op_t ctx, int idx) {
727
+ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
749
728
  ggml_tensor * op = ctx->node(idx);
750
729
 
751
730
  ggml_metal_library_t lib = ctx->lib;
@@ -756,94 +735,80 @@ int ggml_metal_op_fill(ggml_metal_op_t ctx, int idx) {
756
735
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
757
736
  GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
758
737
 
759
- const float val = ggml_get_op_params_f32(op, 0);
738
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
760
739
 
761
- ggml_metal_kargs_fill args = {
762
- /*.val =*/ val
763
- };
740
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
741
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
764
742
 
765
- int64_t n = ggml_nelements(op);
743
+ ggml_metal_kargs_unary args = {
744
+ /*.ne00 =*/ ne00,
745
+ /*.ne01 =*/ ne01,
746
+ /*.ne02 =*/ ne02,
747
+ /*.ne03 =*/ ne03,
748
+ /*.nb00 =*/ nb00,
749
+ /*.nb01 =*/ nb01,
750
+ /*.nb02 =*/ nb02,
751
+ /*.nb03 =*/ nb03,
752
+ /*.ne0 =*/ ne0,
753
+ /*.ne1 =*/ ne1,
754
+ /*.ne2 =*/ ne2,
755
+ /*.ne3 =*/ ne3,
756
+ /*.nb0 =*/ nb0,
757
+ /*.nb1 =*/ nb1,
758
+ /*.nb2 =*/ nb2,
759
+ /*.nb3 =*/ nb3,
760
+ /*.slope =*/ 0.0,
761
+ /*.scale =*/ 0.0,
762
+ /*.bias =*/ 0.0,
763
+ /*.val =*/ 0.0,
764
+ /*.min =*/ 0.0,
765
+ /*.max =*/ 0.0,
766
+ };
766
767
 
767
- if (n % 4 == 0) {
768
- n /= 4;
768
+ if (op->op == GGML_OP_LEAKY_RELU) {
769
+ args.slope = ggml_get_op_params_f32(op, 0);
769
770
  }
770
771
 
771
- auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
772
-
773
- ggml_metal_encoder_set_pipeline(enc, pipeline);
774
- ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
775
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
776
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
777
-
778
- ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
779
-
780
- return 1;
781
- }
782
-
783
- int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
784
- ggml_tensor * op = ctx->node(idx);
785
-
786
- ggml_metal_library_t lib = ctx->lib;
787
- ggml_metal_encoder_t enc = ctx->enc;
788
-
789
- GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
790
- GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
791
- GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
792
- GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
793
-
794
- float min;
795
- float max;
796
- memcpy(&min, ((const int32_t *) op->op_params) + 0, sizeof(float));
797
- memcpy(&max, ((const int32_t *) op->op_params) + 1, sizeof(float));
798
-
799
- ggml_metal_kargs_clamp args = {
800
- /*.min =*/ min,
801
- /*.max =*/ max,
802
- };
772
+ if (op->op == GGML_OP_SCALE) {
773
+ args.scale = ggml_get_op_params_f32(op, 0);
774
+ args.bias = ggml_get_op_params_f32(op, 1);
775
+ }
803
776
 
804
- int64_t n = ggml_nelements(op);
777
+ if (op->op == GGML_OP_FILL) {
778
+ args.val = ggml_get_op_params_f32(op, 0);
779
+ }
805
780
 
806
- if (n % 4 == 0) {
807
- n /= 4;
781
+ if (op->op == GGML_OP_CLAMP) {
782
+ args.min = ggml_get_op_params_f32(op, 0);
783
+ args.max = ggml_get_op_params_f32(op, 1);
808
784
  }
809
785
 
810
786
  auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
811
787
 
788
+ if (pipeline.c4) {
789
+ args.ne00 = ne00/4;
790
+ args.ne0 = ne0/4;
791
+ }
792
+
812
793
  ggml_metal_encoder_set_pipeline(enc, pipeline);
813
794
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
814
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
815
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
816
-
817
- ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
818
-
819
- return 1;
820
- }
795
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
796
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
821
797
 
822
- int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
823
- ggml_tensor * op = ctx->node(idx);
798
+ if (pipeline.cnt) {
799
+ const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);
824
800
 
825
- ggml_metal_library_t lib = ctx->lib;
826
- ggml_metal_encoder_t enc = ctx->enc;
801
+ ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
802
+ } else {
803
+ const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
827
804
 
828
- GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
829
- GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
830
- GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
831
- GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
805
+ const int nth = MIN(args.ne00, nth_max);
832
806
 
833
- int64_t n = ggml_nelements(op);
807
+ const int nk0 = (args.ne00 + nth - 1)/nth;
834
808
 
835
- if (n % 4 == 0) {
836
- n /= 4;
809
+ ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1);
837
810
  }
838
811
 
839
- auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
840
-
841
- ggml_metal_encoder_set_pipeline(enc, pipeline);
842
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 0);
843
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 1);
844
-
845
- ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
846
-
847
812
  return 1;
848
813
  }
849
814
 
@@ -953,6 +918,11 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
953
918
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
954
919
  GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
955
920
 
921
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
922
+
923
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
924
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
925
+
956
926
  ggml_metal_kargs_sum_rows args = {
957
927
  /*.ne00 =*/ ne00,
958
928
  /*.ne01 =*/ ne01,
@@ -974,21 +944,26 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
974
944
 
975
945
  auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
976
946
 
947
+ if (pipeline.c4) {
948
+ args.ne00 = ne00/4;
949
+ args.ne0 = ne0/4;
950
+ }
951
+
977
952
  int nth = 32; // SIMD width
978
953
 
979
- while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
954
+ while (nth < args.ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
980
955
  nth *= 2;
981
956
  }
982
957
 
983
958
  nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
984
- nth = std::min(nth, ne00);
959
+ nth = std::min(nth, (int) args.ne00);
985
960
 
986
961
  const size_t smem = pipeline.smem;
987
962
 
988
963
  ggml_metal_encoder_set_pipeline(enc, pipeline);
989
964
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
990
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
991
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
965
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
966
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
992
967
 
993
968
  ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
994
969
 
@@ -1247,6 +1222,48 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
1247
1222
  return 1;
1248
1223
  }
1249
1224
 
1225
+ int ggml_metal_op_diag(ggml_metal_op_t ctx, int idx) {
1226
+ ggml_tensor * op = ctx->node(idx);
1227
+
1228
+ ggml_metal_library_t lib = ctx->lib;
1229
+ ggml_metal_encoder_t enc = ctx->enc;
1230
+
1231
+ GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne);
1232
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1233
+ GGML_TENSOR_LOCALS(int32_t, ne, op, ne);
1234
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1235
+
1236
+ ggml_metal_kargs_diag args = {
1237
+ /*.ne00 =*/ne00,
1238
+ /*.ne01 =*/ne01,
1239
+ /*.ne02 =*/ne02,
1240
+ /*.ne03 =*/ne03,
1241
+ /*.nb00 =*/nb00,
1242
+ /*.nb01 =*/nb01,
1243
+ /*.nb02 =*/nb02,
1244
+ /*.nb03 =*/nb03,
1245
+ /*.ne0 =*/ne0,
1246
+ /*.ne1 =*/ne1,
1247
+ /*.ne2 =*/ne2,
1248
+ /*.ne3 =*/ne3,
1249
+ /*.nb0 =*/nb0,
1250
+ /*.nb1 =*/nb1,
1251
+ /*.nb2 =*/nb2,
1252
+ /*.nb3 =*/nb3,
1253
+ };
1254
+
1255
+ auto pipeline = ggml_metal_library_get_pipeline_diag(lib, op);
1256
+
1257
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
1258
+ ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
1259
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1260
+ ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 2);
1261
+
1262
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, 32, 1, 1);
1263
+
1264
+ return 1;
1265
+ }
1266
+
1250
1267
  int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
1251
1268
  ggml_tensor * op = ctx->node(idx);
1252
1269
 
@@ -1508,7 +1525,180 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
1508
1525
  return 1;
1509
1526
  }
1510
1527
 
1511
- int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
1528
+ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
1529
+ ggml_tensor * op = ctx->node(idx);
1530
+
1531
+ ggml_metal_library_t lib = ctx->lib;
1532
+ ggml_metal_encoder_t enc = ctx->enc;
1533
+
1534
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1535
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1536
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1537
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1538
+
1539
+ const int64_t B = op->op == GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1];
1540
+ const int64_t T = op->src[0]->ne[2];
1541
+ const int64_t C = op->ne[0];
1542
+ const int64_t H = op->src[0]->ne[1];
1543
+
1544
+ auto pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op);
1545
+
1546
+ int ida = 0;
1547
+
1548
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
1549
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
1550
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
1551
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
1552
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
1553
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);
1554
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), ida++);
1555
+ if (op->op == GGML_OP_RWKV_WKV7) {
1556
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), ida++);
1557
+ }
1558
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), ida++);
1559
+ ggml_metal_encoder_set_bytes (enc, (void *) &B, sizeof(B), ida++);
1560
+ ggml_metal_encoder_set_bytes (enc, (void *) &T, sizeof(T), ida++);
1561
+ ggml_metal_encoder_set_bytes (enc, (void *) &C, sizeof(C), ida++);
1562
+ ggml_metal_encoder_set_bytes (enc, (void *) &H, sizeof(H), ida++);
1563
+
1564
+ ggml_metal_encoder_dispatch_threadgroups(enc, B * H, 1, 1, C/H, 1, 1);
1565
+
1566
+ return 1;
1567
+ }
1568
+
1569
+ int ggml_metal_op_gated_delta_net(ggml_metal_op_t ctx, int idx) {
1570
+ ggml_tensor * op = ctx->node(idx);
1571
+
1572
+ ggml_metal_library_t lib = ctx->lib;
1573
+ ggml_metal_encoder_t enc = ctx->enc;
1574
+
1575
+
1576
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1577
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1578
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1579
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1580
+ GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
1581
+ GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
1582
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1583
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1584
+
1585
+ auto pipeline = ggml_metal_library_get_pipeline_gated_delta_net(lib, op);
1586
+
1587
+ int ida = 0;
1588
+
1589
+ ggml_metal_kargs_gated_delta_net args = {
1590
+ /*.ne00 =*/ ne00,
1591
+ /*.ne01 =*/ ne01,
1592
+ /*.ne02 =*/ ne02,
1593
+ /*.ne03 =*/ ne03,
1594
+ /*.nb00 =*/ nb00,
1595
+ /*.nb01 =*/ nb01,
1596
+ /*.nb02 =*/ nb02,
1597
+ /*.nb03 =*/ nb03,
1598
+ /*.ne10 =*/ ne10,
1599
+ /*.ne11 =*/ ne11,
1600
+ /*.ne12 =*/ ne12,
1601
+ /*.ne13 =*/ ne13,
1602
+ /*.nb10 =*/ nb10,
1603
+ /*.nb11 =*/ nb11,
1604
+ /*.nb12 =*/ nb12,
1605
+ /*.nb13 =*/ nb13,
1606
+ /*.ne20 =*/ ne20,
1607
+ /*.ne21 =*/ ne21,
1608
+ /*.ne22 =*/ ne22,
1609
+ /*.ne23 =*/ ne23,
1610
+ /*.nb20 =*/ nb20,
1611
+ /*.nb21 =*/ nb21,
1612
+ /*.nb22 =*/ nb22,
1613
+ /*.nb23 =*/ nb23,
1614
+ /*.ns02 =*/ (int32_t) (nb02/sizeof(float)),
1615
+ /*.ns12 =*/ (int32_t) (nb12/sizeof(float)),
1616
+ /*.ns22 =*/ (int32_t) (nb22/sizeof(float)),
1617
+ /*.ne0 =*/ ne0,
1618
+ /*.ne1 =*/ ne1,
1619
+ /*.ne2 =*/ ne2,
1620
+ /*.ne3 =*/ ne3,
1621
+ /*.nb0 =*/ nb0,
1622
+ /*.nb1 =*/ nb1,
1623
+ /*.nb2 =*/ nb2,
1624
+ /*.nb3 =*/ nb3,
1625
+ };
1626
+
1627
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
1628
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
1629
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); // q
1630
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); // k
1631
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); // v
1632
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++); // gate
1633
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++); // beta
1634
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), ida++); // state
1635
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), ida++); // dst
1636
+
1637
+ const int nsg = pipeline.nsg;
1638
+
1639
+ ggml_metal_encoder_dispatch_threadgroups(enc, op->src[2]->ne[0]/nsg, op->src[2]->ne[1], op->src[2]->ne[3], 32, nsg, 1);
1640
+
1641
+ return 1;
1642
+ }
1643
+
1644
+ int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
1645
+ ggml_tensor * op = ctx->node(idx);
1646
+
1647
+ ggml_metal_library_t lib = ctx->lib;
1648
+ ggml_metal_encoder_t enc = ctx->enc;
1649
+
1650
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1651
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1652
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1653
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1654
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1655
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1656
+
1657
+ ggml_metal_kargs_solve_tri args = {
1658
+ /*.ne00 =*/ ne00,
1659
+ /*.ne01 =*/ ne01,
1660
+ /*.ne02 =*/ ne02,
1661
+ /*.ne03 =*/ ne03,
1662
+ /*.nb00 =*/ nb00,
1663
+ /*.nb01 =*/ nb01,
1664
+ /*.nb02 =*/ nb02,
1665
+ /*.nb03 =*/ nb03,
1666
+ /*.ne10 =*/ ne10,
1667
+ /*.ne11 =*/ ne11,
1668
+ /*.ne12 =*/ ne12,
1669
+ /*.ne13 =*/ ne13,
1670
+ /*.nb10 =*/ nb10,
1671
+ /*.nb11 =*/ nb11,
1672
+ /*.nb12 =*/ nb12,
1673
+ /*.nb13 =*/ nb13,
1674
+ /*.ne0 =*/ ne0,
1675
+ /*.ne1 =*/ ne1,
1676
+ /*.ne2 =*/ ne2,
1677
+ /*.ne3 =*/ ne3,
1678
+ /*.nb0 =*/ nb0,
1679
+ /*.nb1 =*/ nb1,
1680
+ /*.nb2 =*/ nb2,
1681
+ /*.nb3 =*/ nb3,
1682
+ };
1683
+
1684
+ auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op);
1685
+
1686
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
1687
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1688
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1689
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1690
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
1691
+
1692
+ const int nsg = pipeline.nsg;
1693
+
1694
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, pipeline.smem, 0);
1695
+
1696
+ ggml_metal_encoder_dispatch_threadgroups(enc, (ne10 + nsg - 1)/nsg, ne02, ne03, 32, nsg, 1);
1697
+
1698
+ return 1;
1699
+ }
1700
+
1701
+ int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) {
1512
1702
  ggml_tensor * op = ctx->node(idx);
1513
1703
 
1514
1704
  ggml_metal_library_t lib = ctx->lib;
@@ -1516,35 +1706,122 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
1516
1706
 
1517
1707
  GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1518
1708
  GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1709
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1710
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1519
1711
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1520
1712
  GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1521
1713
 
1522
- const int64_t B = op->op == GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1];
1523
- const int64_t T = op->src[0]->ne[2];
1524
- const int64_t C = op->ne[0];
1525
- const int64_t H = op->src[0]->ne[1];
1714
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
1715
+ ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
1716
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
1526
1717
 
1527
- auto pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op);
1718
+ const size_t pnb1 = ((const int32_t *) op->op_params)[0];
1719
+ const size_t pnb2 = ((const int32_t *) op->op_params)[1];
1720
+ const size_t pnb3 = ((const int32_t *) op->op_params)[2];
1721
+ const size_t offs = ((const int32_t *) op->op_params)[3];
1528
1722
 
1529
- int ida = 0;
1723
+ const bool inplace = (bool) ((const int32_t *) op->op_params)[4];
1530
1724
 
1531
- ggml_metal_encoder_set_pipeline(enc, pipeline);
1532
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
1533
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
1534
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
1535
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
1536
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);
1537
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), ida++);
1538
- if (op->op == GGML_OP_RWKV_WKV7) {
1539
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), ida++);
1725
+ if (!inplace) {
1726
+ // run a separate kernel to cpy src->dst
1727
+ // not sure how to avoid this
1728
+ // TODO: make a simpler cpy_bytes kernel
1729
+
1730
+ //const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
1731
+ auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
1732
+
1733
+ ggml_metal_kargs_cpy args = {
1734
+ /*.nk0 =*/ ne00,
1735
+ /*.ne00 =*/ ne00,
1736
+ /*.ne01 =*/ ne01,
1737
+ /*.ne02 =*/ ne02,
1738
+ /*.ne03 =*/ ne03,
1739
+ /*.nb00 =*/ nb00,
1740
+ /*.nb01 =*/ nb01,
1741
+ /*.nb02 =*/ nb02,
1742
+ /*.nb03 =*/ nb03,
1743
+ /*.ne0 =*/ ne0,
1744
+ /*.ne1 =*/ ne1,
1745
+ /*.ne2 =*/ ne2,
1746
+ /*.ne3 =*/ ne3,
1747
+ /*.nb0 =*/ nb0,
1748
+ /*.nb1 =*/ nb1,
1749
+ /*.nb2 =*/ nb2,
1750
+ /*.nb3 =*/ nb3,
1751
+ };
1752
+
1753
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
1754
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1755
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
1756
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
1757
+
1758
+ const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
1759
+
1760
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
1761
+
1762
+ ggml_metal_op_concurrency_reset(ctx);
1540
1763
  }
1541
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), ida++);
1542
- ggml_metal_encoder_set_bytes (enc, (void *) &B, sizeof(B), ida++);
1543
- ggml_metal_encoder_set_bytes (enc, (void *) &T, sizeof(T), ida++);
1544
- ggml_metal_encoder_set_bytes (enc, (void *) &C, sizeof(C), ida++);
1545
- ggml_metal_encoder_set_bytes (enc, (void *) &H, sizeof(H), ida++);
1546
1764
 
1547
- ggml_metal_encoder_dispatch_threadgroups(enc, B * H, 1, 1, C/H, 1, 1);
1765
+ auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[1]->type, op->type);
1766
+
1767
+ GGML_ASSERT(ne10 % ggml_blck_size(op->src[1]->type) == 0);
1768
+
1769
+ int64_t nk0 = ne10;
1770
+ if (ggml_is_quantized(op->src[1]->type)) {
1771
+ nk0 = ne10/16;
1772
+ } else if (ggml_is_quantized(op->type)) {
1773
+ nk0 = ne10/ggml_blck_size(op->type);
1774
+ }
1775
+
1776
+ int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1777
+
1778
+ // when rows are small, we can batch them together in a single threadgroup
1779
+ int nrptg = 1;
1780
+
1781
+ // TODO: relax this constraint in the future
1782
+ if (ggml_blck_size(op->src[1]->type) == 1 && ggml_blck_size(op->type) == 1) {
1783
+ if (nth > nk0) {
1784
+ nrptg = (nth + nk0 - 1)/nk0;
1785
+ nth = nk0;
1786
+
1787
+ if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
1788
+ nrptg--;
1789
+ }
1790
+ }
1791
+ }
1792
+
1793
+ nth = std::min<int>(nth, nk0);
1794
+
1795
+ ggml_metal_kargs_cpy args = {
1796
+ /*.nk0 =*/ nk0,
1797
+ /*.ne00 =*/ ne10,
1798
+ /*.ne01 =*/ ne11,
1799
+ /*.ne02 =*/ ne12,
1800
+ /*.ne03 =*/ ne13,
1801
+ /*.nb00 =*/ nb10,
1802
+ /*.nb01 =*/ nb11,
1803
+ /*.nb02 =*/ nb12,
1804
+ /*.nb03 =*/ nb13,
1805
+ /*.ne0 =*/ ne10,
1806
+ /*.ne1 =*/ ne11,
1807
+ /*.ne2 =*/ ne12,
1808
+ /*.ne3 =*/ ne13,
1809
+ /*.nb0 =*/ ggml_element_size(op),
1810
+ /*.nb1 =*/ pnb1,
1811
+ /*.nb2 =*/ pnb2,
1812
+ /*.nb3 =*/ pnb3,
1813
+ };
1814
+
1815
+ const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
1816
+
1817
+ bid_dst.offs += offs;
1818
+
1819
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
1820
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1821
+ ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
1822
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
1823
+
1824
+ ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne11 + nrptg - 1)/nrptg, ne12, ne13, nth, nrptg, 1);
1548
1825
 
1549
1826
  return 1;
1550
1827
  }
@@ -1622,6 +1899,54 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
1622
1899
  return 1;
1623
1900
  }
1624
1901
 
1902
+ int ggml_metal_op_pool_1d(ggml_metal_op_t ctx, int idx) {
1903
+ ggml_tensor * op = ctx->node(idx);
1904
+
1905
+ ggml_metal_library_t lib = ctx->lib;
1906
+ ggml_metal_encoder_t enc = ctx->enc;
1907
+
1908
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1909
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1910
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1911
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1912
+
1913
+ const int32_t * opts = op->op_params;
1914
+ ggml_op_pool op_pool = (ggml_op_pool) opts[0];
1915
+
1916
+ const int32_t k0 = opts[1];
1917
+ const int32_t s0 = opts[2];
1918
+ const int32_t p0 = opts[3];
1919
+
1920
+ const int64_t IW = op->src[0]->ne[0];
1921
+ const int64_t OW = op->ne[0];
1922
+
1923
+ const int64_t np = ggml_nelements(op);
1924
+
1925
+ ggml_metal_kargs_pool_1d args_pool_1d = {
1926
+ /* .k0 = */ k0,
1927
+ /* .s0 = */ s0,
1928
+ /* .p0 = */ p0,
1929
+ /* .IW = */ IW,
1930
+ /* .OW = */ OW,
1931
+ /* .np = */ np
1932
+ };
1933
+
1934
+ auto pipeline = ggml_metal_library_get_pipeline_pool_1d(lib, op, op_pool);
1935
+
1936
+ const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
1937
+ const int ntg = (np + nth - 1) / nth;
1938
+
1939
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
1940
+ ggml_metal_encoder_set_bytes (enc, &args_pool_1d, sizeof(args_pool_1d), 0);
1941
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1942
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
1943
+
1944
+ ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1);
1945
+
1946
+ return 1;
1947
+ }
1948
+
1949
+
1625
1950
  int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
1626
1951
  ggml_tensor * op = ctx->node(idx);
1627
1952
 
@@ -1717,6 +2042,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
1717
2042
  (
1718
2043
  op->src[0]->type == GGML_TYPE_F32 || // TODO: helper function
1719
2044
  op->src[0]->type == GGML_TYPE_F16 ||
2045
+ op->src[0]->type == GGML_TYPE_BF16 ||
1720
2046
  op->src[0]->type == GGML_TYPE_Q4_0 ||
1721
2047
  op->src[0]->type == GGML_TYPE_Q4_1 ||
1722
2048
  op->src[0]->type == GGML_TYPE_Q5_0 ||
@@ -1731,6 +2057,8 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
1731
2057
  op->src[0]->type == GGML_TYPE_Q4_K ||
1732
2058
  op->src[0]->type == GGML_TYPE_Q5_K ||
1733
2059
  op->src[0]->type == GGML_TYPE_Q6_K ||
2060
+ op->src[0]->type == GGML_TYPE_Q2_K ||
2061
+ op->src[0]->type == GGML_TYPE_Q3_K ||
1734
2062
  false) && (ne11 >= 4 && ne11 <= 8)
1735
2063
  )
1736
2064
  )
@@ -1759,7 +2087,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
1759
2087
  const int16_t r0ptg = nypsg*nsg; // num src0 rows per threadgroup
1760
2088
  int16_t r1ptg = 4; // num src1 rows per threadgroup
1761
2089
 
1762
- // note: not sure how optimal are those across all different hardware. there might be someting cleverer
2090
+ // note: not sure how optimal are those across all different hardware. there might be something cleverer
1763
2091
  switch (ne11) {
1764
2092
  case 2:
1765
2093
  r1ptg = 2; break;
@@ -2239,7 +2567,7 @@ size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
2239
2567
  // return res;
2240
2568
  //}
2241
2569
 
2242
- const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
2570
+ const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPSG : OP_FLASH_ATTN_EXT_NQPSG;
2243
2571
  const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
2244
2572
 
2245
2573
  const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
@@ -2355,7 +2683,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2355
2683
 
2356
2684
  if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
2357
2685
  // half8x8 kernel
2358
- const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup
2686
+ const int nqptg = OP_FLASH_ATTN_EXT_NQPSG; // queries per threadgroup
2359
2687
  const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup
2360
2688
 
2361
2689
  GGML_ASSERT(nqptg <= 32);
@@ -2464,7 +2792,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2464
2792
 
2465
2793
  // simdgroups per threadgroup (a.k.a. warps)
2466
2794
  //nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
2467
- int32_t nsg = 4;
2795
+ int32_t nsg = ne00 >= 512 ? 8 : 4;
2468
2796
 
2469
2797
  const size_t smem = FATTN_SMEM(nsg);
2470
2798
 
@@ -2522,9 +2850,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2522
2850
  #undef FATTN_SMEM
2523
2851
  } else {
2524
2852
  // half4x4 kernel
2525
- const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup
2853
+ const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPSG; // queries per threadgroup
2526
2854
  const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
2527
- const int nkpsg = 1*ncpsg;
2855
+ const int nhptg = 1; // heads per threadgroup
2528
2856
 
2529
2857
  GGML_ASSERT(nqptg <= 32);
2530
2858
  GGML_ASSERT(nqptg % 1 == 0);
@@ -2576,6 +2904,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2576
2904
  ggml_metal_op_concurrency_reset(ctx);
2577
2905
  }
2578
2906
 
2907
+ // note: for simplicity assume the K is larger or equal than V
2908
+ GGML_ASSERT(ne10 >= ne20);
2909
+
2579
2910
  // ne00 + 2*ncpsg*(nsg)
2580
2911
  // for each query, we load it as f16 in shared memory (ne00)
2581
2912
  // and store the soft_max values and the mask
@@ -2583,28 +2914,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2583
2914
  // ne20*(nsg)
2584
2915
  // each simdgroup has a full f32 head vector in shared mem to accumulate results
2585
2916
  //
2586
- #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*GGML_PAD(ne20, 128)*(nsg))*(sizeof(float)/2), 16))
2587
-
2588
- int64_t nsgmax = 2;
2589
- while (true) {
2590
- const size_t smem = FATTN_SMEM(nsgmax);
2591
- // avoid using more than half of the threadgroup memory - can cause slow downs especially for large head sizes
2592
- if (smem > props_dev->max_theadgroup_memory_size/2) {
2593
- break;
2594
- }
2595
- nsgmax *= 2;
2596
- }
2597
- nsgmax /= 2;
2598
-
2599
- // simdgroups per threadgroup (a.k.a. warps)
2600
- //const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
2601
- const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) 1024/32)));
2917
+ #define FATTN_SMEM(nsg) (GGML_PAD(((GGML_PAD(ne00, 128) + 4*ncpsg + 2*GGML_PAD(ne20, 128))*(nsg))*(sizeof(float)/2), 16))
2602
2918
 
2603
2919
  int64_t nsg = 1;
2604
- while (nsg <= nsgt) {
2605
- nsg *= 2;
2606
- }
2607
- nsg /= 2;
2608
2920
 
2609
2921
  // workgroups
2610
2922
  // each workgroup handles nsg*nkpsg cache values
@@ -2617,7 +2929,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2617
2929
  } else {
2618
2930
  nwg = 32;
2619
2931
  nsg = 1;
2620
- while (2*nwg*nsg*nkpsg < ne11 && nsg < 4) {
2932
+ while (2*nwg*nsg*ncpsg < ne11 && nsg < 4) {
2621
2933
  nsg *= 2;
2622
2934
  }
2623
2935
  }
@@ -2683,7 +2995,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2683
2995
 
2684
2996
  ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2685
2997
 
2686
- ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
2998
+ ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1);
2687
2999
  } else {
2688
3000
  // sanity checks
2689
3001
  assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);
@@ -2696,7 +3008,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2696
3008
  ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);
2697
3009
 
2698
3010
  ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2699
- ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
3011
+ ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1);
2700
3012
 
2701
3013
  // sync the 2 kernels
2702
3014
  ggml_metal_op_concurrency_reset(ctx);
@@ -2748,8 +3060,6 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
2748
3060
  GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
2749
3061
  GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
2750
3062
 
2751
- bool bcast_row = false;
2752
-
2753
3063
  ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
2754
3064
  ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
2755
3065
  ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
@@ -2843,18 +3153,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
2843
3153
 
2844
3154
  struct ggml_metal_pipeline_with_params pipeline;
2845
3155
 
2846
- if (ggml_nelements(op->src[1]) == ne10 && ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) {
2847
- GGML_ASSERT(ggml_is_contiguous(op->src[0]));
2848
-
2849
- // src1 is a row
2850
- GGML_ASSERT(ne11 == 1);
2851
-
2852
- pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, true);
2853
-
2854
- bcast_row = true;
2855
- } else {
2856
- pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, false);
2857
- }
3156
+ pipeline = ggml_metal_library_get_pipeline_bin(lib, op, n_fuse);
2858
3157
 
2859
3158
  if (n_fuse > 1) {
2860
3159
  bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));
@@ -2868,20 +3167,26 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
2868
3167
  }
2869
3168
  }
2870
3169
 
3170
+ if (pipeline.c4) {
3171
+ args.ne00 = ne00/4;
3172
+ args.ne10 = ne10/4;
3173
+ args.ne0 = ne0/4;
3174
+ }
3175
+
2871
3176
  ggml_metal_encoder_set_pipeline(enc, pipeline);
2872
3177
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2873
3178
  ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
2874
3179
  ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
2875
3180
  ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
2876
3181
 
2877
- if (bcast_row) {
2878
- const int64_t n = ggml_nelements(op)/4;
2879
-
2880
- ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
3182
+ if (pipeline.cnt) {
3183
+ ggml_metal_encoder_dispatch_threadgroups(enc, args.ne0, ggml_nrows(op), 1, 1, 1, 1);
2881
3184
  } else {
2882
- int nth = 32;
3185
+ const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
3186
+
3187
+ int nth = 1;
2883
3188
 
2884
- while (16*nth < ne0 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
3189
+ while (2*nth < args.ne0 && nth < nth_max) {
2885
3190
  nth *= 2;
2886
3191
  }
2887
3192
 
@@ -2902,39 +3207,59 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
2902
3207
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2903
3208
  GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2904
3209
 
3210
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
3211
+
3212
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
3213
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
3214
+
2905
3215
  float eps;
2906
3216
  memcpy(&eps, op->op_params, sizeof(float));
2907
3217
 
2908
- int nth = 32; // SIMD width
2909
-
2910
3218
  ggml_metal_kargs_l2_norm args = {
2911
- /*.ne00 =*/ ne00,
2912
- /*.ne00_4 =*/ ne00/4,
2913
- /*.nb01 =*/ nb01,
2914
- /*.eps =*/ eps,
3219
+ /*.ne00 =*/ ne00,
3220
+ /*.ne01 =*/ ne01,
3221
+ /*.ne02 =*/ ne02,
3222
+ /*.ne03 =*/ ne03,
3223
+ /*.nb00 =*/ nb00,
3224
+ /*.nb01 =*/ nb01,
3225
+ /*.nb02 =*/ nb02,
3226
+ /*.nb03 =*/ nb03,
3227
+ /*.ne0 =*/ ne0,
3228
+ /*.ne1 =*/ ne1,
3229
+ /*.ne2 =*/ ne2,
3230
+ /*.ne3 =*/ ne3,
3231
+ /*.nb0 =*/ nb0,
3232
+ /*.nb1 =*/ nb1,
3233
+ /*.nb2 =*/ nb2,
3234
+ /*.nb3 =*/ nb3,
3235
+ /*.eps =*/ eps,
2915
3236
  };
2916
3237
 
2917
3238
  auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
2918
3239
 
2919
- while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
3240
+ if (pipeline.c4) {
3241
+ args.ne00 = ne00/4;
3242
+ args.ne0 = ne0/4;
3243
+ }
3244
+
3245
+ int nth = 32; // SIMD width
3246
+
3247
+ while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
2920
3248
  nth *= 2;
2921
3249
  }
2922
3250
 
2923
3251
  nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
2924
- nth = std::min(nth, ne00/4);
2925
3252
 
2926
3253
  const size_t smem = pipeline.smem;
2927
3254
 
2928
- const int64_t nrows = ggml_nrows(op->src[0]);
2929
-
2930
3255
  ggml_metal_encoder_set_pipeline(enc, pipeline);
2931
3256
  ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2932
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
2933
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
3257
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
3258
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
2934
3259
 
2935
3260
  ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2936
3261
 
2937
- ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
3262
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
2938
3263
 
2939
3264
  return 1;
2940
3265
  }
@@ -3484,32 +3809,43 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
3484
3809
  GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3485
3810
  GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3486
3811
 
3487
- const float sf0 = (float)ne0/op->src[0]->ne[0];
3488
- const float sf1 = (float)ne1/op->src[0]->ne[1];
3489
- const float sf2 = (float)ne2/op->src[0]->ne[2];
3490
- const float sf3 = (float)ne3/op->src[0]->ne[3];
3812
+ float sf0 = (float)ne0/op->src[0]->ne[0];
3813
+ float sf1 = (float)ne1/op->src[0]->ne[1];
3814
+ float sf2 = (float)ne2/op->src[0]->ne[2];
3815
+ float sf3 = (float)ne3/op->src[0]->ne[3];
3816
+
3817
+ const int32_t mode_flags = ggml_get_op_params_i32(op, 0);
3818
+
3819
+ float poffs = 0.5f;
3820
+
3821
+ if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
3822
+ poffs = 0.0f;
3823
+ sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
3824
+ sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
3825
+ }
3491
3826
 
3492
3827
  ggml_metal_kargs_upscale args = {
3493
- /*.ne00 =*/ ne00,
3494
- /*.ne01 =*/ ne01,
3495
- /*.ne02 =*/ ne02,
3496
- /*.ne03 =*/ ne03,
3497
- /*.nb00 =*/ nb00,
3498
- /*.nb01 =*/ nb01,
3499
- /*.nb02 =*/ nb02,
3500
- /*.nb03 =*/ nb03,
3501
- /*.ne0 =*/ ne0,
3502
- /*.ne1 =*/ ne1,
3503
- /*.ne2 =*/ ne2,
3504
- /*.ne3 =*/ ne3,
3505
- /*.nb0 =*/ nb0,
3506
- /*.nb1 =*/ nb1,
3507
- /*.nb2 =*/ nb2,
3508
- /*.nb3 =*/ nb3,
3509
- /*.sf0 =*/ sf0,
3510
- /*.sf1 =*/ sf1,
3511
- /*.sf2 =*/ sf2,
3512
- /*.sf3 =*/ sf3
3828
+ /*.ne00 =*/ ne00,
3829
+ /*.ne01 =*/ ne01,
3830
+ /*.ne02 =*/ ne02,
3831
+ /*.ne03 =*/ ne03,
3832
+ /*.nb00 =*/ nb00,
3833
+ /*.nb01 =*/ nb01,
3834
+ /*.nb02 =*/ nb02,
3835
+ /*.nb03 =*/ nb03,
3836
+ /*.ne0 =*/ ne0,
3837
+ /*.ne1 =*/ ne1,
3838
+ /*.ne2 =*/ ne2,
3839
+ /*.ne3 =*/ ne3,
3840
+ /*.nb0 =*/ nb0,
3841
+ /*.nb1 =*/ nb1,
3842
+ /*.nb2 =*/ nb2,
3843
+ /*.nb3 =*/ nb3,
3844
+ /*.sf0 =*/ sf0,
3845
+ /*.sf1 =*/ sf1,
3846
+ /*.sf2 =*/ sf2,
3847
+ /*.sf3 =*/ sf3,
3848
+ /*.poffs =*/ poffs,
3513
3849
  };
3514
3850
 
3515
3851
  auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op);
@@ -3942,42 +4278,6 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
3942
4278
  return 1;
3943
4279
  }
3944
4280
 
3945
- int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
3946
- ggml_tensor * op = ctx->node(idx);
3947
-
3948
- ggml_metal_library_t lib = ctx->lib;
3949
- ggml_metal_encoder_t enc = ctx->enc;
3950
-
3951
- GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3952
- GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3953
- GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3954
- GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3955
-
3956
- float slope;
3957
- memcpy(&slope, op->op_params, sizeof(float));
3958
-
3959
- ggml_metal_kargs_leaky_relu args = {
3960
- /*.slope =*/ slope
3961
- };
3962
-
3963
- auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
3964
-
3965
- int64_t n = ggml_nelements(op);
3966
-
3967
- if (n % 4 == 0) {
3968
- n /= 4;
3969
- }
3970
-
3971
- ggml_metal_encoder_set_pipeline(enc, pipeline);
3972
- ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3973
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3974
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
3975
-
3976
- ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
3977
-
3978
- return 1;
3979
- }
3980
-
3981
4281
  int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) {
3982
4282
  ggml_tensor * op = ctx->node(idx);
3983
4283