whispercpp 1.3.5 → 1.3.7

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (1017)
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/LICENSE +1 -1
  5. data/README.md +133 -3
  6. data/Rakefile +18 -3
  7. data/ext/dependencies.rb +10 -4
  8. data/ext/dependencies_for_windows.rb +17 -0
  9. data/ext/extconf.rb +20 -7
  10. data/ext/options.rb +54 -14
  11. data/ext/options_for_windows.rb +51 -0
  12. data/ext/ruby_whisper.c +56 -46
  13. data/ext/ruby_whisper.h +165 -2
  14. data/ext/ruby_whisper_context.c +297 -126
  15. data/ext/ruby_whisper_context_params.c +163 -0
  16. data/ext/ruby_whisper_log_queue.c +180 -0
  17. data/ext/ruby_whisper_log_settable.h +47 -0
  18. data/ext/ruby_whisper_model.c +0 -1
  19. data/ext/ruby_whisper_parakeet.c +49 -0
  20. data/ext/ruby_whisper_parakeet_context.c +304 -0
  21. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  22. data/ext/ruby_whisper_parakeet_model.c +84 -0
  23. data/ext/ruby_whisper_parakeet_params.c +548 -0
  24. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  25. data/ext/ruby_whisper_parakeet_token.c +188 -0
  26. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  27. data/ext/ruby_whisper_params.c +256 -66
  28. data/ext/ruby_whisper_segment.c +6 -7
  29. data/ext/ruby_whisper_token.c +29 -9
  30. data/ext/ruby_whisper_transcribe.cpp +46 -16
  31. data/ext/ruby_whisper_vad_context.c +48 -1
  32. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  33. data/ext/ruby_whisper_vad_params.c +0 -1
  34. data/ext/ruby_whisper_vad_segment.c +0 -1
  35. data/ext/ruby_whisper_vad_segments.c +0 -1
  36. data/ext/sources/CMakeLists.txt +41 -3
  37. data/ext/sources/CMakePresets.json +95 -0
  38. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  39. data/ext/sources/cmake/parakeet.pc.in +10 -0
  40. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  41. data/ext/sources/cmake/whisper.pc.in +1 -1
  42. data/ext/sources/examples/CMakeLists.txt +4 -2
  43. data/ext/sources/examples/bench/bench.cpp +24 -19
  44. data/ext/sources/examples/cli/cli.cpp +51 -9
  45. data/ext/sources/examples/common-ggml.cpp +4 -0
  46. data/ext/sources/examples/common-whisper.cpp +139 -67
  47. data/ext/sources/examples/common-whisper.h +11 -0
  48. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  49. data/ext/sources/examples/miniaudio.h +4507 -2131
  50. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  51. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  52. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  53. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  54. data/ext/sources/examples/server/server.cpp +213 -163
  55. data/ext/sources/ggml/CMakeLists.txt +29 -15
  56. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  57. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  58. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  59. data/ext/sources/ggml/include/ggml-backend.h +73 -11
  60. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  61. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  62. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  63. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  64. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  65. data/ext/sources/ggml/include/ggml-rpc.h +8 -3
  66. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  67. data/ext/sources/ggml/include/ggml.h +155 -16
  68. data/ext/sources/ggml/include/gguf.h +10 -2
  69. data/ext/sources/ggml/src/CMakeLists.txt +25 -5
  70. data/ext/sources/ggml/src/ggml-alloc.c +9 -10
  71. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  72. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  73. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  74. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +40 -86
  76. data/ext/sources/ggml/src/ggml-backend.cpp +114 -10
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -2
  79. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +1016 -442
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +111 -85
  83. data/ext/sources/ggml/src/ggml-cann/common.h +23 -14
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +255 -92
  85. data/ext/sources/ggml/src/ggml-common.h +22 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +68 -34
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +44 -19
  88. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  89. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +101 -101
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +194 -1
  91. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2874 -613
  92. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +5480 -840
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1361 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -11
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +186 -36
  99. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +119 -19
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +112 -26
  101. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  102. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  103. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -0
  105. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +153 -16
  106. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +17 -0
  107. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  108. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +976 -251
  109. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +671 -266
  110. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1277 -263
  111. data/ext/sources/ggml/src/ggml-cpu/ops.h +4 -0
  112. data/ext/sources/ggml/src/ggml-cpu/quants.c +95 -0
  113. data/ext/sources/ggml/src/ggml-cpu/quants.h +6 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2893 -679
  115. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  116. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +226 -0
  117. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +114 -19
  118. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  119. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  120. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  121. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  122. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  123. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  124. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  125. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  126. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  127. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  128. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  129. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  130. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  131. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  132. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  133. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  134. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +54 -53
  135. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  136. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +18 -8
  137. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  138. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  139. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +73 -28
  140. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +69 -41
  141. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  142. data/ext/sources/ggml/src/ggml-cuda/common.cuh +359 -29
  143. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  144. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  145. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  146. data/ext/sources/ggml/src/ggml-cuda/convert.cu +94 -27
  147. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  148. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +20 -9
  149. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  150. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +333 -85
  151. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +632 -190
  152. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  153. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +162 -49
  154. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +43 -18
  155. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +44 -14
  156. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  157. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +241 -23
  158. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  159. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  160. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  161. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +312 -0
  162. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  163. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  164. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1454 -599
  165. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  166. data/ext/sources/ggml/src/ggml-cuda/mean.cu +13 -10
  167. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +397 -183
  168. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  169. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +161 -88
  170. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  171. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +522 -431
  172. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +139 -72
  173. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  174. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +608 -88
  175. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -0
  176. data/ext/sources/ggml/src/ggml-cuda/norm.cu +47 -79
  177. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  178. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  179. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +134 -27
  180. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  181. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +7 -17
  182. data/ext/sources/ggml/src/ggml-cuda/rope.cu +244 -137
  183. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  184. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  185. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  186. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  187. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  188. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  189. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  190. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +96 -40
  191. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  192. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  193. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +6 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +2 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +6 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +2 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +2 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +2 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -5
  226. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +202 -135
  227. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  228. data/ext/sources/ggml/src/ggml-cuda/unary.cu +86 -2
  229. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +4 -0
  230. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +111 -17
  231. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  232. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +30 -2
  233. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +84 -46
  235. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1612 -753
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +51 -11
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +361 -261
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +294 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +753 -241
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +295 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +471 -296
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +159 -53
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +3 -3
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +372 -0
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +86 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +137 -0
  254. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  255. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  256. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  257. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  258. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  259. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  260. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  261. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  262. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +97 -14
  263. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +163 -67
  264. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  265. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  266. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +308 -0
  267. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +262 -0
  268. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +291 -0
  269. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  270. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +216 -0
  271. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  272. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  273. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  274. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  275. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  276. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  278. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +142 -0
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -1348
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +547 -635
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +3556 -1101
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +475 -269
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +94 -72
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +222 -217
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +432 -0
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +886 -117
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  297. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +40 -0
  302. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +28 -9
  303. data/ext/sources/ggml/src/ggml-impl.h +68 -1
  304. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  305. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  306. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  307. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  308. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +409 -83
  309. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +54 -5
  310. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +254 -52
  311. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +254 -23
  312. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +756 -285
  313. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +7 -4
  314. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +359 -133
  315. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1867 -1123
  316. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  317. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +71 -4
  318. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +14127 -5314
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +97 -88
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +104 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1978 -67
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl +129 -0
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +195 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  367. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  368. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  369. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  370. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  371. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  372. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  373. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  374. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  375. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  376. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  377. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  378. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  379. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  380. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  381. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  382. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  383. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  384. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  385. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  386. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  387. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  388. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  389. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +178 -0
  390. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  391. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  392. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  393. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  394. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  395. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  396. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  397. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  398. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  399. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  400. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  401. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +985 -0
  402. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  403. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +380 -0
  404. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  405. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1132 -0
  406. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +956 -0
  407. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  412. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  413. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  414. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  415. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  416. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  417. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  418. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  419. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  420. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  421. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  422. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  423. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  424. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +149 -0
  425. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  426. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  427. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  428. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  429. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  430. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  431. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  432. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +47 -0
  433. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +40 -0
  434. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  435. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  436. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  437. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  438. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  439. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  440. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +317 -0
  441. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  442. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +257 -0
  443. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +86 -0
  444. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +880 -0
  445. data/ext/sources/ggml/src/ggml-openvino/utils.h +143 -0
  446. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  447. data/ext/sources/ggml/src/ggml-quants.c +385 -119
  448. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  449. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  450. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  451. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  452. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  453. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +64 -91
  454. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  455. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +4 -1
  456. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  457. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  458. data/ext/sources/ggml/src/ggml-sycl/common.hpp +356 -11
  459. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +184 -14
  460. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +31 -1
  461. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  462. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  463. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  464. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  465. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  466. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  467. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  468. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  469. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +77 -156
  470. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -2
  471. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  472. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  473. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1181 -0
  474. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +59 -0
  475. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1246 -0
  476. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +674 -0
  477. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +227 -0
  478. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  479. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  480. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  481. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +347 -0
  482. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +9 -0
  483. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  484. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  485. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +1134 -236
  486. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  487. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  488. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  489. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  490. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  491. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  492. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  493. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  494. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +72 -1
  495. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  496. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  497. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  498. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  499. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  500. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  501. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  502. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  503. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  504. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  505. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  506. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  507. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  508. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  509. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  510. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  511. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  512. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  513. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  514. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  515. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  516. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +8 -0
  517. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +8 -0
  518. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +8 -0
  519. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +8 -0
  520. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +8 -0
  521. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +8 -0
  522. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +8 -0
  523. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +8 -0
  524. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +8 -0
  525. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +8 -0
  526. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +8 -0
  527. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +8 -0
  530. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +8 -0
  531. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +8 -0
  532. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +8 -0
  534. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +8 -0
  535. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +8 -0
  536. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +8 -0
  537. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +8 -0
  538. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +8 -0
  539. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +8 -0
  540. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +8 -0
  541. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +8 -0
  542. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +8 -0
  543. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +8 -0
  544. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +8 -0
  545. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +8 -0
  547. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +8 -0
  548. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +8 -0
  549. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +8 -0
  550. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +8 -0
  551. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +8 -0
  552. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  553. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  554. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  555. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +228 -53
  556. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  557. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  558. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  559. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  560. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  561. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  562. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  563. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  564. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  565. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  566. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  567. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  568. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  569. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  570. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  571. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  572. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  573. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  574. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  575. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  576. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  577. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  578. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +123 -0
  579. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +160 -0
  580. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  581. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +71 -0
  582. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  583. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  584. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  585. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  586. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  587. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  588. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  589. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  590. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  591. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  592. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +99 -0
  593. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  594. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  595. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  596. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +545 -0
  597. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +115 -0
  598. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  599. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3250 -940
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +533 -180
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +113 -68
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +412 -222
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +222 -83
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +189 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +22 -20
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +51 -14
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +39 -63
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +13 -7
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +27 -11
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -149
  663. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  664. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +3221 -97
  665. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3493 -1997
  666. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  667. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  668. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  669. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  670. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  671. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +142 -0
  672. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +115 -141
  673. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +93 -0
  674. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  675. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -44
  676. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  677. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +198 -230
  678. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  679. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  680. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  681. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  682. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  683. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  684. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +234 -335
  685. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  686. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  687. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +871 -42
  688. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  689. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  690. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  691. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +149 -0
  692. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +36 -138
  693. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +151 -0
  694. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  695. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  696. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  697. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  698. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  699. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  700. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  701. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  702. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  703. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +15 -40
  704. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  705. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +39 -12
  706. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  707. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  708. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  709. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  710. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  711. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  712. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +213 -0
  713. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  714. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +24 -15
  715. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  716. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +253 -16
  717. data/ext/sources/ggml/src/ggml.c +268 -52
  718. data/ext/sources/ggml/src/gguf.cpp +377 -47
  719. data/ext/sources/include/parakeet.h +342 -0
  720. data/ext/sources/include/whisper.h +10 -0
  721. data/ext/sources/media/matmul.png +0 -0
  722. data/ext/sources/src/CMakeLists.txt +23 -0
  723. data/ext/sources/src/parakeet-arch.h +188 -0
  724. data/ext/sources/src/parakeet.cpp +3838 -0
  725. data/ext/sources/src/whisper.cpp +62 -40
  726. data/extsources.rb +26 -10
  727. data/lib/whisper/log_settable.rb +36 -0
  728. data/lib/whisper/model/uri.rb +13 -1
  729. data/lib/whisper/output.rb +74 -0
  730. data/sig/whisper.rbs +445 -55
  731. data/test/helper.rb +2 -0
  732. data/test/jfk_reader/jfk_reader.c +50 -7
  733. data/test/test_callback.rb +1 -0
  734. data/test/test_context_params.rb +82 -0
  735. data/test/test_package.rb +6 -5
  736. data/test/test_parakeet.rb +28 -0
  737. data/test/test_parakeet_callback.rb +107 -0
  738. data/test/test_parakeet_context.rb +116 -0
  739. data/test/test_parakeet_context_params.rb +24 -0
  740. data/test/test_parakeet_model.rb +21 -0
  741. data/test/test_parakeet_params.rb +78 -0
  742. data/test/test_parakeet_segment.rb +42 -0
  743. data/test/test_parakeet_token.rb +73 -0
  744. data/test/test_params.rb +2 -0
  745. data/test/test_token.rb +11 -0
  746. data/test/test_vad_context.rb +58 -8
  747. data/test/test_vad_segment.rb +1 -1
  748. data/test/test_whisper.rb +44 -6
  749. data/whispercpp.gemspec +2 -2
  750. metadata +426 -280
  751. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  752. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  753. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  754. data/ext/sources/bindings/javascript/package.json +0 -26
  755. data/ext/sources/bindings/javascript/whisper.js +0 -19
  756. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  757. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  758. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  759. data/ext/sources/examples/addon.node/index.js +0 -59
  760. data/ext/sources/examples/addon.node/package.json +0 -16
  761. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  762. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  763. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  764. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  765. data/ext/sources/examples/coi-serviceworker.js +0 -146
  766. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  767. data/ext/sources/examples/command/command.cpp +0 -802
  768. data/ext/sources/examples/command/commands.txt +0 -9
  769. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  770. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  771. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  772. data/ext/sources/examples/generate-karaoke.sh +0 -57
  773. data/ext/sources/examples/helpers.js +0 -191
  774. data/ext/sources/examples/livestream.sh +0 -112
  775. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  776. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  777. data/ext/sources/examples/lsp/whisper.vim +0 -362
  778. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  779. data/ext/sources/examples/python/whisper_processor.py +0 -54
  780. data/ext/sources/examples/server/bench.js +0 -29
  781. data/ext/sources/examples/server.py +0 -120
  782. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  783. data/ext/sources/examples/stream/stream.cpp +0 -437
  784. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  785. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  786. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  787. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  788. data/ext/sources/examples/sycl/build.sh +0 -22
  789. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  790. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  791. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -47
  792. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  793. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -494
  794. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -88
  795. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2559
  796. data/ext/sources/examples/talk-llama/llama-arch.h +0 -586
  797. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -917
  798. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  799. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -876
  800. data/ext/sources/examples/talk-llama/llama-chat.h +0 -70
  801. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3645
  802. data/ext/sources/examples/talk-llama/llama-context.h +0 -360
  803. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  804. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -42
  805. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  806. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  807. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2282
  808. data/ext/sources/examples/talk-llama/llama-graph.h +0 -910
  809. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -241
  810. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -284
  811. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  812. data/ext/sources/examples/talk-llama/llama-impl.h +0 -63
  813. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  814. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  815. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -328
  816. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  817. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2100
  818. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -390
  819. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  820. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  821. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  822. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1167
  823. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  824. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  825. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  826. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -735
  827. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  828. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1247
  829. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -176
  830. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -285
  831. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -37
  832. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -8338
  833. data/ext/sources/examples/talk-llama/llama-model.h +0 -544
  834. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1072
  835. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  836. data/ext/sources/examples/talk-llama/llama-sampling.cpp +0 -3771
  837. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -44
  838. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3900
  839. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -182
  840. data/ext/sources/examples/talk-llama/llama.cpp +0 -1140
  841. data/ext/sources/examples/talk-llama/llama.h +0 -1540
  842. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -191
  843. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  844. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  845. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -138
  846. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  847. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -122
  848. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -144
  849. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -135
  850. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -178
  851. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -160
  852. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  853. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  854. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  855. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  856. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  857. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  858. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  859. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -123
  860. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  861. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -144
  862. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -259
  863. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -134
  864. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  865. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -150
  866. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  867. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  868. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  869. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -113
  870. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  871. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  872. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  873. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  874. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  875. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  876. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  877. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -150
  878. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  879. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  880. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -196
  881. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -211
  882. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +0 -283
  883. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  884. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -141
  885. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  886. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -154
  887. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  888. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  889. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  890. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -175
  891. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  892. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  893. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  894. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -168
  895. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  896. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -55
  897. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -123
  898. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -199
  899. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -124
  900. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  901. data/ext/sources/examples/talk-llama/models/models.h +0 -569
  902. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -116
  903. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  904. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -150
  905. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  906. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  907. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  908. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  909. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  910. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  911. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  912. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  913. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  914. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  915. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  916. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  917. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -316
  918. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  919. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -168
  920. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  921. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  922. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  923. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  924. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -117
  925. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -124
  926. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -873
  927. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -149
  928. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -141
  929. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  930. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  931. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -162
  932. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  933. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  934. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -135
  935. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  936. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  937. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  938. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  939. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  940. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  941. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  942. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  943. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  944. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  945. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  946. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  947. data/ext/sources/examples/talk-llama/speak +0 -40
  948. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  949. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  950. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  951. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  952. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  953. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1147
  954. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  955. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  956. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  957. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  958. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  959. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  960. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  961. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  962. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  963. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  964. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  965. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  966. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  967. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  968. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  969. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  970. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +0 -157
  971. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -165
  972. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  973. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  974. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  975. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  976. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  977. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  978. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  979. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  980. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  981. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  982. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  983. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  984. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -147
  985. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  986. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +0 -907
  987. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +0 -247
  988. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  989. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  990. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  991. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
  992. data/ext/sources/tests/CMakeLists.txt +0 -112
  993. data/ext/sources/tests/earnings21/eval.mk +0 -58
  994. data/ext/sources/tests/earnings21/eval.py +0 -68
  995. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  996. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  997. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  998. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  999. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  1000. data/ext/sources/tests/en-0-ref.txt +0 -1
  1001. data/ext/sources/tests/en-1-ref.txt +0 -1
  1002. data/ext/sources/tests/en-2-ref.txt +0 -1
  1003. data/ext/sources/tests/es-0-ref.txt +0 -1
  1004. data/ext/sources/tests/librispeech/eval.mk +0 -39
  1005. data/ext/sources/tests/librispeech/eval.py +0 -47
  1006. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  1007. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  1008. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  1009. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  1010. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  1011. data/ext/sources/tests/run-tests.sh +0 -130
  1012. data/ext/sources/tests/test-c.c +0 -3
  1013. data/ext/sources/tests/test-vad-full.cpp +0 -56
  1014. data/ext/sources/tests/test-vad.cpp +0 -83
  1015. data/ext/sources/tests/test-whisper.js +0 -58
  1016. data/lib/whisper/context.rb +0 -15
  1017. data/lib/whisper/segment.rb +0 -58
@@ -0,0 +1,2263 @@
1
+ #include "ggml.h"
2
+ #include "ggml-impl.h"
3
+ #include "ggml-backend.h"
4
+ #include "ggml-backend-impl.h"
5
+ #include "ggml-alloc.h"
6
+ #include "ggml-cpp.h"
7
+
8
+ #include <algorithm>
9
+ #include <cassert>
10
+ #include <cmath>
11
+ #include <cstddef>
12
+ #include <cstdint>
13
+ #include <cstring>
14
+ #include <map>
15
+ #include <memory>
16
+ #include <set>
17
+ #include <string>
18
+ #include <tuple>
19
+ #include <utility>
20
+ #include <vector>
21
+
22
+ struct ggml_backend_meta_device;
23
+ struct ggml_backend_meta_buffer_type;
24
+ struct ggml_backend_meta_buffer;
25
+ struct ggml_backend_meta;
26
+
27
+ const char * ggml_backend_meta_split_axis_name(enum ggml_backend_meta_split_axis split_axis) {
28
+ switch (split_axis) {
29
+ case GGML_BACKEND_SPLIT_AXIS_0:
30
+ return "0";
31
+ case GGML_BACKEND_SPLIT_AXIS_1:
32
+ return "1";
33
+ case GGML_BACKEND_SPLIT_AXIS_2:
34
+ return "2";
35
+ case GGML_BACKEND_SPLIT_AXIS_3:
36
+ return "3";
37
+ case GGML_BACKEND_SPLIT_AXIS_MIRRORED:
38
+ return "MIRRORED";
39
+ case GGML_BACKEND_SPLIT_AXIS_PARTIAL:
40
+ return "PARTIAL";
41
+ case GGML_BACKEND_SPLIT_AXIS_NONE:
42
+ return "NONE";
43
+ case GGML_BACKEND_SPLIT_AXIS_UNKNOWN:
44
+ return "UNKNOWN";
45
+ default:
46
+ GGML_ABORT("fatal error");
47
+ }
48
+ }
49
+
50
+ //
51
+ // meta backend device
52
+ //
53
+
54
+ struct ggml_backend_meta_device_context {
55
+ std::vector<ggml_backend_dev_t> simple_devs;
56
+ ggml_backend_meta_get_split_state_t get_split_state;
57
+ void * get_split_state_ud;
58
+
59
+ std::string name;
60
+ std::string description;
61
+
62
+ ggml_backend_meta_device_context(
63
+ std::vector<ggml_backend_dev_t> simple_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud) :
64
+ simple_devs(std::move(simple_devs)), get_split_state(get_split_state), get_split_state_ud(get_split_state_ud) {
65
+ name = std::string("Meta(");
66
+ description = std::string("Meta(");
67
+ for (size_t i = 0; i < simple_devs.size(); i++) {
68
+ if (i > 0) {
69
+ name += ",";
70
+ description += ",";
71
+ }
72
+ name += ggml_backend_dev_name (simple_devs[i]);
73
+ description += ggml_backend_dev_description(simple_devs[i]);
74
+ }
75
+ name += ")";
76
+ description += ")";
77
+ }
78
+
79
+ bool operator<(const ggml_backend_meta_device_context & other) const {
80
+ return std::tie(simple_devs, get_split_state, get_split_state_ud)
81
+ < std::tie(other.simple_devs, other.get_split_state, other.get_split_state_ud);
82
+ }
83
+ };
84
+
85
+ static bool ggml_backend_dev_is_meta(ggml_backend_dev_t dev);
86
+
87
+ static const char * ggml_backend_meta_device_get_name(ggml_backend_dev_t dev) {
88
+ GGML_ASSERT(ggml_backend_dev_is_meta(dev));
89
+ const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context;
90
+ return meta_dev_ctx->name.c_str();
91
+ }
92
+
93
+ static const char * ggml_backend_meta_device_get_description(ggml_backend_dev_t dev) {
94
+ GGML_ASSERT(ggml_backend_dev_is_meta(dev));
95
+ const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context;
96
+ return meta_dev_ctx->description.c_str();
97
+ }
98
+
99
+ static void ggml_backend_meta_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
100
+ GGML_ASSERT(ggml_backend_dev_is_meta(dev));
101
+ const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context;
102
+ *free = 0;
103
+ *total = 0;
104
+ for (ggml_backend_dev_t dev : meta_dev_ctx->simple_devs) {
105
+ size_t tmp_free, tmp_total;
106
+ ggml_backend_dev_memory(dev, &tmp_free, &tmp_total);
107
+ *free += tmp_free;
108
+ *total += tmp_total;
109
+ }
110
+ }
111
+
112
+ static enum ggml_backend_dev_type ggml_backend_meta_device_get_type(ggml_backend_dev_t dev) {
113
+ return GGML_BACKEND_DEVICE_TYPE_META;
114
+
115
+ GGML_UNUSED(dev);
116
+ }
117
+
118
+ static void ggml_backend_meta_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
119
+ GGML_ASSERT(ggml_backend_dev_is_meta(dev));
120
+ const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context;
121
+
122
+ // TODO replace placeholders
123
+ props->name = ggml_backend_meta_device_get_name(dev);
124
+ props->description = ggml_backend_meta_device_get_description(dev);
125
+ props->type = ggml_backend_meta_device_get_type(dev);
126
+ props->device_id = 0;
127
+
128
+ ggml_backend_meta_device_get_memory(dev, &props->memory_free, &props->memory_total);
129
+
130
+ props->caps = {
131
+ /* .async = */ true,
132
+ /* .host_buffer = */ false, // Not implemented.
133
+ /* .buffer_from_host_ptr = */ false, // Not implemented.
134
+ /* .events = */ false, // Not implemented.
135
+ };
136
+ for (ggml_backend_dev_t simple_dev : meta_dev_ctx->simple_devs) {
137
+ ggml_backend_dev_props tmp_props;
138
+ ggml_backend_dev_get_props(simple_dev, &tmp_props);
139
+ props->caps.async = props->caps.async && tmp_props.caps.async;
140
+ props->caps.host_buffer = props->caps.host_buffer && tmp_props.caps.host_buffer;
141
+ props->caps.buffer_from_host_ptr = props->caps.buffer_from_host_ptr && tmp_props.caps.buffer_from_host_ptr;
142
+ props->caps.events = props->caps.events && tmp_props.caps.events;
143
+ }
144
+ }
145
+
146
+ static ggml_backend_t ggml_backend_meta_device_init_backend(ggml_backend_dev_t dev, const char * params);
147
+
148
+ static ggml_backend_buffer_type_t ggml_backend_meta_device_get_buffer_type(ggml_backend_dev_t dev);
149
+
150
+ static ggml_backend_buffer_type_t ggml_backend_meta_device_get_host_buffer_type(ggml_backend_dev_t dev);
151
+
152
+ static bool ggml_backend_meta_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
153
+ GGML_ASSERT(ggml_backend_dev_is_meta(dev));
154
+ const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context;
155
+ return std::all_of(meta_dev_ctx->simple_devs.begin(), meta_dev_ctx->simple_devs.end(),
156
+ [op](ggml_backend_dev_t simple_dev) { return ggml_backend_dev_supports_op(simple_dev, op); });
157
+ }
158
+
159
+ static bool ggml_backend_meta_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
160
+ GGML_ASSERT(ggml_backend_dev_is_meta(dev));
161
+ ggml_backend_dev_t dev_buft = ggml_backend_buft_get_device(buft);
162
+ if (!ggml_backend_dev_is_meta(dev_buft)) {
163
+ return false;
164
+ }
165
+ const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context;
166
+ const ggml_backend_meta_device_context * meta_buft_dev_ctx = (const ggml_backend_meta_device_context *) dev_buft->context;
167
+ if (meta_dev_ctx->simple_devs.size() != meta_buft_dev_ctx->simple_devs.size()) {
168
+ return false;
169
+ }
170
+ for (size_t i = 0; i < meta_dev_ctx->simple_devs.size(); i++) {
171
+ if (meta_dev_ctx->simple_devs[i] != meta_buft_dev_ctx->simple_devs[i]) {
172
+ return false;
173
+ }
174
+ }
175
+ return true;
176
+ }
177
+
178
+ static const ggml_backend_device_i ggml_backend_meta_device_iface = { // callback table copied into every device created by ggml_backend_meta_device
179
+ /* .get_name = */ ggml_backend_meta_device_get_name, // also compared by ggml_backend_dev_is_meta to recognize meta devices
180
+ /* .get_description = */ ggml_backend_meta_device_get_description,
181
+ /* .get_memory = */ ggml_backend_meta_device_get_memory,
182
+ /* .get_type = */ ggml_backend_meta_device_get_type,
183
+ /* .get_props = */ ggml_backend_meta_device_get_props,
184
+ /* .init_backend = */ ggml_backend_meta_device_init_backend,
185
+ /* .get_buffer_type = */ ggml_backend_meta_device_get_buffer_type,
186
+ /* .get_host_buffer_type = */ ggml_backend_meta_device_get_host_buffer_type,
187
+ /* .buffer_from_host_ptr = */ nullptr, // no callback provided; get_props reports this capability as false
188
+ /* .supports_op = */ ggml_backend_meta_device_supports_op,
189
+ /* .supports_buft = */ ggml_backend_meta_device_supports_buft,
190
+ /* .offload_op = */ nullptr, // no callback provided
191
+ /* .event_new = */ nullptr, // no event callbacks provided; get_props reports the events capability as false
192
+ /* .event_free = */ nullptr,
193
+ /* .event_synchronize = */ nullptr,
194
+ };
195
+
196
+ static bool ggml_backend_dev_is_meta(ggml_backend_dev_t dev) {
197
+ return dev != nullptr && dev->iface.get_name == ggml_backend_meta_device_iface.get_name;
198
+ }
199
+
200
+ static size_t ggml_backend_meta_dev_n_devs(ggml_backend_dev_t meta_dev) {
201
+ GGML_ASSERT(ggml_backend_dev_is_meta(meta_dev));
202
+ const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) meta_dev->context;
203
+ return meta_dev_ctx->simple_devs.size();
204
+ }
205
+
206
+ static ggml_backend_dev_t ggml_backend_meta_dev_simple_dev(ggml_backend_dev_t meta_dev, size_t index) {
207
+ GGML_ASSERT(ggml_backend_dev_is_meta(meta_dev));
208
+ const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) meta_dev->context;
209
+ GGML_ASSERT(index < meta_dev_ctx->simple_devs.size());
210
+ return meta_dev_ctx->simple_devs[index];
211
+ }
212
+
213
+ ggml_backend_dev_t ggml_backend_meta_device(
214
+ ggml_backend_dev_t * devs, size_t n_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud) {
215
+ GGML_ASSERT(n_devs <= GGML_BACKEND_META_MAX_DEVICES);
216
+ // TODO: this is not thread-safe - needs to be fixed
217
+ static std::vector<std::unique_ptr<ggml_backend_meta_device_context>> ctxs;
218
+ static std::map<ggml_backend_meta_device_context, struct ggml_backend_device> meta_devs;
219
+
220
+ std::vector<ggml_backend_dev_t> simple_devs;
221
+ simple_devs.reserve(n_devs);
222
+ for (size_t i = 0; i < n_devs; i++) {
223
+ simple_devs.push_back(devs[i]);
224
+ }
225
+ ggml_backend_meta_device_context ctx(simple_devs, get_split_state, get_split_state_ud);
226
+
227
+ {
228
+ auto it = meta_devs.find(ctx);
229
+ if (it != meta_devs.end()) {
230
+ return &it->second;
231
+ }
232
+ }
233
+ ctxs.push_back(std::make_unique<ggml_backend_meta_device_context>(ctx));
234
+
235
+ struct ggml_backend_device meta_dev = {
236
+ /*iface =*/ ggml_backend_meta_device_iface,
237
+ /*reg =*/ nullptr,
238
+ /*ctx =*/ ctxs.back().get(),
239
+ };
240
+
241
+ auto result = meta_devs.emplace(*ctxs.back(), meta_dev);
242
+ return &result.first->second;
243
+ }
244
+
245
+ //
246
+ // meta backend buffer type
247
+ //
248
+
249
+ struct ggml_backend_meta_buffer_type_context {
250
+ std::vector<ggml_backend_buffer_type_t> simple_bufts;
251
+
252
+ std::string name;
253
+
254
+ ggml_backend_meta_buffer_type_context(std::vector<ggml_backend_buffer_type_t> simple_bufts) : simple_bufts(std::move(simple_bufts)) {
255
+ name = "Meta(";
256
+ for (size_t i = 0; i < simple_bufts.size(); i++) {
257
+ if (i > 0) {
258
+ name += ",";
259
+ }
260
+ name += ggml_backend_buft_name(simple_bufts[i]);
261
+ }
262
+ name += ")";
263
+ }
264
+
265
+ bool operator<(const ggml_backend_meta_buffer_type_context & other) const {
266
+ return simple_bufts < other.simple_bufts;
267
+ }
268
+ };
269
+
270
+ static size_t ggml_backend_meta_buft_n_bufts(ggml_backend_buffer_type_t meta_buft) {
271
+ GGML_ASSERT(ggml_backend_buft_is_meta(meta_buft));
272
+ const ggml_backend_meta_buffer_type_context * meta_buft_ctx = (const ggml_backend_meta_buffer_type_context *) meta_buft->context;
273
+ return meta_buft_ctx->simple_bufts.size();
274
+ }
275
+
276
+ static const char * ggml_backend_meta_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
277
+ GGML_ASSERT(ggml_backend_buft_is_meta(buft));
278
+ const ggml_backend_meta_buffer_type_context * meta_buft_ctx = (const ggml_backend_meta_buffer_type_context *) buft->context;
279
+ return meta_buft_ctx->name.c_str();
280
+ }
281
+
282
+ static ggml_backend_buffer_type_t ggml_backend_meta_buft_simple_buft(ggml_backend_buffer_type_t meta_buft, size_t index) {
283
+ GGML_ASSERT(ggml_backend_buft_is_meta(meta_buft));
284
+ const ggml_backend_meta_buffer_type_context * meta_buft_ctx = (const ggml_backend_meta_buffer_type_context *) meta_buft->context;
285
+ GGML_ASSERT(index < meta_buft_ctx->simple_bufts.size());
286
+ return meta_buft_ctx->simple_bufts[index];
287
+ }
288
+
289
+ static ggml_backend_buffer_t ggml_backend_meta_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size);
290
+
291
+ static size_t ggml_backend_meta_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
292
+ const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft);
293
+ size_t max_alignment = 1;
294
+ for (size_t i = 0; i < n_simple_bufts; i++) {
295
+ const size_t alignment = ggml_backend_buft_get_alignment(ggml_backend_meta_buft_simple_buft(buft, i));
296
+ max_alignment = std::max(max_alignment, alignment);
297
+ GGML_ASSERT(max_alignment % alignment == 0);
298
+ }
299
+ return max_alignment;
300
+ }
301
+
302
+ static size_t ggml_backend_meta_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
303
+ const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft);
304
+ size_t max_size = SIZE_MAX;
305
+ for (size_t i = 0; i < n_simple_bufts; i++) {
306
+ max_size = std::min(max_size, ggml_backend_buft_get_max_size(ggml_backend_meta_buft_simple_buft(buft, i)));
307
+ }
308
+ return max_size;
309
+ }
310
+
311
+ static size_t ggml_backend_meta_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
312
+ const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft);
313
+ size_t max_alloc_size = 0;
314
+ for (size_t i = 0; i < n_simple_bufts; i++) {
315
+ const size_t alloc_size = ggml_backend_buft_get_alloc_size(ggml_backend_meta_buft_simple_buft(buft, i), tensor);
316
+ max_alloc_size = std::max(max_alloc_size, alloc_size);
317
+ }
318
+ return max_alloc_size;
319
+ }
320
+
321
+ static bool ggml_backend_meta_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
322
+ const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft);
323
+ for (size_t i = 0; i < n_simple_bufts; i++) {
324
+ if (!ggml_backend_buft_is_host(ggml_backend_meta_buft_simple_buft(buft, i))) {
325
+ return false;
326
+ }
327
+ }
328
+ return true;
329
+ }
330
+
331
+ static const struct ggml_backend_buffer_type_i ggml_backend_meta_buffer_type_iface = { // callback table copied into every buffer type created by ggml_backend_meta_device_get_buffer_type
332
+ /* .get_name = */ ggml_backend_meta_buffer_type_get_name, // also compared by ggml_backend_buft_is_meta to recognize meta buffer types
333
+ /* .alloc_buffer = */ ggml_backend_meta_buffer_type_alloc_buffer, // only forward-declared above this table
334
+ /* .get_alignment = */ ggml_backend_meta_buffer_type_get_alignment, // largest alignment of the simple buffer types
335
+ /* .get_max_size = */ ggml_backend_meta_buffer_type_get_max_size, // smallest max size of the simple buffer types
336
+ /* .get_alloc_size = */ ggml_backend_meta_buffer_type_get_alloc_size, // largest alloc size of the simple buffer types
337
+ /* .is_host = */ ggml_backend_meta_buffer_type_is_host, // true only if all simple buffer types are host buffer types
338
+ };
339
+
340
+ bool ggml_backend_buft_is_meta(ggml_backend_buffer_type_t buft) {
341
+ return buft != nullptr && buft->iface.get_name == ggml_backend_meta_buffer_type_iface.get_name;
342
+ }
343
+
344
+ static ggml_backend_buffer_type_t ggml_backend_meta_device_get_buffer_type(ggml_backend_dev_t dev) {
345
+ static std::map<ggml_backend_dev_t, struct ggml_backend_buffer_type> meta_bufts;
346
+ GGML_ASSERT(ggml_backend_dev_is_meta(dev));
347
+ {
348
+ auto it = meta_bufts.find(dev);
349
+ if (it != meta_bufts.end()) {
350
+ return &it->second;
351
+ }
352
+ }
353
+
354
+ const size_t n_devs = ggml_backend_meta_dev_n_devs(dev);
355
+ std::vector<ggml_backend_buffer_type_t> simple_bufts;
356
+ simple_bufts.reserve(n_devs);
357
+ for (size_t i = 0; i < n_devs; i++) {
358
+ simple_bufts.push_back(ggml_backend_dev_buffer_type(ggml_backend_meta_dev_simple_dev(dev, i)));
359
+ }
360
+ ggml_backend_meta_buffer_type_context * buft_ctx = new ggml_backend_meta_buffer_type_context(simple_bufts);
361
+
362
+ struct ggml_backend_buffer_type meta_buft = {
363
+ /*iface =*/ ggml_backend_meta_buffer_type_iface,
364
+ /*device =*/ dev,
365
+ /*ctx =*/ buft_ctx,
366
+ };
367
+ auto result = meta_bufts.emplace(dev, meta_buft);
368
+ return &result.first->second;
369
+ }
370
+
371
+ static ggml_backend_buffer_type_t ggml_backend_meta_device_get_host_buffer_type(ggml_backend_dev_t dev) {
372
+ GGML_ASSERT(ggml_backend_dev_is_meta(dev));
373
+ const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context;
374
+
375
+ ggml_backend_buffer_type_t host_buft = nullptr;
376
+ for (ggml_backend_dev_t simple_dev : meta_dev_ctx->simple_devs) {
377
+ ggml_backend_buffer_type_t simple_host_buft = ggml_backend_dev_host_buffer_type(simple_dev);
378
+ if (simple_host_buft == nullptr) {
379
+ return nullptr;
380
+ }
381
+ if (host_buft == nullptr) {
382
+ host_buft = simple_host_buft;
383
+ } else if (host_buft != simple_host_buft) {
384
+ // if different simple devices have different host buffer types,
385
+ // we cannot provide a single host buffer type for the meta device
386
+ return nullptr;
387
+ }
388
+ }
389
+ return host_buft;
390
+ }
391
+
392
+ //
393
+ // meta backend buffer
394
+ //
395
+
396
+ // Container to hold the tensor slices per simple ggml backend buffer.
397
+ struct ggml_backend_meta_simple_tensor_container {
398
+ std::vector<ggml_context_ptr> ctxs;
399
+ std::map<const ggml_tensor *, std::vector<ggml_tensor *>> simple_tensors;
400
+
401
+ ggml_backend_meta_simple_tensor_container(const ggml_init_params & params, const int n_simple) {
402
+ ctxs.reserve(n_simple);
403
+ for (int i = 0; i < n_simple; i++) {
404
+ ctxs.emplace_back(ggml_init(params));
405
+ }
406
+ }
407
+ ggml_backend_meta_simple_tensor_container() {}
408
+ };
409
+
410
+ struct ggml_backend_meta_buffer_context {
411
+ // FIXME
412
+ // Most tensors can simply be stored statically in their own buffer.
413
+ // Externally created views however also need a mapping to simple tensors but they use the buffer of the view source.
414
+ // If external views are simply using that buffer they will slowly deplete its memory.
415
+ // Current solution: rotating set of 2 "compute" containers to hold external views, works correctly for llama.cpp.
416
+ // Long-term: tie the lifetime of external views to the meta backend executing the graph instead,
417
+ // currently not possible due to graph-external operations in the backend scheduler.
418
+ ggml_backend_meta_simple_tensor_container stc_static;
419
+ ggml_backend_meta_simple_tensor_container stc_compute[2];
420
+ int stc_compute_index = 0;
421
+ int stc_compute_index_next = 0;
422
+ std::vector<ggml_backend_buffer_ptr> bufs;
423
+
424
+ // FIXME
425
+ // The size of the split state cache is unbounded and can theoretically grow infinitely large.
426
+ // However, it is also expensive to build and clearing it on every rebuild in ggml_backend_meta_graph_compute is too expensive.
427
+ static constexpr size_t nbtc = GGML_TENSOR_SIZE - sizeof(ggml_tensor::padding);
428
+ std::map<std::pair<const ggml_tensor *, bool>, std::pair<ggml_backend_meta_split_state, char[nbtc]>> split_state_cache;
429
+
430
+ int debug;
431
+
432
+ ggml_backend_meta_buffer_context(
433
+ ggml_backend_meta_simple_tensor_container & stc_static,
434
+ ggml_backend_meta_simple_tensor_container & stc_compute_0,
435
+ ggml_backend_meta_simple_tensor_container & stc_compute_1,
436
+ const std::vector<ggml_backend_buffer_t> & bufs)
437
+ : stc_static(std::move(stc_static)), stc_compute{std::move(stc_compute_0), std::move(stc_compute_1)} {
438
+ this->bufs.reserve(bufs.size());
439
+ for (ggml_backend_buffer_t buf : bufs) {
440
+ this->bufs.emplace_back(buf);
441
+ }
442
+ const char * GGML_META_DEBUG = getenv("GGML_META_DEBUG");
443
+ debug = GGML_META_DEBUG ? atoi(GGML_META_DEBUG) : 0;
444
+ }
445
+
446
+ ggml_backend_meta_simple_tensor_container & get_simple_tensor_container(const ggml_tensor * tensor) {
447
+ if (stc_static.simple_tensors.find(tensor) != stc_static.simple_tensors.end()) {
448
+ return stc_static;
449
+ }
450
+ return stc_compute[stc_compute_index];
451
+ }
452
+ };
453
+
454
+ static void ggml_backend_meta_buffer_free_buffer(ggml_backend_buffer_t buffer) {
455
+ GGML_ASSERT(ggml_backend_buffer_is_meta(buffer));
456
+ ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buffer->context;
457
+ delete buf_ctx;
458
+ }
459
+
460
+ static size_t ggml_backend_meta_buffer_n_bufs(ggml_backend_buffer_t meta_buf) {
461
+ GGML_ASSERT(ggml_backend_buffer_is_meta(meta_buf));
462
+ ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) meta_buf->context;
463
+ return buf_ctx->bufs.size();
464
+ }
465
+
466
+ static ggml_backend_buffer_t ggml_backend_meta_buffer_simple_buffer(ggml_backend_buffer_t meta_buf, size_t index) {
467
+ GGML_ASSERT(ggml_backend_buffer_is_meta(meta_buf));
468
+ ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) meta_buf->context;
469
+ GGML_ASSERT(index < buf_ctx->bufs.size());
470
+ return buf_ctx->bufs[index].get();
471
+ }
472
+
473
+ static struct ggml_tensor * ggml_backend_meta_buffer_simple_tensor(const struct ggml_tensor * tensor, size_t index) {
474
+ GGML_ASSERT(ggml_backend_buffer_is_meta(tensor->buffer));
475
+ ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context;
476
+ GGML_ASSERT(index < buf_ctx->bufs.size());
477
+
478
+ ggml_backend_meta_simple_tensor_container & stc = buf_ctx->get_simple_tensor_container(tensor);
479
+ auto it = stc.simple_tensors.find(tensor);
480
+ if (it == stc.simple_tensors.end()) {
481
+ return nullptr;
482
+ }
483
+ return it->second[index];
484
+ }
485
+
486
+ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync);
487
+
488
+ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(
489
+ ggml_backend_meta_simple_tensor_container & stc, const struct ggml_tensor * tensor, bool assume_sync) {
490
+ // FIXME Currently this function preserves/erases the information in n_segments and nr in an inconsistent way.
491
+ // Since the operations in question are developed specifically for llama.cpp this currently does not manifest as a bug there.
492
+ // However, in a broader ggml context with arbitrary ggml graphs this can lead to unexpected results.
493
+ const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer);
494
+ ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context;
495
+
496
+ auto split_states_equal = [&](const ggml_backend_meta_split_state & a, const ggml_backend_meta_split_state & b) -> bool {
497
+ if (a.axis != b.axis) {
498
+ return false;
499
+ }
500
+ for (size_t j = 0; j < n_bufs; j++) {
501
+ int64_t sum_a = 0;
502
+ for (size_t s = 0; s < a.n_segments; s++) {
503
+ sum_a += a.ne[s*n_bufs + j] * a.nr[s];
504
+ }
505
+ int64_t sum_b = 0;
506
+ for (size_t s = 0; s < b.n_segments; s++) {
507
+ sum_b += b.ne[s*n_bufs + j] * b.nr[s];
508
+ }
509
+ if (sum_a != sum_b) {
510
+ return false;
511
+ }
512
+ }
513
+ return true;
514
+ };
515
+
516
+ auto handle_generic = [&](const std::vector<ggml_backend_meta_split_state> & src_ss, bool scalar_only) -> ggml_backend_meta_split_state {
517
+ ggml_backend_meta_split_state ret = {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, {1}, 1};
518
+ for (size_t i = 0; i < GGML_MAX_SRC; i++) {
519
+ if (tensor->src[i] == nullptr || tensor->src[i] == tensor) {
520
+ continue;
521
+ }
522
+ if (ret.axis == GGML_BACKEND_SPLIT_AXIS_NONE) {
523
+ ret = src_ss[i];
524
+ } else if (!split_states_equal(src_ss[i], ret)) {
525
+ ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
526
+ break;
527
+ }
528
+ }
529
+ if (ret.axis == GGML_BACKEND_SPLIT_AXIS_NONE) {
530
+ ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
531
+ }
532
+ if (scalar_only && ret.axis >= 0 && ret.axis < GGML_MAX_DIMS) {
533
+ ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
534
+ }
535
+ GGML_ASSERT(ret.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN);
536
+ return ret;
537
+ };
538
+
539
+ // Some ops process data on a per-row bases:
540
+ auto handle_per_row = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
541
+ GGML_ASSERT(src_ss[0].axis != GGML_BACKEND_SPLIT_AXIS_0);
542
+ return src_ss[0];
543
+ };
544
+
545
+ // Some ops broadcast the src1 data across src0:
546
+ auto handle_bin_bcast = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
547
+ if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS &&
548
+ tensor->src[1]->ne[src_ss[0].axis] == 1 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
549
+ return src_ss[0];
550
+ }
551
+ if (src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && (src_ss[0].axis == src_ss[1].axis ||
552
+ (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL)))) {
553
+ return src_ss[0]; // GGML_OP_ADD_ID
554
+ }
555
+ GGML_ASSERT(tensor->src[2] == nullptr || src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED);
556
+ return handle_generic(src_ss, /*scalar_only =*/ false);
557
+ };
558
+
559
+ auto handle_concat = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
560
+ const ggml_backend_meta_split_axis concat_axis = ggml_backend_meta_split_axis(ggml_get_op_params_i32(tensor, 0));
561
+ if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis >= 0 && src_ss[1].axis < GGML_MAX_DIMS) {
562
+ GGML_ASSERT(concat_axis != src_ss[1].axis);
563
+ return src_ss[1];
564
+ }
565
+ if (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) {
566
+ GGML_ASSERT(concat_axis != src_ss[0].axis);
567
+ return src_ss[0];
568
+ }
569
+ if (src_ss[0].axis == src_ss[1].axis && src_ss[0].axis != concat_axis) {
570
+ return src_ss[0];
571
+ }
572
+ return handle_generic(src_ss, /*scalar_only =*/ true);
573
+ };
574
+
575
+ auto handle_mul_mat = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
576
+ if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
577
+ return {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, {1}, 1};
578
+ }
579
+ if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
580
+ ggml_backend_meta_split_state ret = src_ss[0];
581
+ ret.axis = GGML_BACKEND_SPLIT_AXIS_0;
582
+ ret.nr[0] = 1;
583
+ ret.n_segments = 1;
584
+ return ret;
585
+ }
586
+ if (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
587
+ return src_ss[1];
588
+ }
589
+ if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_0) {
590
+ GGML_ASSERT(split_states_equal(src_ss[0], src_ss[1]));
591
+ return {assume_sync ? GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_PARTIAL, {0}, {1}, 1};
592
+ }
593
+ GGML_ABORT("fatal error");
594
+ //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
595
+ };
596
+
597
+ auto handle_reshape = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
598
+ switch (src_ss[0].axis) {
599
+ case GGML_BACKEND_SPLIT_AXIS_0:
600
+ case GGML_BACKEND_SPLIT_AXIS_1:
601
+ case GGML_BACKEND_SPLIT_AXIS_2:
602
+ case GGML_BACKEND_SPLIT_AXIS_3: {
603
+ GGML_ASSERT(src_ss[0].n_segments == 1);
604
+ if (src_ss[0].axis == ggml_n_dims(tensor->src[0]) - 1 && src_ss[0].nr[0] == 1) {
605
+ return {ggml_backend_meta_split_axis(ggml_n_dims(tensor) - 1), {0}, {1}, 1};
606
+ }
607
+ int64_t base_ne_in = tensor->src[0]->ne[0];
608
+ for (int dim = 1; dim <= src_ss[0].axis; dim++) {
609
+ base_ne_in *= tensor->src[0]->ne[dim];
610
+ }
611
+ base_ne_in /= src_ss[0].nr[0];
612
+ int64_t base_ne_out = 1;
613
+ for (int dim = 0; dim < GGML_MAX_DIMS; dim++) {
614
+ const int64_t base_ne_out_next = base_ne_out *= tensor->ne[dim];
615
+ if (base_ne_out_next % base_ne_in == 0) {
616
+ return {ggml_backend_meta_split_axis(dim), {0}, {uint32_t(base_ne_out_next/base_ne_in)}, 1};
617
+ }
618
+ if (base_ne_out_next > base_ne_in) {
619
+ GGML_ASSERT(src_ss[0].n_segments == 1);
620
+ GGML_ASSERT(src_ss[0].nr[0] == 1);
621
+ return {ggml_backend_meta_split_axis(dim), {0}, {1}, 1};
622
+ }
623
+ base_ne_out = base_ne_out_next;
624
+ }
625
+ GGML_ABORT("shape mismatch for %s", ggml_op_name(tensor->op));
626
+ }
627
+ case GGML_BACKEND_SPLIT_AXIS_MIRRORED:
628
+ case GGML_BACKEND_SPLIT_AXIS_PARTIAL: {
629
+ return src_ss[0];
630
+ }
631
+ default: {
632
+ GGML_ABORT("fatal error");
633
+ //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
634
+ }
635
+ }
636
+ };
637
+
638
+ auto handle_cpy = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
639
+ if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) {
640
+ return handle_reshape(src_ss);
641
+ }
642
+ return handle_generic(src_ss, /*scalar_only =*/ false);
643
+ };
644
+
645
+ auto handle_view = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
646
+ if (ggml_is_contiguous(tensor) && ggml_is_contiguous(tensor->src[0])) {
647
+ return handle_reshape(src_ss);
648
+ }
649
+ const int axis = src_ss[0].axis;
650
+ {
651
+ bool all_strides_the_same = true;
652
+ for (int dim = 0; dim < GGML_MAX_DIMS; dim++) {
653
+ if (tensor->ne[dim] == 1 && tensor->src[0]->ne[dim] == 1) {
654
+ continue;
655
+ }
656
+ if (tensor->nb[dim] != tensor->src[0]->nb[dim]) {
657
+ all_strides_the_same = false;
658
+ break;
659
+ }
660
+ }
661
+ if (all_strides_the_same) {
662
+ return src_ss[0];
663
+ }
664
+ }
665
+ if (!ggml_is_permuted(tensor) && !ggml_is_permuted(tensor->src[0]) && axis >= 0 && axis < GGML_MAX_DIMS-1) {
666
+ for (int dim = 0; dim < GGML_MAX_DIMS-1; dim++) {
667
+ if (tensor->nb[dim+1] == tensor->src[0]->nb[axis+1]) {
668
+ return {ggml_backend_meta_split_axis(dim), {0}, {1}, 1};
669
+ }
670
+ }
671
+ GGML_ABORT("fatal error");
672
+ }
673
+ if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED || src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL) {
674
+ return src_ss[0];
675
+ }
676
+ GGML_ABORT("view of permuted tensor not implemented");
677
+ //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
678
+ };
679
+
680
+ auto handle_permute = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
681
+ switch (src_ss[0].axis) {
682
+ case GGML_BACKEND_SPLIT_AXIS_0:
683
+ case GGML_BACKEND_SPLIT_AXIS_1:
684
+ case GGML_BACKEND_SPLIT_AXIS_2:
685
+ case GGML_BACKEND_SPLIT_AXIS_3: {
686
+ GGML_ASSERT(src_ss[0].n_segments == 1 || src_ss[0].nr[0] == 1);
687
+ return {ggml_backend_meta_split_axis(tensor->op_params[src_ss[0].axis]), {0}, {src_ss[0].nr[0]}, 1};
688
+ }
689
+ case GGML_BACKEND_SPLIT_AXIS_MIRRORED:
690
+ case GGML_BACKEND_SPLIT_AXIS_PARTIAL: {
691
+ return src_ss[0];
692
+ }
693
+ default: {
694
+ GGML_ABORT("fatal error");
695
+ //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
696
+ }
697
+ }
698
+ };
699
+
700
+ auto handle_transpose = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
701
+ switch (src_ss[0].axis) {
702
+ case GGML_BACKEND_SPLIT_AXIS_0:
703
+ case GGML_BACKEND_SPLIT_AXIS_1: {
704
+ GGML_ASSERT(src_ss[0].n_segments == 1 || src_ss[0].nr[0] == 1);
705
+ return {ggml_backend_meta_split_axis(int(src_ss[0].axis) ^ 1), {0}, {src_ss[0].nr[0]}, 1};
706
+ }
707
+ case GGML_BACKEND_SPLIT_AXIS_2:
708
+ case GGML_BACKEND_SPLIT_AXIS_3:
709
+ case GGML_BACKEND_SPLIT_AXIS_MIRRORED:
710
+ case GGML_BACKEND_SPLIT_AXIS_PARTIAL: {
711
+ return src_ss[0];
712
+ }
713
+ default: {
714
+ GGML_ABORT("fatal error");
715
+ //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
716
+ }
717
+ }
718
+ };
719
+
720
+ auto handle_get_rows = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
721
+ if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
722
+ return src_ss[0];
723
+ }
724
+ return handle_generic(src_ss, /*scalar_only =*/ true);
725
+ };
726
+
727
+ auto handle_set_rows = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
728
+ GGML_ASSERT(src_ss[0].axis != GGML_BACKEND_SPLIT_AXIS_1);
729
+ GGML_ASSERT(src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED);
730
+ GGML_ASSERT(split_states_equal(src_ss[0], src_ss[2]));
731
+ return src_ss[0];
732
+ };
733
+
734
+ auto handle_rope = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
735
+ GGML_ASSERT(src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED);
736
+ return src_ss[0];
737
+ };
738
+
739
+ auto handle_pad = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
740
+ if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) {
741
+ GGML_ASSERT(tensor->op_params[2*src_ss[0].axis + 0] == 0);
742
+ GGML_ASSERT(tensor->op_params[2*src_ss[0].axis + 1] == 0);
743
+ }
744
+ return src_ss[0];
745
+ };
746
+
747
+ auto handle_flash_attn_ext = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
748
+ GGML_ASSERT( src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_2);
749
+ GGML_ASSERT( src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_2);
750
+ GGML_ASSERT( src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_2);
751
+ GGML_ASSERT(tensor->src[4] == nullptr || src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED);
752
+ GGML_ASSERT(tensor->src[4] == nullptr || src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_0);
753
+ return {GGML_BACKEND_SPLIT_AXIS_1, {0}, {1}, 1};
754
+ };
755
+
756
+ auto handle_ssm_conv = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
757
+ if (src_ss[0].axis == src_ss[1].axis) {
758
+ if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0) {
759
+ return {GGML_BACKEND_SPLIT_AXIS_1, {0}, {1}, 1};
760
+ }
761
+ if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1) {
762
+ return {GGML_BACKEND_SPLIT_AXIS_0, {0}, {1}, 1};
763
+ }
764
+ }
765
+ return handle_generic(src_ss, /*scalar_only =*/ false);
766
+ };
767
+
768
+ auto handle_gated_delta_net = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
769
+ if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED &&
770
+ src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED &&
771
+ src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
772
+ return src_ss[0];
773
+ }
774
+ GGML_ASSERT(src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1);
775
+ GGML_ASSERT(src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_1);
776
+ GGML_ASSERT(src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_1);
777
+ GGML_ASSERT(src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_1);
778
+ GGML_ASSERT(src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_1);
779
+ // state shape is [S_v, S_v, H_v, n_seqs] (s0 only); the heads dim is its own axis 2,
780
+ // so a head-aligned split on the input cache lands on axis 2 here.
781
+ GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_1 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_0);
782
+ return {GGML_BACKEND_SPLIT_AXIS_0, {0}, {1}, 1};
783
+ };
784
+
785
+ auto calculate_split_state = [&]() -> ggml_backend_meta_split_state {
786
+ if (ggml_nelements(tensor) == 0) {
787
+ return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
788
+ }
789
+ if (ggml_backend_buffer_get_usage(tensor->buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE && tensor->view_src == nullptr) {
790
+ ggml_backend_dev_t dev = ggml_backend_buft_get_device(ggml_backend_buffer_get_type(tensor->buffer));
791
+ const ggml_backend_meta_device_context * dev_ctx = (const ggml_backend_meta_device_context *) dev->context;
792
+ ggml_backend_meta_split_state ret = dev_ctx->get_split_state(tensor, dev_ctx->get_split_state_ud);
793
+ if (ret.axis >= 0 && ret.axis <= GGML_MAX_DIMS) {
794
+ const int64_t granularity = ret.axis == GGML_BACKEND_SPLIT_AXIS_0 ? ggml_blck_size(tensor->type) : 1;
795
+ int64_t ne_sum = 0;
796
+ for (size_t s = 0; s < ret.n_segments; s++) {
797
+ for (size_t j = 0; j < n_bufs; j++) {
798
+ GGML_ASSERT(ret.ne[s*n_bufs + j] % granularity == 0);
799
+ ne_sum += ret.ne[s*n_bufs + j] * ret.nr[s];
800
+ }
801
+ }
802
+ GGML_ASSERT(ne_sum == tensor->ne[ret.axis]);
803
+ }
804
+ return ret;
805
+ }
806
+
807
+ std::vector<ggml_backend_meta_split_state> src_ss(GGML_MAX_SRC, {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, {1}, 1});
808
+ for (size_t i = 0; i < GGML_MAX_SRC; i++) {
809
+ if (tensor->src[i] == nullptr || tensor->src[i] == tensor) {
810
+ src_ss[i] = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
811
+ continue;
812
+ }
813
+ src_ss[i] = ggml_backend_meta_get_split_state(stc, tensor->src[i], /*assume_sync =*/ true);
814
+ GGML_ASSERT(src_ss[i].axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN);
815
+ }
816
+
817
+ ggml_backend_meta_split_state split_state;
818
+ switch (tensor->op) {
819
+ case GGML_OP_NONE: {
820
+ split_state = {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, {1}, 1};
821
+ } break;
822
+ case GGML_OP_DUP: {
823
+ split_state = handle_generic(src_ss, /*scalar_only =*/ true);
824
+ } break;
825
+ case GGML_OP_ADD:
826
+ case GGML_OP_ADD_ID: {
827
+ split_state = handle_bin_bcast(src_ss);
828
+ } break;
829
+ case GGML_OP_ADD1:
830
+ case GGML_OP_ACC: {
831
+ split_state = handle_generic(src_ss, /*scalar_only =*/ true);
832
+ } break;
833
+ case GGML_OP_SUB:
834
+ case GGML_OP_MUL:
835
+ case GGML_OP_DIV: {
836
+ split_state = handle_bin_bcast(src_ss);
837
+ } break;
838
+ case GGML_OP_SQR:
839
+ case GGML_OP_SQRT:
840
+ case GGML_OP_LOG:
841
+ case GGML_OP_SIN:
842
+ case GGML_OP_COS: {
843
+ split_state = handle_generic(src_ss, /*scalar_only =*/ false);
844
+ } break;
845
+ case GGML_OP_SUM: {
846
+ split_state = handle_generic(src_ss, /*scalar_only =*/ true);
847
+ } break;
848
+ case GGML_OP_SUM_ROWS:
849
+ case GGML_OP_CUMSUM:
850
+ case GGML_OP_MEAN:
851
+ case GGML_OP_ARGMAX:
852
+ case GGML_OP_COUNT_EQUAL: {
853
+ split_state = handle_per_row(src_ss);
854
+ } break;
855
+ case GGML_OP_REPEAT:
856
+ case GGML_OP_REPEAT_BACK: {
857
+ split_state = handle_generic(src_ss, /*scalar_only =*/ false);
858
+ } break;
859
+ case GGML_OP_CONCAT: {
860
+ split_state = handle_concat(src_ss);
861
+ } break;
862
+ case GGML_OP_SILU_BACK: {
863
+ split_state = handle_generic(src_ss, /*scalar_only =*/ false);
864
+ } break;
865
+ case GGML_OP_NORM:
866
+ case GGML_OP_RMS_NORM:
867
+ case GGML_OP_RMS_NORM_BACK:
868
+ case GGML_OP_GROUP_NORM:
869
+ case GGML_OP_L2_NORM: {
870
+ split_state = handle_per_row(src_ss);
871
+ } break;
872
+ case GGML_OP_MUL_MAT:
873
+ case GGML_OP_MUL_MAT_ID: {
874
+ split_state = handle_mul_mat(src_ss);
875
+ } break;
876
+ case GGML_OP_OUT_PROD: {
877
+ split_state = handle_generic(src_ss, /*scalar_only =*/ true);
878
+ } break;
879
+ case GGML_OP_SCALE: {
880
+ split_state = handle_generic(src_ss, /*scalar_only =*/ false);
881
+ } break;
882
+ case GGML_OP_SET: {
883
+ split_state = handle_generic(src_ss, /*scalar_only =*/ true);
884
+ } break;
885
+ case GGML_OP_CPY: {
886
+ split_state = handle_cpy(src_ss);
887
+ } break;
888
+ case GGML_OP_CONT:
889
+ case GGML_OP_RESHAPE: {
890
+ split_state = handle_reshape(src_ss);
891
+ } break;
892
+ case GGML_OP_VIEW: {
893
+ split_state = handle_view(src_ss);
894
+ } break;
895
+ case GGML_OP_PERMUTE: {
896
+ split_state = handle_permute(src_ss);
897
+ } break;
898
+ case GGML_OP_TRANSPOSE: {
899
+ split_state = handle_transpose(src_ss);
900
+ } break;
901
+ case GGML_OP_GET_ROWS: {
902
+ split_state = handle_get_rows(src_ss);
903
+ } break;
904
+ case GGML_OP_GET_ROWS_BACK: {
905
+ split_state = handle_generic(src_ss, /*scalar_only =*/ true);
906
+ } break;
907
+ case GGML_OP_SET_ROWS: {
908
+ split_state = handle_set_rows(src_ss);
909
+ } break;
910
+ case GGML_OP_DIAG:
911
+ case GGML_OP_DIAG_MASK_INF:
912
+ case GGML_OP_DIAG_MASK_ZERO: {
913
+ split_state = handle_generic(src_ss, /*scalar_only =*/ true);
914
+ } break;
915
+ case GGML_OP_SOFT_MAX:
916
+ case GGML_OP_SOFT_MAX_BACK: {
917
+ split_state = handle_generic(src_ss, /*scalar_only =*/ false);
918
+ } break;
919
+ case GGML_OP_ROPE: {
920
+ split_state = handle_rope(src_ss);
921
+ } break;
922
+ case GGML_OP_ROPE_BACK: {
923
+ split_state = handle_generic(src_ss, /*scalar_only =*/ true);
924
+ } break;
925
+ case GGML_OP_CLAMP: {
926
+ split_state = handle_generic(src_ss, /*scalar_only =*/ false);
927
+ } break;
928
+ case GGML_OP_CONV_TRANSPOSE_1D:
929
+ case GGML_OP_IM2COL:
930
+ case GGML_OP_IM2COL_BACK:
931
+ case GGML_OP_IM2COL_3D:
932
+ case GGML_OP_CONV_2D:
933
+ case GGML_OP_CONV_3D:
934
+ case GGML_OP_CONV_2D_DW:
935
+ case GGML_OP_CONV_TRANSPOSE_2D:
936
+ case GGML_OP_POOL_1D:
937
+ case GGML_OP_POOL_2D:
938
+ case GGML_OP_POOL_2D_BACK:
939
+ case GGML_OP_UPSCALE: {
940
+ split_state = handle_generic(src_ss, /*scalar_only =*/ true);
941
+ } break;
942
+ case GGML_OP_PAD: {
943
+ split_state = handle_pad(src_ss);
944
+ } break;
945
+ case GGML_OP_PAD_REFLECT_1D:
946
+ case GGML_OP_ROLL:
947
+ case GGML_OP_ARANGE:
948
+ case GGML_OP_TIMESTEP_EMBEDDING: {
949
+ split_state = handle_generic(src_ss, /*scalar_only =*/ true);
950
+ } break;
951
+ case GGML_OP_ARGSORT:
952
+ case GGML_OP_TOP_K: {
953
+ split_state = handle_per_row(src_ss);
954
+ } break;
955
+ case GGML_OP_LEAKY_RELU: {
956
+ split_state = handle_generic(src_ss, /*scalar_only =*/ false);
957
+ } break;
958
+ case GGML_OP_TRI: {
959
+ split_state = handle_generic(src_ss, /*scalar_only =*/ true);
960
+ } break;
961
+ case GGML_OP_FILL: {
962
+ split_state = handle_generic(src_ss, /*scalar_only =*/ false);
963
+ } break;
964
+ case GGML_OP_FLASH_ATTN_EXT: {
965
+ split_state = handle_flash_attn_ext(src_ss);
966
+ } break;
967
+ case GGML_OP_FLASH_ATTN_BACK: {
968
+ split_state = handle_generic(src_ss, /*scalar_only =*/ true);
969
+ } break;
970
+ case GGML_OP_SSM_CONV: {
971
+ split_state = handle_ssm_conv(src_ss);
972
+ } break;
973
+ case GGML_OP_SSM_SCAN:
974
+ case GGML_OP_WIN_PART:
975
+ case GGML_OP_WIN_UNPART:
976
+ case GGML_OP_GET_REL_POS:
977
+ case GGML_OP_ADD_REL_POS:
978
+ case GGML_OP_RWKV_WKV6:
979
+ case GGML_OP_GATED_LINEAR_ATTN:
980
+ case GGML_OP_RWKV_WKV7:
981
+ case GGML_OP_SOLVE_TRI: {
982
+ split_state = handle_generic(src_ss, /*scalar_only =*/ true);
983
+ } break;
984
+ case GGML_OP_GATED_DELTA_NET: {
985
+ split_state = handle_gated_delta_net(src_ss);
986
+ } break;
987
+ case GGML_OP_UNARY: {
988
+ split_state = handle_generic(src_ss, /*scalar_only =*/ false);
989
+ } break;
990
+ case GGML_OP_MAP_CUSTOM1:
991
+ case GGML_OP_MAP_CUSTOM2:
992
+ case GGML_OP_MAP_CUSTOM3:
993
+ case GGML_OP_CUSTOM: {
994
+ split_state = handle_generic(src_ss, /*scalar_only =*/ true);
995
+ } break;
996
+ case GGML_OP_CROSS_ENTROPY_LOSS:
997
+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK: {
998
+ split_state = handle_per_row(src_ss);
999
+ } break;
1000
+ case GGML_OP_OPT_STEP_ADAMW:
1001
+ case GGML_OP_OPT_STEP_SGD:
1002
+ case GGML_OP_GLU: {
1003
+ split_state = handle_generic(src_ss, /*scalar_only =*/ false);
1004
+ } break;
1005
+ default: {
1006
+ GGML_ABORT("ggml op not implemented: %s", ggml_op_name(tensor->op));
1007
+ split_state = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1};
1008
+ } break;
1009
+ }
1010
+ if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) {
1011
+ bool first_src_split_by_axis = true;
1012
+ const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer);
1013
+
1014
+ for (size_t i = 0; i < GGML_MAX_SRC; i++) {
1015
+ if (tensor->src[i] == nullptr || src_ss[i].axis < 0 || src_ss[i].axis >= GGML_MAX_DIMS) {
1016
+ continue;
1017
+ }
1018
+ if (first_src_split_by_axis) {
1019
+ for (size_t j = 0; j < n_bufs; j++) {
1020
+ // Take over ratio from src:
1021
+ for (size_t s = 0; s < src_ss[i].n_segments; s++) {
1022
+ split_state.ne[s*n_bufs + j] = 0;
1023
+ }
1024
+ for (size_t s = 0; s < src_ss[i].n_segments; s++) {
1025
+ split_state.ne[j] += src_ss[i].ne[s*n_bufs + j] * src_ss[i].nr[s];
1026
+ }
1027
+ split_state.ne[j] *= tensor->ne[split_state.axis];
1028
+ if (split_state.ne[j] != 0 || tensor->src[i]->ne[src_ss[i].axis] != 0) {
1029
+ const int64_t div = tensor->src[i]->ne[src_ss[i].axis] * split_state.nr[0];
1030
+ GGML_ASSERT(split_state.ne[j] % div == 0);
1031
+ split_state.ne[j] /= div;
1032
+ }
1033
+ }
1034
+ } else {
1035
+ GGML_ASSERT(split_state.n_segments == 1);
1036
+ for (size_t j = 0; j < n_bufs; j++) {
1037
+ // Assert that ratio is consistent:
1038
+ int64_t sum = 0;
1039
+ for (size_t s = 0; s < src_ss[i].n_segments; s++) {
1040
+ sum += src_ss[i].ne[s*n_bufs + j] * src_ss[i].nr[s];
1041
+ }
1042
+ GGML_ASSERT(split_state.ne[j]*split_state.nr[0] * tensor->src[i]->ne[src_ss[i].axis]
1043
+ == sum * tensor->ne[split_state.axis]);
1044
+ }
1045
+ }
1046
+ first_src_split_by_axis = false;
1047
+ }
1048
+ GGML_ASSERT(!first_src_split_by_axis);
1049
+ }
1050
+ return split_state;
1051
+ };
1052
+
1053
+ const std::pair key = std::make_pair(tensor, assume_sync);
1054
+ auto it = buf_ctx->split_state_cache.find(key);
1055
+ if (it != buf_ctx->split_state_cache.end() && memcmp(it->second.second, (const char *) tensor, sizeof(it->second.second)) != 0) {
1056
+ buf_ctx->split_state_cache.clear();
1057
+ it = buf_ctx->split_state_cache.end();
1058
+ }
1059
+
1060
+ if (it == buf_ctx->split_state_cache.end()) {
1061
+ buf_ctx->split_state_cache[key].first = calculate_split_state();
1062
+ memcpy(buf_ctx->split_state_cache[key].second, tensor, sizeof(buf_ctx->split_state_cache[key].second));
1063
+ if (buf_ctx->debug > 0) {
1064
+ std::string srcs_info;
1065
+ for (size_t i = 0; i < GGML_MAX_SRC; i++) {
1066
+ if (tensor->src[i] == nullptr) {
1067
+ continue;
1068
+ }
1069
+ if (!srcs_info.empty()) {
1070
+ srcs_info += ", ";
1071
+ }
1072
+ const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor->src[0], true);
1073
+ GGML_ASSERT(split_state.n_segments == 1);
1074
+ const char * axis_name = ggml_backend_meta_split_axis_name(split_state.axis);
1075
+ std::string ne_info;
1076
+ for (size_t j = 0; j < n_bufs; j++) {
1077
+ if (!ne_info.empty()) {
1078
+ ne_info += ", ";
1079
+ }
1080
+ ne_info += std::to_string(split_state.ne[j]) + "x" + std::to_string(split_state.nr[0]);
1081
+ }
1082
+ srcs_info += std::string(tensor->src[i]->name) + "[" + ggml_op_name(tensor->src[i]->op) + ", " + axis_name + ", {" + ne_info + "}]";
1083
+ }
1084
+ std::string ne_info;
1085
+ for (size_t j = 0; j < n_bufs; j++) {
1086
+ if (!ne_info.empty()) {
1087
+ ne_info += ", ";
1088
+ }
1089
+ const ggml_backend_meta_split_state & ss = buf_ctx->split_state_cache[key].first;
1090
+ ne_info += std::to_string(ss.ne[j]) + "x" + std::to_string(ss.nr[0]);
1091
+ }
1092
+ GGML_LOG_DEBUG("SPLIT_STATE: {%s} -> %s[%s, %s, {%s}]\n", srcs_info.c_str(), tensor->name, ggml_op_name(tensor->op),
1093
+ ggml_backend_meta_split_axis_name(buf_ctx->split_state_cache[key].first.axis), ne_info.c_str());
1094
+ }
1095
+ }
1096
+
1097
+ ggml_backend_meta_split_state ret = buf_ctx->split_state_cache[key].first;
1098
+ GGML_ASSERT(ret.axis != GGML_BACKEND_SPLIT_AXIS_NONE);
1099
+ #ifndef NDEBUG
1100
+ if (ret.axis >= 0 && ret.axis < GGML_MAX_DIMS) {
1101
+ int64_t ne_ret = 0;
1102
+ for (size_t s = 0; s < ret.n_segments; s++) {
1103
+ for (size_t j = 0; j < n_bufs; j++) {
1104
+ ne_ret += ret.ne[s*n_bufs + j] * ret.nr[s];
1105
+ }
1106
+ }
1107
+ assert(ne_ret == tensor->ne[int(ret.axis)]);
1108
+ }
1109
+ #endif // NDEBUG
1110
+ return ret;
1111
+ }
1112
+
1113
+ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync) {
1114
+ GGML_ASSERT(ggml_backend_buffer_is_meta(tensor->buffer));
1115
+ ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context;
1116
+ return ggml_backend_meta_get_split_state(buf_ctx->get_simple_tensor_container(tensor), tensor, assume_sync);
1117
+ }
1118
+
1119
+ static void * ggml_backend_meta_buffer_get_base(ggml_backend_buffer_t buffer) {
1120
+ GGML_UNUSED(buffer);
1121
+ return (void *) 0x1000000000000000; // FIXME
1122
+ }
1123
+
1124
+ static enum ggml_status ggml_backend_meta_buffer_init_tensor_impl(ggml_backend_meta_simple_tensor_container & stc, ggml_tensor * tensor) {
1125
+ GGML_ASSERT(ggml_backend_buffer_is_meta(tensor->buffer));
1126
+ ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context;
1127
+ const size_t n_simple_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer);
1128
+
1129
+ const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(stc, tensor, /*assume_sync =*/ true);
1130
+ GGML_ASSERT(ggml_nelements(tensor) == 0 || split_state.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN);
1131
+ GGML_ASSERT(split_state.n_segments <= 16);
1132
+
1133
+ int split_dim = split_state.axis;
1134
+ int64_t ne[GGML_MAX_DIMS];
1135
+ size_t nb[GGML_MAX_DIMS];
1136
+ for (size_t k = 0; k < GGML_MAX_DIMS; k++) {
1137
+ ne[k] = tensor->ne[k];
1138
+ nb[k] = tensor->nb[k];
1139
+ }
1140
+
1141
+ std::vector<ggml_tensor *> simple_tensors;
1142
+ simple_tensors.reserve(n_simple_bufs);
1143
+ for (size_t j = 0; j < n_simple_bufs; j++) {
1144
+ ggml_context * simple_ctx = stc.ctxs[j].get();
1145
+ ggml_backend_buffer_t simple_buf = buf_ctx->bufs[j].get();
1146
+
1147
+ if (split_dim >= 0 && split_dim < GGML_MAX_DIMS) {
1148
+ // TODO: the following assert fails for llama-parallel even though the results are correct:
1149
+ // GGML_ASSERT(ggml_is_contiguously_allocated(tensor));
1150
+ ne[split_dim] = 0;
1151
+ for (size_t s = 0; s < split_state.n_segments; s++) {
1152
+ ne[split_dim] += split_state.ne[s*n_simple_bufs + j] * split_state.nr[s];
1153
+ }
1154
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
1155
+ if (tensor->nb[i] > tensor->nb[split_dim]) {
1156
+ nb[i] = tensor->nb[i] * ne[split_dim]/tensor->ne[split_dim];
1157
+ }
1158
+ }
1159
+ }
1160
+
1161
+ ggml_tensor * t_ij = ggml_new_tensor(simple_ctx, tensor->type, GGML_MAX_DIMS, ne);
1162
+ t_ij->op = tensor->op;
1163
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
1164
+ t_ij->nb[i] = nb[i];
1165
+ }
1166
+ t_ij->flags = tensor->flags;
1167
+ memcpy(t_ij->op_params, tensor->op_params, sizeof(tensor->op_params));
1168
+ ggml_set_name(t_ij, tensor->name);
1169
+ t_ij->buffer = simple_buf;
1170
+ t_ij->view_src = tensor->view_src;
1171
+ t_ij->view_offs = tensor->view_offs;
1172
+ if (t_ij->view_src != nullptr && ggml_backend_buffer_is_meta(t_ij->view_src->buffer)) {
1173
+ t_ij->view_src = ggml_backend_meta_buffer_simple_tensor(tensor->view_src, j);
1174
+ if (t_ij->view_offs > 0 && split_dim >= 0 && split_dim < GGML_MAX_DIMS) {
1175
+ GGML_ASSERT(tensor->ne[split_dim] != 0);
1176
+ const int split_dim_view_src = ggml_backend_meta_get_split_state(tensor->view_src, /*assume_sync =*/ true).axis;
1177
+ GGML_ASSERT(split_dim_view_src >= 0 && split_dim_view_src < GGML_MAX_DIMS);
1178
+
1179
+ // The offset can be internal to the data split, in those cases the view offset should not be scaled.
1180
+ // If however, the offset is larger than the data split then it needs to be scaled proportionally.
1181
+ bool split_internal_offset = t_ij->view_offs <= tensor->view_src->nb[split_dim_view_src];
1182
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
1183
+ const size_t dim_size = tensor->ne[i] * tensor->nb[i];
1184
+ if (tensor->view_offs <= dim_size && dim_size < tensor->nb[split_dim]) {
1185
+ split_internal_offset = true;
1186
+ break;
1187
+ }
1188
+ }
1189
+ if (!split_internal_offset) {
1190
+ t_ij->view_offs = t_ij->view_offs * ne[split_dim]/tensor->ne[split_dim];
1191
+ }
1192
+ }
1193
+ }
1194
+ if (t_ij->view_src != nullptr) {
1195
+ t_ij->data = (char *) t_ij->view_src->data + t_ij->view_offs;
1196
+ } else if (simple_buf != nullptr) {
1197
+ t_ij->data = (char *) ggml_backend_buffer_get_base(simple_buf)
1198
+ + size_t(tensor->data) - size_t(ggml_backend_buffer_get_base(tensor->buffer));
1199
+ }
1200
+ t_ij->extra = tensor->extra;
1201
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
1202
+ t_ij->src[i] = tensor->src[i];
1203
+ if (tensor->src[i] == tensor) {
1204
+ t_ij->src[i] = t_ij;
1205
+ } else if (t_ij->src[i] != nullptr && ggml_backend_buffer_is_meta(t_ij->src[i]->buffer)) {
1206
+ t_ij->src[i] = ggml_backend_meta_buffer_simple_tensor(tensor->src[i], j);
1207
+ }
1208
+ }
1209
+
1210
+ simple_tensors.push_back(t_ij);
1211
+ }
1212
+
1213
+ // If one of the sources has a zero-sized slice, disable the computation:
1214
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
1215
+ if (tensor->src[i] == nullptr || !ggml_backend_buffer_is_meta(tensor->src[i]->buffer)) {
1216
+ continue;
1217
+ }
1218
+
1219
+ const ggml_backend_meta_split_state split_state_src = ggml_backend_meta_get_split_state(tensor->src[i], /*assume_sync =*/ true);
1220
+ if (split_state_src.axis < 0 || split_state_src.axis >= GGML_MAX_DIMS) {
1221
+ continue;
1222
+ }
1223
+ for (size_t j = 0; j < n_simple_bufs; j++) {
1224
+ int64_t ne_sum = 0;
1225
+ for (size_t s = 0; s < split_state_src.n_segments; s++) {
1226
+ ne_sum += split_state_src.ne[s*n_simple_bufs + j] * split_state_src.nr[s];
1227
+ }
1228
+ if (ne_sum == 0) {
1229
+ simple_tensors[j]->flags &= ~GGML_TENSOR_FLAG_COMPUTE;
1230
+ }
1231
+ }
1232
+ }
1233
+
1234
+ stc.simple_tensors[tensor] = simple_tensors;
1235
+
1236
+ return GGML_STATUS_SUCCESS;
1237
+ }
1238
+
1239
+ static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
1240
+ GGML_ASSERT(ggml_backend_buffer_is_meta(buffer));
1241
+ ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buffer->context;
1242
+ buf_ctx->stc_compute_index = buf_ctx->stc_compute_index_next;
1243
+ return ggml_backend_meta_buffer_init_tensor_impl(buf_ctx->get_simple_tensor_container(tensor), tensor);
1244
+ }
1245
+
1246
+ static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
1247
+ const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(buffer);
1248
+ GGML_ASSERT(ggml_is_contiguous(tensor));
1249
+
1250
+ const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false);
1251
+
1252
+ if (split_state.n_segments != 1 || split_state.nr[0] != 1) {
1253
+ GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS);
1254
+ GGML_ASSERT(split_state.nr[0] != 0);
1255
+ GGML_ASSERT(tensor->ne[3] == 1);
1256
+
1257
+ size_t offset_data = 0;
1258
+ std::vector<size_t> simple_offsets(n_bufs, 0);
1259
+ if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_0) {
1260
+ GGML_ASSERT(tensor->ne[2] == 1);
1261
+
1262
+ const size_t row_stride = tensor->nb[1];
1263
+ GGML_ASSERT(offset % row_stride == 0);
1264
+ GGML_ASSERT(size % row_stride == 0);
1265
+ const int64_t row_start = offset / row_stride;
1266
+ const int64_t row_count = size / row_stride;
1267
+ GGML_ASSERT(row_start + row_count <= tensor->ne[1]);
1268
+
1269
+ const int64_t blck_size = ggml_blck_size(tensor->type);
1270
+ for (size_t s = 0; s < split_state.n_segments; s++) {
1271
+ for (size_t r = 0; r < split_state.nr[s]; r++) {
1272
+ for (size_t j = 0; j < n_bufs; j++) {
1273
+ ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
1274
+ GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0);
1275
+ const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0];
1276
+ ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data,
1277
+ simple_offsets[j] + row_start * simple_tensor->nb[1], nbytes,
1278
+ row_count, simple_tensor->nb[1], tensor->nb[1]);
1279
+ offset_data += nbytes;
1280
+ simple_offsets[j] += nbytes;
1281
+ }
1282
+ }
1283
+ }
1284
+ GGML_ASSERT(offset_data*row_count == size);
1285
+ return;
1286
+ }
1287
+ GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1);
1288
+
1289
+ const size_t row_stride = tensor->nb[2];
1290
+ GGML_ASSERT(offset % row_stride == 0);
1291
+ GGML_ASSERT(size % row_stride == 0);
1292
+ const int64_t row_start = offset / row_stride;
1293
+ const int64_t row_count = size / row_stride;
1294
+ GGML_ASSERT(row_start + row_count <= tensor->ne[2]);
1295
+
1296
+ for (size_t s = 0; s < split_state.n_segments; s++) {
1297
+ for (size_t r = 0; r < split_state.nr[s]; r++) {
1298
+ for (size_t j = 0; j < n_bufs; j++) {
1299
+ ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
1300
+ const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1];
1301
+ ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data,
1302
+ simple_offsets[j] + row_start * simple_tensor->nb[2], nbytes,
1303
+ row_count, simple_tensor->nb[2], tensor->nb[2]);
1304
+ offset_data += nbytes;
1305
+ simple_offsets[j] += nbytes;
1306
+ }
1307
+ }
1308
+ }
1309
+ GGML_ASSERT(offset_data*row_count == size);
1310
+ return;
1311
+ }
1312
+
1313
+ switch (split_state.axis) {
1314
+ case GGML_BACKEND_SPLIT_AXIS_0:
1315
+ case GGML_BACKEND_SPLIT_AXIS_1:
1316
+ case GGML_BACKEND_SPLIT_AXIS_2: {
1317
+ // Exploit that tensors are contiguous to splice it with simple tensors as "chunks".
1318
+ const size_t chunk_size_full = tensor->nb[split_state.axis + 1];
1319
+ GGML_ASSERT(offset % chunk_size_full == 0);
1320
+ GGML_ASSERT(size % chunk_size_full == 0);
1321
+ const int64_t i_start = offset /chunk_size_full;
1322
+ const int64_t i_stop = (offset + size)/chunk_size_full;
1323
+ size_t offset_j = 0;
1324
+ for (size_t j = 0; j < n_bufs; j++) {
1325
+ ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
1326
+ const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1];
1327
+ if (chunk_size_j == 0) {
1328
+ continue;
1329
+ }
1330
+ const size_t simple_offset = i_start * chunk_size_j;
1331
+ ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_j, simple_offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full);
1332
+ offset_j += chunk_size_j;
1333
+ }
1334
+ GGML_ASSERT(offset_j == chunk_size_full);
1335
+ } break;
1336
+ case GGML_BACKEND_SPLIT_AXIS_MIRRORED: {
1337
+ for (size_t j = 0; j < n_bufs; j++) {
1338
+ ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
1339
+ ggml_backend_tensor_set(simple_tensor, data, offset, size);
1340
+ }
1341
+ } break;
1342
+ case GGML_BACKEND_SPLIT_AXIS_PARTIAL: {
1343
+ GGML_ASSERT(tensor->type == GGML_TYPE_F32);
1344
+ const int64_t ne = ggml_nelements(tensor);
1345
+ std::vector<float> tmp;
1346
+ tmp.reserve(ne);
1347
+ for (int64_t i = 0; i < ne; i++) {
1348
+ tmp.push_back(((const float *) data)[i] / n_bufs);
1349
+ }
1350
+ for (size_t j = 0; j < n_bufs; j++) {
1351
+ ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
1352
+ ggml_backend_tensor_set(simple_tensor, tmp.data(), offset, size);
1353
+ }
1354
+ } break;
1355
+ default: {
1356
+ GGML_ABORT("fatal error");
1357
+ }
1358
+ }
1359
+ }
1360
+
1361
+ static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
1362
+ const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(buffer);
1363
+ GGML_ASSERT(ggml_is_contiguous(tensor));
1364
+
1365
+ const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false);
1366
+
1367
+ if (split_state.n_segments != 1 || split_state.nr[0] != 1) {
1368
+ GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS);
1369
+ GGML_ASSERT(split_state.nr[0] != 0);
1370
+ GGML_ASSERT(tensor->ne[3] == 1);
1371
+
1372
+ size_t offset_data = 0;
1373
+ std::vector<size_t> simple_offsets(n_bufs, 0);
1374
+ if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_0) {
1375
+ GGML_ASSERT(tensor->ne[2] == 1);
1376
+
1377
+ const size_t row_stride = tensor->nb[1];
1378
+ GGML_ASSERT(offset % row_stride == 0);
1379
+ GGML_ASSERT(size % row_stride == 0);
1380
+ const int64_t row_start = offset / row_stride;
1381
+ const int64_t row_count = size / row_stride;
1382
+ GGML_ASSERT(row_start + row_count <= tensor->ne[1]);
1383
+
1384
+ const int64_t blck_size = ggml_blck_size(tensor->type);
1385
+ for (size_t s = 0; s < split_state.n_segments; s++) {
1386
+ for (size_t r = 0; r < split_state.nr[s]; r++) {
1387
+ for (size_t j = 0; j < n_bufs; j++) {
1388
+ const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
1389
+ GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0);
1390
+ const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0];
1391
+ ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data,
1392
+ simple_offsets[j] + row_start * simple_tensor->nb[1], nbytes,
1393
+ row_count, simple_tensor->nb[1], tensor->nb[1]);
1394
+ offset_data += nbytes;
1395
+ simple_offsets[j] += nbytes;
1396
+ }
1397
+ }
1398
+ }
1399
+ GGML_ASSERT(offset_data*row_count == size);
1400
+ return;
1401
+ }
1402
+ GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1);
1403
+
1404
+ const size_t row_stride = tensor->nb[2];
1405
+ GGML_ASSERT(offset % row_stride == 0);
1406
+ GGML_ASSERT(size % row_stride == 0);
1407
+ const int64_t row_start = offset / row_stride;
1408
+ const int64_t row_count = size / row_stride;
1409
+ GGML_ASSERT(row_start + row_count <= tensor->ne[2]);
1410
+
1411
+ for (size_t s = 0; s < split_state.n_segments; s++) {
1412
+ for (size_t r = 0; r < split_state.nr[s]; r++) {
1413
+ for (size_t j = 0; j < n_bufs; j++) {
1414
+ const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
1415
+ const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1];
1416
+ ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data,
1417
+ simple_offsets[j] + row_start * simple_tensor->nb[2], nbytes,
1418
+ row_count, simple_tensor->nb[2], tensor->nb[2]);
1419
+ offset_data += nbytes;
1420
+ simple_offsets[j] += nbytes;
1421
+ }
1422
+ }
1423
+ }
1424
+ GGML_ASSERT(offset_data*row_count == size);
1425
+ return;
1426
+ }
1427
+
1428
+ switch (split_state.axis) {
1429
+ case GGML_BACKEND_SPLIT_AXIS_0:
1430
+ case GGML_BACKEND_SPLIT_AXIS_1:
1431
+ case GGML_BACKEND_SPLIT_AXIS_2: {
1432
+ // Exploit that tensors are contiguous to splice it with simple tensors as "chunks".
1433
+ const size_t chunk_size_full = tensor->nb[split_state.axis + 1];
1434
+ GGML_ASSERT(offset % chunk_size_full == 0);
1435
+ GGML_ASSERT(size % chunk_size_full == 0);
1436
+ const int64_t i_start = offset /chunk_size_full;
1437
+ const int64_t i_stop = (offset + size)/chunk_size_full;
1438
+ size_t offset_j = 0;
1439
+ for (size_t j = 0; j < n_bufs; j++){
1440
+ const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
1441
+ const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1];
1442
+ if (chunk_size_j == 0) {
1443
+ continue;
1444
+ }
1445
+ const size_t simple_offset = i_start * chunk_size_j;
1446
+ ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_j, simple_offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full);
1447
+ offset_j += chunk_size_j;
1448
+ }
1449
+ GGML_ASSERT(offset_j == chunk_size_full);
1450
+ } break;
1451
+ case GGML_BACKEND_SPLIT_AXIS_MIRRORED: {
1452
+ // TODO other simple backend may be better
1453
+ const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, 0);
1454
+ ggml_backend_tensor_get(simple_tensor, data, offset, size);
1455
+ } break;
1456
+ default: {
1457
+ GGML_ABORT("fatal error");
1458
+ }
1459
+ }
1460
+ }
1461
+
1462
+ static void ggml_backend_meta_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
1463
+ const size_t n_buffers = ggml_backend_meta_buffer_n_bufs(buffer);
1464
+ for (size_t i = 0; i < n_buffers; i++) {
1465
+ ggml_backend_buffer_clear(ggml_backend_meta_buffer_simple_buffer(buffer, i), value);
1466
+ }
1467
+ }
1468
+
1469
+ static void ggml_backend_meta_buffer_reset(ggml_backend_buffer_t buffer) {
1470
+ GGML_ASSERT(ggml_backend_buffer_is_meta(buffer));
1471
+ ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buffer->context;
1472
+ for (size_t i = 0; i < buf_ctx->bufs.size(); i++) {
1473
+ ggml_backend_buffer_reset(ggml_backend_meta_buffer_simple_buffer(buffer, i));
1474
+ }
1475
+ }
1476
+
1477
+ static const ggml_backend_buffer_i ggml_backend_meta_buffer_iface = {
1478
+ /* .free_buffer = */ ggml_backend_meta_buffer_free_buffer,
1479
+ /* .get_base = */ ggml_backend_meta_buffer_get_base,
1480
+ /* .init_tensor = */ ggml_backend_meta_buffer_init_tensor,
1481
+ /* .memset_tensor = */ nullptr, // TODO implement
1482
+ /* .set_tensor = */ ggml_backend_meta_buffer_set_tensor,
1483
+ /* .get_tensor = */ ggml_backend_meta_buffer_get_tensor,
1484
+ /* .set_tensor_2d = */ nullptr,
1485
+ /* .get_tensor_2d = */ nullptr,
1486
+ /* .cpy_tensor = */ nullptr,
1487
+ /* .clear = */ ggml_backend_meta_buffer_clear,
1488
+ /* .reset = */ ggml_backend_meta_buffer_reset,
1489
+ };
1490
+
1491
+ bool ggml_backend_buffer_is_meta(ggml_backend_buffer_t buf) {
1492
+ return buf != nullptr && buf->iface.free_buffer == ggml_backend_meta_buffer_iface.free_buffer;
1493
+ }
1494
+
1495
+ static ggml_backend_buffer_t ggml_backend_meta_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
1496
+ const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft);
1497
+
1498
+ const ggml_init_params params = {
1499
+ /*.mem_size =*/ 1024*1024*ggml_tensor_overhead(), // FIXME
1500
+ /*.mem_buffer =*/ nullptr,
1501
+ /*.no_alloc =*/ true,
1502
+ };
1503
+ ggml_backend_meta_simple_tensor_container stc_static;
1504
+ ggml_backend_meta_simple_tensor_container stc_compute_0(params, n_simple_bufts);
1505
+ ggml_backend_meta_simple_tensor_container stc_compute_1(params, n_simple_bufts);
1506
+
1507
+ size_t max_size = 0;
1508
+ std::vector<ggml_backend_buffer_t> bufs;
1509
+ bufs.reserve(n_simple_bufts);
1510
+ for (size_t i = 0; i < n_simple_bufts; i++) {
1511
+ bufs.push_back(ggml_backend_buft_alloc_buffer(ggml_backend_meta_buft_simple_buft(buft, i), size));
1512
+ GGML_ASSERT(bufs.back() != nullptr);
1513
+ max_size = std::max(max_size, ggml_backend_buffer_get_size(bufs.back()));
1514
+ }
1515
+ ggml_backend_meta_buffer_context * buf_ctx = new ggml_backend_meta_buffer_context(stc_static, stc_compute_0, stc_compute_1, bufs);
1516
+
1517
+ return ggml_backend_buffer_init(buft, ggml_backend_meta_buffer_iface, buf_ctx, max_size);
1518
+ }
1519
+
1520
+ struct ggml_backend_buffer * ggml_backend_meta_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) {
1521
+ const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft);
1522
+
1523
+ constexpr size_t compute_headroom = 16; // Maximum number of views per statically allocated tensor that can be created between evals.
1524
+ const ggml_init_params params_static = {
1525
+ /*.mem_size =*/ ggml_get_mem_size(ctx),
1526
+ /*.mem_buffer =*/ nullptr,
1527
+ /*.no_alloc =*/ true,
1528
+ };
1529
+ const ggml_init_params params_compute = {
1530
+ /*.mem_size =*/ compute_headroom*ggml_get_mem_size(ctx),
1531
+ /*.mem_buffer =*/ nullptr,
1532
+ /*.no_alloc =*/ true,
1533
+ };
1534
+ ggml_backend_meta_simple_tensor_container stc_static (params_static, n_simple_bufts);
1535
+ ggml_backend_meta_simple_tensor_container stc_compute_0(params_compute, n_simple_bufts);
1536
+ ggml_backend_meta_simple_tensor_container stc_compute_1(params_compute, n_simple_bufts);
1537
+
1538
+ std::vector<ggml_backend_buffer_t> bufs(n_simple_bufts, nullptr);
1539
+ ggml_backend_meta_buffer_context * meta_buf_ctx = new ggml_backend_meta_buffer_context(stc_static, stc_compute_0, stc_compute_1, bufs);
1540
+
1541
+ ggml_backend_buffer_t meta_buf = ggml_backend_buffer_init(buft, ggml_backend_meta_buffer_iface, meta_buf_ctx, 0);
1542
+ for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
1543
+ t->buffer = meta_buf;
1544
+ ggml_backend_meta_buffer_init_tensor_impl(meta_buf_ctx->stc_static, t);
1545
+ t->data = (void *) 0x2000000000000000; // FIXME
1546
+ }
1547
+ for (size_t i = 0; i < n_simple_bufts; i++) {
1548
+ ggml_context * ctx = meta_buf_ctx->stc_static.ctxs[i].get();
1549
+ ggml_backend_buffer_type_t simple_buft = ggml_backend_meta_buft_simple_buft(buft, i);
1550
+
1551
+ // If a ggml_context only has zero-sized tensors, ggml_backend_alloc_ctx_tensors_from_buft returns NULL.
1552
+ // For those edge cases, allocate a dummy buffer instead.
1553
+ bool any_nonzero_slice = false;
1554
+ for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
1555
+ if (ggml_nelements(t) != 0) {
1556
+ any_nonzero_slice = true;
1557
+ break;
1558
+ }
1559
+ }
1560
+ if (any_nonzero_slice) {
1561
+ meta_buf_ctx->bufs[i].reset(ggml_backend_alloc_ctx_tensors_from_buft(ctx, simple_buft));
1562
+ } else {
1563
+ meta_buf_ctx->bufs[i].reset(ggml_backend_buft_alloc_buffer(simple_buft, 0));
1564
+ for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
1565
+ t->buffer = meta_buf_ctx->bufs[i].get();
1566
+ }
1567
+ }
1568
+ GGML_ASSERT(meta_buf_ctx->bufs[i]);
1569
+ meta_buf->size = std::max(meta_buf->size, ggml_backend_buffer_get_size(meta_buf_ctx->bufs[i].get()));
1570
+ }
1571
+ return meta_buf;
1572
+ }
1573
+
1574
+ //
1575
+ // meta backend
1576
+ //
1577
+
1578
+ static ggml_guid_t ggml_backend_meta_guid() {
1579
+ static ggml_guid guid = {0xf1, 0x0e, 0x34, 0xcf, 0x9c, 0x6f, 0x43, 0xcb, 0x96, 0x92, 0xbe, 0x8e, 0xbb, 0x71, 0x3f, 0xda};
1580
+ return &guid;
1581
+ }
1582
+
1583
+ struct ggml_backend_meta_context {
1584
+ struct cgraph_config {
1585
+ ggml_cgraph * cgraph_main = nullptr;
1586
+ int offset = 0; // Node offset vs. original graph
1587
+
1588
+ std::vector<ggml_cgraph *> cgraphs_aux;
1589
+ };
1590
+ struct backend_config {
1591
+ ggml_backend_t backend;
1592
+
1593
+ std::vector<cgraph_config> cgraphs;
1594
+ std::vector<ggml_tensor *> nodes;
1595
+ std::vector<ggml_backend_buffer_ptr> bufs;
1596
+
1597
+ backend_config(ggml_backend_t backend, const size_t n_reduce_steps) : backend(backend) {
1598
+ bufs.resize(n_reduce_steps);
1599
+ }
1600
+ };
1601
+ std::string name;
1602
+ std::vector<backend_config> backend_configs;
1603
+ ggml_context_ptr ctx;
1604
+ std::vector<ggml_cgraph *> cgraphs_aux;
1605
+ std::vector<ggml_tensor *> nodes_aux;
1606
+ size_t n_reduce_steps;
1607
+ int max_nnodes = 0;
1608
+ size_t max_tmp_size = 0;
1609
+ size_t max_subgraphs = 0;
1610
+ size_t n_subgraphs = 0;
1611
+ uint64_t uid = 0;
1612
+
1613
+ void * comm_ctx = nullptr;
1614
+ ggml_backend_comm_allreduce_tensor_t comm_allreduce = nullptr;
1615
+
1616
+ ggml_backend_meta_context(ggml_backend_dev_t meta_dev, const char * params) {
1617
+ const size_t n_devs = ggml_backend_meta_dev_n_devs(meta_dev);
1618
+ n_reduce_steps = std::ceil(std::log2(n_devs));
1619
+ name = "Meta(";
1620
+ std::vector<ggml_backend_t> simple_backends;
1621
+ backend_configs.reserve(n_devs);
1622
+ simple_backends.reserve(n_devs);
1623
+ for (size_t i = 0; i < n_devs; i++) {
1624
+ ggml_backend_dev_t simple_dev = ggml_backend_meta_dev_simple_dev(meta_dev, i);
1625
+ if (i > 0) {
1626
+ name += ",";
1627
+ }
1628
+ name += ggml_backend_dev_name(simple_dev);
1629
+ simple_backends.push_back(ggml_backend_dev_init(simple_dev, params));
1630
+ backend_configs.emplace_back(simple_backends.back(), n_reduce_steps);
1631
+ }
1632
+ name += ")";
1633
+
1634
+ if (n_devs > 1) {
1635
+ ggml_backend_comm_init_t comm_init = (ggml_backend_comm_init_t) ggml_backend_reg_get_proc_address(
1636
+ ggml_backend_dev_backend_reg(ggml_backend_get_device(simple_backends[0])), "ggml_backend_comm_init");
1637
+ if (comm_init != nullptr) {
1638
+ comm_ctx = comm_init(simple_backends.data(), simple_backends.size());
1639
+ }
1640
+ }
1641
+ if (comm_ctx != nullptr) {
1642
+ comm_allreduce = (ggml_backend_comm_allreduce_tensor_t)
1643
+ ggml_backend_reg_get_proc_address(ggml_backend_dev_backend_reg(
1644
+ ggml_backend_get_device(simple_backends[0])), "ggml_backend_comm_allreduce_tensor");
1645
+ GGML_ASSERT(comm_allreduce != nullptr);
1646
+ }
1647
+ }
1648
+
1649
+ ~ggml_backend_meta_context() {
1650
+ if (comm_ctx != nullptr) {
1651
+ ggml_backend_comm_free_t comm_free = (ggml_backend_comm_free_t) ggml_backend_reg_get_proc_address(
1652
+ ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_configs[0].backend)), "ggml_backend_comm_free");
1653
+ GGML_ASSERT(comm_free != nullptr);
1654
+ comm_free(comm_ctx);
1655
+ }
1656
+ for (auto & bc : backend_configs) {
1657
+ ggml_backend_free(bc.backend);
1658
+ }
1659
+ }
1660
+ };
1661
+
1662
+ static const char * ggml_backend_meta_get_name(ggml_backend_t backend) {
1663
+ GGML_ASSERT(ggml_backend_is_meta(backend));
1664
+ const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) backend->context;
1665
+ return backend_ctx->name.c_str();
1666
+ }
1667
+
1668
+ static void ggml_backend_meta_free(ggml_backend_t backend) {
1669
+ GGML_ASSERT(ggml_backend_is_meta(backend));
1670
+ ggml_backend_meta_context * backend_ctx = (ggml_backend_meta_context *) backend->context;
1671
+ delete backend_ctx;
1672
+ delete backend;
1673
+ }
1674
+
1675
+ static void ggml_backend_meta_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
1676
+ const size_t n_backends = ggml_backend_meta_n_backends(backend);
1677
+ GGML_ASSERT(offset == 0);
1678
+ GGML_ASSERT(ggml_is_contiguous(tensor));
1679
+
1680
+ const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false);
1681
+ GGML_ASSERT(split_state.n_segments == 1);
1682
+ GGML_ASSERT(split_state.nr[0] == 1);
1683
+
1684
+ switch (split_state.axis) {
1685
+ case GGML_BACKEND_SPLIT_AXIS_0:
1686
+ case GGML_BACKEND_SPLIT_AXIS_1:
1687
+ case GGML_BACKEND_SPLIT_AXIS_2: {
1688
+ // Exploit that tensors are contiguous to splice it with simple tensors as "chunks".
1689
+ const size_t chunk_size_full = tensor->nb[split_state.axis + 1];
1690
+ GGML_ASSERT(offset % chunk_size_full == 0);
1691
+ GGML_ASSERT(size % chunk_size_full == 0);
1692
+ const int64_t i_start = offset /chunk_size_full;
1693
+ const int64_t i_stop = (offset + size)/chunk_size_full;
1694
+ size_t offset_j = 0;
1695
+ for (size_t j = 0; j < n_backends; j++){
1696
+ ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j);
1697
+ ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
1698
+ const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1];
1699
+ if (chunk_size_j == 0) {
1700
+ continue;
1701
+ }
1702
+ ggml_backend_tensor_set_2d_async(simple_backend, simple_tensor, (const char *) data + offset_j, offset, chunk_size_j,
1703
+ i_stop - i_start, chunk_size_j, chunk_size_full);
1704
+ offset_j += chunk_size_j;
1705
+ }
1706
+ GGML_ASSERT(offset_j == chunk_size_full);
1707
+ } break;
1708
+ case GGML_BACKEND_SPLIT_AXIS_MIRRORED: {
1709
+ for (size_t j = 0; j < n_backends; j++) {
1710
+ ggml_backend_tensor_set_async(
1711
+ ggml_backend_meta_simple_backend(backend, j), ggml_backend_meta_buffer_simple_tensor(tensor, j), data, offset, size);
1712
+ }
1713
+ } break;
1714
+ default: {
1715
+ GGML_ABORT("fatal error");
1716
+ }
1717
+ }
1718
+ }
1719
+
1720
+ static void ggml_backend_meta_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
1721
+ const size_t n_backends = ggml_backend_meta_n_backends(backend);
1722
+ GGML_ASSERT(offset == 0);
1723
+ GGML_ASSERT(ggml_is_contiguous(tensor));
1724
+
1725
+ const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false);
1726
+ GGML_ASSERT(split_state.n_segments == 1);
1727
+ GGML_ASSERT(split_state.nr[0] == 1);
1728
+
1729
+ switch (split_state.axis) {
1730
+ case GGML_BACKEND_SPLIT_AXIS_0:
1731
+ case GGML_BACKEND_SPLIT_AXIS_1:
1732
+ case GGML_BACKEND_SPLIT_AXIS_2: {
1733
+ // Exploit that tensors are contiguous to splice it with simple tensors as "chunks".
1734
+ const size_t chunk_size_full = tensor->nb[split_state.axis + 1];
1735
+ GGML_ASSERT(offset % chunk_size_full == 0);
1736
+ GGML_ASSERT(size % chunk_size_full == 0);
1737
+ const int64_t i_start = offset /chunk_size_full;
1738
+ const int64_t i_stop = (offset + size)/chunk_size_full;
1739
+ size_t offset_j = 0;
1740
+ for (size_t j = 0; j < n_backends; j++){
1741
+ ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j);
1742
+ const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
1743
+ const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1];
1744
+ if (chunk_size_j == 0) {
1745
+ continue;
1746
+ }
1747
+ ggml_backend_tensor_get_2d_async(simple_backend, simple_tensor, (char *) data + offset_j, offset, chunk_size_j,
1748
+ i_stop - i_start, chunk_size_j, chunk_size_full);
1749
+ offset_j += chunk_size_j;
1750
+ }
1751
+ GGML_ASSERT(offset_j == chunk_size_full);
1752
+ } break;
1753
+ case GGML_BACKEND_SPLIT_AXIS_MIRRORED: {
1754
+ // TODO other simple backend may be better
1755
+ ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, 0);
1756
+ const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, 0);
1757
+ ggml_backend_tensor_get_async(simple_backend, simple_tensor, data, offset, size);
1758
+ } break;
1759
+ default: {
1760
+ GGML_ABORT("fatal error");
1761
+ }
1762
+ }
1763
+ }
1764
+
1765
+ static void ggml_backend_meta_synchronize(ggml_backend_t backend) {
1766
+ const size_t n_backends = ggml_backend_meta_n_backends(backend);
1767
+ for (size_t i = 0; i < n_backends; i++) {
1768
+ ggml_backend_synchronize(ggml_backend_meta_simple_backend(backend, i));
1769
+ }
1770
+ }
1771
+
1772
+ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
1773
+ GGML_ASSERT(cgraph->grads == nullptr);
1774
+ const size_t n_backends = ggml_backend_meta_n_backends(backend);
1775
+ ggml_backend_meta_context * backend_ctx = (ggml_backend_meta_context *) backend->context;
1776
+
1777
+ // If the previous cgraph had a defined UID it can be used to skip rebuilding the subgraphs per simple backend.
1778
+ const bool needs_rebuild = (cgraph->uid == 0) || (cgraph->uid != backend_ctx->uid);
1779
+
1780
+ bool max_nnodes_raised = false;
1781
+ if (cgraph->n_nodes > backend_ctx->max_nnodes) {
1782
+ for (size_t j = 0; j < n_backends; j++) {
1783
+ auto & bcj = backend_ctx->backend_configs[j];
1784
+ bcj.nodes.resize(cgraph->n_nodes);
1785
+ bcj.cgraphs.resize(cgraph->n_nodes);
1786
+ }
1787
+ backend_ctx->max_nnodes = cgraph->n_nodes;
1788
+ max_nnodes_raised = true;
1789
+ assert(needs_rebuild);
1790
+ }
1791
+
1792
+ if (needs_rebuild) {
1793
+ std::set<ggml_backend_buffer_t> used_buffers;
1794
+ for (int i = 0; i < cgraph->n_leafs; i++) {
1795
+ if (ggml_backend_buffer_is_meta(cgraph->leafs[i]->buffer)) {
1796
+ used_buffers.emplace(cgraph->leafs[i]->buffer);
1797
+ }
1798
+ }
1799
+ for (int i = 0; i < cgraph->n_nodes; i++) {
1800
+ if (ggml_backend_buffer_is_meta(cgraph->nodes[i]->buffer)) {
1801
+ used_buffers.emplace(cgraph->nodes[i]->buffer);
1802
+ }
1803
+ }
1804
+ for (ggml_backend_buffer_t buf : used_buffers) {
1805
+ ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buf->context;
1806
+ buf_ctx->stc_compute_index_next = buf_ctx->stc_compute_index ^ 1;
1807
+ ggml_backend_meta_simple_tensor_container & stc = buf_ctx->stc_compute[buf_ctx->stc_compute_index_next];
1808
+ for (ggml_context_ptr & ctx : stc.ctxs) {
1809
+ ggml_reset(ctx.get());
1810
+ }
1811
+ stc.simple_tensors.clear();
1812
+ }
1813
+ size_t n_subgraphs = 0;
1814
+ size_t max_tmp_size = 0;
1815
+
1816
+ for (size_t j = 0; j < n_backends; j++) {
1817
+ auto & bcj = backend_ctx->backend_configs[j];
1818
+
1819
+ for (int i = 0; i < cgraph->n_nodes; i++) {
1820
+ ggml_tensor * node = cgraph->nodes[i];
1821
+ if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) {
1822
+ // FIXME s_copy_main is on the CPU and its view seems to be incorrectly added to the graph nodes.
1823
+ // For regular usage this doesn't matter since it's a noop but trying to call ggml_backend_meta_buffer_simple_tensor results in a crash.
1824
+ bcj.nodes[i] = node;
1825
+ continue;
1826
+ }
1827
+ bcj.nodes[i] = ggml_backend_meta_buffer_simple_tensor(node, j);
1828
+ GGML_ASSERT(bcj.nodes[i]);
1829
+ }
1830
+ }
1831
+
1832
+ {
1833
+ // For MoE models it may make sense to delay the AllReduce in order to reduce I/O:
1834
+ auto get_i_delayed = [&](const int i) -> int {
1835
+ int id = i; // i_delayed
1836
+ int idr = i; // i_delayed return, last safe return value
1837
+
1838
+ ggml_tensor * node = cgraph->nodes[id];
1839
+ int32_t n_used = ggml_node_get_use_count(cgraph, id);
1840
+
1841
+ // Skip MIRRORED nodes that don't consume node
1842
+ auto skip_unrelated = [&]() {
1843
+ while (id + 1 < cgraph->n_nodes) {
1844
+ ggml_tensor * next = cgraph->nodes[id+1];
1845
+ if (ggml_backend_meta_get_split_state(next, false).axis != GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
1846
+ break;
1847
+ }
1848
+ bool safe = true;
1849
+ for (int s = 0; s < GGML_MAX_SRC; s++) {
1850
+ if (next->src[s] == nullptr) {
1851
+ continue;
1852
+ }
1853
+ if (next->src[s] == node) {
1854
+ safe = false;
1855
+ break;
1856
+ }
1857
+ if (ggml_backend_meta_get_split_state(next->src[s], false).axis != GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
1858
+ safe = false;
1859
+ break;
1860
+ }
1861
+ }
1862
+ if (!safe) {
1863
+ break;
1864
+ }
1865
+ id++;
1866
+ }
1867
+ };
1868
+
1869
+ skip_unrelated();
1870
+ if (id + 1 >= cgraph->n_nodes) {
1871
+ return idr;
1872
+ }
1873
+ {
1874
+ ggml_tensor * next = cgraph->nodes[id+1];
1875
+ if (next->op == GGML_OP_ADD_ID && next->src[0] == node &&
1876
+ ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL &&
1877
+ ggml_backend_meta_get_split_state(next->src[2], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
1878
+ node = next;
1879
+ id++;
1880
+ idr = id;
1881
+ n_used = ggml_node_get_use_count(cgraph, id);
1882
+ }
1883
+ }
1884
+ // Chain of MULs with MIRRORED src[1]
1885
+ while (true) {
1886
+ skip_unrelated();
1887
+ if (id + 1 >= cgraph->n_nodes) {
1888
+ return idr;
1889
+ }
1890
+ ggml_tensor * next = cgraph->nodes[id+1];
1891
+ if (next->op == GGML_OP_MUL && next->src[0] == node &&
1892
+ ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
1893
+ node = next;
1894
+ id++;
1895
+ idr = id;
1896
+ n_used = ggml_node_get_use_count(cgraph, id);
1897
+ } else {
1898
+ break;
1899
+ }
1900
+ }
1901
+
1902
+ if (n_used != node->ne[1] || id + 2*n_used-1 >= cgraph->n_nodes) {
1903
+ return idr;
1904
+ }
1905
+ for (int32_t k = 0; k < n_used; k++) {
1906
+ ggml_tensor * next = cgraph->nodes[id+1];
1907
+ if (next->op != GGML_OP_VIEW || next->view_src != node || next->view_offs != k*node->nb[1] ||
1908
+ next->ne[0] != node->ne[0] || next->ne[1] != node->ne[2] || next->nb[1] != node->nb[2] ||
1909
+ ggml_node_get_use_count(cgraph, id+1) != 1) {
1910
+ return idr;
1911
+ }
1912
+ id++;
1913
+ }
1914
+ {
1915
+ ggml_tensor * next = cgraph->nodes[id+1];
1916
+ if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id - (n_used-1)] ||
1917
+ next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) {
1918
+ return idr;
1919
+ }
1920
+ id++;
1921
+ }
1922
+ for (int32_t k = 0; k < n_used - 2; k++) {
1923
+ ggml_tensor * next = cgraph->nodes[id+1];
1924
+ if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id] ||
1925
+ next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) {
1926
+ return idr;
1927
+ }
1928
+ id++;
1929
+ }
1930
+ idr = id;
1931
+ return idr;
1932
+ };
1933
+
1934
+ int i_start = 0;
1935
+ for (int i = 0; i < cgraph->n_nodes; i++) {
1936
+ ggml_tensor * node = cgraph->nodes[i];
1937
+ if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) {
1938
+ continue;
1939
+ }
1940
+ const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(node, /*assume_sync =*/ false);
1941
+ if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL) {
1942
+ max_tmp_size = std::max(max_tmp_size, ggml_nbytes(node));
1943
+ }
1944
+ const bool new_subgraph = i + 1 == cgraph->n_nodes || split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL;
1945
+ if (!new_subgraph) {
1946
+ continue;
1947
+ }
1948
+
1949
+ const int i_delayed = get_i_delayed(i);
1950
+
1951
+ // If we can delay the AllReduce we need to consider the interaction with zero-sized tensor slices.
1952
+ // A backend with such a slice would normally have valid data after participating in the AllReduce with a node that has
1953
+ // its compute flag disabled and thus gets its data zeroed out.
1954
+ // If the AllReduce is delayed then the nodes until that point also need to have their compute flag disabled.
1955
+ if (i_delayed > i) {
1956
+ for (size_t j = 0; j < n_backends; j++) {
1957
+ auto & bcj = backend_ctx->backend_configs[j];
1958
+ if ((bcj.nodes[i]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
1959
+ for (int ii = i + 1; ii <= i_delayed; ii++) {
1960
+ bcj.nodes[ii]->flags &= ~GGML_TENSOR_FLAG_COMPUTE;
1961
+ }
1962
+ }
1963
+ }
1964
+ }
1965
+
1966
+ i = i_delayed;
1967
+
1968
+ for (size_t j = 0; j < n_backends; j++) {
1969
+ auto & bcj = backend_ctx->backend_configs[j];
1970
+ bcj.cgraphs[n_subgraphs].offset = i_start;
1971
+ }
1972
+ n_subgraphs++;
1973
+ i_start = i + 1;
1974
+ }
1975
+ GGML_ASSERT(i_start == cgraph->n_nodes);
1976
+ }
1977
+
1978
+ backend_ctx->uid = cgraph->uid;
1979
+ backend_ctx->n_subgraphs = n_subgraphs;
1980
+
1981
+ if (max_tmp_size > backend_ctx->max_tmp_size) {
1982
+ for (size_t j = 0; j < n_backends; j++) {
1983
+ auto & bcj = backend_ctx->backend_configs[j];
1984
+ for (size_t i = 0; i < backend_ctx->n_reduce_steps; i++) {
1985
+ bcj.bufs[i].reset(ggml_backend_alloc_buffer(bcj.backend, max_tmp_size));
1986
+ }
1987
+ }
1988
+ backend_ctx->max_tmp_size = max_tmp_size;
1989
+ }
1990
+
1991
+ if (max_nnodes_raised || n_subgraphs > backend_ctx->max_subgraphs) {
1992
+ backend_ctx->max_subgraphs = std::max(backend_ctx->max_subgraphs, n_subgraphs);
1993
+ const size_t n_nodes_per_device = 3 * backend_ctx->n_reduce_steps; // tmp + ADD (+zeroing) graph per step and device
1994
+ const size_t n_cgraphs_per_device = 2 * backend_ctx->n_reduce_steps; // ADD ( + zeroing) graph per step and device
1995
+ const size_t mem_per_device_graphs_main = backend_ctx->max_subgraphs*ggml_graph_overhead_custom(backend_ctx->max_nnodes, cgraph->grads);
1996
+ const size_t mem_per_device_graphs_aux = n_cgraphs_per_device*backend_ctx->max_subgraphs*ggml_graph_overhead_custom(1, cgraph->grads);
1997
+ const size_t mem_per_device_nodes_aux = n_nodes_per_device*backend_ctx->max_subgraphs*ggml_tensor_overhead();
1998
+ const ggml_init_params params = {
1999
+ /*.mem_size =*/ n_backends * (mem_per_device_graphs_main + mem_per_device_graphs_aux + mem_per_device_nodes_aux),
2000
+ /*.mem_buffer =*/ nullptr,
2001
+ /*.no_alloc =*/ true,
2002
+ };
2003
+ backend_ctx->ctx.reset(ggml_init(params));
2004
+ for (size_t j = 0; j < n_backends; j++) {
2005
+ auto & bcj = backend_ctx->backend_configs[j];
2006
+ for (size_t i = 0; i < n_subgraphs; i++) {
2007
+ bcj.cgraphs[i].cgraph_main = ggml_new_graph_custom(backend_ctx->ctx.get(), cgraph->n_nodes, /*grads =*/ false);
2008
+ }
2009
+ }
2010
+ backend_ctx->cgraphs_aux.resize(n_backends*n_cgraphs_per_device*backend_ctx->max_subgraphs);
2011
+ for (size_t k = 0; k < backend_ctx->cgraphs_aux.size(); k++) {
2012
+ backend_ctx->cgraphs_aux[k] = ggml_new_graph_custom(backend_ctx->ctx.get(), 1, cgraph->grads);
2013
+ }
2014
+ backend_ctx->nodes_aux.resize(n_backends*n_nodes_per_device*backend_ctx->max_subgraphs);
2015
+ for (size_t k = 0; k < backend_ctx->nodes_aux.size(); k++) {
2016
+ backend_ctx->nodes_aux[k] = ggml_new_tensor_1d(backend_ctx->ctx.get(), GGML_TYPE_F32, 1);
2017
+ }
2018
+ }
2019
+
2020
+ for (size_t j = 0; j < n_backends; j++) {
2021
+ auto & bcj = backend_ctx->backend_configs[j];
2022
+ for (size_t i_graph = 0; i_graph < n_subgraphs; i_graph++) {
2023
+ ggml_cgraph * cgraph_ij = bcj.cgraphs[i_graph].cgraph_main;
2024
+ const size_t i_node_start = bcj.cgraphs[i_graph].offset;
2025
+ const size_t i_node_stop = i_graph + 1 < n_subgraphs ? bcj.cgraphs[i_graph + 1].offset : cgraph->n_nodes;
2026
+ cgraph_ij->n_nodes = i_node_stop - i_node_start;
2027
+ ggml_hash_set_reset(&cgraph_ij->visited_hash_set);
2028
+ for (size_t i_node = i_node_start; i_node < i_node_stop; i_node++) {
2029
+ ggml_tensor * node_ij = bcj.nodes[i_node];
2030
+ cgraph_ij->nodes[i_node - i_node_start] = node_ij;
2031
+ const size_t hash_pos_orig = ggml_hash_find(&cgraph->visited_hash_set, cgraph->nodes[i_node]);
2032
+ const size_t hash_pos_ij = ggml_hash_insert(&cgraph_ij->visited_hash_set, node_ij);
2033
+ cgraph_ij->use_counts[hash_pos_ij] = cgraph->use_counts[hash_pos_orig];
2034
+ }
2035
+ cgraph_ij->uid = ggml_graph_next_uid();
2036
+ }
2037
+ }
2038
+ }
2039
+
2040
+ size_t iga = 0; // i graph aux
2041
+ size_t ina = 0; // i node aux
2042
+
2043
+ auto get_node_aux = [&](ggml_tensor * t) -> ggml_tensor * {
2044
+ ggml_tensor * ret = backend_ctx->nodes_aux[ina++];
2045
+ memset(ret, 0, sizeof(ggml_tensor));
2046
+ ret->op = GGML_OP_NONE;
2047
+ ret->type = t->type;
2048
+ for (size_t k = 0; k < GGML_MAX_DIMS; k++) {
2049
+ ret->ne[k] = t->ne[k];
2050
+ ret->nb[k] = t->nb[k];
2051
+ }
2052
+ return ret;
2053
+ };
2054
+ auto set_tmp_data = [&](ggml_tensor * tensor, const size_t j, const size_t i_buf) {
2055
+ auto & bcj = backend_ctx->backend_configs[j];
2056
+ ggml_backend_buffer_ptr & buf_ptr = bcj.bufs[i_buf];
2057
+ if (!buf_ptr || ggml_backend_buffer_get_size(buf_ptr.get()) < backend_ctx->max_tmp_size) {
2058
+ buf_ptr.reset(ggml_backend_alloc_buffer(bcj.backend, backend_ctx->max_tmp_size));
2059
+ }
2060
+ tensor->buffer = buf_ptr.get();
2061
+ tensor->data = ggml_backend_buffer_get_base(buf_ptr.get());
2062
+ };
2063
+ // FIXME usage_counts
2064
+ auto get_cgraph_aux = [&]() -> ggml_cgraph * {
2065
+ ggml_cgraph * ret = backend_ctx->cgraphs_aux[iga++];
2066
+ return ret;
2067
+ };
2068
+
2069
+ // Preferentially use backend-specific allreduce_tensor_async (e.g. NCCL for CUDA), use a generic fallback if unavailable:
2070
+ auto allreduce_fallback = [&](size_t i) -> ggml_status {
2071
+ std::vector<ggml_cgraph *> step_cgraphs(n_backends, nullptr);
2072
+
2073
+ // Zero out nodes that were disabled due to having a zero-sized slice:
2074
+ for (size_t j = 0; j < n_backends; j++) {
2075
+ auto & bcj = backend_ctx->backend_configs[j];
2076
+ ggml_tensor * node = bcj.cgraphs[i].cgraph_main->nodes[bcj.cgraphs[i].cgraph_main->n_nodes - 1];
2077
+ if (node->flags & GGML_TENSOR_FLAG_COMPUTE) {
2078
+ continue;
2079
+ }
2080
+ ggml_tensor * node_zero = get_node_aux(node);
2081
+ node_zero->op = GGML_OP_SCALE; // FIXME 0.0f * NaN == NaN
2082
+ node_zero->src[0] = node;
2083
+ ggml_set_op_params_f32(node_zero, 0, 0.0f);
2084
+ node_zero->data = node->data;
2085
+ node_zero->buffer = node->buffer;
2086
+ node_zero->flags |= GGML_TENSOR_FLAG_COMPUTE;
2087
+
2088
+ step_cgraphs[j] = get_cgraph_aux();
2089
+ step_cgraphs[j]->nodes[0] = node_zero;
2090
+ step_cgraphs[j]->n_nodes = 1;
2091
+ const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, step_cgraphs[j]);
2092
+ if (status != GGML_STATUS_SUCCESS) {
2093
+ return status;
2094
+ }
2095
+ }
2096
+ std::fill(step_cgraphs.begin(), step_cgraphs.end(), nullptr);
2097
+
2098
+ auto push_data = [&](const size_t j_src, const size_t j_dst, const size_t i_buf) {
2099
+ assert(step_cgraphs[j_dst] == nullptr);
2100
+ auto & bcj_src = backend_ctx->backend_configs[j_src];
2101
+ auto & bcj_dst = backend_ctx->backend_configs[j_dst];
2102
+
2103
+ ggml_tensor * node_src = bcj_src.cgraphs[i].cgraph_main->nodes[bcj_src.cgraphs[i].cgraph_main->n_nodes - 1];
2104
+ ggml_tensor * node_dst = bcj_dst.cgraphs[i].cgraph_main->nodes[bcj_dst.cgraphs[i].cgraph_main->n_nodes - 1];
2105
+ GGML_ASSERT(ggml_is_contiguous(node_src));
2106
+ GGML_ASSERT(ggml_is_contiguous(node_dst));
2107
+
2108
+ ggml_tensor * node_tmp = get_node_aux(node_dst);
2109
+ set_tmp_data(node_tmp, j_dst, i_buf);
2110
+
2111
+ ggml_backend_tensor_copy_async(bcj_src.backend, bcj_dst.backend, node_src, node_tmp);
2112
+
2113
+ ggml_tensor * node_red = get_node_aux(node_dst);
2114
+ node_red->view_src = node_dst->view_src == nullptr ? node_dst : node_dst->view_src;
2115
+ node_red->view_offs = node_dst->view_offs;
2116
+ node_red->op = GGML_OP_ADD;
2117
+ node_red->src[0] = node_dst;
2118
+ node_red->src[1] = node_tmp;
2119
+ node_red->flags |= GGML_TENSOR_FLAG_COMPUTE;
2120
+ ggml_backend_view_init(node_red);
2121
+
2122
+ ggml_cgraph * cgraph_aux = get_cgraph_aux();
2123
+ cgraph_aux->nodes[0] = node_red;
2124
+ cgraph_aux->n_nodes = 1;
2125
+ step_cgraphs[j_dst] = cgraph_aux;
2126
+ };
2127
+
2128
+ size_t offset_j = n_backends/2;
2129
+ while ((offset_j & (offset_j - 1)) != 0) {
2130
+ offset_j--;
2131
+ }
2132
+ const size_t offset_j_max = offset_j;
2133
+ size_t i_buf = 0;
2134
+
2135
+ // If n_backends is not a power of 2, fold in the excess prior to butterfly reduction:
2136
+ for (size_t j_src = 2*offset_j_max; j_src < n_backends; j_src++) {
2137
+ const size_t j_dst = j_src - 2*offset_j_max;
2138
+ push_data(j_src, j_dst, i_buf);
2139
+ const ggml_status status = ggml_backend_graph_compute_async(backend_ctx->backend_configs[j_dst].backend, step_cgraphs[j_dst]);
2140
+ if (status != GGML_STATUS_SUCCESS) {
2141
+ return status;
2142
+ }
2143
+ i_buf = 1;
2144
+ }
2145
+
2146
+ // Butterfly reduction:
2147
+ for (; offset_j >= 1; offset_j /= 2) {
2148
+ std::fill(step_cgraphs.begin(), step_cgraphs.end(), nullptr);
2149
+
2150
+ for (size_t j = 0; j < 2*offset_j_max; j++) {
2151
+ const size_t j_other = j ^ offset_j;
2152
+ if (j_other >= n_backends) {
2153
+ continue;
2154
+ }
2155
+ push_data(j, j_other, i_buf);
2156
+ }
2157
+
2158
+ for (size_t j = 0; j < 2*offset_j_max; j++) {
2159
+ if (step_cgraphs[j] == nullptr) {
2160
+ continue;
2161
+ }
2162
+ auto & bcj = backend_ctx->backend_configs[j];
2163
+ const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, step_cgraphs[j]);
2164
+ if (status != GGML_STATUS_SUCCESS) {
2165
+ return status;
2166
+ }
2167
+ }
2168
+ i_buf++;
2169
+ }
2170
+ assert(i_buf == backend_ctx->n_reduce_steps);
2171
+
2172
+ // If n_backends is not a power of 2, copy back the reduced tensors to the excess:
2173
+ for (size_t j = 2*offset_j_max; j < n_backends; j++) {
2174
+ auto & bcj_src = backend_ctx->backend_configs[j - 2*offset_j_max];
2175
+ auto & bcj_dst = backend_ctx->backend_configs[j];
2176
+
2177
+ ggml_tensor * node_src = bcj_src.cgraphs[i].cgraph_main->nodes[bcj_src.cgraphs[i].cgraph_main->n_nodes - 1];
2178
+ ggml_tensor * node_dst = bcj_dst.cgraphs[i].cgraph_main->nodes[bcj_dst.cgraphs[i].cgraph_main->n_nodes - 1];
2179
+ ggml_backend_tensor_copy_async(bcj_src.backend, bcj_dst.backend, node_src, node_dst);
2180
+ }
2181
+
2182
+ return GGML_STATUS_SUCCESS;
2183
+ };
2184
+
2185
+
2186
+ for (size_t i = 0; i < backend_ctx->n_subgraphs; i++) {
2187
+ for (size_t j = 0; j < n_backends; j++) {
2188
+ auto & bcj = backend_ctx->backend_configs[j];
2189
+ const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, bcj.cgraphs[i].cgraph_main);
2190
+ if (status != GGML_STATUS_SUCCESS) {
2191
+ return status;
2192
+ }
2193
+ }
2194
+
2195
+ if (n_backends > 1 && i < backend_ctx->n_subgraphs - 1) {
2196
+ bool backend_allreduce_success = false;
2197
+ if (backend_ctx->comm_ctx) {
2198
+ std::vector<ggml_tensor *> nodes;
2199
+ nodes.reserve(n_backends);
2200
+ for (size_t j = 0; j < n_backends; j++) {
2201
+ auto & bcj = backend_ctx->backend_configs[j];
2202
+ ggml_cgraph * cgraph_ij = bcj.cgraphs[i].cgraph_main;
2203
+ nodes.push_back(cgraph_ij->nodes[cgraph_ij->n_nodes-1]);
2204
+ }
2205
+ backend_allreduce_success = backend_ctx->comm_allreduce(backend_ctx->comm_ctx, nodes.data());
2206
+ }
2207
+
2208
+ if (!backend_allreduce_success) {
2209
+ const ggml_status status = allreduce_fallback(i);
2210
+ if (status != GGML_STATUS_SUCCESS) {
2211
+ return status;
2212
+ }
2213
+ }
2214
+ }
2215
+ }
2216
+ return GGML_STATUS_SUCCESS;
2217
+ }
2218
+
2219
+ static const ggml_backend_i ggml_backend_meta_i = { // Callback table installed on every meta backend (copied into backend->iface on creation). Entries are positional; the /* .name = */ markers are comments only.
2220
+ /* .get_name = */ ggml_backend_meta_get_name, // this pointer is also what ggml_backend_is_meta compares against
2221
+ /* .free = */ ggml_backend_meta_free,
2222
+ /* .set_tensor_async = */ ggml_backend_meta_set_tensor_async,
2223
+ /* .get_tensor_async = */ ggml_backend_meta_get_tensor_async,
2224
+ /* .set_tensor_2d_async = */ nullptr, // no meta-specific callback is supplied for this slot
2225
+ /* .get_tensor_2d_async = */ nullptr, // no meta-specific callback is supplied for this slot
2226
+ /* .cpy_tensor_async = */ nullptr, // no meta-specific callback is supplied for this slot
2227
+ /* .synchronize = */ ggml_backend_meta_synchronize,
2228
+ /* .graph_plan_create = */ nullptr, // none of the graph_plan_* slots are populated
2229
+ /* .graph_plan_free = */ nullptr,
2230
+ /* .graph_plan_update = */ nullptr,
2231
+ /* .graph_plan_compute = */ nullptr,
2232
+ /* .graph_compute = */ ggml_backend_meta_graph_compute, // the only compute entry point provided by this table
2233
+ /* .event_record = */ nullptr, // event slots are left empty
2234
+ /* .event_wait = */ nullptr,
2235
+ /* .graph_optimize = */ nullptr, // NOTE(review): presumably callers treat nullptr slots as "not supported" — confirm against the ggml_backend_i contract
2236
+ };
2237
+
2238
+ bool ggml_backend_is_meta(ggml_backend_t backend) { // Reports whether `backend` is a meta backend; a null handle yields false.
2239
+ return backend != nullptr && backend->iface.get_name == ggml_backend_meta_i.get_name; // identity test: the backend's get_name callback must be the one stored in ggml_backend_meta_i
2240
+ }
2241
+
2242
+ static ggml_backend_t ggml_backend_meta_device_init_backend(ggml_backend_dev_t dev, const char * params) { // Creates a meta backend for `dev`: builds its context from (dev, params) and returns a newly allocated ggml_backend wired to the meta callback table.
2243
+ ggml_backend_meta_context * backend_ctx = new ggml_backend_meta_context(dev, params); // `params` is forwarded to the context constructor unchanged
2244
+
2245
+ ggml_backend_t backend = new struct ggml_backend; // NOTE(review): both heap objects are presumably released through the .free callback (ggml_backend_meta_free) — confirm
2246
+ backend->guid = ggml_backend_meta_guid();
2247
+ backend->iface = ggml_backend_meta_i; // copies the callback table by value; ggml_backend_is_meta relies on the get_name pointer stored here
2248
+ backend->device = dev;
2249
+ backend->context = backend_ctx; // the accessor functions cast this pointer back to ggml_backend_meta_context
2250
+ return backend;
2251
+ }
2252
+
2253
+ size_t ggml_backend_meta_n_backends(ggml_backend_t meta_backend) { // Returns how many entries the meta backend holds in backend_configs.
2254
+ GGML_ASSERT(ggml_backend_is_meta(meta_backend)); // guards the context cast below: passing a null or non-meta backend fails this assertion
2255
+ const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) meta_backend->context;
2256
+ return backend_ctx->backend_configs.size();
2257
+ }
2258
+
2259
+ ggml_backend_t ggml_backend_meta_simple_backend(ggml_backend_t meta_backend, size_t index) { // Returns the backend handle stored at position `index` of the meta backend's backend_configs.
2260
+ GGML_ASSERT(ggml_backend_is_meta(meta_backend)); // guards the context cast below: passing a null or non-meta backend fails this assertion
2261
+ const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) meta_backend->context;
2262
+ return backend_ctx->backend_configs[index].backend; // NOTE(review): `index` is not range-checked here; callers presumably keep it below ggml_backend_meta_n_backends() — confirm
2263
+ }