whispercpp 1.3.3 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (963)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +79 -25
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/CMakeLists.txt +1 -0
  23. data/ext/sources/examples/addon.node/addon.cpp +19 -19
  24. data/ext/sources/examples/addon.node/index.js +7 -5
  25. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  26. data/ext/sources/examples/bench/bench.cpp +26 -16
  27. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  28. data/ext/sources/examples/cli/cli.cpp +122 -111
  29. data/ext/sources/examples/command/command.cpp +26 -24
  30. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  31. data/ext/sources/examples/common-ggml.cpp +2 -0
  32. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  33. data/ext/sources/examples/lsp/lsp.cpp +19 -17
  34. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  35. data/ext/sources/examples/server/server.cpp +34 -24
  36. data/ext/sources/examples/server.py +6 -1
  37. data/ext/sources/examples/stream/stream.cpp +4 -2
  38. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  39. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  40. data/ext/sources/examples/talk-llama/CMakeLists.txt +7 -3
  41. data/ext/sources/examples/talk-llama/llama-adapter.cpp +113 -7
  42. data/ext/sources/examples/talk-llama/llama-adapter.h +13 -1
  43. data/ext/sources/examples/talk-llama/llama-arch.cpp +2136 -1491
  44. data/ext/sources/examples/talk-llama/llama-arch.h +125 -3
  45. data/ext/sources/examples/talk-llama/llama-batch.cpp +174 -100
  46. data/ext/sources/examples/talk-llama/llama-batch.h +46 -20
  47. data/ext/sources/examples/talk-llama/llama-chat.cpp +199 -8
  48. data/ext/sources/examples/talk-llama/llama-chat.h +11 -0
  49. data/ext/sources/examples/talk-llama/llama-context.cpp +1213 -413
  50. data/ext/sources/examples/talk-llama/llama-context.h +99 -36
  51. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -4
  52. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  53. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  54. data/ext/sources/examples/talk-llama/llama-graph.cpp +883 -294
  55. data/ext/sources/examples/talk-llama/llama-graph.h +361 -161
  56. data/ext/sources/examples/talk-llama/llama-hparams.cpp +144 -6
  57. data/ext/sources/examples/talk-llama/llama-hparams.h +100 -23
  58. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  59. data/ext/sources/examples/talk-llama/llama-impl.h +3 -1
  60. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +328 -0
  61. data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
  62. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2100 -0
  63. data/ext/sources/examples/talk-llama/llama-kv-cache.h +373 -27
  64. data/ext/sources/examples/talk-llama/llama-kv-cells.h +124 -30
  65. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +63 -41
  66. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
  67. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +77 -35
  68. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +15 -16
  69. data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
  70. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  71. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  72. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +93 -9
  73. data/ext/sources/examples/talk-llama/llama-model-loader.h +9 -2
  74. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  75. data/ext/sources/examples/talk-llama/llama-model.cpp +3369 -10145
  76. data/ext/sources/examples/talk-llama/llama-model.h +104 -12
  77. data/ext/sources/examples/talk-llama/llama-quant.cpp +53 -30
  78. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1520 -324
  79. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  80. data/ext/sources/examples/talk-llama/llama-vocab.cpp +562 -39
  81. data/ext/sources/examples/talk-llama/llama-vocab.h +50 -0
  82. data/ext/sources/examples/talk-llama/llama.cpp +794 -12
  83. data/ext/sources/examples/talk-llama/llama.h +246 -190
  84. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  85. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  86. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  88. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  89. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  90. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  91. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  92. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  93. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  94. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  95. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  96. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  97. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  98. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  99. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  100. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  101. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  102. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  103. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  104. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  105. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  106. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  107. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  108. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  109. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  110. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  111. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  112. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  113. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  114. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  115. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  116. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  117. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  118. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  119. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  120. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  121. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  122. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  123. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  124. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  125. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  126. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  127. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  128. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  129. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  130. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  131. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  132. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  133. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  134. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  135. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  136. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  137. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  138. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  139. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  140. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  141. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  142. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  143. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  144. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  145. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  146. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  147. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  148. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  149. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  150. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  151. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  153. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  154. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  155. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  156. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  157. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  158. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  159. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  160. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  161. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  162. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  163. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  165. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  166. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  167. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  168. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  169. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  170. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  171. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  172. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  173. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  174. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  175. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  176. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  177. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  178. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  179. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  180. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  181. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  182. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  183. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  184. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  185. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  186. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  187. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  188. data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
  189. data/ext/sources/examples/talk-llama/unicode.cpp +309 -16
  190. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  191. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  192. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
  193. data/ext/sources/examples/whisper.wasm/index-tmpl.html +18 -17
  194. data/ext/sources/ggml/CMakeLists.txt +135 -79
  195. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  196. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  197. data/ext/sources/ggml/include/ggml-backend.h +21 -2
  198. data/ext/sources/ggml/include/ggml-cpu.h +2 -1
  199. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  200. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  201. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  202. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  203. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  204. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  205. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  206. data/ext/sources/ggml/include/ggml.h +406 -23
  207. data/ext/sources/ggml/src/CMakeLists.txt +99 -13
  208. data/ext/sources/ggml/src/ggml-alloc.c +368 -161
  209. data/ext/sources/ggml/src/ggml-backend-impl.h +5 -5
  210. data/ext/sources/ggml/src/ggml-backend-reg.cpp +55 -14
  211. data/ext/sources/ggml/src/ggml-backend.cpp +290 -57
  212. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  213. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -13
  214. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  215. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +59 -45
  216. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  217. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2586 -1917
  218. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +348 -309
  219. data/ext/sources/ggml/src/ggml-cann/common.h +350 -133
  220. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +894 -625
  221. data/ext/sources/ggml/src/ggml-common.h +17 -0
  222. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +167 -75
  223. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -2
  224. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  225. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +560 -622
  226. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1002 -270
  227. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +107 -587
  228. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  229. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  230. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +373 -486
  231. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  232. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  233. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
  234. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  235. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  236. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4682 -1660
  237. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +82 -4
  238. data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
  239. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +18 -9
  240. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +263 -111
  241. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +39 -28
  242. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +683 -82
  243. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +38 -43
  244. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +435 -119
  245. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  246. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1234 -1182
  247. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  248. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +2167 -1480
  249. data/ext/sources/ggml/src/ggml-cpu/ops.h +10 -12
  250. data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
  251. data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
  252. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1132 -81
  253. data/ext/sources/ggml/src/ggml-cpu/repack.h +36 -0
  254. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +120 -93
  255. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
  256. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  257. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  258. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  259. data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
  260. data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
  261. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  262. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  263. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +294 -27
  264. data/ext/sources/ggml/src/ggml-cpu/vec.h +606 -48
  265. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +92 -17
  266. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  267. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  268. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  269. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  270. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  271. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  272. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  273. data/ext/sources/ggml/src/ggml-cuda/common.cuh +588 -128
  274. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  275. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  276. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
  278. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +25 -0
  279. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  280. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +335 -485
  281. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  282. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  283. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  284. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  286. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  287. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  288. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +519 -378
  289. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +750 -637
  290. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +49 -0
  291. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1244 -0
  292. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +586 -0
  293. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +98 -61
  294. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  295. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +230 -197
  296. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  297. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  298. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  299. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  300. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1557 -294
  301. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  302. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  303. data/ext/sources/ggml/src/ggml-cuda/mean.cu +57 -2
  304. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +915 -69
  305. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +171 -0
  306. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +835 -0
  307. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  308. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  309. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +109 -67
  310. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1601 -733
  311. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +802 -0
  312. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +12 -0
  313. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +286 -149
  314. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  315. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  316. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  317. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  318. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  319. data/ext/sources/ggml/src/ggml-cuda/pad.cu +86 -32
  320. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  321. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  322. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +163 -10
  323. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  324. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  325. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  326. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/rope.cu +207 -98
  328. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  329. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  330. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +330 -0
  331. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  332. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  333. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  334. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  335. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  336. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +325 -61
  337. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  338. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  339. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -12
  340. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +291 -104
  341. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  342. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
  343. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  344. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  345. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  346. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  347. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  348. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  349. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  350. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  351. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  352. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  353. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  354. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  355. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  356. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  357. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  358. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  359. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  360. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  361. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  362. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  363. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  364. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  365. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  366. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  367. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  368. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  369. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  370. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  371. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  372. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  373. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  374. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  375. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  376. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  377. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  378. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  379. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  380. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  381. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  382. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  383. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  384. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  385. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  386. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  387. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  388. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +40 -19
  389. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  390. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  391. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  392. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  393. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  394. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  395. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  396. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  397. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  398. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  399. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  400. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  401. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  402. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  403. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  404. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  405. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  406. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  407. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  408. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +351 -0
  409. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +21 -0
  410. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  411. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  412. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  413. data/ext/sources/ggml/src/ggml-cuda/unary.cu +189 -5
  414. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +44 -0
  415. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +248 -6
  416. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  417. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +8 -0
  418. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +70 -37
  419. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +10 -3
  420. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  421. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  422. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  423. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  424. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  425. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  426. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  427. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  428. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  429. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  430. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  431. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  432. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  433. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  434. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  435. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  436. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  437. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  438. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  439. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  440. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  441. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  442. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  443. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  444. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  445. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  446. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  447. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  448. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  449. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  450. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  451. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +16 -13
  452. data/ext/sources/ggml/src/ggml-impl.h +186 -15
  453. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
  454. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  455. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  456. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  457. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +609 -0
  458. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1743 -0
  459. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +273 -0
  460. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1686 -0
  461. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +356 -61
  462. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +4161 -0
  463. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +94 -0
  464. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +724 -0
  465. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +4495 -1876
  466. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +21 -9
  467. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +29 -0
  468. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4005 -427
  469. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  470. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  471. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  472. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  473. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +147 -0
  474. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  475. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  476. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  477. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  478. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +371 -0
  479. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  480. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  481. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  482. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  483. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  484. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
  485. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  486. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  487. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  488. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  489. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  490. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  491. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  492. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +146 -0
  493. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +147 -0
  494. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  495. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  496. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  497. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  498. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  499. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  500. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  501. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  502. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  503. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  504. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  505. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +94 -0
  506. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  507. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  508. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +208 -0
  509. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  510. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  511. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  512. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  513. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  514. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  515. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  516. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  517. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  518. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +33 -0
  519. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  520. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  521. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  522. data/ext/sources/ggml/src/ggml-quants.c +111 -16
  523. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  524. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +497 -195
  525. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  526. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  527. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +6 -5
  530. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  531. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +50 -30
  532. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
  533. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +200 -99
  534. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  535. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  536. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -309
  537. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  538. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  539. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  540. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +77 -34
  541. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +397 -314
  542. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
  543. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
  544. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
  545. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +643 -413
  546. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  547. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
  548. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
  549. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +223 -132
  550. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +230 -55
  551. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  552. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  553. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  554. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  555. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  556. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  557. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  558. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
  559. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  560. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  561. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  562. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  563. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +65 -59
  564. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  565. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  566. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  567. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  568. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +330 -165
  569. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  570. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  571. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  572. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  573. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +60 -6
  574. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
  575. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  576. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +7398 -2635
  577. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  578. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  579. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +43 -3
  580. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  581. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  582. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  583. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +15 -6
  584. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +56 -39
  585. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  586. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  587. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  588. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  589. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  590. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  591. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +347 -0
  592. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  593. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  594. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +5 -5
  595. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +67 -13
  596. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  597. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  598. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  599. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +158 -16
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +38 -3
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +2 -2
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +3 -2
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +7 -6
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +5 -3
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +4 -4
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +2 -2
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +4 -4
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +4 -4
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +2 -2
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +103 -36
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +220 -0
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +139 -45
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +113 -38
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +75 -14
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +19 -17
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +21 -12
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +28 -18
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +4 -0
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +33 -17
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +125 -0
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +227 -0
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +20 -14
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +143 -0
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -556
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +230 -51
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +566 -0
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +90 -223
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +195 -0
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +41 -5
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +59 -9
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +104 -14
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -52
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -35
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -35
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +5 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +3 -3
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +30 -8
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +6 -2
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +16 -6
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +5 -4
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  740. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +435 -24
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +148 -6
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl +25 -0
  743. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +619 -177
  744. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  745. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +80 -0
  746. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  747. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3087 -0
  748. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  749. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  750. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  751. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  752. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  753. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +147 -0
  754. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  755. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  756. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  757. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  758. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  759. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  760. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  761. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  762. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  763. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
  764. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  765. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  766. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  767. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  768. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  769. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  770. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  771. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  772. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  773. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  774. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  775. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  776. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  777. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  778. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  779. data/ext/sources/ggml/src/ggml.c +901 -129
  780. data/ext/sources/ggml/src/gguf.cpp +8 -1
  781. data/ext/sources/include/whisper.h +1 -0
  782. data/ext/sources/src/CMakeLists.txt +3 -1
  783. data/ext/sources/src/whisper.cpp +124 -81
  784. data/ext/sources/tests/CMakeLists.txt +8 -1
  785. data/ext/sources/tests/test-vad-full.cpp +7 -5
  786. data/ext/sources/tests/test-vad.cpp +3 -3
  787. data/extsources.rb +1 -0
  788. data/lib/whisper/model/uri.rb +17 -18
  789. data/sig/whisper.rbs +126 -2
  790. data/test/test_params.rb +24 -8
  791. data/test/test_segment.rb +0 -1
  792. data/test/test_token.rb +70 -0
  793. data/test/test_vad.rb +1 -1
  794. data/test/test_vad_context.rb +50 -0
  795. data/test/test_vad_segment.rb +19 -0
  796. data/test/test_vad_segments.rb +16 -0
  797. data/test/test_whisper.rb +8 -1
  798. data/whispercpp.gemspec +1 -1
  799. metadata +439 -179
  800. data/ext/sources/build-xcframework.sh +0 -547
  801. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
  802. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
  803. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
  804. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  805. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  806. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  807. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  808. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  809. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  810. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  811. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  812. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  813. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  814. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  815. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  816. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  817. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  818. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  819. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  820. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  821. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  822. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  823. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  824. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  825. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  826. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  827. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -506
  828. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +0 -11
  829. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  830. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  831. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  832. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  833. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  834. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  835. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  836. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  837. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  838. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  839. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  840. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  841. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  842. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  843. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  844. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  845. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  846. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  847. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  848. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  849. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  850. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  851. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  852. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  853. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  854. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  855. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  856. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  857. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  858. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  859. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  860. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  861. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  862. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  863. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  864. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  865. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  866. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  867. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  868. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  869. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  870. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  871. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  872. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  873. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  874. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  875. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  876. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  877. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  878. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  879. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  880. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  881. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  882. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  883. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  884. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  885. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  886. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  887. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  888. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  889. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  890. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  891. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  892. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  893. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  894. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  895. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  896. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  897. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  898. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  899. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  900. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  901. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  902. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  903. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  904. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  905. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  906. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  907. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  908. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  909. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  910. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  911. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  912. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  913. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  914. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  915. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  916. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  917. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  918. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  919. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  920. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  921. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  922. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  923. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  924. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  925. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  926. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  927. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  928. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  929. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  930. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  931. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  932. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  933. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  934. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  935. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  936. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  937. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  938. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  939. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  940. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  941. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  942. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  943. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  944. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  945. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  946. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  947. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  948. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  949. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  950. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  951. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  952. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  953. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  954. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
  955. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +0 -162
  956. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +0 -118
  957. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -99
  958. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -58
  959. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  960. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  961. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  962. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  963. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
@@ -29,9 +29,18 @@
29
29
  #include <cstring>
30
30
  #include <fstream>
31
31
  #include <filesystem>
32
+ #include <algorithm>
33
+
34
+ static const char * RPC_DEBUG = std::getenv("GGML_RPC_DEBUG");
35
+
36
+ #define LOG_DBG(...) \
37
+ do { if (RPC_DEBUG) GGML_LOG_DEBUG(__VA_ARGS__); } while (0)
38
+
32
39
 
33
40
  namespace fs = std::filesystem;
34
41
 
42
+ static constexpr size_t MAX_CHUNK_SIZE = 1024ull * 1024ull * 1024ull; // 1 GiB
43
+
35
44
  #ifdef _WIN32
36
45
  typedef SOCKET sockfd_t;
37
46
  using ssize_t = __int64;
@@ -44,7 +53,7 @@ struct socket_t {
44
53
  sockfd_t fd;
45
54
  socket_t(sockfd_t fd) : fd(fd) {}
46
55
  ~socket_t() {
47
- GGML_PRINT_DEBUG("[%s] closing socket %d\n", __func__, this->fd);
56
+ LOG_DBG("[%s] closing socket %d\n", __func__, this->fd);
48
57
  #ifdef _WIN32
49
58
  closesocket(this->fd);
50
59
  #else
@@ -96,9 +105,13 @@ enum rpc_cmd {
96
105
  RPC_CMD_INIT_TENSOR,
97
106
  RPC_CMD_GET_ALLOC_SIZE,
98
107
  RPC_CMD_HELLO,
108
+ RPC_CMD_DEVICE_COUNT,
109
+ RPC_CMD_GRAPH_RECOMPUTE,
99
110
  RPC_CMD_COUNT,
100
111
  };
101
112
 
113
+ static_assert(RPC_CMD_HELLO == 14, "RPC_CMD_HELLO must be always 14");
114
+
102
115
  // Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
103
116
  const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
104
117
 
@@ -108,8 +121,14 @@ struct rpc_msg_hello_rsp {
108
121
  uint8_t patch;
109
122
  };
110
123
 
124
+ struct rpc_msg_device_count_rsp {
125
+ uint32_t device_count;
126
+ };
127
+
111
128
  struct rpc_msg_get_alloc_size_req {
129
+ uint32_t device;
112
130
  rpc_tensor tensor;
131
+ rpc_tensor srcs[GGML_MAX_SRC];
113
132
  };
114
133
 
115
134
  struct rpc_msg_get_alloc_size_rsp {
@@ -121,6 +140,7 @@ struct rpc_msg_init_tensor_req {
121
140
  };
122
141
 
123
142
  struct rpc_msg_alloc_buffer_req {
143
+ uint32_t device;
124
144
  uint64_t size;
125
145
  };
126
146
 
@@ -129,10 +149,18 @@ struct rpc_msg_alloc_buffer_rsp {
129
149
  uint64_t remote_size;
130
150
  };
131
151
 
152
+ struct rpc_msg_get_alignment_req {
153
+ uint32_t device;
154
+ };
155
+
132
156
  struct rpc_msg_get_alignment_rsp {
133
157
  uint64_t alignment;
134
158
  };
135
159
 
160
+ struct rpc_msg_get_max_size_req {
161
+ uint32_t device;
162
+ };
163
+
136
164
  struct rpc_msg_get_max_size_rsp {
137
165
  uint64_t max_size;
138
166
  };
@@ -179,14 +207,19 @@ struct rpc_msg_copy_tensor_rsp {
179
207
  uint8_t result;
180
208
  };
181
209
 
182
- struct rpc_msg_graph_compute_rsp {
183
- uint8_t result;
210
+ struct rpc_msg_get_device_memory_req {
211
+ uint32_t device;
184
212
  };
185
213
 
186
214
  struct rpc_msg_get_device_memory_rsp {
187
215
  uint64_t free_mem;
188
216
  uint64_t total_mem;
189
217
  };
218
+
219
+ struct rpc_msg_graph_recompute_req {
220
+ uint32_t device;
221
+ };
222
+
190
223
  #pragma pack(pop)
191
224
 
192
225
  // RPC data structures
@@ -198,14 +231,41 @@ static ggml_guid_t ggml_backend_rpc_guid() {
198
231
 
199
232
  struct ggml_backend_rpc_buffer_type_context {
200
233
  std::string endpoint;
234
+ uint32_t device;
201
235
  std::string name;
202
- size_t alignment;
203
- size_t max_size;
236
+ size_t alignment;
237
+ size_t max_size;
238
+ };
239
+
240
+ struct graph_cache {
241
+
242
+ bool is_cached(const ggml_cgraph * cgraph) {
243
+ if ((int)last_graph.size() != cgraph->n_nodes) {
244
+ return false;
245
+ }
246
+ for (int i = 0; i < cgraph->n_nodes; i++) {
247
+ if (memcmp(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) {
248
+ return false;
249
+ }
250
+ }
251
+ return true;
252
+ }
253
+
254
+ void add(const ggml_cgraph * cgraph) {
255
+ last_graph.resize(cgraph->n_nodes);
256
+ for (int i = 0; i < cgraph->n_nodes; i++) {
257
+ memcpy(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor));
258
+ }
259
+ }
260
+
261
+ std::vector<ggml_tensor> last_graph;
204
262
  };
205
263
 
206
264
  struct ggml_backend_rpc_context {
207
265
  std::string endpoint;
266
+ uint32_t device;
208
267
  std::string name;
268
+ graph_cache gc;
209
269
  };
210
270
 
211
271
  struct ggml_backend_rpc_buffer_context {
@@ -262,14 +322,14 @@ static std::shared_ptr<socket_t> socket_connect(const char * host, int port) {
262
322
  return nullptr;
263
323
  }
264
324
  if (!set_no_delay(sockfd)) {
265
- fprintf(stderr, "Failed to set TCP_NODELAY\n");
325
+ GGML_LOG_ERROR("Failed to set TCP_NODELAY\n");
266
326
  return nullptr;
267
327
  }
268
328
  addr.sin_family = AF_INET;
269
329
  addr.sin_port = htons(port);
270
330
  struct hostent * server = gethostbyname(host);
271
331
  if (server == NULL) {
272
- fprintf(stderr, "Cannot resolve host '%s'\n", host);
332
+ GGML_LOG_ERROR("Cannot resolve host '%s'\n", host);
273
333
  return nullptr;
274
334
  }
275
335
  memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length);
@@ -286,7 +346,7 @@ static std::shared_ptr<socket_t> socket_accept(sockfd_t srv_sockfd) {
286
346
  return nullptr;
287
347
  }
288
348
  if (!set_no_delay(client_socket_fd)) {
289
- fprintf(stderr, "Failed to set TCP_NODELAY\n");
349
+ GGML_LOG_ERROR("Failed to set TCP_NODELAY\n");
290
350
  return nullptr;
291
351
  }
292
352
  return client_socket;
@@ -299,11 +359,11 @@ static std::shared_ptr<socket_t> create_server_socket(const char * host, int por
299
359
  return nullptr;
300
360
  }
301
361
  if (!set_reuse_addr(sockfd)) {
302
- fprintf(stderr, "Failed to set SO_REUSEADDR\n");
362
+ GGML_LOG_ERROR("Failed to set SO_REUSEADDR\n");
303
363
  return nullptr;
304
364
  }
305
365
  if (inet_addr(host) == INADDR_NONE) {
306
- fprintf(stderr, "Invalid host address: %s\n", host);
366
+ GGML_LOG_ERROR("Invalid host address: %s\n", host);
307
367
  return nullptr;
308
368
  }
309
369
  struct sockaddr_in serv_addr;
@@ -323,11 +383,14 @@ static std::shared_ptr<socket_t> create_server_socket(const char * host, int por
323
383
  static bool send_data(sockfd_t sockfd, const void * data, size_t size) {
324
384
  size_t bytes_sent = 0;
325
385
  while (bytes_sent < size) {
326
- ssize_t n = send(sockfd, (const char *)data + bytes_sent, size - bytes_sent, 0);
386
+ size_t size_to_send = std::min(size - bytes_sent, MAX_CHUNK_SIZE);
387
+ ssize_t n = send(sockfd, (const char *)data + bytes_sent, size_to_send, 0);
327
388
  if (n < 0) {
389
+ GGML_LOG_ERROR("send failed (bytes_sent=%zu, size_to_send=%zu)\n",
390
+ bytes_sent, size_to_send);
328
391
  return false;
329
392
  }
330
- bytes_sent += n;
393
+ bytes_sent += (size_t)n;
331
394
  }
332
395
  return true;
333
396
  }
@@ -335,11 +398,18 @@ static bool send_data(sockfd_t sockfd, const void * data, size_t size) {
335
398
  static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
336
399
  size_t bytes_recv = 0;
337
400
  while (bytes_recv < size) {
338
- ssize_t n = recv(sockfd, (char *)data + bytes_recv, size - bytes_recv, 0);
339
- if (n <= 0) {
401
+ size_t size_to_recv = std::min(size - bytes_recv, MAX_CHUNK_SIZE);
402
+ ssize_t n = recv(sockfd, (char *)data + bytes_recv, size_to_recv, 0);
403
+ if (n < 0) {
404
+ GGML_LOG_ERROR("recv failed (bytes_recv=%zu, size_to_recv=%zu)\n",
405
+ bytes_recv, size_to_recv);
406
+ return false;
407
+ }
408
+ if (n == 0) {
409
+ LOG_DBG("recv returned 0 (peer closed?)\n");
340
410
  return false;
341
411
  }
342
- bytes_recv += n;
412
+ bytes_recv += (size_t)n;
343
413
  }
344
414
  return true;
345
415
  }
@@ -370,7 +440,7 @@ static bool recv_msg(sockfd_t sockfd, std::vector<uint8_t> & input) {
370
440
  try {
371
441
  input.resize(size);
372
442
  } catch (const std::bad_alloc & e) {
373
- fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", size);
443
+ GGML_LOG_ERROR("Failed to allocate input buffer of size %" PRIu64 "\n", size);
374
444
  return false;
375
445
  }
376
446
  return recv_data(sockfd, input.data(), size);
@@ -430,11 +500,11 @@ static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
430
500
  bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
431
501
  RPC_STATUS_ASSERT(status);
432
502
  if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
433
- fprintf(stderr, "RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
503
+ GGML_LOG_ERROR("RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
434
504
  return false;
435
505
  }
436
506
  if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) {
437
- fprintf(stderr, "WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
507
+ GGML_LOG_INFO("WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
438
508
  }
439
509
  return true;
440
510
  }
@@ -454,6 +524,7 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
454
524
  std::string host;
455
525
  int port;
456
526
  if (!parse_endpoint(endpoint, host, port)) {
527
+ GGML_LOG_ERROR("Failed to parse endpoint: %s\n", endpoint.c_str());
457
528
  return nullptr;
458
529
  }
459
530
  #ifdef _WIN32
@@ -475,7 +546,7 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
475
546
  if (!check_server_version(sock)) {
476
547
  return nullptr;
477
548
  }
478
- GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
549
+ LOG_DBG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
479
550
  sockets[endpoint] = sock;
480
551
  return sock;
481
552
  }
@@ -501,14 +572,23 @@ static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
501
572
  return ctx->base_ptr;
502
573
  }
503
574
 
575
+ static bool ggml_backend_buffer_is_rpc(ggml_backend_buffer_t buffer) {
576
+ return buffer->iface.free_buffer == ggml_backend_rpc_buffer_free_buffer;
577
+ }
578
+
504
579
  static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
505
580
  rpc_tensor result;
581
+ if (!tensor) {
582
+ memset(&result, 0, sizeof(result));
583
+ return result;
584
+ }
585
+
506
586
  result.id = reinterpret_cast<uint64_t>(tensor);
507
587
  result.type = tensor->type;
508
- if (tensor->buffer) {
588
+ if (tensor->buffer && ggml_backend_buffer_is_rpc(tensor->buffer)) {
509
589
  ggml_backend_buffer_t buffer = tensor->buffer;
510
590
  ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
511
- result.buffer = ctx->remote_ptr;
591
+ result.buffer = ctx != nullptr ? ctx->remote_ptr : 0;
512
592
  } else {
513
593
  result.buffer = 0;
514
594
  }
@@ -590,22 +670,25 @@ static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, con
590
670
  }
591
671
 
592
672
  static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
593
- // check if src and dst are on the same server
594
- ggml_backend_buffer_t src_buffer = src->buffer;
595
- ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
596
- ggml_backend_buffer_t dst_buffer = dst->buffer;
597
- ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
598
- if (src_ctx->sock != dst_ctx->sock) {
599
- return false;
673
+ if (ggml_backend_buffer_is_rpc(src->buffer)) {
674
+ // check if src and dst are on the same server
675
+ ggml_backend_buffer_t src_buffer = src->buffer;
676
+ ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
677
+ ggml_backend_buffer_t dst_buffer = dst->buffer;
678
+ ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
679
+ if (src_ctx->sock != dst_ctx->sock) {
680
+ return false;
681
+ }
682
+ ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
683
+ rpc_msg_copy_tensor_req request;
684
+ request.src = serialize_tensor(src);
685
+ request.dst = serialize_tensor(dst);
686
+ rpc_msg_copy_tensor_rsp response;
687
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
688
+ RPC_STATUS_ASSERT(status);
689
+ return response.result;
600
690
  }
601
- ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
602
- rpc_msg_copy_tensor_req request;
603
- request.src = serialize_tensor(src);
604
- request.dst = serialize_tensor(dst);
605
- rpc_msg_copy_tensor_rsp response;
606
- bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
607
- RPC_STATUS_ASSERT(status);
608
- return response.result;
691
+ return false;
609
692
  }
610
693
 
611
694
  static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
@@ -634,7 +717,7 @@ static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t
634
717
 
635
718
  static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
636
719
  ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
637
- rpc_msg_alloc_buffer_req request = {size};
720
+ rpc_msg_alloc_buffer_req request = {buft_ctx->device, size};
638
721
  rpc_msg_alloc_buffer_rsp response;
639
722
  auto sock = get_socket(buft_ctx->endpoint);
640
723
  bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
@@ -650,9 +733,10 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
650
733
  }
651
734
  }
652
735
 
653
- static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
736
+ static size_t get_alignment(const std::shared_ptr<socket_t> & sock, uint32_t device) {
737
+ rpc_msg_get_alignment_req request = {device};
654
738
  rpc_msg_get_alignment_rsp response;
655
- bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response));
739
+ bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, &request, sizeof(request), &response, sizeof(response));
656
740
  RPC_STATUS_ASSERT(status);
657
741
  return response.alignment;
658
742
  }
@@ -662,9 +746,10 @@ static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_typ
662
746
  return buft_ctx->alignment;
663
747
  }
664
748
 
665
- static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
749
+ static size_t get_max_size(const std::shared_ptr<socket_t> & sock, uint32_t device) {
750
+ rpc_msg_get_max_size_req request = {device};
666
751
  rpc_msg_get_max_size_rsp response;
667
- bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, nullptr, 0, &response, sizeof(response));
752
+ bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, &request, sizeof(request), &response, sizeof(response));
668
753
  RPC_STATUS_ASSERT(status);
669
754
  return response.max_size;
670
755
  }
@@ -675,23 +760,41 @@ static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
675
760
  }
676
761
 
677
762
  static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
763
+ // should we query the remote server for the actual size
764
+ bool rpc_get = false;
765
+
678
766
  // See comments in init_tensor.
679
- if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) {
767
+ rpc_get |= ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr);
768
+
769
+ // ops that require additional memory for fleeting data on certain backends
770
+ // ref: https://github.com/ggml-org/llama.cpp/pull/15966
771
+ rpc_get |= tensor->op == GGML_OP_FLASH_ATTN_EXT;
772
+ rpc_get |= tensor->op == GGML_OP_MUL_MAT_ID;
773
+
774
+ if (rpc_get) {
680
775
  ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
681
776
  auto sock = get_socket(buft_ctx->endpoint);
682
777
 
683
- rpc_msg_get_alloc_size_req request;
778
+ rpc_msg_get_alloc_size_req request = {
779
+ /*.device =*/ buft_ctx->device,
780
+ /*.tensor =*/ serialize_tensor(tensor),
781
+ /*.srcs =*/ {},
782
+ };
684
783
 
685
- request.tensor = serialize_tensor(tensor);
784
+ // .get_alloc_size could be a function of the tensor's srcs, so we must serialize them as well
785
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
786
+ request.srcs[i] = serialize_tensor(tensor->src[i]);
787
+ }
686
788
 
789
+ // TODO: cache the alloc responses to avoid extra RPC calls?
687
790
  rpc_msg_get_alloc_size_rsp response;
688
791
  bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response));
689
792
  RPC_STATUS_ASSERT(status);
690
793
 
691
794
  return response.alloc_size;
692
- } else {
693
- return ggml_nbytes(tensor);
694
795
  }
796
+
797
+ return ggml_nbytes(tensor);
695
798
  }
696
799
 
697
800
  static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
@@ -735,7 +838,7 @@ static void add_tensor(ggml_tensor * tensor, std::vector<rpc_tensor> & tensors,
735
838
  tensors.push_back(serialize_tensor(tensor));
736
839
  }
737
840
 
738
- static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
841
+ static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
739
842
  uint32_t n_nodes = cgraph->n_nodes;
740
843
  std::vector<rpc_tensor> tensors;
741
844
  std::unordered_set<ggml_tensor*> visited;
@@ -743,29 +846,45 @@ static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & o
743
846
  add_tensor(cgraph->nodes[i], tensors, visited);
744
847
  }
745
848
  // serialization format:
746
- // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
849
+ // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
747
850
  uint32_t n_tensors = tensors.size();
748
- int output_size = sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
851
+ int output_size = 2*sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
749
852
  output.resize(output_size, 0);
750
- memcpy(output.data(), &n_nodes, sizeof(n_nodes));
853
+ uint8_t * dest = output.data();
854
+ memcpy(dest, &device, sizeof(device));
855
+ dest += sizeof(device);
856
+ memcpy(dest, &n_nodes, sizeof(n_nodes));
857
+ dest += sizeof(n_nodes);
751
858
  for (uint32_t i = 0; i < n_nodes; i++) {
752
- memcpy(output.data() + sizeof(n_nodes) + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));
859
+ memcpy(dest + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));
753
860
  }
754
- uint32_t * out_ntensors = (uint32_t *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t));
755
- *out_ntensors = n_tensors;
756
- rpc_tensor * out_tensors = (rpc_tensor *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t));
861
+ dest += n_nodes * sizeof(uint64_t);
862
+ memcpy(dest, &n_tensors, sizeof(n_tensors));
863
+ dest += sizeof(n_tensors);
864
+ rpc_tensor * out_tensors = (rpc_tensor *)dest;
757
865
  memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor));
758
866
  }
759
867
 
760
868
  static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
761
869
  ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
762
- std::vector<uint8_t> input;
763
- serialize_graph(cgraph, input);
764
- rpc_msg_graph_compute_rsp response;
765
- auto sock = get_socket(rpc_ctx->endpoint);
766
- bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
767
- RPC_STATUS_ASSERT(status);
768
- return (enum ggml_status)response.result;
870
+
871
+ GGML_ASSERT(cgraph->n_nodes > 0);
872
+ bool reuse = rpc_ctx->gc.is_cached(cgraph);
873
+ if (reuse) {
874
+ rpc_msg_graph_recompute_req request;
875
+ request.device = rpc_ctx->device;
876
+ auto sock = get_socket(rpc_ctx->endpoint);
877
+ bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request));
878
+ RPC_STATUS_ASSERT(status);
879
+ } else {
880
+ rpc_ctx->gc.add(cgraph);
881
+ std::vector<uint8_t> input;
882
+ serialize_graph(rpc_ctx->device, cgraph, input);
883
+ auto sock = get_socket(rpc_ctx->endpoint);
884
+ bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size());
885
+ RPC_STATUS_ASSERT(status);
886
+ }
887
+ return GGML_STATUS_SUCCESS;
769
888
  }
770
889
 
771
890
  static ggml_backend_i ggml_backend_rpc_interface = {
@@ -782,51 +901,57 @@ static ggml_backend_i ggml_backend_rpc_interface = {
782
901
  /* .graph_compute = */ ggml_backend_rpc_graph_compute,
783
902
  /* .event_record = */ NULL,
784
903
  /* .event_wait = */ NULL,
904
+ /* .graph_optimize = */ NULL,
785
905
  };
786
906
 
787
- ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
907
+ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device) {
788
908
  static std::mutex mutex;
789
909
  std::lock_guard<std::mutex> lock(mutex);
910
+ std::string buft_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
790
911
  // NOTE: buffer types are allocated and never freed; this is by design
791
912
  static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;
792
- auto it = buft_map.find(endpoint);
913
+ auto it = buft_map.find(buft_name);
793
914
  if (it != buft_map.end()) {
794
915
  return it->second;
795
916
  }
796
917
  auto sock = get_socket(endpoint);
797
918
  if (sock == nullptr) {
798
- fprintf(stderr, "Failed to connect to %s\n", endpoint);
919
+ GGML_LOG_ERROR("Failed to connect to %s\n", endpoint);
799
920
  return nullptr;
800
921
  }
801
- size_t alignment = get_alignment(sock);
802
- size_t max_size = get_max_size(sock);
922
+ size_t alignment = get_alignment(sock, device);
923
+ size_t max_size = get_max_size(sock, device);
803
924
  ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
804
925
  /* .endpoint = */ endpoint,
805
- /* .name = */ "RPC[" + std::string(endpoint) + "]",
926
+ /* .device = */ device,
927
+ /* .name = */ buft_name,
806
928
  /* .alignment = */ alignment,
807
929
  /* .max_size = */ max_size
808
930
  };
809
-
931
+ auto reg = ggml_backend_rpc_add_server(endpoint);
810
932
  ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
811
933
  /* .iface = */ ggml_backend_rpc_buffer_type_interface,
812
- /* .device = */ ggml_backend_rpc_add_device(endpoint),
934
+ /* .device = */ ggml_backend_reg_dev_get(reg, device),
813
935
  /* .context = */ buft_ctx
814
936
  };
815
- buft_map[endpoint] = buft;
937
+ buft_map[buft_name] = buft;
816
938
  return buft;
817
939
  }
818
940
 
819
- ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
941
+ ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
942
+ std::string dev_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
820
943
  ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
821
- /* .endpoint = */ endpoint,
822
- /* .name = */ "RPC[" + std::string(endpoint) + "]",
944
+ /* .endpoint = */ endpoint,
945
+ /* .device = */ device,
946
+ /* .name = */ dev_name,
947
+ /* .gc = */ {},
823
948
  };
824
-
949
+ auto reg = ggml_backend_rpc_add_server(endpoint);
825
950
  ggml_backend_t backend = new ggml_backend {
826
- /* .guid = */ ggml_backend_rpc_guid(),
827
- /* .interface = */ ggml_backend_rpc_interface,
828
- /* .device = */ ggml_backend_rpc_add_device(endpoint),
829
- /* .context = */ ctx
951
+ /* .guid = */ ggml_backend_rpc_guid(),
952
+ /* .iface = */ ggml_backend_rpc_interface,
953
+ /* .device = */ ggml_backend_reg_dev_get(reg, device),
954
+ /* .context = */ ctx
830
955
  };
831
956
  return backend;
832
957
  }
@@ -835,37 +960,40 @@ bool ggml_backend_is_rpc(ggml_backend_t backend) {
835
960
  return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
836
961
  }
837
962
 
838
- static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
963
+ static void get_device_memory(const std::shared_ptr<socket_t> & sock, uint32_t device, size_t * free, size_t * total) {
964
+ rpc_msg_get_device_memory_req request;
965
+ request.device = device;
839
966
  rpc_msg_get_device_memory_rsp response;
840
- bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response));
967
+ bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, &request, sizeof(request), &response, sizeof(response));
841
968
  RPC_STATUS_ASSERT(status);
842
969
  *free = response.free_mem;
843
970
  *total = response.total_mem;
844
971
  }
845
972
 
846
- void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
973
+ void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total) {
847
974
  auto sock = get_socket(endpoint);
848
975
  if (sock == nullptr) {
849
976
  *free = 0;
850
977
  *total = 0;
851
978
  return;
852
979
  }
853
- get_device_memory(sock, free, total);
980
+ get_device_memory(sock, device, free, total);
854
981
  }
855
982
 
856
983
  // RPC server-side implementation
857
984
 
858
985
  class rpc_server {
859
986
  public:
860
- rpc_server(ggml_backend_t backend, const char * cache_dir)
861
- : backend(backend), cache_dir(cache_dir) {
987
+ rpc_server(std::vector<ggml_backend_t> all_backends, const char * cache_dir)
988
+ : backends(std::move(all_backends)), cache_dir(cache_dir) {
989
+ stored_graphs.resize(backends.size());
862
990
  }
863
991
  ~rpc_server();
864
992
 
865
993
  void hello(rpc_msg_hello_rsp & response);
866
- void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
867
- void get_alignment(rpc_msg_get_alignment_rsp & response);
868
- void get_max_size(rpc_msg_get_max_size_rsp & response);
994
+ bool alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
995
+ bool get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response);
996
+ bool get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response);
869
997
  bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response);
870
998
  bool free_buffer(const rpc_msg_free_buffer_req & request);
871
999
  bool buffer_clear(const rpc_msg_buffer_clear_req & request);
@@ -873,9 +1001,16 @@ public:
873
1001
  bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response);
874
1002
  bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
875
1003
  bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
876
- bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
1004
+ bool graph_compute(const std::vector<uint8_t> & input);
1005
+ bool graph_recompute(const rpc_msg_graph_recompute_req & request);
877
1006
  bool init_tensor(const rpc_msg_init_tensor_req & request);
878
1007
  bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
1008
+ bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
1009
+
1010
+ struct stored_graph {
1011
+ ggml_context_ptr ctx_ptr;
1012
+ ggml_cgraph * graph;
1013
+ };
879
1014
 
880
1015
  private:
881
1016
  bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
@@ -886,22 +1021,28 @@ private:
886
1021
  std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map);
887
1022
 
888
1023
 
889
- ggml_backend_t backend;
1024
+ std::vector<ggml_backend_t> backends;
890
1025
  const char * cache_dir;
891
1026
  std::unordered_set<ggml_backend_buffer_t> buffers;
1027
+ // store the last computed graph for each backend
1028
+ std::vector<stored_graph> stored_graphs;
892
1029
  };
893
1030
 
894
1031
  void rpc_server::hello(rpc_msg_hello_rsp & response) {
895
1032
  response.major = RPC_PROTO_MAJOR_VERSION;
896
1033
  response.minor = RPC_PROTO_MINOR_VERSION;
897
1034
  response.patch = RPC_PROTO_PATCH_VERSION;
898
- GGML_PRINT_DEBUG("[%s] version: %d.%d.%d\n", __func__, response.major, response.minor, response.patch);
1035
+ LOG_DBG("[%s] version: %d.%d.%d\n", __func__, response.major, response.minor, response.patch);
899
1036
  }
900
1037
 
901
1038
  bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
1039
+ uint32_t dev_id = request.device;
1040
+ if (dev_id >= backends.size()) {
1041
+ return false;
1042
+ }
902
1043
  ggml_backend_buffer_type_t buft;
903
1044
  struct ggml_init_params params {
904
- /*.mem_size =*/ ggml_tensor_overhead(),
1045
+ /*.mem_size =*/ ggml_tensor_overhead()*(1 + GGML_MAX_SRC),
905
1046
  /*.mem_buffer =*/ NULL,
906
1047
  /*.no_alloc =*/ true,
907
1048
  };
@@ -909,56 +1050,78 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_
909
1050
  ggml_context_ptr ctx_ptr { ggml_init(params) };
910
1051
  GGML_ASSERT(ctx_ptr != nullptr);
911
1052
  ggml_context * ctx = ctx_ptr.get();
912
- ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
913
1053
 
1054
+ ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
914
1055
  if (tensor == nullptr) {
915
1056
  GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n");
916
1057
  return false;
917
1058
  }
1059
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
1060
+ if (request.srcs[i].id != 0) {
1061
+ tensor->src[i] = deserialize_tensor(ctx, &request.srcs[i]);
1062
+ }
1063
+ }
918
1064
 
1065
+ LOG_DBG("[%s] device: %d, buffer: %p, data: %p\n", __func__, dev_id, (void*)tensor->buffer, tensor->data);
919
1066
  if (tensor->buffer == nullptr) {
920
1067
  //No buffer allocated.
921
- buft = ggml_backend_get_default_buffer_type(backend);
1068
+ buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
922
1069
  } else {
923
1070
  buft = tensor->buffer->buft;
924
1071
  }
925
1072
 
926
- response.alloc_size = ggml_backend_buft_get_alloc_size(buft,tensor);
1073
+ response.alloc_size = ggml_backend_buft_get_alloc_size(buft, tensor);
927
1074
 
928
1075
  return true;
929
1076
  }
930
1077
 
931
- void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
932
- ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
1078
+ bool rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
1079
+ uint32_t dev_id = request.device;
1080
+ if (dev_id >= backends.size()) {
1081
+ return false;
1082
+ }
1083
+ ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
933
1084
  ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
934
1085
  response.remote_ptr = 0;
935
1086
  response.remote_size = 0;
936
1087
  if (buffer != nullptr) {
937
1088
  response.remote_ptr = reinterpret_cast<uint64_t>(buffer);
938
1089
  response.remote_size = buffer->size;
939
- GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size);
1090
+ LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n",
1091
+ __func__, dev_id, request.size, response.remote_ptr, response.remote_size);
940
1092
  buffers.insert(buffer);
941
1093
  } else {
942
- GGML_LOG_ERROR("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size);
1094
+ LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> failed\n", __func__, dev_id, request.size);
943
1095
  }
1096
+ return true;
944
1097
  }
945
1098
 
946
- void rpc_server::get_alignment(rpc_msg_get_alignment_rsp & response) {
947
- ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
1099
+ bool rpc_server::get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response) {
1100
+ uint32_t dev_id = request.device;
1101
+ if (dev_id >= backends.size()) {
1102
+ return false;
1103
+ }
1104
+ ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
948
1105
  size_t alignment = ggml_backend_buft_get_alignment(buft);
949
- GGML_PRINT_DEBUG("[%s] alignment: %lu\n", __func__, alignment);
1106
+ LOG_DBG("[%s] device: %d, alignment: %lu\n", __func__, dev_id, alignment);
950
1107
  response.alignment = alignment;
1108
+ return true;
951
1109
  }
952
1110
 
953
- void rpc_server::get_max_size(rpc_msg_get_max_size_rsp & response) {
954
- ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
1111
+ bool rpc_server::get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response) {
1112
+ uint32_t dev_id = request.device;
1113
+ if (dev_id >= backends.size()) {
1114
+ return false;
1115
+ }
1116
+ ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
955
1117
  size_t max_size = ggml_backend_buft_get_max_size(buft);
956
- GGML_PRINT_DEBUG("[%s] max_size: %lu\n", __func__, max_size);
1118
+ LOG_DBG("[%s] device: %d, max_size: %lu\n", __func__, dev_id, max_size);
957
1119
  response.max_size = max_size;
1120
+ return true;
958
1121
  }
959
1122
 
960
1123
  bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {
961
- GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
1124
+ LOG_DBG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
962
1125
  ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
963
1126
  if (buffers.find(buffer) == buffers.end()) {
964
1127
  GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
@@ -970,7 +1133,7 @@ bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rp
970
1133
  }
971
1134
 
972
1135
  bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) {
973
- GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
1136
+ LOG_DBG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
974
1137
  ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
975
1138
  if (buffers.find(buffer) == buffers.end()) {
976
1139
  GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
@@ -982,7 +1145,7 @@ bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) {
982
1145
  }
983
1146
 
984
1147
  bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {
985
- GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, request.remote_ptr, request.value);
1148
+ LOG_DBG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, request.remote_ptr, request.value);
986
1149
  ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
987
1150
  if (buffers.find(buffer) == buffers.end()) {
988
1151
  GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
@@ -1055,11 +1218,11 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
1055
1218
  GGML_ASSERT(ctx_ptr != nullptr);
1056
1219
  ggml_context * ctx = ctx_ptr.get();
1057
1220
  ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
1058
- if (tensor == nullptr) {
1221
+ if (tensor == nullptr || tensor->buffer == nullptr) {
1059
1222
  GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
1060
1223
  return false;
1061
1224
  }
1062
- GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
1225
+ LOG_DBG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
1063
1226
 
1064
1227
  // sanitize tensor->data
1065
1228
  {
@@ -1082,7 +1245,7 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
1082
1245
  fs::path cache_file = fs::path(cache_dir) / hash_str;
1083
1246
  std::ofstream ofs(cache_file, std::ios::binary);
1084
1247
  ofs.write((const char *)data, size);
1085
- printf("[%s] saved to '%s'\n", __func__, cache_file.c_str());
1248
+ GGML_LOG_INFO("[%s] saved to '%s'\n", __func__, cache_file.c_str());
1086
1249
  }
1087
1250
  ggml_backend_tensor_set(tensor, data, offset, size);
1088
1251
  return true;
@@ -1095,7 +1258,8 @@ bool rpc_server::get_cached_file(uint64_t hash, std::vector<uint8_t> & data) {
1095
1258
  char hash_str[17];
1096
1259
  snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
1097
1260
  fs::path cache_file = fs::path(cache_dir) / hash_str;
1098
- if (!fs::exists(cache_file)) {
1261
+ std::error_code ec;
1262
+ if (!fs::exists(cache_file, ec)) {
1099
1263
  return false;
1100
1264
  }
1101
1265
  std::ifstream ifs(cache_file, std::ios::binary);
@@ -1124,12 +1288,12 @@ bool rpc_server::set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rp
1124
1288
  GGML_ASSERT(ctx_ptr != nullptr);
1125
1289
  ggml_context * ctx = ctx_ptr.get();
1126
1290
  ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
1127
- if (tensor == nullptr) {
1291
+ if (tensor == nullptr || tensor->buffer == nullptr) {
1128
1292
  GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
1129
1293
  return false;
1130
1294
  }
1131
- GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n",
1132
- __func__, (void*)tensor->buffer, tensor->data, request.offset, size, request.hash);
1295
+ LOG_DBG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n",
1296
+ __func__, (void*)tensor->buffer, tensor->data, request.offset, size, request.hash);
1133
1297
 
1134
1298
  // sanitize tensor->data
1135
1299
  {
@@ -1163,7 +1327,7 @@ bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) {
1163
1327
  GGML_LOG_ERROR("Null tensor pointer passed to server init_tensor function.\n");
1164
1328
  return false;
1165
1329
  }
1166
-
1330
+ LOG_DBG("[%s] buffer: %p, data: %p\n", __func__, (void*)tensor->buffer, tensor->data);
1167
1331
  // Call the backend's buffer_init_tensor function
1168
1332
  ggml_backend_buffer_t buffer = tensor->buffer;
1169
1333
  if (buffer && buffer->iface.init_tensor) {
@@ -1192,11 +1356,11 @@ bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<
1192
1356
  GGML_ASSERT(ctx_ptr != nullptr);
1193
1357
  ggml_context * ctx = ctx_ptr.get();
1194
1358
  ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
1195
- if (tensor == nullptr) {
1359
+ if (tensor == nullptr || tensor->buffer == nullptr) {
1196
1360
  GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
1197
1361
  return false;
1198
1362
  }
1199
- GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size);
1363
+ LOG_DBG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size);
1200
1364
 
1201
1365
  // sanitize tensor->data
1202
1366
  {
@@ -1229,7 +1393,7 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co
1229
1393
 
1230
1394
  ggml_tensor * src = deserialize_tensor(ctx, &request.src);
1231
1395
  ggml_tensor * dst = deserialize_tensor(ctx, &request.dst);
1232
- if (src == nullptr || dst == nullptr) {
1396
+ if (src == nullptr || dst == nullptr || src->buffer == nullptr || dst->buffer == nullptr) {
1233
1397
  GGML_LOG_ERROR("[%s] error deserializing tensors\n", __func__);
1234
1398
  return false;
1235
1399
  }
@@ -1240,7 +1404,7 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co
1240
1404
  uint64_t dst_buf_sz = (uint64_t) ggml_backend_buffer_get_size(dst->buffer);
1241
1405
 
1242
1406
  if (dst_data + src_size > dst_base + dst_buf_sz) {
1243
- GGML_PRINT_DEBUG("[%s] out-of-bounds write in rpc_server::copy_tensor:\n"
1407
+ GGML_LOG_ERROR("[%s] out-of-bounds write in rpc_server::copy_tensor:\n"
1244
1408
  " write range : [0x%" PRIx64 ", 0x%" PRIx64 "]\n"
1245
1409
  " buffer base: [0x%" PRIx64 ", 0x%" PRIx64 "]\n",
1246
1410
  __func__,
@@ -1251,8 +1415,8 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co
1251
1415
  return false;
1252
1416
  }
1253
1417
 
1254
- GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n",
1255
- __func__, (void*) src->buffer, (void*) dst->buffer);
1418
+ LOG_DBG("[%s] src->buffer: %p, dst->buffer: %p\n",
1419
+ __func__, (void*) src->buffer, (void*) dst->buffer);
1256
1420
 
1257
1421
  response.result = ggml_backend_buffer_copy_tensor(src, dst);
1258
1422
  return true;
@@ -1310,25 +1474,35 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
1310
1474
  return result;
1311
1475
  }
1312
1476
 
1313
- bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
1477
+ bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
1314
1478
  // serialization format:
1315
- // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
1316
- if (input.size() < sizeof(uint32_t)) {
1479
+ // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
1480
+ if (input.size() < 2*sizeof(uint32_t)) {
1481
+ return false;
1482
+ }
1483
+ const uint8_t * src = input.data();
1484
+ uint32_t device;
1485
+ memcpy(&device, src, sizeof(device));
1486
+ src += sizeof(device);
1487
+ if (device >= backends.size()) {
1317
1488
  return false;
1318
1489
  }
1319
1490
  uint32_t n_nodes;
1320
- memcpy(&n_nodes, input.data(), sizeof(n_nodes));
1321
- if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
1491
+ memcpy(&n_nodes, src, sizeof(n_nodes));
1492
+ src += sizeof(n_nodes);
1493
+ if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
1322
1494
  return false;
1323
1495
  }
1324
- const uint64_t * nodes = (const uint64_t *)(input.data() + sizeof(n_nodes));
1496
+ const uint64_t * nodes = (const uint64_t *)src;
1497
+ src += n_nodes*sizeof(uint64_t);
1325
1498
  uint32_t n_tensors;
1326
- memcpy(&n_tensors, input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t), sizeof(n_tensors));
1327
- if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
1499
+ memcpy(&n_tensors, src, sizeof(n_tensors));
1500
+ src += sizeof(n_tensors);
1501
+ if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
1328
1502
  return false;
1329
1503
  }
1330
- const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t) + sizeof(n_tensors));
1331
- GGML_PRINT_DEBUG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);
1504
+ const rpc_tensor * tensors = (const rpc_tensor *)src;
1505
+ LOG_DBG("[%s] device: %u, n_nodes: %u, n_tensors: %u\n", __func__, device, n_nodes, n_tensors);
1332
1506
 
1333
1507
  size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
1334
1508
 
@@ -1343,10 +1517,12 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
1343
1517
  struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false);
1344
1518
  graph->n_nodes = n_nodes;
1345
1519
  std::unordered_map<uint64_t, const rpc_tensor*> tensor_ptrs;
1520
+ tensor_ptrs.reserve(n_tensors);
1346
1521
  for (uint32_t i = 0; i < n_tensors; i++) {
1347
- tensor_ptrs[tensors[i].id] = &tensors[i];
1522
+ tensor_ptrs.emplace(tensors[i].id, &tensors[i]);
1348
1523
  }
1349
1524
  std::unordered_map<uint64_t, ggml_tensor*> tensor_map;
1525
+ tensor_map.reserve(n_nodes);
1350
1526
  for (uint32_t i = 0; i < n_nodes; i++) {
1351
1527
  int64_t id;
1352
1528
  memcpy(&id, &nodes[i], sizeof(id));
@@ -1360,8 +1536,39 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
1360
1536
  return false;
1361
1537
  }
1362
1538
  }
1363
- ggml_status status = ggml_backend_graph_compute(backend, graph);
1364
- response.result = status;
1539
+ ggml_status status = ggml_backend_graph_compute(backends[device], graph);
1540
+ GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
1541
+ stored_graphs[device].ctx_ptr.swap(ctx_ptr);
1542
+ stored_graphs[device].graph = graph;
1543
+ return true;
1544
+ }
1545
+
1546
+ bool rpc_server::graph_recompute(const rpc_msg_graph_recompute_req & request) {
1547
+ uint32_t device = request.device;
1548
+ if (device >= backends.size()) {
1549
+ return false;
1550
+ }
1551
+ if (stored_graphs[device].graph == nullptr) {
1552
+ return false;
1553
+ }
1554
+ ggml_cgraph * graph = stored_graphs[device].graph;
1555
+ LOG_DBG("[%s] device: %u\n", __func__, device);
1556
+ ggml_status status = ggml_backend_graph_compute(backends[device], graph);
1557
+ GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
1558
+ return true;
1559
+ }
1560
+
1561
+ bool rpc_server::get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response) {
1562
+ uint32_t dev_id = request.device;
1563
+ if (dev_id >= backends.size()) {
1564
+ return false;
1565
+ }
1566
+ size_t free, total;
1567
+ ggml_backend_dev_t dev = ggml_backend_get_device(backends[dev_id]);
1568
+ ggml_backend_dev_memory(dev, &free, &total);
1569
+ response.free_mem = free;
1570
+ response.total_mem = total;
1571
+ LOG_DBG("[%s] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", __func__, dev_id, response.free_mem, response.total_mem);
1365
1572
  return true;
1366
1573
  }
1367
1574
 
@@ -1371,16 +1578,16 @@ rpc_server::~rpc_server() {
1371
1578
  }
1372
1579
  }
1373
1580
 
1374
- static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
1375
- sockfd_t sockfd, size_t free_mem, size_t total_mem) {
1376
- rpc_server server(backend, cache_dir);
1581
+ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir,
1582
+ sockfd_t sockfd) {
1583
+ rpc_server server(backends, cache_dir);
1377
1584
  uint8_t cmd;
1378
1585
  if (!recv_data(sockfd, &cmd, 1)) {
1379
1586
  return;
1380
1587
  }
1381
1588
  // the first command sent by the client must be HELLO
1382
1589
  if (cmd != RPC_CMD_HELLO) {
1383
- fprintf(stderr, "Expected HELLO command, update client\n");
1590
+ GGML_LOG_ERROR("Expected HELLO command, update client\n");
1384
1591
  return;
1385
1592
  }
1386
1593
  if (!recv_msg(sockfd, nullptr, 0)) {
@@ -1397,7 +1604,7 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
1397
1604
  }
1398
1605
  if (cmd >= RPC_CMD_COUNT) {
1399
1606
  // fail fast if the command is invalid
1400
- fprintf(stderr, "Unknown command: %d\n", cmd);
1607
+ GGML_LOG_ERROR("Unknown command: %d\n", cmd);
1401
1608
  break;
1402
1609
  }
1403
1610
  switch (cmd) {
@@ -1405,13 +1612,26 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
1405
1612
  // HELLO command is handled above
1406
1613
  return;
1407
1614
  }
1615
+ case RPC_CMD_DEVICE_COUNT: {
1616
+ if (!recv_msg(sockfd, nullptr, 0)) {
1617
+ return;
1618
+ }
1619
+ rpc_msg_device_count_rsp response;
1620
+ response.device_count = backends.size();
1621
+ if (!send_msg(sockfd, &response, sizeof(response))) {
1622
+ return;
1623
+ }
1624
+ break;
1625
+ }
1408
1626
  case RPC_CMD_ALLOC_BUFFER: {
1409
1627
  rpc_msg_alloc_buffer_req request;
1410
1628
  if (!recv_msg(sockfd, &request, sizeof(request))) {
1411
1629
  return;
1412
1630
  }
1413
1631
  rpc_msg_alloc_buffer_rsp response;
1414
- server.alloc_buffer(request, response);
1632
+ if (!server.alloc_buffer(request, response)) {
1633
+ return;
1634
+ }
1415
1635
  if (!send_msg(sockfd, &response, sizeof(response))) {
1416
1636
  return;
1417
1637
  }
@@ -1432,22 +1652,28 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
1432
1652
  break;
1433
1653
  }
1434
1654
  case RPC_CMD_GET_ALIGNMENT: {
1435
- if (!recv_msg(sockfd, nullptr, 0)) {
1655
+ rpc_msg_get_alignment_req request;
1656
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
1436
1657
  return;
1437
1658
  }
1438
1659
  rpc_msg_get_alignment_rsp response;
1439
- server.get_alignment(response);
1660
+ if (!server.get_alignment(request, response)) {
1661
+ return;
1662
+ }
1440
1663
  if (!send_msg(sockfd, &response, sizeof(response))) {
1441
1664
  return;
1442
1665
  }
1443
1666
  break;
1444
1667
  }
1445
1668
  case RPC_CMD_GET_MAX_SIZE: {
1446
- if (!recv_msg(sockfd, nullptr, 0)) {
1669
+ rpc_msg_get_max_size_req request;
1670
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
1447
1671
  return;
1448
1672
  }
1449
1673
  rpc_msg_get_max_size_rsp response;
1450
- server.get_max_size(response);
1674
+ if (!server.get_max_size(request, response)) {
1675
+ return;
1676
+ }
1451
1677
  if (!send_msg(sockfd, &response, sizeof(response))) {
1452
1678
  return;
1453
1679
  }
@@ -1563,45 +1789,77 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
1563
1789
  if (!recv_msg(sockfd, input)) {
1564
1790
  return;
1565
1791
  }
1566
- rpc_msg_graph_compute_rsp response;
1567
- if (!server.graph_compute(input, response)) {
1792
+ if (!server.graph_compute(input)) {
1568
1793
  return;
1569
1794
  }
1570
- if (!send_msg(sockfd, &response, sizeof(response))) {
1795
+ break;
1796
+ }
1797
+ case RPC_CMD_GRAPH_RECOMPUTE: {
1798
+ rpc_msg_graph_recompute_req request;
1799
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
1800
+ return;
1801
+ }
1802
+ if (!server.graph_recompute(request)) {
1571
1803
  return;
1572
1804
  }
1573
1805
  break;
1574
1806
  }
1575
1807
  case RPC_CMD_GET_DEVICE_MEMORY: {
1576
- if (!recv_msg(sockfd, nullptr, 0)) {
1808
+ rpc_msg_get_device_memory_req request;
1809
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
1577
1810
  return;
1578
1811
  }
1579
1812
  rpc_msg_get_device_memory_rsp response;
1580
- response.free_mem = free_mem;
1581
- response.total_mem = total_mem;
1813
+ if (!server.get_device_memory(request, response)) {
1814
+ return;
1815
+ }
1582
1816
  if (!send_msg(sockfd, &response, sizeof(response))) {
1583
1817
  return;
1584
1818
  }
1585
1819
  break;
1586
1820
  }
1587
1821
  default: {
1588
- fprintf(stderr, "Unknown command: %d\n", cmd);
1822
+ GGML_LOG_ERROR("Unknown command: %d\n", cmd);
1589
1823
  return;
1590
1824
  }
1591
1825
  }
1592
1826
  }
1593
1827
  }
1594
1828
 
1595
- void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
1596
- const char * cache_dir,
1597
- size_t free_mem, size_t total_mem) {
1829
+ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
1830
+ size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices) {
1831
+ if (n_devices == 0 || devices == nullptr) {
1832
+ fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n");
1833
+ return;
1834
+ }
1835
+ std::vector<ggml_backend_t> backends;
1598
1836
  printf("Starting RPC server v%d.%d.%d\n",
1599
1837
  RPC_PROTO_MAJOR_VERSION,
1600
1838
  RPC_PROTO_MINOR_VERSION,
1601
1839
  RPC_PROTO_PATCH_VERSION);
1602
1840
  printf(" endpoint : %s\n", endpoint);
1603
1841
  printf(" local cache : %s\n", cache_dir ? cache_dir : "n/a");
1604
- printf(" backend memory : %zu MB\n", free_mem / (1024 * 1024));
1842
+ printf("Devices:\n");
1843
+ for (size_t i = 0; i < n_devices; i++) {
1844
+ auto dev = devices[i];
1845
+ size_t free, total;
1846
+ ggml_backend_dev_memory(dev, &free, &total);
1847
+ printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
1848
+ total / 1024 / 1024, free / 1024 / 1024);
1849
+ auto backend = ggml_backend_dev_init(dev, nullptr);
1850
+ if (!backend) {
1851
+ fprintf(stderr, "Failed to create backend for device %s\n", dev->iface.get_name(dev));
1852
+ return;
1853
+ }
1854
+ backends.push_back(backend);
1855
+ ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
1856
+ if (reg) {
1857
+ auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
1858
+ if (ggml_backend_set_n_threads_fn) {
1859
+ ggml_backend_set_n_threads_fn(backend, n_threads);
1860
+ }
1861
+ }
1862
+ }
1605
1863
 
1606
1864
  std::string host;
1607
1865
  int port;
@@ -1629,22 +1887,27 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint
1629
1887
  fprintf(stderr, "Failed to accept client connection\n");
1630
1888
  return;
1631
1889
  }
1632
- printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
1890
+ printf("Accepted client connection\n");
1633
1891
  fflush(stdout);
1634
- rpc_serve_client(backend, cache_dir, client_socket->fd, free_mem, total_mem);
1892
+ rpc_serve_client(backends, cache_dir, client_socket->fd);
1635
1893
  printf("Client connection closed\n");
1636
1894
  fflush(stdout);
1637
1895
  }
1638
1896
  #ifdef _WIN32
1639
1897
  WSACleanup();
1640
1898
  #endif
1899
+ for (auto backend : backends) {
1900
+ ggml_backend_free(backend);
1901
+ }
1641
1902
  }
1642
1903
 
1643
1904
  // device interface
1644
1905
 
1645
1906
  struct ggml_backend_rpc_device_context {
1646
1907
  std::string endpoint;
1908
+ uint32_t device;
1647
1909
  std::string name;
1910
+ std::string description;
1648
1911
  };
1649
1912
 
1650
1913
  static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
@@ -1656,15 +1919,13 @@ static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
1656
1919
  static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) {
1657
1920
  ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1658
1921
 
1659
- return ctx->name.c_str();
1922
+ return ctx->description.c_str();
1660
1923
  }
1661
1924
 
1662
1925
  static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
1663
1926
  ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1664
1927
 
1665
- ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total);
1666
-
1667
- GGML_UNUSED(dev);
1928
+ ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), ctx->device, free, total);
1668
1929
  }
1669
1930
 
1670
1931
  static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
@@ -1690,7 +1951,7 @@ static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggm
1690
1951
  static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) {
1691
1952
  ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1692
1953
 
1693
- return ggml_backend_rpc_init(ctx->endpoint.c_str());
1954
+ return ggml_backend_rpc_init(ctx->endpoint.c_str(), ctx->device);
1694
1955
 
1695
1956
  GGML_UNUSED(params);
1696
1957
  }
@@ -1698,7 +1959,7 @@ static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const
1698
1959
  static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
1699
1960
  ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1700
1961
 
1701
- return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
1962
+ return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str(), ctx->device);
1702
1963
 
1703
1964
  GGML_UNUSED(dev);
1704
1965
  }
@@ -1716,7 +1977,7 @@ static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_b
1716
1977
  }
1717
1978
  ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
1718
1979
  ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context;
1719
- return buft_ctx->endpoint == dev_ctx->endpoint;
1980
+ return buft_ctx->endpoint == dev_ctx->endpoint && buft_ctx->device == dev_ctx->device;
1720
1981
  }
1721
1982
 
1722
1983
  static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
@@ -1739,28 +2000,34 @@ static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
1739
2000
 
1740
2001
  // backend reg interface
1741
2002
 
1742
- static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
1743
- return "RPC";
2003
+ struct ggml_backend_rpc_reg_context {
2004
+ std::string name;
2005
+ std::vector<ggml_backend_dev_t> devices;
2006
+ };
1744
2007
 
1745
- GGML_UNUSED(reg);
2008
+ static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
2009
+ ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
2010
+ return ctx ? ctx->name.c_str() : "RPC";
1746
2011
  }
1747
2012
 
1748
2013
  static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
1749
- return 0;
1750
-
1751
- GGML_UNUSED(reg);
2014
+ ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
2015
+ return ctx ? ctx->devices.size() : 0;
1752
2016
  }
1753
2017
 
1754
2018
  static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
1755
- GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_add_device instead");
1756
-
1757
- GGML_UNUSED(reg);
1758
- GGML_UNUSED(index);
2019
+ ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
2020
+ if (ctx == nullptr) {
2021
+ GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_rpc_add_server instead");
2022
+ } else {
2023
+ GGML_ASSERT(index < ctx->devices.size());
2024
+ return ctx->devices[index];
2025
+ }
1759
2026
  }
1760
2027
 
1761
2028
  static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {
1762
- if (std::strcmp(name, "ggml_backend_rpc_add_device") == 0) {
1763
- return (void *)ggml_backend_rpc_add_device;
2029
+ if (std::strcmp(name, "ggml_backend_rpc_add_server") == 0) {
2030
+ return (void *)ggml_backend_rpc_add_server;
1764
2031
  }
1765
2032
  if (std::strcmp(name, "ggml_backend_rpc_start_server") == 0) {
1766
2033
  return (void *)ggml_backend_rpc_start_server;
@@ -1787,30 +2054,65 @@ ggml_backend_reg_t ggml_backend_rpc_reg(void) {
1787
2054
  return &ggml_backend_rpc_reg;
1788
2055
  }
1789
2056
 
1790
- ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) {
1791
- static std::unordered_map<std::string, ggml_backend_dev_t> dev_map;
2057
+ static uint32_t ggml_backend_rpc_get_device_count(const char * endpoint) {
2058
+ auto sock = get_socket(endpoint);
2059
+ if (sock == nullptr) {
2060
+ GGML_LOG_ERROR("Failed to connect to %s\n", endpoint);
2061
+ return 0;
2062
+ }
2063
+ rpc_msg_device_count_rsp response;
2064
+ bool status = send_rpc_cmd(sock, RPC_CMD_DEVICE_COUNT, nullptr, 0, &response, sizeof(response));
2065
+ RPC_STATUS_ASSERT(status);
2066
+ return response.device_count;
2067
+ }
2068
+
2069
+ static const ggml_backend_reg_i ggml_backend_rpc_reg_interface = {
2070
+ /* .get_name = */ ggml_backend_rpc_reg_get_name,
2071
+ /* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,
2072
+ /* .get_device = */ ggml_backend_rpc_reg_get_device,
2073
+ /* .get_proc_address = */ ggml_backend_rpc_get_proc_address,
2074
+ };
1792
2075
 
2076
+ ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint) {
2077
+ static std::unordered_map<std::string, ggml_backend_reg_t> reg_map;
1793
2078
  static std::mutex mutex;
2079
+ static uint32_t dev_id = 0;
1794
2080
  std::lock_guard<std::mutex> lock(mutex);
1795
-
1796
- if (dev_map.find(endpoint) != dev_map.end()) {
1797
- return dev_map[endpoint];
2081
+ if (reg_map.find(endpoint) != reg_map.end()) {
2082
+ return reg_map[endpoint];
1798
2083
  }
1799
-
1800
- ggml_backend_rpc_device_context * ctx = new ggml_backend_rpc_device_context {
1801
- /* .endpoint = */ endpoint,
1802
- /* .name = */ "RPC[" + std::string(endpoint) + "]",
1803
- };
1804
-
1805
- ggml_backend_dev_t dev = new ggml_backend_device {
1806
- /* .iface = */ ggml_backend_rpc_device_i,
1807
- /* .reg = */ ggml_backend_rpc_reg(),
1808
- /* .context = */ ctx,
2084
+ uint32_t dev_count = ggml_backend_rpc_get_device_count(endpoint);
2085
+ if (dev_count == 0) {
2086
+ return nullptr;
2087
+ }
2088
+ ggml_backend_rpc_reg_context * ctx = new ggml_backend_rpc_reg_context;
2089
+ ctx->name = "RPC[" + std::string(endpoint) + "]";
2090
+ for (uint32_t ind = 0; ind < dev_count; ind++) {
2091
+ std::string dev_name = "RPC" + std::to_string(dev_id);
2092
+ std::string dev_desc = std::string(endpoint);
2093
+ ggml_backend_rpc_device_context * dev_ctx = new ggml_backend_rpc_device_context {
2094
+ /* .endpoint = */ endpoint,
2095
+ /* .device = */ ind,
2096
+ /* .name = */ dev_name,
2097
+ /* .description = */ dev_desc
2098
+ };
2099
+
2100
+ ggml_backend_dev_t dev = new ggml_backend_device {
2101
+ /* .iface = */ ggml_backend_rpc_device_i,
2102
+ /* .reg = */ ggml_backend_rpc_reg(),
2103
+ /* .context = */ dev_ctx,
2104
+ };
2105
+ ctx->devices.push_back(dev);
2106
+ dev_id++;
2107
+ }
2108
+ ggml_backend_reg_t reg = new ggml_backend_reg {
2109
+ /* .api_version = */ GGML_BACKEND_API_VERSION,
2110
+ /* .iface = */ ggml_backend_rpc_reg_interface,
2111
+ /* .context = */ ctx
1809
2112
  };
1810
-
1811
- dev_map[endpoint] = dev;
1812
-
1813
- return dev;
2113
+ reg_map[endpoint] = reg;
2114
+ return reg;
1814
2115
  }
1815
2116
 
2117
+
1816
2118
  GGML_BACKEND_DL_IMPL(ggml_backend_rpc_reg)