whispercpp 1.3.6 → 1.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (828)
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/README.md +38 -5
  5. data/Rakefile +18 -3
  6. data/ext/dependencies.rb +10 -4
  7. data/ext/dependencies_for_windows.rb +17 -0
  8. data/ext/extconf.rb +20 -8
  9. data/ext/options.rb +54 -14
  10. data/ext/options_for_windows.rb +51 -0
  11. data/ext/ruby_whisper.c +36 -42
  12. data/ext/ruby_whisper.h +135 -0
  13. data/ext/ruby_whisper_context.c +107 -28
  14. data/ext/ruby_whisper_log_queue.c +180 -0
  15. data/ext/ruby_whisper_log_settable.h +47 -0
  16. data/ext/ruby_whisper_parakeet.c +49 -0
  17. data/ext/ruby_whisper_parakeet_context.c +304 -0
  18. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  19. data/ext/ruby_whisper_parakeet_model.c +84 -0
  20. data/ext/ruby_whisper_parakeet_params.c +548 -0
  21. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  22. data/ext/ruby_whisper_parakeet_token.c +188 -0
  23. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  24. data/ext/ruby_whisper_params.c +256 -65
  25. data/ext/ruby_whisper_segment.c +6 -6
  26. data/ext/ruby_whisper_transcribe.cpp +42 -15
  27. data/ext/sources/CMakeLists.txt +41 -3
  28. data/ext/sources/CMakePresets.json +95 -0
  29. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  30. data/ext/sources/cmake/parakeet.pc.in +10 -0
  31. data/ext/sources/cmake/whisper.pc.in +1 -1
  32. data/ext/sources/examples/CMakeLists.txt +4 -2
  33. data/ext/sources/examples/bench/bench.cpp +1 -1
  34. data/ext/sources/examples/cli/cli.cpp +43 -9
  35. data/ext/sources/examples/common-ggml.cpp +2 -0
  36. data/ext/sources/examples/common-whisper.cpp +139 -67
  37. data/ext/sources/examples/common-whisper.h +11 -0
  38. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  39. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  40. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  41. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  42. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  43. data/ext/sources/examples/server/server.cpp +199 -163
  44. data/ext/sources/ggml/CMakeLists.txt +21 -13
  45. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  46. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  47. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  48. data/ext/sources/ggml/include/ggml-backend.h +72 -10
  49. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  50. data/ext/sources/ggml/include/ggml-rpc.h +3 -3
  51. data/ext/sources/ggml/include/ggml.h +101 -9
  52. data/ext/sources/ggml/include/gguf.h +10 -2
  53. data/ext/sources/ggml/src/CMakeLists.txt +22 -5
  54. data/ext/sources/ggml/src/ggml-alloc.c +5 -1
  55. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  56. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  57. data/ext/sources/ggml/src/ggml-backend-reg.cpp +12 -0
  58. data/ext/sources/ggml/src/ggml-backend.cpp +110 -9
  59. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +4 -0
  60. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +672 -257
  61. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +71 -0
  62. data/ext/sources/ggml/src/ggml-cann/common.h +20 -10
  63. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +211 -30
  64. data/ext/sources/ggml/src/ggml-common.h +11 -0
  65. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +58 -29
  66. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +2 -0
  67. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +16 -16
  68. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +116 -7
  69. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +65 -0
  70. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  71. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  72. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +4279 -1292
  73. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +5 -35
  74. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +0 -1
  75. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  76. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +177 -27
  77. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +1 -1
  78. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +5 -0
  79. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  80. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +10 -0
  81. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +95 -5
  82. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
  83. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +146 -134
  84. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +88 -70
  85. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +372 -73
  86. data/ext/sources/ggml/src/ggml-cpu/ops.h +3 -0
  87. data/ext/sources/ggml/src/ggml-cpu/quants.c +55 -0
  88. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  89. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3 -0
  90. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +90 -0
  91. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +3 -16
  92. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  93. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  94. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  95. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  96. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  97. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  98. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  99. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  100. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  101. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  102. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  103. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  104. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  105. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  106. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  107. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +37 -53
  108. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  109. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +17 -7
  110. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  111. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  112. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +62 -26
  113. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +44 -18
  114. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  115. data/ext/sources/ggml/src/ggml-cuda/common.cuh +242 -28
  116. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  117. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  118. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  119. data/ext/sources/ggml/src/ggml-cuda/convert.cu +53 -0
  120. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  121. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +14 -6
  122. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  123. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +278 -44
  124. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +331 -130
  125. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  126. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +126 -27
  127. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +40 -15
  128. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +18 -9
  129. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +152 -49
  130. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  131. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  132. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  133. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +84 -35
  134. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  135. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1069 -609
  136. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  137. data/ext/sources/ggml/src/ggml-cuda/mean.cu +4 -2
  138. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +242 -195
  139. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +3 -3
  140. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  141. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +502 -423
  142. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +19 -12
  143. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +485 -57
  144. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -1
  145. data/ext/sources/ggml/src/ggml-cuda/norm.cu +36 -10
  146. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  147. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +133 -26
  148. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  149. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +5 -1
  150. data/ext/sources/ggml/src/ggml-cuda/rope.cu +11 -4
  151. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  152. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  153. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  154. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  155. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  156. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +45 -13
  157. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  158. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  159. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  160. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  161. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +1 -0
  162. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  163. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  164. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  165. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +1 -0
  166. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  167. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  168. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  169. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  170. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  171. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  172. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  173. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  174. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  175. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  176. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  177. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  178. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  179. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  180. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  181. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  182. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  183. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  184. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  185. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  186. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  187. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  188. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  189. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  190. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  191. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -4
  192. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +26 -23
  193. data/ext/sources/ggml/src/ggml-cuda/unary.cu +31 -2
  194. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  195. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +80 -0
  196. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  197. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -4
  198. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  199. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +2 -1
  200. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1428 -743
  201. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +45 -7
  202. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +53 -84
  203. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +25 -12
  204. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +165 -184
  205. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  206. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  207. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +170 -127
  208. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  209. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  210. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  211. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +125 -97
  212. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  213. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +148 -42
  214. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +2 -2
  215. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +252 -62
  216. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +9 -0
  217. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +87 -1
  218. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  219. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  220. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  221. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  222. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  223. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  224. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  225. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  226. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +96 -13
  227. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +182 -57
  228. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  229. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +71 -3
  230. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +27 -10
  231. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +63 -23
  232. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +9 -8
  233. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  235. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +1 -0
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +5 -8
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +529 -815
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2522 -234
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +291 -95
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +59 -37
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +121 -133
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +244 -151
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +6 -6
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +719 -45
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +3 -1
  254. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +22 -9
  255. data/ext/sources/ggml/src/ggml-impl.h +6 -1
  256. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +138 -13
  257. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +32 -1
  258. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +164 -28
  259. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +80 -0
  260. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +190 -19
  261. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +2 -0
  262. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +39 -26
  263. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +823 -322
  264. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  265. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +54 -5
  266. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +12248 -5907
  267. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +67 -0
  268. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +59 -0
  269. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1819 -112
  270. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  271. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  272. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  273. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  274. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  275. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  276. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  277. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  278. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  279. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  280. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  281. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  282. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  283. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  284. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  285. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  286. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mm_q8_0_f32_8x4.cl → gemm_noshuffle_q8_0_f32.cl} +1 -1
  287. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  288. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  289. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  290. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  291. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  292. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  293. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  294. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  295. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  296. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  297. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  300. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  301. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  302. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  303. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  304. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  305. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  306. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  307. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  308. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  309. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  310. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  311. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  312. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  313. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  314. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  315. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  316. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  317. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  318. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +48 -64
  322. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +15 -5
  323. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +18 -11
  324. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +35 -13
  325. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +264 -192
  326. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +33 -7
  327. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  328. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +1 -0
  329. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +1 -0
  330. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  331. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +27 -3
  332. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +67 -36
  333. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +1 -0
  334. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +101 -44
  335. data/ext/sources/ggml/src/ggml-openvino/utils.h +23 -3
  336. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  337. data/ext/sources/ggml/src/ggml-quants.c +289 -114
  338. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  339. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  340. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  341. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  342. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  343. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +50 -4
  344. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +1 -1
  345. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +3 -1
  346. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  347. data/ext/sources/ggml/src/ggml-sycl/common.hpp +41 -1
  348. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +115 -13
  349. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +9 -0
  350. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  351. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  352. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  353. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  354. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  355. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  356. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1 -90
  357. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +0 -2
  358. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  359. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  360. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +7 -5
  361. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +4 -0
  362. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +76 -168
  363. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +7 -0
  364. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +3 -1
  365. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  366. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  367. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +69 -31
  368. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +1 -0
  369. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  370. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  371. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +823 -190
  372. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  373. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  374. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  375. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  376. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  377. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +71 -0
  378. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  379. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  380. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  381. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  382. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  383. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  384. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  385. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  386. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  387. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +1 -0
  388. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +1 -0
  389. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +1 -0
  390. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +1 -0
  391. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +1 -0
  392. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +1 -0
  393. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +1 -0
  394. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +1 -0
  395. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +1 -0
  396. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +1 -0
  397. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +1 -0
  398. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +1 -0
  399. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +1 -0
  400. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +1 -0
  401. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +1 -0
  402. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +1 -0
  403. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +1 -0
  404. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +1 -0
  405. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +1 -0
  406. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +1 -0
  407. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +1 -0
  408. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +1 -0
  409. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +1 -0
  410. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +1 -0
  411. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +1 -0
  412. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +1 -0
  413. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +1 -0
  414. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +1 -0
  415. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +1 -0
  416. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +1 -0
  417. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +1 -0
  418. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +1 -0
  420. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +1 -0
  421. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +1 -0
  422. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +1 -0
  423. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  424. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  425. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  426. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +215 -53
  427. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +4 -0
  428. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +2 -0
  429. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +2 -0
  430. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +1 -0
  431. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +1 -0
  432. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +0 -2
  433. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +11 -0
  434. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +2060 -535
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +197 -48
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +60 -59
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +115 -113
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +122 -31
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +125 -64
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -17
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +43 -10
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +5 -2
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +3 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +11 -1
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +171 -147
  484. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  485. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +2202 -283
  486. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2610 -1403
  487. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  488. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  489. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +8 -7
  490. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +76 -95
  491. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +19 -1
  492. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  493. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -50
  494. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +107 -184
  495. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  496. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  497. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  498. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  499. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  500. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  501. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +183 -78
  502. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  503. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  504. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +655 -495
  505. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  506. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  507. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  508. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +8 -6
  509. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +5 -1
  510. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +80 -409
  511. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  512. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  513. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  514. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  515. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  516. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  517. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  518. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +6 -4
  519. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  520. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +2 -3
  521. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  522. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  523. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  524. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  525. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  526. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +68 -48
  527. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  528. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +18 -14
  529. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +1 -1
  530. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +244 -10
  531. data/ext/sources/ggml/src/ggml.c +110 -28
  532. data/ext/sources/ggml/src/gguf.cpp +173 -28
  533. data/ext/sources/include/parakeet.h +342 -0
  534. data/ext/sources/include/whisper.h +10 -0
  535. data/ext/sources/media/matmul.png +0 -0
  536. data/ext/sources/src/CMakeLists.txt +23 -0
  537. data/ext/sources/src/parakeet-arch.h +188 -0
  538. data/ext/sources/src/parakeet.cpp +3838 -0
  539. data/ext/sources/src/whisper.cpp +56 -12
  540. data/extsources.rb +26 -10
  541. data/lib/whisper/log_settable.rb +36 -0
  542. data/lib/whisper/model/uri.rb +13 -1
  543. data/lib/whisper/output.rb +74 -0
  544. data/sig/whisper.rbs +411 -62
  545. data/test/helper.rb +2 -0
  546. data/test/jfk_reader/jfk_reader.c +50 -7
  547. data/test/test_callback.rb +1 -0
  548. data/test/test_package.rb +6 -5
  549. data/test/test_parakeet.rb +28 -0
  550. data/test/test_parakeet_callback.rb +107 -0
  551. data/test/test_parakeet_context.rb +116 -0
  552. data/test/test_parakeet_context_params.rb +24 -0
  553. data/test/test_parakeet_model.rb +21 -0
  554. data/test/test_parakeet_params.rb +78 -0
  555. data/test/test_parakeet_segment.rb +42 -0
  556. data/test/test_parakeet_token.rb +73 -0
  557. data/test/test_params.rb +2 -0
  558. data/test/test_vad_segment.rb +1 -1
  559. data/test/test_whisper.rb +24 -6
  560. data/whispercpp.gemspec +2 -2
  561. metadata +215 -281
  562. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  563. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  564. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  565. data/ext/sources/bindings/javascript/package.json +0 -26
  566. data/ext/sources/bindings/javascript/whisper.js +0 -19
  567. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  568. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  569. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  570. data/ext/sources/examples/addon.node/index.js +0 -59
  571. data/ext/sources/examples/addon.node/package.json +0 -16
  572. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  573. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  574. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  575. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  576. data/ext/sources/examples/coi-serviceworker.js +0 -146
  577. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  578. data/ext/sources/examples/command/command.cpp +0 -802
  579. data/ext/sources/examples/command/commands.txt +0 -9
  580. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  581. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  582. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  583. data/ext/sources/examples/generate-karaoke.sh +0 -57
  584. data/ext/sources/examples/helpers.js +0 -191
  585. data/ext/sources/examples/livestream.sh +0 -112
  586. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  587. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  588. data/ext/sources/examples/lsp/whisper.vim +0 -362
  589. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  590. data/ext/sources/examples/python/whisper_processor.py +0 -54
  591. data/ext/sources/examples/server/bench.js +0 -29
  592. data/ext/sources/examples/server.py +0 -120
  593. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  594. data/ext/sources/examples/stream/stream.cpp +0 -437
  595. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  596. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  597. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  598. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  599. data/ext/sources/examples/sycl/build.sh +0 -22
  600. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  601. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  602. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -48
  603. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  604. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -488
  605. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -89
  606. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2877
  607. data/ext/sources/examples/talk-llama/llama-arch.h +0 -628
  608. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -919
  609. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  610. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -896
  611. data/ext/sources/examples/talk-llama/llama-chat.h +0 -71
  612. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3633
  613. data/ext/sources/examples/talk-llama/llama-context.h +0 -359
  614. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  615. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -47
  616. data/ext/sources/examples/talk-llama/llama-ext.h +0 -12
  617. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  618. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  619. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2735
  620. data/ext/sources/examples/talk-llama/llama-graph.h +0 -1031
  621. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -258
  622. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -353
  623. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  624. data/ext/sources/examples/talk-llama/llama-impl.h +0 -75
  625. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  626. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  627. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -330
  628. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  629. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2285
  630. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -389
  631. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  632. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +0 -275
  633. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +0 -140
  634. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  635. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  636. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1165
  637. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  638. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  639. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  640. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -752
  641. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  642. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1655
  643. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -206
  644. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -299
  645. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -40
  646. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -9056
  647. data/ext/sources/examples/talk-llama/llama-model.h +0 -597
  648. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1304
  649. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  650. data/ext/sources/examples/talk-llama/llama-sampler.cpp +0 -3885
  651. data/ext/sources/examples/talk-llama/llama-sampler.h +0 -42
  652. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3970
  653. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -187
  654. data/ext/sources/examples/talk-llama/llama.cpp +0 -1194
  655. data/ext/sources/examples/talk-llama/llama.h +0 -1573
  656. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -190
  657. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  658. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  659. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -137
  660. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  661. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -123
  662. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -143
  663. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -133
  664. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -184
  665. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -145
  666. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  667. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  668. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  669. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  670. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  671. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  672. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  673. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -122
  674. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  675. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -142
  676. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -262
  677. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +0 -445
  678. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -132
  679. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  680. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -148
  681. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  682. data/ext/sources/examples/talk-llama/models/eurobert.cpp +0 -97
  683. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +0 -145
  684. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  685. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  686. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -111
  687. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  688. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  689. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  690. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  691. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  692. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  693. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  694. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -157
  695. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  696. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  697. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -195
  698. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -210
  699. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  700. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -139
  701. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  702. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -153
  703. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  704. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  705. data/ext/sources/examples/talk-llama/models/jais2.cpp +0 -123
  706. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  707. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +0 -381
  708. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -196
  709. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  710. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  711. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  712. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -175
  713. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  714. data/ext/sources/examples/talk-llama/models/mamba-base.cpp +0 -289
  715. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -54
  716. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -129
  717. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -200
  718. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -123
  719. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  720. data/ext/sources/examples/talk-llama/models/models.h +0 -704
  721. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -109
  722. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  723. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -162
  724. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  725. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  726. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  727. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  728. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  729. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  730. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  731. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  732. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +0 -122
  733. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  734. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  735. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  736. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  737. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -320
  738. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  739. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -169
  740. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  741. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  742. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  743. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  744. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -120
  745. data/ext/sources/examples/talk-llama/models/qwen35.cpp +0 -381
  746. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +0 -422
  747. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -131
  748. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -525
  749. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -140
  750. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -132
  751. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  752. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  753. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -164
  754. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  755. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  756. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -137
  757. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  758. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  759. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  760. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  761. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  762. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  763. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  764. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +0 -165
  765. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  766. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  767. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  768. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  769. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  770. data/ext/sources/examples/talk-llama/speak +0 -40
  771. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  772. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  773. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  774. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  775. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  776. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1103
  777. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  778. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  779. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  780. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  781. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  782. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  783. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  784. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  785. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  786. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  787. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  788. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  789. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  790. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  791. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -155
  792. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  793. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  794. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +0 -123
  795. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +0 -17
  796. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +0 -333
  797. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  798. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -182
  799. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  800. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -718
  801. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  802. data/ext/sources/tests/CMakeLists.txt +0 -112
  803. data/ext/sources/tests/earnings21/eval.mk +0 -58
  804. data/ext/sources/tests/earnings21/eval.py +0 -68
  805. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  806. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  807. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  808. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  809. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  810. data/ext/sources/tests/en-0-ref.txt +0 -1
  811. data/ext/sources/tests/en-1-ref.txt +0 -1
  812. data/ext/sources/tests/en-2-ref.txt +0 -1
  813. data/ext/sources/tests/es-0-ref.txt +0 -1
  814. data/ext/sources/tests/librispeech/eval.mk +0 -39
  815. data/ext/sources/tests/librispeech/eval.py +0 -47
  816. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  817. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  818. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  819. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  820. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  821. data/ext/sources/tests/run-tests.sh +0 -130
  822. data/ext/sources/tests/test-c.c +0 -3
  823. data/ext/sources/tests/test-vad-full.cpp +0 -56
  824. data/ext/sources/tests/test-vad.cpp +0 -83
  825. data/ext/sources/tests/test-whisper.js +0 -58
  826. data/lib/whisper/context.rb +0 -15
  827. data/lib/whisper/segment.rb +0 -58
  828. /data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general_q8_0_f32.cl → gemv_noshuffle_q8_0_f32.cl} +0 -0
@@ -0,0 +1,2263 @@
1
+ #include "ggml.h"
2
+ #include "ggml-impl.h"
3
+ #include "ggml-backend.h"
4
+ #include "ggml-backend-impl.h"
5
+ #include "ggml-alloc.h"
6
+ #include "ggml-cpp.h"
7
+
8
+ #include <algorithm>
9
+ #include <cassert>
10
+ #include <cmath>
11
+ #include <cstddef>
12
+ #include <cstdint>
13
+ #include <cstring>
14
+ #include <map>
15
+ #include <memory>
16
+ #include <set>
17
+ #include <string>
18
+ #include <tuple>
19
+ #include <utility>
20
+ #include <vector>
21
+
22
+ struct ggml_backend_meta_device;
23
+ struct ggml_backend_meta_buffer_type;
24
+ struct ggml_backend_meta_buffer;
25
+ struct ggml_backend_meta;
26
+
27
+ const char * ggml_backend_meta_split_axis_name(enum ggml_backend_meta_split_axis split_axis) {
28
+ switch (split_axis) {
29
+ case GGML_BACKEND_SPLIT_AXIS_0:
30
+ return "0";
31
+ case GGML_BACKEND_SPLIT_AXIS_1:
32
+ return "1";
33
+ case GGML_BACKEND_SPLIT_AXIS_2:
34
+ return "2";
35
+ case GGML_BACKEND_SPLIT_AXIS_3:
36
+ return "3";
37
+ case GGML_BACKEND_SPLIT_AXIS_MIRRORED:
38
+ return "MIRRORED";
39
+ case GGML_BACKEND_SPLIT_AXIS_PARTIAL:
40
+ return "PARTIAL";
41
+ case GGML_BACKEND_SPLIT_AXIS_NONE:
42
+ return "NONE";
43
+ case GGML_BACKEND_SPLIT_AXIS_UNKNOWN:
44
+ return "UNKNOWN";
45
+ default:
46
+ GGML_ABORT("fatal error");
47
+ }
48
+ }
49
+
50
+ //
51
+ // meta backend device
52
+ //
53
+
54
+ struct ggml_backend_meta_device_context {
55
+ std::vector<ggml_backend_dev_t> simple_devs;
56
+ ggml_backend_meta_get_split_state_t get_split_state;
57
+ void * get_split_state_ud;
58
+
59
+ std::string name;
60
+ std::string description;
61
+
62
+ ggml_backend_meta_device_context(
63
+ std::vector<ggml_backend_dev_t> simple_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud) :
64
+ simple_devs(std::move(simple_devs)), get_split_state(get_split_state), get_split_state_ud(get_split_state_ud) {
65
+ name = std::string("Meta(");
66
+ description = std::string("Meta(");
67
+ for (size_t i = 0; i < simple_devs.size(); i++) {
68
+ if (i > 0) {
69
+ name += ",";
70
+ description += ",";
71
+ }
72
+ name += ggml_backend_dev_name (simple_devs[i]);
73
+ description += ggml_backend_dev_description(simple_devs[i]);
74
+ }
75
+ name += ")";
76
+ description += ")";
77
+ }
78
+
79
+ bool operator<(const ggml_backend_meta_device_context & other) const {
80
+ return std::tie(simple_devs, get_split_state, get_split_state_ud)
81
+ < std::tie(other.simple_devs, other.get_split_state, other.get_split_state_ud);
82
+ }
83
+ };
84
+
85
+ static bool ggml_backend_dev_is_meta(ggml_backend_dev_t dev);
86
+
87
+ static const char * ggml_backend_meta_device_get_name(ggml_backend_dev_t dev) {
88
+ GGML_ASSERT(ggml_backend_dev_is_meta(dev));
89
+ const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context;
90
+ return meta_dev_ctx->name.c_str();
91
+ }
92
+
93
+ static const char * ggml_backend_meta_device_get_description(ggml_backend_dev_t dev) {
94
+ GGML_ASSERT(ggml_backend_dev_is_meta(dev));
95
+ const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context;
96
+ return meta_dev_ctx->description.c_str();
97
+ }
98
+
99
+ static void ggml_backend_meta_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
100
+ GGML_ASSERT(ggml_backend_dev_is_meta(dev));
101
+ const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context;
102
+ *free = 0;
103
+ *total = 0;
104
+ for (ggml_backend_dev_t dev : meta_dev_ctx->simple_devs) {
105
+ size_t tmp_free, tmp_total;
106
+ ggml_backend_dev_memory(dev, &tmp_free, &tmp_total);
107
+ *free += tmp_free;
108
+ *total += tmp_total;
109
+ }
110
+ }
111
+
112
+ static enum ggml_backend_dev_type ggml_backend_meta_device_get_type(ggml_backend_dev_t dev) {
113
+ return GGML_BACKEND_DEVICE_TYPE_META;
114
+
115
+ GGML_UNUSED(dev);
116
+ }
117
+
118
+ static void ggml_backend_meta_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
119
+ GGML_ASSERT(ggml_backend_dev_is_meta(dev));
120
+ const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context;
121
+
122
+ // TODO replace placeholders
123
+ props->name = ggml_backend_meta_device_get_name(dev);
124
+ props->description = ggml_backend_meta_device_get_description(dev);
125
+ props->type = ggml_backend_meta_device_get_type(dev);
126
+ props->device_id = 0;
127
+
128
+ ggml_backend_meta_device_get_memory(dev, &props->memory_free, &props->memory_total);
129
+
130
+ props->caps = {
131
+ /* .async = */ true,
132
+ /* .host_buffer = */ false, // Not implemented.
133
+ /* .buffer_from_host_ptr = */ false, // Not implemented.
134
+ /* .events = */ false, // Not implemented.
135
+ };
136
+ for (ggml_backend_dev_t simple_dev : meta_dev_ctx->simple_devs) {
137
+ ggml_backend_dev_props tmp_props;
138
+ ggml_backend_dev_get_props(simple_dev, &tmp_props);
139
+ props->caps.async = props->caps.async && tmp_props.caps.async;
140
+ props->caps.host_buffer = props->caps.host_buffer && tmp_props.caps.host_buffer;
141
+ props->caps.buffer_from_host_ptr = props->caps.buffer_from_host_ptr && tmp_props.caps.buffer_from_host_ptr;
142
+ props->caps.events = props->caps.events && tmp_props.caps.events;
143
+ }
144
+ }
145
+
146
+ static ggml_backend_t ggml_backend_meta_device_init_backend(ggml_backend_dev_t dev, const char * params);
147
+
148
+ static ggml_backend_buffer_type_t ggml_backend_meta_device_get_buffer_type(ggml_backend_dev_t dev);
149
+
150
+ static ggml_backend_buffer_type_t ggml_backend_meta_device_get_host_buffer_type(ggml_backend_dev_t dev);
151
+
152
+ static bool ggml_backend_meta_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
153
+ GGML_ASSERT(ggml_backend_dev_is_meta(dev));
154
+ const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context;
155
+ return std::all_of(meta_dev_ctx->simple_devs.begin(), meta_dev_ctx->simple_devs.end(),
156
+ [op](ggml_backend_dev_t simple_dev) { return ggml_backend_dev_supports_op(simple_dev, op); });
157
+ }
158
+
159
+ static bool ggml_backend_meta_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
160
+ GGML_ASSERT(ggml_backend_dev_is_meta(dev));
161
+ ggml_backend_dev_t dev_buft = ggml_backend_buft_get_device(buft);
162
+ if (!ggml_backend_dev_is_meta(dev_buft)) {
163
+ return false;
164
+ }
165
+ const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context;
166
+ const ggml_backend_meta_device_context * meta_buft_dev_ctx = (const ggml_backend_meta_device_context *) dev_buft->context;
167
+ if (meta_dev_ctx->simple_devs.size() != meta_buft_dev_ctx->simple_devs.size()) {
168
+ return false;
169
+ }
170
+ for (size_t i = 0; i < meta_dev_ctx->simple_devs.size(); i++) {
171
+ if (meta_dev_ctx->simple_devs[i] != meta_buft_dev_ctx->simple_devs[i]) {
172
+ return false;
173
+ }
174
+ }
175
+ return true;
176
+ }
177
+
178
+ static const ggml_backend_device_i ggml_backend_meta_device_iface = {
179
+ /* .get_name = */ ggml_backend_meta_device_get_name,
180
+ /* .get_description = */ ggml_backend_meta_device_get_description,
181
+ /* .get_memory = */ ggml_backend_meta_device_get_memory,
182
+ /* .get_type = */ ggml_backend_meta_device_get_type,
183
+ /* .get_props = */ ggml_backend_meta_device_get_props,
184
+ /* .init_backend = */ ggml_backend_meta_device_init_backend,
185
+ /* .get_buffer_type = */ ggml_backend_meta_device_get_buffer_type,
186
+ /* .get_host_buffer_type = */ ggml_backend_meta_device_get_host_buffer_type,
187
+ /* .buffer_from_host_ptr = */ nullptr,
188
+ /* .supports_op = */ ggml_backend_meta_device_supports_op,
189
+ /* .supports_buft = */ ggml_backend_meta_device_supports_buft,
190
+ /* .offload_op = */ nullptr,
191
+ /* .event_new = */ nullptr,
192
+ /* .event_free = */ nullptr,
193
+ /* .event_synchronize = */ nullptr,
194
+ };
195
+
196
+ static bool ggml_backend_dev_is_meta(ggml_backend_dev_t dev) {
197
+ return dev != nullptr && dev->iface.get_name == ggml_backend_meta_device_iface.get_name;
198
+ }
199
+
200
+ static size_t ggml_backend_meta_dev_n_devs(ggml_backend_dev_t meta_dev) {
201
+ GGML_ASSERT(ggml_backend_dev_is_meta(meta_dev));
202
+ const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) meta_dev->context;
203
+ return meta_dev_ctx->simple_devs.size();
204
+ }
205
+
206
+ static ggml_backend_dev_t ggml_backend_meta_dev_simple_dev(ggml_backend_dev_t meta_dev, size_t index) {
207
+ GGML_ASSERT(ggml_backend_dev_is_meta(meta_dev));
208
+ const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) meta_dev->context;
209
+ GGML_ASSERT(index < meta_dev_ctx->simple_devs.size());
210
+ return meta_dev_ctx->simple_devs[index];
211
+ }
212
+
213
+ ggml_backend_dev_t ggml_backend_meta_device(
214
+ ggml_backend_dev_t * devs, size_t n_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud) {
215
+ GGML_ASSERT(n_devs <= GGML_BACKEND_META_MAX_DEVICES);
216
+ // TODO: this is not thread-safe - needs to be fixed
217
+ static std::vector<std::unique_ptr<ggml_backend_meta_device_context>> ctxs;
218
+ static std::map<ggml_backend_meta_device_context, struct ggml_backend_device> meta_devs;
219
+
220
+ std::vector<ggml_backend_dev_t> simple_devs;
221
+ simple_devs.reserve(n_devs);
222
+ for (size_t i = 0; i < n_devs; i++) {
223
+ simple_devs.push_back(devs[i]);
224
+ }
225
+ ggml_backend_meta_device_context ctx(simple_devs, get_split_state, get_split_state_ud);
226
+
227
+ {
228
+ auto it = meta_devs.find(ctx);
229
+ if (it != meta_devs.end()) {
230
+ return &it->second;
231
+ }
232
+ }
233
+ ctxs.push_back(std::make_unique<ggml_backend_meta_device_context>(ctx));
234
+
235
+ struct ggml_backend_device meta_dev = {
236
+ /*iface =*/ ggml_backend_meta_device_iface,
237
+ /*reg =*/ nullptr,
238
+ /*ctx =*/ ctxs.back().get(),
239
+ };
240
+
241
+ auto result = meta_devs.emplace(*ctxs.back(), meta_dev);
242
+ return &result.first->second;
243
+ }
244
+
245
+ //
246
+ // meta backend buffer type
247
+ //
248
+
249
+ struct ggml_backend_meta_buffer_type_context {
250
+ std::vector<ggml_backend_buffer_type_t> simple_bufts;
251
+
252
+ std::string name;
253
+
254
+ ggml_backend_meta_buffer_type_context(std::vector<ggml_backend_buffer_type_t> simple_bufts) : simple_bufts(std::move(simple_bufts)) {
255
+ name = "Meta(";
256
+ for (size_t i = 0; i < simple_bufts.size(); i++) {
257
+ if (i > 0) {
258
+ name += ",";
259
+ }
260
+ name += ggml_backend_buft_name(simple_bufts[i]);
261
+ }
262
+ name += ")";
263
+ }
264
+
265
+ bool operator<(const ggml_backend_meta_buffer_type_context & other) const {
266
+ return simple_bufts < other.simple_bufts;
267
+ }
268
+ };
269
+
270
+ static size_t ggml_backend_meta_buft_n_bufts(ggml_backend_buffer_type_t meta_buft) {
271
+ GGML_ASSERT(ggml_backend_buft_is_meta(meta_buft));
272
+ const ggml_backend_meta_buffer_type_context * meta_buft_ctx = (const ggml_backend_meta_buffer_type_context *) meta_buft->context;
273
+ return meta_buft_ctx->simple_bufts.size();
274
+ }
275
+
276
+ static const char * ggml_backend_meta_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
277
+ GGML_ASSERT(ggml_backend_buft_is_meta(buft));
278
+ const ggml_backend_meta_buffer_type_context * meta_buft_ctx = (const ggml_backend_meta_buffer_type_context *) buft->context;
279
+ return meta_buft_ctx->name.c_str();
280
+ }
281
+
282
+ static ggml_backend_buffer_type_t ggml_backend_meta_buft_simple_buft(ggml_backend_buffer_type_t meta_buft, size_t index) {
283
+ GGML_ASSERT(ggml_backend_buft_is_meta(meta_buft));
284
+ const ggml_backend_meta_buffer_type_context * meta_buft_ctx = (const ggml_backend_meta_buffer_type_context *) meta_buft->context;
285
+ GGML_ASSERT(index < meta_buft_ctx->simple_bufts.size());
286
+ return meta_buft_ctx->simple_bufts[index];
287
+ }
288
+
289
+ static ggml_backend_buffer_t ggml_backend_meta_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size);
290
+
291
+ static size_t ggml_backend_meta_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
292
+ const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft);
293
+ size_t max_alignment = 1;
294
+ for (size_t i = 0; i < n_simple_bufts; i++) {
295
+ const size_t alignment = ggml_backend_buft_get_alignment(ggml_backend_meta_buft_simple_buft(buft, i));
296
+ max_alignment = std::max(max_alignment, alignment);
297
+ GGML_ASSERT(max_alignment % alignment == 0);
298
+ }
299
+ return max_alignment;
300
+ }
301
+
302
+ static size_t ggml_backend_meta_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
303
+ const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft);
304
+ size_t max_size = SIZE_MAX;
305
+ for (size_t i = 0; i < n_simple_bufts; i++) {
306
+ max_size = std::min(max_size, ggml_backend_buft_get_max_size(ggml_backend_meta_buft_simple_buft(buft, i)));
307
+ }
308
+ return max_size;
309
+ }
310
+
311
+ static size_t ggml_backend_meta_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
312
+ const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft);
313
+ size_t max_alloc_size = 0;
314
+ for (size_t i = 0; i < n_simple_bufts; i++) {
315
+ const size_t alloc_size = ggml_backend_buft_get_alloc_size(ggml_backend_meta_buft_simple_buft(buft, i), tensor);
316
+ max_alloc_size = std::max(max_alloc_size, alloc_size);
317
+ }
318
+ return max_alloc_size;
319
+ }
320
+
321
+ static bool ggml_backend_meta_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
322
+ const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft);
323
+ for (size_t i = 0; i < n_simple_bufts; i++) {
324
+ if (!ggml_backend_buft_is_host(ggml_backend_meta_buft_simple_buft(buft, i))) {
325
+ return false;
326
+ }
327
+ }
328
+ return true;
329
+ }
330
+
331
+ static const struct ggml_backend_buffer_type_i ggml_backend_meta_buffer_type_iface = {
332
+ /* .get_name = */ ggml_backend_meta_buffer_type_get_name,
333
+ /* .alloc_buffer = */ ggml_backend_meta_buffer_type_alloc_buffer,
334
+ /* .get_alignment = */ ggml_backend_meta_buffer_type_get_alignment,
335
+ /* .get_max_size = */ ggml_backend_meta_buffer_type_get_max_size,
336
+ /* .get_alloc_size = */ ggml_backend_meta_buffer_type_get_alloc_size,
337
+ /* .is_host = */ ggml_backend_meta_buffer_type_is_host,
338
+ };
339
+
340
+ bool ggml_backend_buft_is_meta(ggml_backend_buffer_type_t buft) {
341
+ return buft != nullptr && buft->iface.get_name == ggml_backend_meta_buffer_type_iface.get_name;
342
+ }
343
+
344
+ static ggml_backend_buffer_type_t ggml_backend_meta_device_get_buffer_type(ggml_backend_dev_t dev) {
345
+ static std::map<ggml_backend_dev_t, struct ggml_backend_buffer_type> meta_bufts;
346
+ GGML_ASSERT(ggml_backend_dev_is_meta(dev));
347
+ {
348
+ auto it = meta_bufts.find(dev);
349
+ if (it != meta_bufts.end()) {
350
+ return &it->second;
351
+ }
352
+ }
353
+
354
+ const size_t n_devs = ggml_backend_meta_dev_n_devs(dev);
355
+ std::vector<ggml_backend_buffer_type_t> simple_bufts;
356
+ simple_bufts.reserve(n_devs);
357
+ for (size_t i = 0; i < n_devs; i++) {
358
+ simple_bufts.push_back(ggml_backend_dev_buffer_type(ggml_backend_meta_dev_simple_dev(dev, i)));
359
+ }
360
+ ggml_backend_meta_buffer_type_context * buft_ctx = new ggml_backend_meta_buffer_type_context(simple_bufts);
361
+
362
+ struct ggml_backend_buffer_type meta_buft = {
363
+ /*iface =*/ ggml_backend_meta_buffer_type_iface,
364
+ /*device =*/ dev,
365
+ /*ctx =*/ buft_ctx,
366
+ };
367
+ auto result = meta_bufts.emplace(dev, meta_buft);
368
+ return &result.first->second;
369
+ }
370
+
371
+ static ggml_backend_buffer_type_t ggml_backend_meta_device_get_host_buffer_type(ggml_backend_dev_t dev) {
372
+ GGML_ASSERT(ggml_backend_dev_is_meta(dev));
373
+ const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context;
374
+
375
+ ggml_backend_buffer_type_t host_buft = nullptr;
376
+ for (ggml_backend_dev_t simple_dev : meta_dev_ctx->simple_devs) {
377
+ ggml_backend_buffer_type_t simple_host_buft = ggml_backend_dev_host_buffer_type(simple_dev);
378
+ if (simple_host_buft == nullptr) {
379
+ return nullptr;
380
+ }
381
+ if (host_buft == nullptr) {
382
+ host_buft = simple_host_buft;
383
+ } else if (host_buft != simple_host_buft) {
384
+ // if different simple devices have different host buffer types,
385
+ // we cannot provide a single host buffer type for the meta device
386
+ return nullptr;
387
+ }
388
+ }
389
+ return host_buft;
390
+ }
391
+
392
+ //
393
+ // meta backend buffer
394
+ //
395
+
396
+ // Container to hold the tensor slices per simple ggml backend buffer.
397
+ struct ggml_backend_meta_simple_tensor_container {
398
+ std::vector<ggml_context_ptr> ctxs;
399
+ std::map<const ggml_tensor *, std::vector<ggml_tensor *>> simple_tensors;
400
+
401
+ ggml_backend_meta_simple_tensor_container(const ggml_init_params & params, const int n_simple) {
402
+ ctxs.reserve(n_simple);
403
+ for (int i = 0; i < n_simple; i++) {
404
+ ctxs.emplace_back(ggml_init(params));
405
+ }
406
+ }
407
+ ggml_backend_meta_simple_tensor_container() {}
408
+ };
409
+
410
+ struct ggml_backend_meta_buffer_context {
411
+ // FIXME
412
+ // Most tensors can simply be stored statically in their own buffer.
413
+ // Externally created views however also need a mapping to simple tensors but they use the buffer of the view source.
414
+ // If external views are simply using that buffer they will slowly deplete its memory.
415
+ // Current solution: rotating set of 2 "compute" containers to hold external views, works correctly for llama.cpp.
416
+ // Long-term: tie the lifetime of external views to the meta backend executing the graph instead,
417
+ // currently not possible due to graph-external operations in the backend scheduler.
418
+ ggml_backend_meta_simple_tensor_container stc_static;
419
+ ggml_backend_meta_simple_tensor_container stc_compute[2];
420
+ int stc_compute_index = 0;
421
+ int stc_compute_index_next = 0;
422
+ std::vector<ggml_backend_buffer_ptr> bufs;
423
+
424
+ // FIXME
425
+ // The size of the split state cache is unbounded and can theoretically grow infinitely large.
426
+ // However, it is also expensive to build and clearing it on every rebuild in ggml_backend_meta_graph_compute is too expensive.
427
+ static constexpr size_t nbtc = GGML_TENSOR_SIZE - sizeof(ggml_tensor::padding);
428
+ std::map<std::pair<const ggml_tensor *, bool>, std::pair<ggml_backend_meta_split_state, char[nbtc]>> split_state_cache;
429
+
430
+ int debug;
431
+
432
+ ggml_backend_meta_buffer_context(
433
+ ggml_backend_meta_simple_tensor_container & stc_static,
434
+ ggml_backend_meta_simple_tensor_container & stc_compute_0,
435
+ ggml_backend_meta_simple_tensor_container & stc_compute_1,
436
+ const std::vector<ggml_backend_buffer_t> & bufs)
437
+ : stc_static(std::move(stc_static)), stc_compute{std::move(stc_compute_0), std::move(stc_compute_1)} {
438
+ this->bufs.reserve(bufs.size());
439
+ for (ggml_backend_buffer_t buf : bufs) {
440
+ this->bufs.emplace_back(buf);
441
+ }
442
+ const char * GGML_META_DEBUG = getenv("GGML_META_DEBUG");
443
+ debug = GGML_META_DEBUG ? atoi(GGML_META_DEBUG) : 0;
444
+ }
445
+
446
+ ggml_backend_meta_simple_tensor_container & get_simple_tensor_container(const ggml_tensor * tensor) {
447
+ if (stc_static.simple_tensors.find(tensor) != stc_static.simple_tensors.end()) {
448
+ return stc_static;
449
+ }
450
+ return stc_compute[stc_compute_index];
451
+ }
452
+ };
453
+
454
+ static void ggml_backend_meta_buffer_free_buffer(ggml_backend_buffer_t buffer) {
455
+ GGML_ASSERT(ggml_backend_buffer_is_meta(buffer));
456
+ ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buffer->context;
457
+ delete buf_ctx;
458
+ }
459
+
460
+ static size_t ggml_backend_meta_buffer_n_bufs(ggml_backend_buffer_t meta_buf) {
461
+ GGML_ASSERT(ggml_backend_buffer_is_meta(meta_buf));
462
+ ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) meta_buf->context;
463
+ return buf_ctx->bufs.size();
464
+ }
465
+
466
+ static ggml_backend_buffer_t ggml_backend_meta_buffer_simple_buffer(ggml_backend_buffer_t meta_buf, size_t index) {
467
+ GGML_ASSERT(ggml_backend_buffer_is_meta(meta_buf));
468
+ ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) meta_buf->context;
469
+ GGML_ASSERT(index < buf_ctx->bufs.size());
470
+ return buf_ctx->bufs[index].get();
471
+ }
472
+
473
+ static struct ggml_tensor * ggml_backend_meta_buffer_simple_tensor(const struct ggml_tensor * tensor, size_t index) {
474
+ GGML_ASSERT(ggml_backend_buffer_is_meta(tensor->buffer));
475
+ ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context;
476
+ GGML_ASSERT(index < buf_ctx->bufs.size());
477
+
478
+ ggml_backend_meta_simple_tensor_container & stc = buf_ctx->get_simple_tensor_container(tensor);
479
+ auto it = stc.simple_tensors.find(tensor);
480
+ if (it == stc.simple_tensors.end()) {
481
+ return nullptr;
482
+ }
483
+ return it->second[index];
484
+ }
485
+
486
+ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync);
487
+
488
+ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(
489
+ ggml_backend_meta_simple_tensor_container & stc, const struct ggml_tensor * tensor, bool assume_sync) {
490
+ // FIXME Currently this function preserves/erases the information in n_segments and nr in an inconsistent way.
491
+ // Since the operations in question are developed specifically for llama.cpp this currently does not manifest as a bug there.
492
+ // However, in a broader ggml context with arbitrary ggml graphs this can lead to unexpected results.
493
+ const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer);
494
+ ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context;
495
+
496
+ auto split_states_equal = [&](const ggml_backend_meta_split_state & a, const ggml_backend_meta_split_state & b) -> bool {
497
+ if (a.axis != b.axis) {
498
+ return false;
499
+ }
500
+ for (size_t j = 0; j < n_bufs; j++) {
501
+ int64_t sum_a = 0;
502
+ for (size_t s = 0; s < a.n_segments; s++) {
503
+ sum_a += a.ne[s*n_bufs + j] * a.nr[s];
504
+ }
505
+ int64_t sum_b = 0;
506
+ for (size_t s = 0; s < b.n_segments; s++) {
507
+ sum_b += b.ne[s*n_bufs + j] * b.nr[s];
508
+ }
509
+ if (sum_a != sum_b) {
510
+ return false;
511
+ }
512
+ }
513
+ return true;
514
+ };
515
+
516
+ auto handle_generic = [&](const std::vector<ggml_backend_meta_split_state> & src_ss, bool scalar_only) -> ggml_backend_meta_split_state {
517
+ ggml_backend_meta_split_state ret = {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, {1}, 1};
518
+ for (size_t i = 0; i < GGML_MAX_SRC; i++) {
519
+ if (tensor->src[i] == nullptr || tensor->src[i] == tensor) {
520
+ continue;
521
+ }
522
+ if (ret.axis == GGML_BACKEND_SPLIT_AXIS_NONE) {
523
+ ret = src_ss[i];
524
+ } else if (!split_states_equal(src_ss[i], ret)) {
525
+ ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
526
+ break;
527
+ }
528
+ }
529
+ if (ret.axis == GGML_BACKEND_SPLIT_AXIS_NONE) {
530
+ ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
531
+ }
532
+ if (scalar_only && ret.axis >= 0 && ret.axis < GGML_MAX_DIMS) {
533
+ ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
534
+ }
535
+ GGML_ASSERT(ret.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN);
536
+ return ret;
537
+ };
538
+
539
+ // Some ops process data on a per-row bases:
540
+ auto handle_per_row = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
541
+ GGML_ASSERT(src_ss[0].axis != GGML_BACKEND_SPLIT_AXIS_0);
542
+ return src_ss[0];
543
+ };
544
+
545
+ // Some ops broadcast the src1 data across src0:
546
+ auto handle_bin_bcast = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
547
+ if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS &&
548
+ tensor->src[1]->ne[src_ss[0].axis] == 1 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
549
+ return src_ss[0];
550
+ }
551
+ if (src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && (src_ss[0].axis == src_ss[1].axis ||
552
+ (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL)))) {
553
+ return src_ss[0]; // GGML_OP_ADD_ID
554
+ }
555
+ GGML_ASSERT(tensor->src[2] == nullptr || src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED);
556
+ return handle_generic(src_ss, /*scalar_only =*/ false);
557
+ };
558
+
559
+ auto handle_concat = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
560
+ const ggml_backend_meta_split_axis concat_axis = ggml_backend_meta_split_axis(ggml_get_op_params_i32(tensor, 0));
561
+ if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis >= 0 && src_ss[1].axis < GGML_MAX_DIMS) {
562
+ GGML_ASSERT(concat_axis != src_ss[1].axis);
563
+ return src_ss[1];
564
+ }
565
+ if (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) {
566
+ GGML_ASSERT(concat_axis != src_ss[0].axis);
567
+ return src_ss[0];
568
+ }
569
+ if (src_ss[0].axis == src_ss[1].axis && src_ss[0].axis != concat_axis) {
570
+ return src_ss[0];
571
+ }
572
+ return handle_generic(src_ss, /*scalar_only =*/ true);
573
+ };
574
+
575
+ auto handle_mul_mat = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
576
+ if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
577
+ return {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, {1}, 1};
578
+ }
579
+ if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
580
+ ggml_backend_meta_split_state ret = src_ss[0];
581
+ ret.axis = GGML_BACKEND_SPLIT_AXIS_0;
582
+ ret.nr[0] = 1;
583
+ ret.n_segments = 1;
584
+ return ret;
585
+ }
586
+ if (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
587
+ return src_ss[1];
588
+ }
589
+ if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_0) {
590
+ GGML_ASSERT(split_states_equal(src_ss[0], src_ss[1]));
591
+ return {assume_sync ? GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_PARTIAL, {0}, {1}, 1};
592
+ }
593
+ GGML_ABORT("fatal error");
594
+ //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
595
+ };
596
+
597
+ auto handle_reshape = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
598
+ switch (src_ss[0].axis) {
599
+ case GGML_BACKEND_SPLIT_AXIS_0:
600
+ case GGML_BACKEND_SPLIT_AXIS_1:
601
+ case GGML_BACKEND_SPLIT_AXIS_2:
602
+ case GGML_BACKEND_SPLIT_AXIS_3: {
603
+ GGML_ASSERT(src_ss[0].n_segments == 1);
604
+ if (src_ss[0].axis == ggml_n_dims(tensor->src[0]) - 1 && src_ss[0].nr[0] == 1) {
605
+ return {ggml_backend_meta_split_axis(ggml_n_dims(tensor) - 1), {0}, {1}, 1};
606
+ }
607
+ int64_t base_ne_in = tensor->src[0]->ne[0];
608
+ for (int dim = 1; dim <= src_ss[0].axis; dim++) {
609
+ base_ne_in *= tensor->src[0]->ne[dim];
610
+ }
611
+ base_ne_in /= src_ss[0].nr[0];
612
+ int64_t base_ne_out = 1;
613
+ for (int dim = 0; dim < GGML_MAX_DIMS; dim++) {
614
+ const int64_t base_ne_out_next = base_ne_out *= tensor->ne[dim];
615
+ if (base_ne_out_next % base_ne_in == 0) {
616
+ return {ggml_backend_meta_split_axis(dim), {0}, {uint32_t(base_ne_out_next/base_ne_in)}, 1};
617
+ }
618
+ if (base_ne_out_next > base_ne_in) {
619
+ GGML_ASSERT(src_ss[0].n_segments == 1);
620
+ GGML_ASSERT(src_ss[0].nr[0] == 1);
621
+ return {ggml_backend_meta_split_axis(dim), {0}, {1}, 1};
622
+ }
623
+ base_ne_out = base_ne_out_next;
624
+ }
625
+ GGML_ABORT("shape mismatch for %s", ggml_op_name(tensor->op));
626
+ }
627
+ case GGML_BACKEND_SPLIT_AXIS_MIRRORED:
628
+ case GGML_BACKEND_SPLIT_AXIS_PARTIAL: {
629
+ return src_ss[0];
630
+ }
631
+ default: {
632
+ GGML_ABORT("fatal error");
633
+ //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
634
+ }
635
+ }
636
+ };
637
+
638
+ auto handle_cpy = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
639
+ if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) {
640
+ return handle_reshape(src_ss);
641
+ }
642
+ return handle_generic(src_ss, /*scalar_only =*/ false);
643
+ };
644
+
645
+ auto handle_view = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
646
+ if (ggml_is_contiguous(tensor) && ggml_is_contiguous(tensor->src[0])) {
647
+ return handle_reshape(src_ss);
648
+ }
649
+ const int axis = src_ss[0].axis;
650
+ {
651
+ bool all_strides_the_same = true;
652
+ for (int dim = 0; dim < GGML_MAX_DIMS; dim++) {
653
+ if (tensor->ne[dim] == 1 && tensor->src[0]->ne[dim] == 1) {
654
+ continue;
655
+ }
656
+ if (tensor->nb[dim] != tensor->src[0]->nb[dim]) {
657
+ all_strides_the_same = false;
658
+ break;
659
+ }
660
+ }
661
+ if (all_strides_the_same) {
662
+ return src_ss[0];
663
+ }
664
+ }
665
+ if (!ggml_is_permuted(tensor) && !ggml_is_permuted(tensor->src[0]) && axis >= 0 && axis < GGML_MAX_DIMS-1) {
666
+ for (int dim = 0; dim < GGML_MAX_DIMS-1; dim++) {
667
+ if (tensor->nb[dim+1] == tensor->src[0]->nb[axis+1]) {
668
+ return {ggml_backend_meta_split_axis(dim), {0}, {1}, 1};
669
+ }
670
+ }
671
+ GGML_ABORT("fatal error");
672
+ }
673
+ if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED || src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL) {
674
+ return src_ss[0];
675
+ }
676
+ GGML_ABORT("view of permuted tensor not implemented");
677
+ //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
678
+ };
679
+
680
+ auto handle_permute = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
681
+ switch (src_ss[0].axis) {
682
+ case GGML_BACKEND_SPLIT_AXIS_0:
683
+ case GGML_BACKEND_SPLIT_AXIS_1:
684
+ case GGML_BACKEND_SPLIT_AXIS_2:
685
+ case GGML_BACKEND_SPLIT_AXIS_3: {
686
+ GGML_ASSERT(src_ss[0].n_segments == 1 || src_ss[0].nr[0] == 1);
687
+ return {ggml_backend_meta_split_axis(tensor->op_params[src_ss[0].axis]), {0}, {src_ss[0].nr[0]}, 1};
688
+ }
689
+ case GGML_BACKEND_SPLIT_AXIS_MIRRORED:
690
+ case GGML_BACKEND_SPLIT_AXIS_PARTIAL: {
691
+ return src_ss[0];
692
+ }
693
+ default: {
694
+ GGML_ABORT("fatal error");
695
+ //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
696
+ }
697
+ }
698
+ };
699
+
700
+ auto handle_transpose = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
701
+ switch (src_ss[0].axis) {
702
+ case GGML_BACKEND_SPLIT_AXIS_0:
703
+ case GGML_BACKEND_SPLIT_AXIS_1: {
704
+ GGML_ASSERT(src_ss[0].n_segments == 1 || src_ss[0].nr[0] == 1);
705
+ return {ggml_backend_meta_split_axis(int(src_ss[0].axis) ^ 1), {0}, {src_ss[0].nr[0]}, 1};
706
+ }
707
+ case GGML_BACKEND_SPLIT_AXIS_2:
708
+ case GGML_BACKEND_SPLIT_AXIS_3:
709
+ case GGML_BACKEND_SPLIT_AXIS_MIRRORED:
710
+ case GGML_BACKEND_SPLIT_AXIS_PARTIAL: {
711
+ return src_ss[0];
712
+ }
713
+ default: {
714
+ GGML_ABORT("fatal error");
715
+ //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
716
+ }
717
+ }
718
+ };
719
+
720
+ auto handle_get_rows = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
721
+ if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
722
+ return src_ss[0];
723
+ }
724
+ return handle_generic(src_ss, /*scalar_only =*/ true);
725
+ };
726
+
727
+ auto handle_set_rows = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
728
+ GGML_ASSERT(src_ss[0].axis != GGML_BACKEND_SPLIT_AXIS_1);
729
+ GGML_ASSERT(src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED);
730
+ GGML_ASSERT(split_states_equal(src_ss[0], src_ss[2]));
731
+ return src_ss[0];
732
+ };
733
+
734
+ auto handle_rope = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
735
+ GGML_ASSERT(src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED);
736
+ return src_ss[0];
737
+ };
738
+
739
+ auto handle_pad = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
740
+ if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) {
741
+ GGML_ASSERT(tensor->op_params[2*src_ss[0].axis + 0] == 0);
742
+ GGML_ASSERT(tensor->op_params[2*src_ss[0].axis + 1] == 0);
743
+ }
744
+ return src_ss[0];
745
+ };
746
+
747
+ auto handle_flash_attn_ext = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
748
+ GGML_ASSERT( src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_2);
749
+ GGML_ASSERT( src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_2);
750
+ GGML_ASSERT( src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_2);
751
+ GGML_ASSERT(tensor->src[4] == nullptr || src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED);
752
+ GGML_ASSERT(tensor->src[4] == nullptr || src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_0);
753
+ return {GGML_BACKEND_SPLIT_AXIS_1, {0}, {1}, 1};
754
+ };
755
+
756
+ auto handle_ssm_conv = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
757
+ if (src_ss[0].axis == src_ss[1].axis) {
758
+ if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0) {
759
+ return {GGML_BACKEND_SPLIT_AXIS_1, {0}, {1}, 1};
760
+ }
761
+ if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1) {
762
+ return {GGML_BACKEND_SPLIT_AXIS_0, {0}, {1}, 1};
763
+ }
764
+ }
765
+ return handle_generic(src_ss, /*scalar_only =*/ false);
766
+ };
767
+
768
+ auto handle_gated_delta_net = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
769
+ if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED &&
770
+ src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED &&
771
+ src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
772
+ return src_ss[0];
773
+ }
774
+ GGML_ASSERT(src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1);
775
+ GGML_ASSERT(src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_1);
776
+ GGML_ASSERT(src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_1);
777
+ GGML_ASSERT(src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_1);
778
+ GGML_ASSERT(src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_1);
779
+ // state shape is [S_v, S_v, H_v, n_seqs] (s0 only); the heads dim is its own axis 2,
780
+ // so a head-aligned split on the input cache lands on axis 2 here.
781
+ GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_1 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_0);
782
+ return {GGML_BACKEND_SPLIT_AXIS_0, {0}, {1}, 1};
783
+ };
784
+
785
+ auto calculate_split_state = [&]() -> ggml_backend_meta_split_state {
786
+ if (ggml_nelements(tensor) == 0) {
787
+ return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
788
+ }
789
+ if (ggml_backend_buffer_get_usage(tensor->buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE && tensor->view_src == nullptr) {
790
+ ggml_backend_dev_t dev = ggml_backend_buft_get_device(ggml_backend_buffer_get_type(tensor->buffer));
791
+ const ggml_backend_meta_device_context * dev_ctx = (const ggml_backend_meta_device_context *) dev->context;
792
+ ggml_backend_meta_split_state ret = dev_ctx->get_split_state(tensor, dev_ctx->get_split_state_ud);
793
+ if (ret.axis >= 0 && ret.axis <= GGML_MAX_DIMS) {
794
+ const int64_t granularity = ret.axis == GGML_BACKEND_SPLIT_AXIS_0 ? ggml_blck_size(tensor->type) : 1;
795
+ int64_t ne_sum = 0;
796
+ for (size_t s = 0; s < ret.n_segments; s++) {
797
+ for (size_t j = 0; j < n_bufs; j++) {
798
+ GGML_ASSERT(ret.ne[s*n_bufs + j] % granularity == 0);
799
+ ne_sum += ret.ne[s*n_bufs + j] * ret.nr[s];
800
+ }
801
+ }
802
+ GGML_ASSERT(ne_sum == tensor->ne[ret.axis]);
803
+ }
804
+ return ret;
805
+ }
806
+
807
+ std::vector<ggml_backend_meta_split_state> src_ss(GGML_MAX_SRC, {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, {1}, 1});
808
+ for (size_t i = 0; i < GGML_MAX_SRC; i++) {
809
+ if (tensor->src[i] == nullptr || tensor->src[i] == tensor) {
810
+ src_ss[i] = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
811
+ continue;
812
+ }
813
+ src_ss[i] = ggml_backend_meta_get_split_state(stc, tensor->src[i], /*assume_sync =*/ true);
814
+ GGML_ASSERT(src_ss[i].axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN);
815
+ }
816
+
817
+ ggml_backend_meta_split_state split_state;
818
+ switch (tensor->op) {
819
+ case GGML_OP_NONE: {
820
+ split_state = {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, {1}, 1};
821
+ } break;
822
+ case GGML_OP_DUP: {
823
+ split_state = handle_generic(src_ss, /*scalar_only =*/ true);
824
+ } break;
825
+ case GGML_OP_ADD:
826
+ case GGML_OP_ADD_ID: {
827
+ split_state = handle_bin_bcast(src_ss);
828
+ } break;
829
+ case GGML_OP_ADD1:
830
+ case GGML_OP_ACC: {
831
+ split_state = handle_generic(src_ss, /*scalar_only =*/ true);
832
+ } break;
833
+ case GGML_OP_SUB:
834
+ case GGML_OP_MUL:
835
+ case GGML_OP_DIV: {
836
+ split_state = handle_bin_bcast(src_ss);
837
+ } break;
838
+ case GGML_OP_SQR:
839
+ case GGML_OP_SQRT:
840
+ case GGML_OP_LOG:
841
+ case GGML_OP_SIN:
842
+ case GGML_OP_COS: {
843
+ split_state = handle_generic(src_ss, /*scalar_only =*/ false);
844
+ } break;
845
+ case GGML_OP_SUM: {
846
+ split_state = handle_generic(src_ss, /*scalar_only =*/ true);
847
+ } break;
848
+ case GGML_OP_SUM_ROWS:
849
+ case GGML_OP_CUMSUM:
850
+ case GGML_OP_MEAN:
851
+ case GGML_OP_ARGMAX:
852
+ case GGML_OP_COUNT_EQUAL: {
853
+ split_state = handle_per_row(src_ss);
854
+ } break;
855
+ case GGML_OP_REPEAT:
856
+ case GGML_OP_REPEAT_BACK: {
857
+ split_state = handle_generic(src_ss, /*scalar_only =*/ false);
858
+ } break;
859
+ case GGML_OP_CONCAT: {
860
+ split_state = handle_concat(src_ss);
861
+ } break;
862
+ case GGML_OP_SILU_BACK: {
863
+ split_state = handle_generic(src_ss, /*scalar_only =*/ false);
864
+ } break;
865
+ case GGML_OP_NORM:
866
+ case GGML_OP_RMS_NORM:
867
+ case GGML_OP_RMS_NORM_BACK:
868
+ case GGML_OP_GROUP_NORM:
869
+ case GGML_OP_L2_NORM: {
870
+ split_state = handle_per_row(src_ss);
871
+ } break;
872
+ case GGML_OP_MUL_MAT:
873
+ case GGML_OP_MUL_MAT_ID: {
874
+ split_state = handle_mul_mat(src_ss);
875
+ } break;
876
+ case GGML_OP_OUT_PROD: {
877
+ split_state = handle_generic(src_ss, /*scalar_only =*/ true);
878
+ } break;
879
+ case GGML_OP_SCALE: {
880
+ split_state = handle_generic(src_ss, /*scalar_only =*/ false);
881
+ } break;
882
+ case GGML_OP_SET: {
883
+ split_state = handle_generic(src_ss, /*scalar_only =*/ true);
884
+ } break;
885
+ case GGML_OP_CPY: {
886
+ split_state = handle_cpy(src_ss);
887
+ } break;
888
+ case GGML_OP_CONT:
889
+ case GGML_OP_RESHAPE: {
890
+ split_state = handle_reshape(src_ss);
891
+ } break;
892
+ case GGML_OP_VIEW: {
893
+ split_state = handle_view(src_ss);
894
+ } break;
895
+ case GGML_OP_PERMUTE: {
896
+ split_state = handle_permute(src_ss);
897
+ } break;
898
+ case GGML_OP_TRANSPOSE: {
899
+ split_state = handle_transpose(src_ss);
900
+ } break;
901
+ case GGML_OP_GET_ROWS: {
902
+ split_state = handle_get_rows(src_ss);
903
+ } break;
904
+ case GGML_OP_GET_ROWS_BACK: {
905
+ split_state = handle_generic(src_ss, /*scalar_only =*/ true);
906
+ } break;
907
+ case GGML_OP_SET_ROWS: {
908
+ split_state = handle_set_rows(src_ss);
909
+ } break;
910
+ case GGML_OP_DIAG:
911
+ case GGML_OP_DIAG_MASK_INF:
912
+ case GGML_OP_DIAG_MASK_ZERO: {
913
+ split_state = handle_generic(src_ss, /*scalar_only =*/ true);
914
+ } break;
915
+ case GGML_OP_SOFT_MAX:
916
+ case GGML_OP_SOFT_MAX_BACK: {
917
+ split_state = handle_generic(src_ss, /*scalar_only =*/ false);
918
+ } break;
919
+ case GGML_OP_ROPE: {
920
+ split_state = handle_rope(src_ss);
921
+ } break;
922
+ case GGML_OP_ROPE_BACK: {
923
+ split_state = handle_generic(src_ss, /*scalar_only =*/ true);
924
+ } break;
925
+ case GGML_OP_CLAMP: {
926
+ split_state = handle_generic(src_ss, /*scalar_only =*/ false);
927
+ } break;
928
+ case GGML_OP_CONV_TRANSPOSE_1D:
929
+ case GGML_OP_IM2COL:
930
+ case GGML_OP_IM2COL_BACK:
931
+ case GGML_OP_IM2COL_3D:
932
+ case GGML_OP_CONV_2D:
933
+ case GGML_OP_CONV_3D:
934
+ case GGML_OP_CONV_2D_DW:
935
+ case GGML_OP_CONV_TRANSPOSE_2D:
936
+ case GGML_OP_POOL_1D:
937
+ case GGML_OP_POOL_2D:
938
+ case GGML_OP_POOL_2D_BACK:
939
+ case GGML_OP_UPSCALE: {
940
+ split_state = handle_generic(src_ss, /*scalar_only =*/ true);
941
+ } break;
942
+ case GGML_OP_PAD: {
943
+ split_state = handle_pad(src_ss);
944
+ } break;
945
+ case GGML_OP_PAD_REFLECT_1D:
946
+ case GGML_OP_ROLL:
947
+ case GGML_OP_ARANGE:
948
+ case GGML_OP_TIMESTEP_EMBEDDING: {
949
+ split_state = handle_generic(src_ss, /*scalar_only =*/ true);
950
+ } break;
951
+ case GGML_OP_ARGSORT:
952
+ case GGML_OP_TOP_K: {
953
+ split_state = handle_per_row(src_ss);
954
+ } break;
955
+ case GGML_OP_LEAKY_RELU: {
956
+ split_state = handle_generic(src_ss, /*scalar_only =*/ false);
957
+ } break;
958
+ case GGML_OP_TRI: {
959
+ split_state = handle_generic(src_ss, /*scalar_only =*/ true);
960
+ } break;
961
+ case GGML_OP_FILL: {
962
+ split_state = handle_generic(src_ss, /*scalar_only =*/ false);
963
+ } break;
964
+ case GGML_OP_FLASH_ATTN_EXT: {
965
+ split_state = handle_flash_attn_ext(src_ss);
966
+ } break;
967
+ case GGML_OP_FLASH_ATTN_BACK: {
968
+ split_state = handle_generic(src_ss, /*scalar_only =*/ true);
969
+ } break;
970
+ case GGML_OP_SSM_CONV: {
971
+ split_state = handle_ssm_conv(src_ss);
972
+ } break;
973
+ case GGML_OP_SSM_SCAN:
974
+ case GGML_OP_WIN_PART:
975
+ case GGML_OP_WIN_UNPART:
976
+ case GGML_OP_GET_REL_POS:
977
+ case GGML_OP_ADD_REL_POS:
978
+ case GGML_OP_RWKV_WKV6:
979
+ case GGML_OP_GATED_LINEAR_ATTN:
980
+ case GGML_OP_RWKV_WKV7:
981
+ case GGML_OP_SOLVE_TRI: {
982
+ split_state = handle_generic(src_ss, /*scalar_only =*/ true);
983
+ } break;
984
+ case GGML_OP_GATED_DELTA_NET: {
985
+ split_state = handle_gated_delta_net(src_ss);
986
+ } break;
987
+ case GGML_OP_UNARY: {
988
+ split_state = handle_generic(src_ss, /*scalar_only =*/ false);
989
+ } break;
990
+ case GGML_OP_MAP_CUSTOM1:
991
+ case GGML_OP_MAP_CUSTOM2:
992
+ case GGML_OP_MAP_CUSTOM3:
993
+ case GGML_OP_CUSTOM: {
994
+ split_state = handle_generic(src_ss, /*scalar_only =*/ true);
995
+ } break;
996
+ case GGML_OP_CROSS_ENTROPY_LOSS:
997
+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK: {
998
+ split_state = handle_per_row(src_ss);
999
+ } break;
1000
+ case GGML_OP_OPT_STEP_ADAMW:
1001
+ case GGML_OP_OPT_STEP_SGD:
1002
+ case GGML_OP_GLU: {
1003
+ split_state = handle_generic(src_ss, /*scalar_only =*/ false);
1004
+ } break;
1005
+ default: {
1006
+ GGML_ABORT("ggml op not implemented: %s", ggml_op_name(tensor->op));
1007
+ split_state = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
1008
+ } break;
1009
+ }
1010
+ if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) {
1011
+ bool first_src_split_by_axis = true;
1012
+ const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer);
1013
+
1014
+ for (size_t i = 0; i < GGML_MAX_SRC; i++) {
1015
+ if (tensor->src[i] == nullptr || src_ss[i].axis < 0 || src_ss[i].axis >= GGML_MAX_DIMS) {
1016
+ continue;
1017
+ }
1018
+ if (first_src_split_by_axis) {
1019
+ for (size_t j = 0; j < n_bufs; j++) {
1020
+ // Take over ratio from src:
1021
+ for (size_t s = 0; s < src_ss[i].n_segments; s++) {
1022
+ split_state.ne[s*n_bufs + j] = 0;
1023
+ }
1024
+ for (size_t s = 0; s < src_ss[i].n_segments; s++) {
1025
+ split_state.ne[j] += src_ss[i].ne[s*n_bufs + j] * src_ss[i].nr[s];
1026
+ }
1027
+ split_state.ne[j] *= tensor->ne[split_state.axis];
1028
+ if (split_state.ne[j] != 0 || tensor->src[i]->ne[src_ss[i].axis] != 0) {
1029
+ const int64_t div = tensor->src[i]->ne[src_ss[i].axis] * split_state.nr[0];
1030
+ GGML_ASSERT(split_state.ne[j] % div == 0);
1031
+ split_state.ne[j] /= div;
1032
+ }
1033
+ }
1034
+ } else {
1035
+ GGML_ASSERT(split_state.n_segments == 1);
1036
+ for (size_t j = 0; j < n_bufs; j++) {
1037
+ // Assert that ratio is consistent:
1038
+ int64_t sum = 0;
1039
+ for (size_t s = 0; s < src_ss[i].n_segments; s++) {
1040
+ sum += src_ss[i].ne[s*n_bufs + j] * src_ss[i].nr[s];
1041
+ }
1042
+ GGML_ASSERT(split_state.ne[j]*split_state.nr[0] * tensor->src[i]->ne[src_ss[i].axis]
1043
+ == sum * tensor->ne[split_state.axis]);
1044
+ }
1045
+ }
1046
+ first_src_split_by_axis = false;
1047
+ }
1048
+ GGML_ASSERT(!first_src_split_by_axis);
1049
+ }
1050
+ return split_state;
1051
+ };
1052
+
1053
+ const std::pair key = std::make_pair(tensor, assume_sync);
1054
+ auto it = buf_ctx->split_state_cache.find(key);
1055
+ if (it != buf_ctx->split_state_cache.end() && memcmp(it->second.second, (const char *) tensor, sizeof(it->second.second)) != 0) {
1056
+ buf_ctx->split_state_cache.clear();
1057
+ it = buf_ctx->split_state_cache.end();
1058
+ }
1059
+
1060
+ if (it == buf_ctx->split_state_cache.end()) {
1061
+ buf_ctx->split_state_cache[key].first = calculate_split_state();
1062
+ memcpy(buf_ctx->split_state_cache[key].second, tensor, sizeof(buf_ctx->split_state_cache[key].second));
1063
+ if (buf_ctx->debug > 0) {
1064
+ std::string srcs_info;
1065
+ for (size_t i = 0; i < GGML_MAX_SRC; i++) {
1066
+ if (tensor->src[i] == nullptr) {
1067
+ continue;
1068
+ }
1069
+ if (!srcs_info.empty()) {
1070
+ srcs_info += ", ";
1071
+ }
1072
+ const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor->src[0], true);
1073
+ GGML_ASSERT(split_state.n_segments == 1);
1074
+ const char * axis_name = ggml_backend_meta_split_axis_name(split_state.axis);
1075
+ std::string ne_info;
1076
+ for (size_t j = 0; j < n_bufs; j++) {
1077
+ if (!ne_info.empty()) {
1078
+ ne_info += ", ";
1079
+ }
1080
+ ne_info += std::to_string(split_state.ne[j]) + "x" + std::to_string(split_state.nr[0]);
1081
+ }
1082
+ srcs_info += std::string(tensor->src[i]->name) + "[" + ggml_op_name(tensor->src[i]->op) + ", " + axis_name + ", {" + ne_info + "}]";
1083
+ }
1084
+ std::string ne_info;
1085
+ for (size_t j = 0; j < n_bufs; j++) {
1086
+ if (!ne_info.empty()) {
1087
+ ne_info += ", ";
1088
+ }
1089
+ const ggml_backend_meta_split_state & ss = buf_ctx->split_state_cache[key].first;
1090
+ ne_info += std::to_string(ss.ne[j]) + "x" + std::to_string(ss.nr[0]);
1091
+ }
1092
+ GGML_LOG_DEBUG("SPLIT_STATE: {%s} -> %s[%s, %s, {%s}]\n", srcs_info.c_str(), tensor->name, ggml_op_name(tensor->op),
1093
+ ggml_backend_meta_split_axis_name(buf_ctx->split_state_cache[key].first.axis), ne_info.c_str());
1094
+ }
1095
+ }
1096
+
1097
+ ggml_backend_meta_split_state ret = buf_ctx->split_state_cache[key].first;
1098
+ GGML_ASSERT(ret.axis != GGML_BACKEND_SPLIT_AXIS_NONE);
1099
+ #ifndef NDEBUG
1100
+ if (ret.axis >= 0 && ret.axis < GGML_MAX_DIMS) {
1101
+ int64_t ne_ret = 0;
1102
+ for (size_t s = 0; s < ret.n_segments; s++) {
1103
+ for (size_t j = 0; j < n_bufs; j++) {
1104
+ ne_ret += ret.ne[s*n_bufs + j] * ret.nr[s];
1105
+ }
1106
+ }
1107
+ assert(ne_ret == tensor->ne[int(ret.axis)]);
1108
+ }
1109
+ #endif // NDEBUG
1110
+ return ret;
1111
+ }
1112
+
1113
+ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync) {
1114
+ GGML_ASSERT(ggml_backend_buffer_is_meta(tensor->buffer));
1115
+ ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context;
1116
+ return ggml_backend_meta_get_split_state(buf_ctx->get_simple_tensor_container(tensor), tensor, assume_sync);
1117
+ }
1118
+
1119
+ static void * ggml_backend_meta_buffer_get_base(ggml_backend_buffer_t buffer) {
1120
+ GGML_UNUSED(buffer);
1121
+ return (void *) 0x1000000000000000; // FIXME
1122
+ }
1123
+
1124
+ static enum ggml_status ggml_backend_meta_buffer_init_tensor_impl(ggml_backend_meta_simple_tensor_container & stc, ggml_tensor * tensor) { // Build one simple tensor per simple buffer for `tensor` (sized/strided by its split state) and register them in stc.simple_tensors[tensor]; always returns GGML_STATUS_SUCCESS.
1125
+ GGML_ASSERT(ggml_backend_buffer_is_meta(tensor->buffer));
1126
+ ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context;
1127
+ const size_t n_simple_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer);
1128
+
1129
+ const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(stc, tensor, /*assume_sync =*/ true);
1130
+ GGML_ASSERT(ggml_nelements(tensor) == 0 || split_state.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN); // an unknown split is only tolerated for empty tensors
1131
+ GGML_ASSERT(split_state.n_segments <= 16); // NOTE(review): presumably the capacity of split_state.ne/nr per buffer - confirm against the struct definition
1132
+
1133
+ int split_dim = split_state.axis; // values outside [0, GGML_MAX_DIMS) denote non-dimension split states (mirrored, partial, ...)
1134
+ int64_t ne[GGML_MAX_DIMS]; // per-buffer shape, starts as a copy of the meta tensor's shape
1135
+ size_t nb[GGML_MAX_DIMS]; // per-buffer strides, starts as a copy of the meta tensor's strides
1136
+ for (size_t k = 0; k < GGML_MAX_DIMS; k++) {
1137
+ ne[k] = tensor->ne[k];
1138
+ nb[k] = tensor->nb[k];
1139
+ }
1140
+
1141
+ std::vector<ggml_tensor *> simple_tensors; // index j corresponds to simple buffer j
1142
+ simple_tensors.reserve(n_simple_bufs);
1143
+ for (size_t j = 0; j < n_simple_bufs; j++) {
1144
+ ggml_context * simple_ctx = stc.ctxs[j].get();
1145
+ ggml_backend_buffer_t simple_buf = buf_ctx->bufs[j].get(); // may be nullptr while buffers are still being allocated
1146
+
1147
+ if (split_dim >= 0 && split_dim < GGML_MAX_DIMS) {
1148
+ // TODO: the following assert fails for llama-parallel even though the results are correct:
1149
+ // GGML_ASSERT(ggml_is_contiguously_allocated(tensor));
1150
+ ne[split_dim] = 0; // slice length of buffer j along the split dim: sum over segments of ne * nr
1151
+ for (size_t s = 0; s < split_state.n_segments; s++) {
1152
+ ne[split_dim] += split_state.ne[s*n_simple_bufs + j] * split_state.nr[s];
1153
+ }
1154
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
1155
+ if (tensor->nb[i] > tensor->nb[split_dim]) { // only dims with a larger stride than the split dim shrink
1156
+ nb[i] = tensor->nb[i] * ne[split_dim]/tensor->ne[split_dim]; // scale stride by slice length / full length
1157
+ }
1158
+ }
1159
+ }
1160
+
1161
+ ggml_tensor * t_ij = ggml_new_tensor(simple_ctx, tensor->type, GGML_MAX_DIMS, ne); // default strides are overwritten below
1162
+ t_ij->op = tensor->op;
1163
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
1164
+ t_ij->nb[i] = nb[i];
1165
+ }
1166
+ t_ij->flags = tensor->flags;
1167
+ memcpy(t_ij->op_params, tensor->op_params, sizeof(tensor->op_params));
1168
+ ggml_set_name(t_ij, tensor->name);
1169
+ t_ij->buffer = simple_buf;
1170
+ t_ij->view_src = tensor->view_src;
1171
+ t_ij->view_offs = tensor->view_offs;
1172
+ if (t_ij->view_src != nullptr && ggml_backend_buffer_is_meta(t_ij->view_src->buffer)) { // a view of a meta tensor becomes a view of the matching simple tensor
1173
+ t_ij->view_src = ggml_backend_meta_buffer_simple_tensor(tensor->view_src, j);
1174
+ if (t_ij->view_offs > 0 && split_dim >= 0 && split_dim < GGML_MAX_DIMS) {
1175
+ GGML_ASSERT(tensor->ne[split_dim] != 0);
1176
+ const int split_dim_view_src = ggml_backend_meta_get_split_state(tensor->view_src, /*assume_sync =*/ true).axis;
1177
+ GGML_ASSERT(split_dim_view_src >= 0 && split_dim_view_src < GGML_MAX_DIMS);
1178
+
1179
+ // The offset can be internal to the data split, in those cases the view offset should not be scaled.
1180
+ // If however, the offset is larger than the data split then it needs to be scaled proportionally.
1181
+ bool split_internal_offset = t_ij->view_offs <= tensor->view_src->nb[split_dim_view_src];
1182
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
1183
+ const size_t dim_size = tensor->ne[i] * tensor->nb[i];
1184
+ if (tensor->view_offs <= dim_size && dim_size < tensor->nb[split_dim]) {
1185
+ split_internal_offset = true;
1186
+ break;
1187
+ }
1188
+ }
1189
+ if (!split_internal_offset) {
1190
+ t_ij->view_offs = t_ij->view_offs * ne[split_dim]/tensor->ne[split_dim]; // same proportional scaling as the strides above
1191
+ }
1192
+ }
1193
+ }
1194
+ if (t_ij->view_src != nullptr) {
1195
+ t_ij->data = (char *) t_ij->view_src->data + t_ij->view_offs;
1196
+ } else if (simple_buf != nullptr) { // keep the same offset from the simple buffer's base as the meta tensor has from the meta base
1197
+ t_ij->data = (char *) ggml_backend_buffer_get_base(simple_buf)
1198
+ + size_t(tensor->data) - size_t(ggml_backend_buffer_get_base(tensor->buffer));
1199
+ }
1200
+ t_ij->extra = tensor->extra;
1201
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
1202
+ t_ij->src[i] = tensor->src[i];
1203
+ if (tensor->src[i] == tensor) { // self-reference stays a self-reference on the simple tensor
1204
+ t_ij->src[i] = t_ij;
1205
+ } else if (t_ij->src[i] != nullptr && ggml_backend_buffer_is_meta(t_ij->src[i]->buffer)) { // meta sources are replaced by their j-th simple tensor
1206
+ t_ij->src[i] = ggml_backend_meta_buffer_simple_tensor(tensor->src[i], j);
1207
+ }
1208
+ }
1209
+
1210
+ simple_tensors.push_back(t_ij);
1211
+ }
1212
+
1213
+ // If one of the sources has a zero-sized slice, disable the computation:
1214
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
1215
+ if (tensor->src[i] == nullptr || !ggml_backend_buffer_is_meta(tensor->src[i]->buffer)) {
1216
+ continue;
1217
+ }
1218
+
1219
+ const ggml_backend_meta_split_state split_state_src = ggml_backend_meta_get_split_state(tensor->src[i], /*assume_sync =*/ true);
1220
+ if (split_state_src.axis < 0 || split_state_src.axis >= GGML_MAX_DIMS) { // only dimension splits can yield empty slices
1221
+ continue;
1222
+ }
1223
+ for (size_t j = 0; j < n_simple_bufs; j++) {
1224
+ int64_t ne_sum = 0;
1225
+ for (size_t s = 0; s < split_state_src.n_segments; s++) {
1226
+ ne_sum += split_state_src.ne[s*n_simple_bufs + j] * split_state_src.nr[s];
1227
+ }
1228
+ if (ne_sum == 0) {
1229
+ simple_tensors[j]->flags &= ~GGML_TENSOR_FLAG_COMPUTE; // clear the compute flag on buffer j's simple tensor
1230
+ }
1231
+ }
1232
+ }
1233
+
1234
+ stc.simple_tensors[tensor] = simple_tensors;
1235
+
1236
+ return GGML_STATUS_SUCCESS;
1237
+ }
1238
+
1239
+ static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
1240
+ GGML_ASSERT(ggml_backend_buffer_is_meta(buffer));
1241
+ ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buffer->context;
1242
+ buf_ctx->stc_compute_index = buf_ctx->stc_compute_index_next;
1243
+ return ggml_backend_meta_buffer_init_tensor_impl(buf_ctx->get_simple_tensor_container(tensor), tensor);
1244
+ }
1245
+
1246
+ static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
1247
+ const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(buffer);
1248
+ GGML_ASSERT(ggml_is_contiguous(tensor));
1249
+
1250
+ const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false);
1251
+
1252
+ if (split_state.n_segments != 1 || split_state.nr[0] != 1) {
1253
+ GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS);
1254
+ GGML_ASSERT(split_state.nr[0] != 0);
1255
+ GGML_ASSERT(tensor->ne[3] == 1);
1256
+
1257
+ size_t offset_data = 0;
1258
+ std::vector<size_t> simple_offsets(n_bufs, 0);
1259
+ if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_0) {
1260
+ GGML_ASSERT(tensor->ne[2] == 1);
1261
+
1262
+ const size_t row_stride = tensor->nb[1];
1263
+ GGML_ASSERT(offset % row_stride == 0);
1264
+ GGML_ASSERT(size % row_stride == 0);
1265
+ const int64_t row_start = offset / row_stride;
1266
+ const int64_t row_count = size / row_stride;
1267
+ GGML_ASSERT(row_start + row_count <= tensor->ne[1]);
1268
+
1269
+ const int64_t blck_size = ggml_blck_size(tensor->type);
1270
+ for (size_t s = 0; s < split_state.n_segments; s++) {
1271
+ for (size_t r = 0; r < split_state.nr[s]; r++) {
1272
+ for (size_t j = 0; j < n_bufs; j++) {
1273
+ ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
1274
+ GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0);
1275
+ const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0];
1276
+ ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data,
1277
+ simple_offsets[j] + row_start * simple_tensor->nb[1], nbytes,
1278
+ row_count, simple_tensor->nb[1], tensor->nb[1]);
1279
+ offset_data += nbytes;
1280
+ simple_offsets[j] += nbytes;
1281
+ }
1282
+ }
1283
+ }
1284
+ GGML_ASSERT(offset_data*row_count == size);
1285
+ return;
1286
+ }
1287
+ GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1);
1288
+
1289
+ const size_t row_stride = tensor->nb[2];
1290
+ GGML_ASSERT(offset % row_stride == 0);
1291
+ GGML_ASSERT(size % row_stride == 0);
1292
+ const int64_t row_start = offset / row_stride;
1293
+ const int64_t row_count = size / row_stride;
1294
+ GGML_ASSERT(row_start + row_count <= tensor->ne[2]);
1295
+
1296
+ for (size_t s = 0; s < split_state.n_segments; s++) {
1297
+ for (size_t r = 0; r < split_state.nr[s]; r++) {
1298
+ for (size_t j = 0; j < n_bufs; j++) {
1299
+ ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
1300
+ const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1];
1301
+ ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data,
1302
+ simple_offsets[j] + row_start * simple_tensor->nb[2], nbytes,
1303
+ row_count, simple_tensor->nb[2], tensor->nb[2]);
1304
+ offset_data += nbytes;
1305
+ simple_offsets[j] += nbytes;
1306
+ }
1307
+ }
1308
+ }
1309
+ GGML_ASSERT(offset_data*row_count == size);
1310
+ return;
1311
+ }
1312
+
1313
+ switch (split_state.axis) {
1314
+ case GGML_BACKEND_SPLIT_AXIS_0:
1315
+ case GGML_BACKEND_SPLIT_AXIS_1:
1316
+ case GGML_BACKEND_SPLIT_AXIS_2: {
1317
+ // Exploit that tensors are contiguous to splice it with simple tensors as "chunks".
1318
+ const size_t chunk_size_full = tensor->nb[split_state.axis + 1];
1319
+ GGML_ASSERT(offset % chunk_size_full == 0);
1320
+ GGML_ASSERT(size % chunk_size_full == 0);
1321
+ const int64_t i_start = offset /chunk_size_full;
1322
+ const int64_t i_stop = (offset + size)/chunk_size_full;
1323
+ size_t offset_j = 0;
1324
+ for (size_t j = 0; j < n_bufs; j++) {
1325
+ ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
1326
+ const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1];
1327
+ if (chunk_size_j == 0) {
1328
+ continue;
1329
+ }
1330
+ const size_t simple_offset = i_start * chunk_size_j;
1331
+ ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_j, simple_offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full);
1332
+ offset_j += chunk_size_j;
1333
+ }
1334
+ GGML_ASSERT(offset_j == chunk_size_full);
1335
+ } break;
1336
+ case GGML_BACKEND_SPLIT_AXIS_MIRRORED: {
1337
+ for (size_t j = 0; j < n_bufs; j++) {
1338
+ ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
1339
+ ggml_backend_tensor_set(simple_tensor, data, offset, size);
1340
+ }
1341
+ } break;
1342
+ case GGML_BACKEND_SPLIT_AXIS_PARTIAL: {
1343
+ GGML_ASSERT(tensor->type == GGML_TYPE_F32);
1344
+ const int64_t ne = ggml_nelements(tensor);
1345
+ std::vector<float> tmp;
1346
+ tmp.reserve(ne);
1347
+ for (int64_t i = 0; i < ne; i++) {
1348
+ tmp.push_back(((const float *) data)[i] / n_bufs);
1349
+ }
1350
+ for (size_t j = 0; j < n_bufs; j++) {
1351
+ ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
1352
+ ggml_backend_tensor_set(simple_tensor, tmp.data(), offset, size);
1353
+ }
1354
+ } break;
1355
+ default: {
1356
+ GGML_ABORT("fatal error");
1357
+ }
1358
+ }
1359
+ }
1360
+
1361
+ static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { // Download `size` bytes starting at byte `offset` of the meta tensor into `data`, gathering from the simple tensors per the split state.
1362
+ const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(buffer);
1363
+ GGML_ASSERT(ggml_is_contiguous(tensor)); // the row/chunk arithmetic below relies on contiguity
1364
+
1365
+ const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false);
1366
+
1367
+ if (split_state.n_segments != 1 || split_state.nr[0] != 1) { // segmented/repeated split: only axis 0 (2D) and axis 1 (3D) are handled
1368
+ GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS);
1369
+ GGML_ASSERT(split_state.nr[0] != 0);
1370
+ GGML_ASSERT(tensor->ne[3] == 1);
1371
+
1372
+ size_t offset_data = 0; // byte offset within one destination row/slice
1373
+ std::vector<size_t> simple_offsets(n_bufs, 0); // per-buffer byte offset within one simple-tensor row/slice
1374
+ if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_0) { // rows along dim 1 interleave per-buffer segments
1375
+ GGML_ASSERT(tensor->ne[2] == 1);
1376
+
1377
+ const size_t row_stride = tensor->nb[1];
1378
+ GGML_ASSERT(offset % row_stride == 0);
1379
+ GGML_ASSERT(size % row_stride == 0);
1380
+ const int64_t row_start = offset / row_stride;
1381
+ const int64_t row_count = size / row_stride;
1382
+ GGML_ASSERT(row_start + row_count <= tensor->ne[1]);
1383
+
1384
+ const int64_t blck_size = ggml_blck_size(tensor->type);
1385
+ for (size_t s = 0; s < split_state.n_segments; s++) {
1386
+ for (size_t r = 0; r < split_state.nr[s]; r++) {
1387
+ for (size_t j = 0; j < n_bufs; j++) {
1388
+ const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
1389
+ GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0);
1390
+ const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0]; // bytes of this segment held by buffer j in each row
1391
+ ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data,
1392
+ simple_offsets[j] + row_start * simple_tensor->nb[1], nbytes,
1393
+ row_count, simple_tensor->nb[1], tensor->nb[1]);
1394
+ offset_data += nbytes;
1395
+ simple_offsets[j] += nbytes;
1396
+ }
1397
+ }
1398
+ }
1399
+ GGML_ASSERT(offset_data*row_count == size); // gathered per-row bytes must cover the requested range exactly
1400
+ return;
1401
+ }
1402
+ GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1); // remaining supported case: slices along dim 2
1403
+
1404
+ const size_t row_stride = tensor->nb[2];
1405
+ GGML_ASSERT(offset % row_stride == 0);
1406
+ GGML_ASSERT(size % row_stride == 0);
1407
+ const int64_t row_start = offset / row_stride;
1408
+ const int64_t row_count = size / row_stride;
1409
+ GGML_ASSERT(row_start + row_count <= tensor->ne[2]);
1410
+
1411
+ for (size_t s = 0; s < split_state.n_segments; s++) {
1412
+ for (size_t r = 0; r < split_state.nr[s]; r++) {
1413
+ for (size_t j = 0; j < n_bufs; j++) {
1414
+ const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
1415
+ const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1];
1416
+ ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data,
1417
+ simple_offsets[j] + row_start * simple_tensor->nb[2], nbytes,
1418
+ row_count, simple_tensor->nb[2], tensor->nb[2]);
1419
+ offset_data += nbytes;
1420
+ simple_offsets[j] += nbytes;
1421
+ }
1422
+ }
1423
+ }
1424
+ GGML_ASSERT(offset_data*row_count == size);
1425
+ return;
1426
+ }
1427
+
1428
+ switch (split_state.axis) { // single-segment split states
1429
+ case GGML_BACKEND_SPLIT_AXIS_0:
1430
+ case GGML_BACKEND_SPLIT_AXIS_1:
1431
+ case GGML_BACKEND_SPLIT_AXIS_2: {
1432
+ // Exploit that tensors are contiguous to splice it with simple tensors as "chunks".
1433
+ const size_t chunk_size_full = tensor->nb[split_state.axis + 1]; // bytes per index along axis+1 in the full tensor
1434
+ GGML_ASSERT(offset % chunk_size_full == 0);
1435
+ GGML_ASSERT(size % chunk_size_full == 0);
1436
+ const int64_t i_start = offset /chunk_size_full;
1437
+ const int64_t i_stop = (offset + size)/chunk_size_full;
1438
+ size_t offset_j = 0;
1439
+ for (size_t j = 0; j < n_bufs; j++){
1440
+ const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
1441
+ const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; // bytes per index along axis+1 in buffer j's slice
1442
+ if (chunk_size_j == 0) { // buffer j holds an empty slice
1443
+ continue;
1444
+ }
1445
+ const size_t simple_offset = i_start * chunk_size_j;
1446
+ ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_j, simple_offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full);
1447
+ offset_j += chunk_size_j;
1448
+ }
1449
+ GGML_ASSERT(offset_j == chunk_size_full); // the slices of all buffers must tile one full chunk
1450
+ } break;
1451
+ case GGML_BACKEND_SPLIT_AXIS_MIRRORED: { // every buffer holds the same data, so read from buffer 0
1452
+ // TODO other simple backend may be better
1453
+ const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, 0);
1454
+ ggml_backend_tensor_get(simple_tensor, data, offset, size);
1455
+ } break;
1456
+ default: { // other split states (e.g. PARTIAL) are not supported for reads
1457
+ GGML_ABORT("fatal error");
1458
+ }
1459
+ }
1460
+ }
1461
+
1462
+ static void ggml_backend_meta_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
1463
+ const size_t n_buffers = ggml_backend_meta_buffer_n_bufs(buffer);
1464
+ for (size_t i = 0; i < n_buffers; i++) {
1465
+ ggml_backend_buffer_clear(ggml_backend_meta_buffer_simple_buffer(buffer, i), value);
1466
+ }
1467
+ }
1468
+
1469
+ static void ggml_backend_meta_buffer_reset(ggml_backend_buffer_t buffer) {
1470
+ GGML_ASSERT(ggml_backend_buffer_is_meta(buffer));
1471
+ ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buffer->context;
1472
+ for (size_t i = 0; i < buf_ctx->bufs.size(); i++) {
1473
+ ggml_backend_buffer_reset(ggml_backend_meta_buffer_simple_buffer(buffer, i));
1474
+ }
1475
+ }
1476
+
1477
+ static const ggml_backend_buffer_i ggml_backend_meta_buffer_iface = { // Callback table for meta buffers; positional initializer, so entry order must match ggml_backend_buffer_i's members.
1478
+ /* .free_buffer = */ ggml_backend_meta_buffer_free_buffer, // also used by ggml_backend_buffer_is_meta to identify meta buffers
1479
+ /* .get_base = */ ggml_backend_meta_buffer_get_base,
1480
+ /* .init_tensor = */ ggml_backend_meta_buffer_init_tensor,
1481
+ /* .memset_tensor = */ nullptr, // TODO implement
1482
+ /* .set_tensor = */ ggml_backend_meta_buffer_set_tensor,
1483
+ /* .get_tensor = */ ggml_backend_meta_buffer_get_tensor,
1484
+ /* .set_tensor_2d = */ nullptr, // not provided by the meta buffer
1485
+ /* .get_tensor_2d = */ nullptr, // not provided by the meta buffer
1486
+ /* .cpy_tensor = */ nullptr, // not provided by the meta buffer
1487
+ /* .clear = */ ggml_backend_meta_buffer_clear,
1488
+ /* .reset = */ ggml_backend_meta_buffer_reset,
1489
+ };
1490
+
1491
+ bool ggml_backend_buffer_is_meta(ggml_backend_buffer_t buf) {
1492
+ return buf != nullptr && buf->iface.free_buffer == ggml_backend_meta_buffer_iface.free_buffer;
1493
+ }
1494
+
1495
+ static ggml_backend_buffer_t ggml_backend_meta_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
1496
+ const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft);
1497
+
1498
+ const ggml_init_params params = {
1499
+ /*.mem_size =*/ 1024*1024*ggml_tensor_overhead(), // FIXME
1500
+ /*.mem_buffer =*/ nullptr,
1501
+ /*.no_alloc =*/ true,
1502
+ };
1503
+ ggml_backend_meta_simple_tensor_container stc_static;
1504
+ ggml_backend_meta_simple_tensor_container stc_compute_0(params, n_simple_bufts);
1505
+ ggml_backend_meta_simple_tensor_container stc_compute_1(params, n_simple_bufts);
1506
+
1507
+ size_t max_size = 0;
1508
+ std::vector<ggml_backend_buffer_t> bufs;
1509
+ bufs.reserve(n_simple_bufts);
1510
+ for (size_t i = 0; i < n_simple_bufts; i++) {
1511
+ bufs.push_back(ggml_backend_buft_alloc_buffer(ggml_backend_meta_buft_simple_buft(buft, i), size));
1512
+ GGML_ASSERT(bufs.back() != nullptr);
1513
+ max_size = std::max(max_size, ggml_backend_buffer_get_size(bufs.back()));
1514
+ }
1515
+ ggml_backend_meta_buffer_context * buf_ctx = new ggml_backend_meta_buffer_context(stc_static, stc_compute_0, stc_compute_1, bufs);
1516
+
1517
+ return ggml_backend_buffer_init(buft, ggml_backend_meta_buffer_iface, buf_ctx, max_size);
1518
+ }
1519
+
1520
+ struct ggml_backend_buffer * ggml_backend_meta_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) {
1521
+ const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft);
1522
+
1523
+ constexpr size_t compute_headroom = 16; // Maximum number of views per statically allocated tensor that can be created between evals.
1524
+ const ggml_init_params params_static = {
1525
+ /*.mem_size =*/ ggml_get_mem_size(ctx),
1526
+ /*.mem_buffer =*/ nullptr,
1527
+ /*.no_alloc =*/ true,
1528
+ };
1529
+ const ggml_init_params params_compute = {
1530
+ /*.mem_size =*/ compute_headroom*ggml_get_mem_size(ctx),
1531
+ /*.mem_buffer =*/ nullptr,
1532
+ /*.no_alloc =*/ true,
1533
+ };
1534
+ ggml_backend_meta_simple_tensor_container stc_static (params_static, n_simple_bufts);
1535
+ ggml_backend_meta_simple_tensor_container stc_compute_0(params_compute, n_simple_bufts);
1536
+ ggml_backend_meta_simple_tensor_container stc_compute_1(params_compute, n_simple_bufts);
1537
+
1538
+ std::vector<ggml_backend_buffer_t> bufs(n_simple_bufts, nullptr);
1539
+ ggml_backend_meta_buffer_context * meta_buf_ctx = new ggml_backend_meta_buffer_context(stc_static, stc_compute_0, stc_compute_1, bufs);
1540
+
1541
+ ggml_backend_buffer_t meta_buf = ggml_backend_buffer_init(buft, ggml_backend_meta_buffer_iface, meta_buf_ctx, 0);
1542
+ for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
1543
+ t->buffer = meta_buf;
1544
+ ggml_backend_meta_buffer_init_tensor_impl(meta_buf_ctx->stc_static, t);
1545
+ t->data = (void *) 0x2000000000000000; // FIXME
1546
+ }
1547
+ for (size_t i = 0; i < n_simple_bufts; i++) {
1548
+ ggml_context * ctx = meta_buf_ctx->stc_static.ctxs[i].get();
1549
+ ggml_backend_buffer_type_t simple_buft = ggml_backend_meta_buft_simple_buft(buft, i);
1550
+
1551
+ // If a ggml_context only has zero-sized tensors, ggml_backend_alloc_ctx_tensors_from_buft returns NULL.
1552
+ // For those edge cases, allocate a dummy buffer instead.
1553
+ bool any_nonzero_slice = false;
1554
+ for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
1555
+ if (ggml_nelements(t) != 0) {
1556
+ any_nonzero_slice = true;
1557
+ break;
1558
+ }
1559
+ }
1560
+ if (any_nonzero_slice) {
1561
+ meta_buf_ctx->bufs[i].reset(ggml_backend_alloc_ctx_tensors_from_buft(ctx, simple_buft));
1562
+ } else {
1563
+ meta_buf_ctx->bufs[i].reset(ggml_backend_buft_alloc_buffer(simple_buft, 0));
1564
+ for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
1565
+ t->buffer = meta_buf_ctx->bufs[i].get();
1566
+ }
1567
+ }
1568
+ GGML_ASSERT(meta_buf_ctx->bufs[i]);
1569
+ meta_buf->size = std::max(meta_buf->size, ggml_backend_buffer_get_size(meta_buf_ctx->bufs[i].get()));
1570
+ }
1571
+ return meta_buf;
1572
+ }
1573
+
1574
+ //
1575
+ // meta backend
1576
+ //
1577
+
1578
+ static ggml_guid_t ggml_backend_meta_guid() {
1579
+ static ggml_guid guid = {0xf1, 0x0e, 0x34, 0xcf, 0x9c, 0x6f, 0x43, 0xcb, 0x96, 0x92, 0xbe, 0x8e, 0xbb, 0x71, 0x3f, 0xda};
1580
+ return &guid;
1581
+ }
1582
+
1583
+ struct ggml_backend_meta_context {
1584
+ struct cgraph_config {
1585
+ ggml_cgraph * cgraph_main = nullptr;
1586
+ int offset = 0; // Node offset vs. original graph
1587
+
1588
+ std::vector<ggml_cgraph *> cgraphs_aux;
1589
+ };
1590
+ struct backend_config {
1591
+ ggml_backend_t backend;
1592
+
1593
+ std::vector<cgraph_config> cgraphs;
1594
+ std::vector<ggml_tensor *> nodes;
1595
+ std::vector<ggml_backend_buffer_ptr> bufs;
1596
+
1597
+ backend_config(ggml_backend_t backend, const size_t n_reduce_steps) : backend(backend) {
1598
+ bufs.resize(n_reduce_steps);
1599
+ }
1600
+ };
1601
+ std::string name;
1602
+ std::vector<backend_config> backend_configs;
1603
+ ggml_context_ptr ctx;
1604
+ std::vector<ggml_cgraph *> cgraphs_aux;
1605
+ std::vector<ggml_tensor *> nodes_aux;
1606
+ size_t n_reduce_steps;
1607
+ int max_nnodes = 0;
1608
+ size_t max_tmp_size = 0;
1609
+ size_t max_subgraphs = 0;
1610
+ size_t n_subgraphs = 0;
1611
+ uint64_t uid = 0;
1612
+
1613
+ void * comm_ctx = nullptr;
1614
+ ggml_backend_comm_allreduce_tensor_t comm_allreduce = nullptr;
1615
+
1616
+ ggml_backend_meta_context(ggml_backend_dev_t meta_dev, const char * params) {
1617
+ const size_t n_devs = ggml_backend_meta_dev_n_devs(meta_dev);
1618
+ n_reduce_steps = std::ceil(std::log2(n_devs));
1619
+ name = "Meta(";
1620
+ std::vector<ggml_backend_t> simple_backends;
1621
+ backend_configs.reserve(n_devs);
1622
+ simple_backends.reserve(n_devs);
1623
+ for (size_t i = 0; i < n_devs; i++) {
1624
+ ggml_backend_dev_t simple_dev = ggml_backend_meta_dev_simple_dev(meta_dev, i);
1625
+ if (i > 0) {
1626
+ name += ",";
1627
+ }
1628
+ name += ggml_backend_dev_name(simple_dev);
1629
+ simple_backends.push_back(ggml_backend_dev_init(simple_dev, params));
1630
+ backend_configs.emplace_back(simple_backends.back(), n_reduce_steps);
1631
+ }
1632
+ name += ")";
1633
+
1634
+ if (n_devs > 1) {
1635
+ ggml_backend_comm_init_t comm_init = (ggml_backend_comm_init_t) ggml_backend_reg_get_proc_address(
1636
+ ggml_backend_dev_backend_reg(ggml_backend_get_device(simple_backends[0])), "ggml_backend_comm_init");
1637
+ if (comm_init != nullptr) {
1638
+ comm_ctx = comm_init(simple_backends.data(), simple_backends.size());
1639
+ }
1640
+ }
1641
+ if (comm_ctx != nullptr) {
1642
+ comm_allreduce = (ggml_backend_comm_allreduce_tensor_t)
1643
+ ggml_backend_reg_get_proc_address(ggml_backend_dev_backend_reg(
1644
+ ggml_backend_get_device(simple_backends[0])), "ggml_backend_comm_allreduce_tensor");
1645
+ GGML_ASSERT(comm_allreduce != nullptr);
1646
+ }
1647
+ }
1648
+
1649
+ ~ggml_backend_meta_context() {
1650
+ if (comm_ctx != nullptr) {
1651
+ ggml_backend_comm_free_t comm_free = (ggml_backend_comm_free_t) ggml_backend_reg_get_proc_address(
1652
+ ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_configs[0].backend)), "ggml_backend_comm_free");
1653
+ GGML_ASSERT(comm_free != nullptr);
1654
+ comm_free(comm_ctx);
1655
+ }
1656
+ for (auto & bc : backend_configs) {
1657
+ ggml_backend_free(bc.backend);
1658
+ }
1659
+ }
1660
+ };
1661
+
1662
+ static const char * ggml_backend_meta_get_name(ggml_backend_t backend) {
1663
+ GGML_ASSERT(ggml_backend_is_meta(backend));
1664
+ const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) backend->context;
1665
+ return backend_ctx->name.c_str();
1666
+ }
1667
+
1668
+ static void ggml_backend_meta_free(ggml_backend_t backend) {
1669
+ GGML_ASSERT(ggml_backend_is_meta(backend));
1670
+ ggml_backend_meta_context * backend_ctx = (ggml_backend_meta_context *) backend->context;
1671
+ delete backend_ctx;
1672
+ delete backend;
1673
+ }
1674
+
1675
+ static void ggml_backend_meta_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
1676
+ const size_t n_backends = ggml_backend_meta_n_backends(backend);
1677
+ GGML_ASSERT(offset == 0);
1678
+ GGML_ASSERT(ggml_is_contiguous(tensor));
1679
+
1680
+ const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false);
1681
+ GGML_ASSERT(split_state.n_segments == 1);
1682
+ GGML_ASSERT(split_state.nr[0] == 1);
1683
+
1684
+ switch (split_state.axis) {
1685
+ case GGML_BACKEND_SPLIT_AXIS_0:
1686
+ case GGML_BACKEND_SPLIT_AXIS_1:
1687
+ case GGML_BACKEND_SPLIT_AXIS_2: {
1688
+ // Exploit that tensors are contiguous to splice it with simple tensors as "chunks".
1689
+ const size_t chunk_size_full = tensor->nb[split_state.axis + 1];
1690
+ GGML_ASSERT(offset % chunk_size_full == 0);
1691
+ GGML_ASSERT(size % chunk_size_full == 0);
1692
+ const int64_t i_start = offset /chunk_size_full;
1693
+ const int64_t i_stop = (offset + size)/chunk_size_full;
1694
+ size_t offset_j = 0;
1695
+ for (size_t j = 0; j < n_backends; j++){
1696
+ ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j);
1697
+ ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
1698
+ const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1];
1699
+ if (chunk_size_j == 0) {
1700
+ continue;
1701
+ }
1702
+ ggml_backend_tensor_set_2d_async(simple_backend, simple_tensor, (const char *) data + offset_j, offset, chunk_size_j,
1703
+ i_stop - i_start, chunk_size_j, chunk_size_full);
1704
+ offset_j += chunk_size_j;
1705
+ }
1706
+ GGML_ASSERT(offset_j == chunk_size_full);
1707
+ } break;
1708
+ case GGML_BACKEND_SPLIT_AXIS_MIRRORED: {
1709
+ for (size_t j = 0; j < n_backends; j++) {
1710
+ ggml_backend_tensor_set_async(
1711
+ ggml_backend_meta_simple_backend(backend, j), ggml_backend_meta_buffer_simple_tensor(tensor, j), data, offset, size);
1712
+ }
1713
+ } break;
1714
+ default: {
1715
+ GGML_ABORT("fatal error");
1716
+ }
1717
+ }
1718
+ }
1719
+
1720
+ static void ggml_backend_meta_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
1721
+ const size_t n_backends = ggml_backend_meta_n_backends(backend);
1722
+ GGML_ASSERT(offset == 0);
1723
+ GGML_ASSERT(ggml_is_contiguous(tensor));
1724
+
1725
+ const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false);
1726
+ GGML_ASSERT(split_state.n_segments == 1);
1727
+ GGML_ASSERT(split_state.nr[0] == 1);
1728
+
1729
+ switch (split_state.axis) {
1730
+ case GGML_BACKEND_SPLIT_AXIS_0:
1731
+ case GGML_BACKEND_SPLIT_AXIS_1:
1732
+ case GGML_BACKEND_SPLIT_AXIS_2: {
1733
+ // Exploit that tensors are contiguous to splice it with simple tensors as "chunks".
1734
+ const size_t chunk_size_full = tensor->nb[split_state.axis + 1];
1735
+ GGML_ASSERT(offset % chunk_size_full == 0);
1736
+ GGML_ASSERT(size % chunk_size_full == 0);
1737
+ const int64_t i_start = offset /chunk_size_full;
1738
+ const int64_t i_stop = (offset + size)/chunk_size_full;
1739
+ size_t offset_j = 0;
1740
+ for (size_t j = 0; j < n_backends; j++){
1741
+ ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j);
1742
+ const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
1743
+ const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1];
1744
+ if (chunk_size_j == 0) {
1745
+ continue;
1746
+ }
1747
+ ggml_backend_tensor_get_2d_async(simple_backend, simple_tensor, (char *) data + offset_j, offset, chunk_size_j,
1748
+ i_stop - i_start, chunk_size_j, chunk_size_full);
1749
+ offset_j += chunk_size_j;
1750
+ }
1751
+ GGML_ASSERT(offset_j == chunk_size_full);
1752
+ } break;
1753
+ case GGML_BACKEND_SPLIT_AXIS_MIRRORED: {
1754
+ // TODO other simple backend may be better
1755
+ ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, 0);
1756
+ const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, 0);
1757
+ ggml_backend_tensor_get_async(simple_backend, simple_tensor, data, offset, size);
1758
+ } break;
1759
+ default: {
1760
+ GGML_ABORT("fatal error");
1761
+ }
1762
+ }
1763
+ }
1764
+
1765
+ static void ggml_backend_meta_synchronize(ggml_backend_t backend) {
1766
+ const size_t n_backends = ggml_backend_meta_n_backends(backend);
1767
+ for (size_t i = 0; i < n_backends; i++) {
1768
+ ggml_backend_synchronize(ggml_backend_meta_simple_backend(backend, i));
1769
+ }
1770
+ }
1771
+
1772
+ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
1773
+ GGML_ASSERT(cgraph->grads == nullptr);
1774
+ const size_t n_backends = ggml_backend_meta_n_backends(backend);
1775
+ ggml_backend_meta_context * backend_ctx = (ggml_backend_meta_context *) backend->context;
1776
+
1777
+ // If the previous cgraph had a defined UID it can be used to skip rebuilding the subgraphs per simple backend.
1778
+ const bool needs_rebuild = (cgraph->uid == 0) || (cgraph->uid != backend_ctx->uid);
1779
+
1780
+ bool max_nnodes_raised = false;
1781
+ if (cgraph->n_nodes > backend_ctx->max_nnodes) {
1782
+ for (size_t j = 0; j < n_backends; j++) {
1783
+ auto & bcj = backend_ctx->backend_configs[j];
1784
+ bcj.nodes.resize(cgraph->n_nodes);
1785
+ bcj.cgraphs.resize(cgraph->n_nodes);
1786
+ }
1787
+ backend_ctx->max_nnodes = cgraph->n_nodes;
1788
+ max_nnodes_raised = true;
1789
+ assert(needs_rebuild);
1790
+ }
1791
+
1792
+ if (needs_rebuild) {
1793
+ std::set<ggml_backend_buffer_t> used_buffers;
1794
+ for (int i = 0; i < cgraph->n_leafs; i++) {
1795
+ if (ggml_backend_buffer_is_meta(cgraph->leafs[i]->buffer)) {
1796
+ used_buffers.emplace(cgraph->leafs[i]->buffer);
1797
+ }
1798
+ }
1799
+ for (int i = 0; i < cgraph->n_nodes; i++) {
1800
+ if (ggml_backend_buffer_is_meta(cgraph->nodes[i]->buffer)) {
1801
+ used_buffers.emplace(cgraph->nodes[i]->buffer);
1802
+ }
1803
+ }
1804
+ for (ggml_backend_buffer_t buf : used_buffers) {
1805
+ ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buf->context;
1806
+ buf_ctx->stc_compute_index_next = buf_ctx->stc_compute_index ^ 1;
1807
+ ggml_backend_meta_simple_tensor_container & stc = buf_ctx->stc_compute[buf_ctx->stc_compute_index_next];
1808
+ for (ggml_context_ptr & ctx : stc.ctxs) {
1809
+ ggml_reset(ctx.get());
1810
+ }
1811
+ stc.simple_tensors.clear();
1812
+ }
1813
+ size_t n_subgraphs = 0;
1814
+ size_t max_tmp_size = 0;
1815
+
1816
+ for (size_t j = 0; j < n_backends; j++) {
1817
+ auto & bcj = backend_ctx->backend_configs[j];
1818
+
1819
+ for (int i = 0; i < cgraph->n_nodes; i++) {
1820
+ ggml_tensor * node = cgraph->nodes[i];
1821
+ if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) {
1822
+ // FIXME s_copy_main is on the CPU and its view seems to be incorrectly added to the graph nodes.
1823
+ // For regular usage this doesn't matter since it's a noop but trying to call ggml_backend_meta_buffer_simple_tensor results in a crash.
1824
+ bcj.nodes[i] = node;
1825
+ continue;
1826
+ }
1827
+ bcj.nodes[i] = ggml_backend_meta_buffer_simple_tensor(node, j);
1828
+ GGML_ASSERT(bcj.nodes[i]);
1829
+ }
1830
+ }
1831
+
1832
+ {
1833
+ // For MoE models it may make sense to delay the AllReduce in order to reduce I/O:
1834
+ auto get_i_delayed = [&](const int i) -> int {
1835
+ int id = i; // i_delayed
1836
+ int idr = i; // i_delayed return, last safe return value
1837
+
1838
+ ggml_tensor * node = cgraph->nodes[id];
1839
+ int32_t n_used = ggml_node_get_use_count(cgraph, id);
1840
+
1841
+ // Skip MIRRORED nodes that don't consume node
1842
+ auto skip_unrelated = [&]() {
1843
+ while (id + 1 < cgraph->n_nodes) {
1844
+ ggml_tensor * next = cgraph->nodes[id+1];
1845
+ if (ggml_backend_meta_get_split_state(next, false).axis != GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
1846
+ break;
1847
+ }
1848
+ bool safe = true;
1849
+ for (int s = 0; s < GGML_MAX_SRC; s++) {
1850
+ if (next->src[s] == nullptr) {
1851
+ continue;
1852
+ }
1853
+ if (next->src[s] == node) {
1854
+ safe = false;
1855
+ break;
1856
+ }
1857
+ if (ggml_backend_meta_get_split_state(next->src[s], false).axis != GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
1858
+ safe = false;
1859
+ break;
1860
+ }
1861
+ }
1862
+ if (!safe) {
1863
+ break;
1864
+ }
1865
+ id++;
1866
+ }
1867
+ };
1868
+
1869
+ skip_unrelated();
1870
+ if (id + 1 >= cgraph->n_nodes) {
1871
+ return idr;
1872
+ }
1873
+ {
1874
+ ggml_tensor * next = cgraph->nodes[id+1];
1875
+ if (next->op == GGML_OP_ADD_ID && next->src[0] == node &&
1876
+ ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL &&
1877
+ ggml_backend_meta_get_split_state(next->src[2], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
1878
+ node = next;
1879
+ id++;
1880
+ idr = id;
1881
+ n_used = ggml_node_get_use_count(cgraph, id);
1882
+ }
1883
+ }
1884
+ // Chain of MULs with MIRRORED src[1]
1885
+ while (true) {
1886
+ skip_unrelated();
1887
+ if (id + 1 >= cgraph->n_nodes) {
1888
+ return idr;
1889
+ }
1890
+ ggml_tensor * next = cgraph->nodes[id+1];
1891
+ if (next->op == GGML_OP_MUL && next->src[0] == node &&
1892
+ ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
1893
+ node = next;
1894
+ id++;
1895
+ idr = id;
1896
+ n_used = ggml_node_get_use_count(cgraph, id);
1897
+ } else {
1898
+ break;
1899
+ }
1900
+ }
1901
+
1902
+ if (n_used != node->ne[1] || id + 2*n_used-1 >= cgraph->n_nodes) {
1903
+ return idr;
1904
+ }
1905
+ for (int32_t k = 0; k < n_used; k++) {
1906
+ ggml_tensor * next = cgraph->nodes[id+1];
1907
+ if (next->op != GGML_OP_VIEW || next->view_src != node || next->view_offs != k*node->nb[1] ||
1908
+ next->ne[0] != node->ne[0] || next->ne[1] != node->ne[2] || next->nb[1] != node->nb[2] ||
1909
+ ggml_node_get_use_count(cgraph, id+1) != 1) {
1910
+ return idr;
1911
+ }
1912
+ id++;
1913
+ }
1914
+ {
1915
+ ggml_tensor * next = cgraph->nodes[id+1];
1916
+ if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id - (n_used-1)] ||
1917
+ next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) {
1918
+ return idr;
1919
+ }
1920
+ id++;
1921
+ }
1922
+ for (int32_t k = 0; k < n_used - 2; k++) {
1923
+ ggml_tensor * next = cgraph->nodes[id+1];
1924
+ if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id] ||
1925
+ next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) {
1926
+ return idr;
1927
+ }
1928
+ id++;
1929
+ }
1930
+ idr = id;
1931
+ return idr;
1932
+ };
1933
+
1934
+ int i_start = 0;
1935
+ for (int i = 0; i < cgraph->n_nodes; i++) {
1936
+ ggml_tensor * node = cgraph->nodes[i];
1937
+ if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) {
1938
+ continue;
1939
+ }
1940
+ const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(node, /*assume_sync =*/ false);
1941
+ if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL) {
1942
+ max_tmp_size = std::max(max_tmp_size, ggml_nbytes(node));
1943
+ }
1944
+ const bool new_subgraph = i + 1 == cgraph->n_nodes || split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL;
1945
+ if (!new_subgraph) {
1946
+ continue;
1947
+ }
1948
+
1949
+ const int i_delayed = get_i_delayed(i);
1950
+
1951
+ // If we can delay the AllReduce we need to consider the interaction with zero-sized tensor slices.
1952
+ // A backend with such a slice would normally have valid data after participating in the AllReduce with a node that has
1953
+ // its compute flag disabled and thus gets its data zeroed out.
1954
+ // If the AllReduce is delayed then the nodes until that point also need to have their compute flag disabled.
1955
+ if (i_delayed > i) {
1956
+ for (size_t j = 0; j < n_backends; j++) {
1957
+ auto & bcj = backend_ctx->backend_configs[j];
1958
+ if ((bcj.nodes[i]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
1959
+ for (int ii = i + 1; ii <= i_delayed; ii++) {
1960
+ bcj.nodes[ii]->flags &= ~GGML_TENSOR_FLAG_COMPUTE;
1961
+ }
1962
+ }
1963
+ }
1964
+ }
1965
+
1966
+ i = i_delayed;
1967
+
1968
+ for (size_t j = 0; j < n_backends; j++) {
1969
+ auto & bcj = backend_ctx->backend_configs[j];
1970
+ bcj.cgraphs[n_subgraphs].offset = i_start;
1971
+ }
1972
+ n_subgraphs++;
1973
+ i_start = i + 1;
1974
+ }
1975
+ GGML_ASSERT(i_start == cgraph->n_nodes);
1976
+ }
1977
+
1978
+ backend_ctx->uid = cgraph->uid;
1979
+ backend_ctx->n_subgraphs = n_subgraphs;
1980
+
1981
+ if (max_tmp_size > backend_ctx->max_tmp_size) {
1982
+ for (size_t j = 0; j < n_backends; j++) {
1983
+ auto & bcj = backend_ctx->backend_configs[j];
1984
+ for (size_t i = 0; i < backend_ctx->n_reduce_steps; i++) {
1985
+ bcj.bufs[i].reset(ggml_backend_alloc_buffer(bcj.backend, max_tmp_size));
1986
+ }
1987
+ }
1988
+ backend_ctx->max_tmp_size = max_tmp_size;
1989
+ }
1990
+
1991
+ if (max_nnodes_raised || n_subgraphs > backend_ctx->max_subgraphs) {
1992
+ backend_ctx->max_subgraphs = std::max(backend_ctx->max_subgraphs, n_subgraphs);
1993
+ const size_t n_nodes_per_device = 3 * backend_ctx->n_reduce_steps; // tmp + ADD (+zeroing) graph per step and device
1994
+ const size_t n_cgraphs_per_device = 2 * backend_ctx->n_reduce_steps; // ADD ( + zeroing) graph per step and device
1995
+ const size_t mem_per_device_graphs_main = backend_ctx->max_subgraphs*ggml_graph_overhead_custom(backend_ctx->max_nnodes, cgraph->grads);
1996
+ const size_t mem_per_device_graphs_aux = n_cgraphs_per_device*backend_ctx->max_subgraphs*ggml_graph_overhead_custom(1, cgraph->grads);
1997
+ const size_t mem_per_device_nodes_aux = n_nodes_per_device*backend_ctx->max_subgraphs*ggml_tensor_overhead();
1998
+ const ggml_init_params params = {
1999
+ /*.mem_size =*/ n_backends * (mem_per_device_graphs_main + mem_per_device_graphs_aux + mem_per_device_nodes_aux),
2000
+ /*.mem_buffer =*/ nullptr,
2001
+ /*.no_alloc =*/ true,
2002
+ };
2003
+ backend_ctx->ctx.reset(ggml_init(params));
2004
+ for (size_t j = 0; j < n_backends; j++) {
2005
+ auto & bcj = backend_ctx->backend_configs[j];
2006
+ for (size_t i = 0; i < n_subgraphs; i++) {
2007
+ bcj.cgraphs[i].cgraph_main = ggml_new_graph_custom(backend_ctx->ctx.get(), cgraph->n_nodes, /*grads =*/ false);
2008
+ }
2009
+ }
2010
+ backend_ctx->cgraphs_aux.resize(n_backends*n_cgraphs_per_device*backend_ctx->max_subgraphs);
2011
+ for (size_t k = 0; k < backend_ctx->cgraphs_aux.size(); k++) {
2012
+ backend_ctx->cgraphs_aux[k] = ggml_new_graph_custom(backend_ctx->ctx.get(), 1, cgraph->grads);
2013
+ }
2014
+ backend_ctx->nodes_aux.resize(n_backends*n_nodes_per_device*backend_ctx->max_subgraphs);
2015
+ for (size_t k = 0; k < backend_ctx->nodes_aux.size(); k++) {
2016
+ backend_ctx->nodes_aux[k] = ggml_new_tensor_1d(backend_ctx->ctx.get(), GGML_TYPE_F32, 1);
2017
+ }
2018
+ }
2019
+
2020
+ for (size_t j = 0; j < n_backends; j++) {
2021
+ auto & bcj = backend_ctx->backend_configs[j];
2022
+ for (size_t i_graph = 0; i_graph < n_subgraphs; i_graph++) {
2023
+ ggml_cgraph * cgraph_ij = bcj.cgraphs[i_graph].cgraph_main;
2024
+ const size_t i_node_start = bcj.cgraphs[i_graph].offset;
2025
+ const size_t i_node_stop = i_graph + 1 < n_subgraphs ? bcj.cgraphs[i_graph + 1].offset : cgraph->n_nodes;
2026
+ cgraph_ij->n_nodes = i_node_stop - i_node_start;
2027
+ ggml_hash_set_reset(&cgraph_ij->visited_hash_set);
2028
+ for (size_t i_node = i_node_start; i_node < i_node_stop; i_node++) {
2029
+ ggml_tensor * node_ij = bcj.nodes[i_node];
2030
+ cgraph_ij->nodes[i_node - i_node_start] = node_ij;
2031
+ const size_t hash_pos_orig = ggml_hash_find(&cgraph->visited_hash_set, cgraph->nodes[i_node]);
2032
+ const size_t hash_pos_ij = ggml_hash_insert(&cgraph_ij->visited_hash_set, node_ij);
2033
+ cgraph_ij->use_counts[hash_pos_ij] = cgraph->use_counts[hash_pos_orig];
2034
+ }
2035
+ cgraph_ij->uid = ggml_graph_next_uid();
2036
+ }
2037
+ }
2038
+ }
2039
+
2040
+ size_t iga = 0; // i graph aux
2041
+ size_t ina = 0; // i node aux
2042
+
2043
+ auto get_node_aux = [&](ggml_tensor * t) -> ggml_tensor * {
2044
+ ggml_tensor * ret = backend_ctx->nodes_aux[ina++];
2045
+ memset(ret, 0, sizeof(ggml_tensor));
2046
+ ret->op = GGML_OP_NONE;
2047
+ ret->type = t->type;
2048
+ for (size_t k = 0; k < GGML_MAX_DIMS; k++) {
2049
+ ret->ne[k] = t->ne[k];
2050
+ ret->nb[k] = t->nb[k];
2051
+ }
2052
+ return ret;
2053
+ };
2054
+ auto set_tmp_data = [&](ggml_tensor * tensor, const size_t j, const size_t i_buf) {
2055
+ auto & bcj = backend_ctx->backend_configs[j];
2056
+ ggml_backend_buffer_ptr & buf_ptr = bcj.bufs[i_buf];
2057
+ if (!buf_ptr || ggml_backend_buffer_get_size(buf_ptr.get()) < backend_ctx->max_tmp_size) {
2058
+ buf_ptr.reset(ggml_backend_alloc_buffer(bcj.backend, backend_ctx->max_tmp_size));
2059
+ }
2060
+ tensor->buffer = buf_ptr.get();
2061
+ tensor->data = ggml_backend_buffer_get_base(buf_ptr.get());
2062
+ };
2063
+ // FIXME usage_counts
2064
+ auto get_cgraph_aux = [&]() -> ggml_cgraph * {
2065
+ ggml_cgraph * ret = backend_ctx->cgraphs_aux[iga++];
2066
+ return ret;
2067
+ };
2068
+
2069
+ // Preferentially use backend-specific allreduce_tensor_async (e.g. NCCL for CUDA), use a generic fallback if unavailable:
2070
+ auto allreduce_fallback = [&](size_t i) -> ggml_status {
2071
+ std::vector<ggml_cgraph *> step_cgraphs(n_backends, nullptr);
2072
+
2073
+ // Zero out nodes that were disabled due to having a zero-sized slice:
2074
+ for (size_t j = 0; j < n_backends; j++) {
2075
+ auto & bcj = backend_ctx->backend_configs[j];
2076
+ ggml_tensor * node = bcj.cgraphs[i].cgraph_main->nodes[bcj.cgraphs[i].cgraph_main->n_nodes - 1];
2077
+ if (node->flags & GGML_TENSOR_FLAG_COMPUTE) {
2078
+ continue;
2079
+ }
2080
+ ggml_tensor * node_zero = get_node_aux(node);
2081
+ node_zero->op = GGML_OP_SCALE; // FIXME 0.0f * NaN == NaN
2082
+ node_zero->src[0] = node;
2083
+ ggml_set_op_params_f32(node_zero, 0, 0.0f);
2084
+ node_zero->data = node->data;
2085
+ node_zero->buffer = node->buffer;
2086
+ node_zero->flags |= GGML_TENSOR_FLAG_COMPUTE;
2087
+
2088
+ step_cgraphs[j] = get_cgraph_aux();
2089
+ step_cgraphs[j]->nodes[0] = node_zero;
2090
+ step_cgraphs[j]->n_nodes = 1;
2091
+ const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, step_cgraphs[j]);
2092
+ if (status != GGML_STATUS_SUCCESS) {
2093
+ return status;
2094
+ }
2095
+ }
2096
+ std::fill(step_cgraphs.begin(), step_cgraphs.end(), nullptr);
2097
+
2098
+ auto push_data = [&](const size_t j_src, const size_t j_dst, const size_t i_buf) {
2099
+ assert(step_cgraphs[j_dst] == nullptr);
2100
+ auto & bcj_src = backend_ctx->backend_configs[j_src];
2101
+ auto & bcj_dst = backend_ctx->backend_configs[j_dst];
2102
+
2103
+ ggml_tensor * node_src = bcj_src.cgraphs[i].cgraph_main->nodes[bcj_src.cgraphs[i].cgraph_main->n_nodes - 1];
2104
+ ggml_tensor * node_dst = bcj_dst.cgraphs[i].cgraph_main->nodes[bcj_dst.cgraphs[i].cgraph_main->n_nodes - 1];
2105
+ GGML_ASSERT(ggml_is_contiguous(node_src));
2106
+ GGML_ASSERT(ggml_is_contiguous(node_dst));
2107
+
2108
+ ggml_tensor * node_tmp = get_node_aux(node_dst);
2109
+ set_tmp_data(node_tmp, j_dst, i_buf);
2110
+
2111
+ ggml_backend_tensor_copy_async(bcj_src.backend, bcj_dst.backend, node_src, node_tmp);
2112
+
2113
+ ggml_tensor * node_red = get_node_aux(node_dst);
2114
+ node_red->view_src = node_dst->view_src == nullptr ? node_dst : node_dst->view_src;
2115
+ node_red->view_offs = node_dst->view_offs;
2116
+ node_red->op = GGML_OP_ADD;
2117
+ node_red->src[0] = node_dst;
2118
+ node_red->src[1] = node_tmp;
2119
+ node_red->flags |= GGML_TENSOR_FLAG_COMPUTE;
2120
+ ggml_backend_view_init(node_red);
2121
+
2122
+ ggml_cgraph * cgraph_aux = get_cgraph_aux();
2123
+ cgraph_aux->nodes[0] = node_red;
2124
+ cgraph_aux->n_nodes = 1;
2125
+ step_cgraphs[j_dst] = cgraph_aux;
2126
+ };
2127
+
2128
+ size_t offset_j = n_backends/2;
2129
+ while ((offset_j & (offset_j - 1)) != 0) {
2130
+ offset_j--;
2131
+ }
2132
+ const size_t offset_j_max = offset_j;
2133
+ size_t i_buf = 0;
2134
+
2135
+ // If n_backends is not a power of 2, fold in the excess prior to butterfly reduction:
2136
+ for (size_t j_src = 2*offset_j_max; j_src < n_backends; j_src++) {
2137
+ const size_t j_dst = j_src - 2*offset_j_max;
2138
+ push_data(j_src, j_dst, i_buf);
2139
+ const ggml_status status = ggml_backend_graph_compute_async(backend_ctx->backend_configs[j_dst].backend, step_cgraphs[j_dst]);
2140
+ if (status != GGML_STATUS_SUCCESS) {
2141
+ return status;
2142
+ }
2143
+ i_buf = 1;
2144
+ }
2145
+
2146
+ // Butterfly reduction:
2147
+ for (; offset_j >= 1; offset_j /= 2) {
2148
+ std::fill(step_cgraphs.begin(), step_cgraphs.end(), nullptr);
2149
+
2150
+ for (size_t j = 0; j < 2*offset_j_max; j++) {
2151
+ const size_t j_other = j ^ offset_j;
2152
+ if (j_other >= n_backends) {
2153
+ continue;
2154
+ }
2155
+ push_data(j, j_other, i_buf);
2156
+ }
2157
+
2158
+ for (size_t j = 0; j < 2*offset_j_max; j++) {
2159
+ if (step_cgraphs[j] == nullptr) {
2160
+ continue;
2161
+ }
2162
+ auto & bcj = backend_ctx->backend_configs[j];
2163
+ const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, step_cgraphs[j]);
2164
+ if (status != GGML_STATUS_SUCCESS) {
2165
+ return status;
2166
+ }
2167
+ }
2168
+ i_buf++;
2169
+ }
2170
+ assert(i_buf == backend_ctx->n_reduce_steps);
2171
+
2172
+ // If n_backends is not a power of 2, copy back the reduced tensors to the excess:
2173
+ for (size_t j = 2*offset_j_max; j < n_backends; j++) {
2174
+ auto & bcj_src = backend_ctx->backend_configs[j - 2*offset_j_max];
2175
+ auto & bcj_dst = backend_ctx->backend_configs[j];
2176
+
2177
+ ggml_tensor * node_src = bcj_src.cgraphs[i].cgraph_main->nodes[bcj_src.cgraphs[i].cgraph_main->n_nodes - 1];
2178
+ ggml_tensor * node_dst = bcj_dst.cgraphs[i].cgraph_main->nodes[bcj_dst.cgraphs[i].cgraph_main->n_nodes - 1];
2179
+ ggml_backend_tensor_copy_async(bcj_src.backend, bcj_dst.backend, node_src, node_dst);
2180
+ }
2181
+
2182
+ return GGML_STATUS_SUCCESS;
2183
+ };
2184
+
2185
+
2186
+ for (size_t i = 0; i < backend_ctx->n_subgraphs; i++) {
2187
+ for (size_t j = 0; j < n_backends; j++) {
2188
+ auto & bcj = backend_ctx->backend_configs[j];
2189
+ const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, bcj.cgraphs[i].cgraph_main);
2190
+ if (status != GGML_STATUS_SUCCESS) {
2191
+ return status;
2192
+ }
2193
+ }
2194
+
2195
+ if (n_backends > 1 && i < backend_ctx->n_subgraphs - 1) {
2196
+ bool backend_allreduce_success = false;
2197
+ if (backend_ctx->comm_ctx) {
2198
+ std::vector<ggml_tensor *> nodes;
2199
+ nodes.reserve(n_backends);
2200
+ for (size_t j = 0; j < n_backends; j++) {
2201
+ auto & bcj = backend_ctx->backend_configs[j];
2202
+ ggml_cgraph * cgraph_ij = bcj.cgraphs[i].cgraph_main;
2203
+ nodes.push_back(cgraph_ij->nodes[cgraph_ij->n_nodes-1]);
2204
+ }
2205
+ backend_allreduce_success = backend_ctx->comm_allreduce(backend_ctx->comm_ctx, nodes.data());
2206
+ }
2207
+
2208
+ if (!backend_allreduce_success) {
2209
+ const ggml_status status = allreduce_fallback(i);
2210
+ if (status != GGML_STATUS_SUCCESS) {
2211
+ return status;
2212
+ }
2213
+ }
2214
+ }
2215
+ }
2216
+ return GGML_STATUS_SUCCESS;
2217
+ }
2218
+
2219
+ static const ggml_backend_i ggml_backend_meta_i = {
2220
+ /* .get_name = */ ggml_backend_meta_get_name,
2221
+ /* .free = */ ggml_backend_meta_free,
2222
+ /* .set_tensor_async = */ ggml_backend_meta_set_tensor_async,
2223
+ /* .get_tensor_async = */ ggml_backend_meta_get_tensor_async,
2224
+ /* .set_tensor_2d_async = */ nullptr,
2225
+ /* .get_tensor_2d_async = */ nullptr,
2226
+ /* .cpy_tensor_async = */ nullptr,
2227
+ /* .synchronize = */ ggml_backend_meta_synchronize,
2228
+ /* .graph_plan_create = */ nullptr,
2229
+ /* .graph_plan_free = */ nullptr,
2230
+ /* .graph_plan_update = */ nullptr,
2231
+ /* .graph_plan_compute = */ nullptr,
2232
+ /* .graph_compute = */ ggml_backend_meta_graph_compute,
2233
+ /* .event_record = */ nullptr,
2234
+ /* .event_wait = */ nullptr,
2235
+ /* .graph_optimize = */ nullptr,
2236
+ };
2237
+
2238
+ bool ggml_backend_is_meta(ggml_backend_t backend) {
2239
+ return backend != nullptr && backend->iface.get_name == ggml_backend_meta_i.get_name;
2240
+ }
2241
+
2242
+ static ggml_backend_t ggml_backend_meta_device_init_backend(ggml_backend_dev_t dev, const char * params) {
2243
+ ggml_backend_meta_context * backend_ctx = new ggml_backend_meta_context(dev, params);
2244
+
2245
+ ggml_backend_t backend = new struct ggml_backend;
2246
+ backend->guid = ggml_backend_meta_guid();
2247
+ backend->iface = ggml_backend_meta_i;
2248
+ backend->device = dev;
2249
+ backend->context = backend_ctx;
2250
+ return backend;
2251
+ }
2252
+
2253
+ size_t ggml_backend_meta_n_backends(ggml_backend_t meta_backend) {
2254
+ GGML_ASSERT(ggml_backend_is_meta(meta_backend));
2255
+ const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) meta_backend->context;
2256
+ return backend_ctx->backend_configs.size();
2257
+ }
2258
+
2259
+ ggml_backend_t ggml_backend_meta_simple_backend(ggml_backend_t meta_backend, size_t index) {
2260
+ GGML_ASSERT(ggml_backend_is_meta(meta_backend));
2261
+ const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) meta_backend->context;
2262
+ return backend_ctx->backend_configs[index].backend;
2263
+ }