whispercpp 1.3.4 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (630)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +47 -23
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  23. data/ext/sources/examples/cli/cli.cpp +121 -112
  24. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  25. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  26. data/ext/sources/examples/server/server.cpp +10 -11
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
  31. data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
  33. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  34. data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
  35. data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
  36. data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
  37. data/ext/sources/examples/talk-llama/llama-context.h +57 -9
  38. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  40. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  41. data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
  42. data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
  43. data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
  44. data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
  45. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  46. data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
  49. data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
  50. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  51. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  52. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
  53. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  54. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  55. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  56. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
  57. data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
  58. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  59. data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
  60. data/ext/sources/examples/talk-llama/llama-model.h +44 -3
  61. data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
  62. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
  63. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
  66. data/ext/sources/examples/talk-llama/llama.cpp +729 -2
  67. data/ext/sources/examples/talk-llama/llama.h +152 -14
  68. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  69. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  70. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  71. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  72. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  73. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  74. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  75. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  76. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  77. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  78. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  79. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  80. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  81. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  82. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  83. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  84. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  85. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  86. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  88. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  89. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  90. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  91. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  92. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  108. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  109. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  110. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  111. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  112. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  113. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  114. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  116. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  117. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  118. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  119. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  120. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  121. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  122. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  123. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  124. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  125. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  126. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  127. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  128. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  129. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  130. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  131. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  132. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  133. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  134. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  135. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  136. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  137. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  138. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  139. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  140. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  141. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  142. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  143. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  144. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  145. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  146. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  147. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  148. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  149. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  150. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  151. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  153. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  154. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  155. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  156. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  157. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  158. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  159. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  160. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  161. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  162. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  163. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  165. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  166. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  167. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  168. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  169. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  170. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  171. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  172. data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
  173. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  174. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  175. data/ext/sources/ggml/CMakeLists.txt +82 -54
  176. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  177. data/ext/sources/ggml/include/ggml-backend.h +4 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +1 -0
  179. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  180. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  181. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  182. data/ext/sources/ggml/include/ggml.h +190 -12
  183. data/ext/sources/ggml/src/CMakeLists.txt +82 -11
  184. data/ext/sources/ggml/src/ggml-alloc.c +124 -41
  185. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  186. data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
  187. data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
  188. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  189. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
  190. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
  191. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  192. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
  193. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
  194. data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
  195. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
  196. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
  197. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
  198. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  199. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
  200. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
  201. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  202. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  203. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
  204. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  205. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
  206. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
  207. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
  208. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
  209. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  210. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
  211. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  212. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
  213. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  214. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
  215. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  216. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
  217. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
  218. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
  219. data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
  220. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
  221. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  222. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  223. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  224. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
  225. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  226. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  227. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  228. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  229. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  230. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
  231. data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
  232. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  233. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
  235. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  236. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  237. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  238. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  239. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
  241. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
  242. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  243. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
  244. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
  245. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
  246. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  247. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
  248. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  249. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  250. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
  251. data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
  252. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
  253. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
  254. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
  255. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  256. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  258. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
  259. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
  260. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
  261. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
  262. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
  264. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
  265. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  266. data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
  267. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  268. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  269. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  270. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  271. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
  272. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  273. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  274. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
  275. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  276. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  278. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  279. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  280. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  281. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  282. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  283. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  284. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
  286. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  287. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  288. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
  289. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
  290. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  291. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  292. data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
  293. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
  294. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  295. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  296. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
  297. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  317. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  321. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  322. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  323. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  324. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  325. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  326. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  327. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  328. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  329. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
  330. data/ext/sources/ggml/src/ggml-impl.h +67 -6
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
  335. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
  336. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
  337. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
  338. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
  339. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
  340. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
  341. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  342. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
  343. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
  365. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  366. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  367. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  368. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  369. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
  370. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
  371. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  372. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  373. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  374. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
  375. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  376. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  377. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  378. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  379. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
  380. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
  381. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
  382. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
  383. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  384. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
  385. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  386. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  387. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  388. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  389. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  390. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  391. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  392. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  393. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  394. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  395. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
  396. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  397. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  398. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  399. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  400. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  401. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  402. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
  403. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  404. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  484. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  485. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  486. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  487. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  488. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
  489. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  490. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  491. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  492. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  493. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  494. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  495. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  496. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  497. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  498. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  499. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  500. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  501. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  502. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  503. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  504. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  505. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  506. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
  507. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
  508. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  509. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  510. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
  511. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  512. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  513. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  514. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  515. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  516. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  517. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  518. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  519. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  520. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  521. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  522. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  523. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  524. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  525. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
  526. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  527. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  528. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  529. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  530. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  531. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
  532. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
  533. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  560. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  561. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  562. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  563. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  564. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  565. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  566. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  567. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  568. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  569. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
  570. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  571. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  572. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  573. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
  574. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  584. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  585. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  586. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  587. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  588. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  589. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  590. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  591. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  592. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  593. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  594. data/ext/sources/ggml/src/ggml.c +425 -33
  595. data/ext/sources/include/whisper.h +1 -0
  596. data/ext/sources/src/CMakeLists.txt +3 -1
  597. data/ext/sources/src/whisper.cpp +101 -35
  598. data/ext/sources/tests/CMakeLists.txt +2 -2
  599. data/ext/sources/tests/test-vad-full.cpp +4 -2
  600. data/ext/sources/tests/test-vad.cpp +1 -1
  601. data/extsources.rb +1 -0
  602. data/lib/whisper/model/uri.rb +17 -18
  603. data/sig/whisper.rbs +119 -2
  604. data/test/test_params.rb +16 -8
  605. data/test/test_segment.rb +0 -1
  606. data/test/test_token.rb +70 -0
  607. data/test/test_vad.rb +1 -1
  608. data/test/test_vad_context.rb +50 -0
  609. data/test/test_vad_segment.rb +19 -0
  610. data/test/test_vad_segments.rb +16 -0
  611. data/test/test_whisper.rb +7 -0
  612. data/whispercpp.gemspec +1 -1
  613. metadata +287 -34
  614. data/ext/sources/build-xcframework.sh +0 -571
  615. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  618. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  619. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  620. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  621. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  622. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  623. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  624. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  625. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  626. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  627. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  628. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  629. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  630. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -105,9 +105,13 @@ enum rpc_cmd {
105
105
  RPC_CMD_INIT_TENSOR,
106
106
  RPC_CMD_GET_ALLOC_SIZE,
107
107
  RPC_CMD_HELLO,
108
+ RPC_CMD_DEVICE_COUNT,
109
+ RPC_CMD_GRAPH_RECOMPUTE,
108
110
  RPC_CMD_COUNT,
109
111
  };
110
112
 
113
+ static_assert(RPC_CMD_HELLO == 14, "RPC_CMD_HELLO must be always 14");
114
+
111
115
  // Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
112
116
  const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
113
117
 
@@ -117,8 +121,14 @@ struct rpc_msg_hello_rsp {
117
121
  uint8_t patch;
118
122
  };
119
123
 
124
+ struct rpc_msg_device_count_rsp {
125
+ uint32_t device_count;
126
+ };
127
+
120
128
  struct rpc_msg_get_alloc_size_req {
129
+ uint32_t device;
121
130
  rpc_tensor tensor;
131
+ rpc_tensor srcs[GGML_MAX_SRC];
122
132
  };
123
133
 
124
134
  struct rpc_msg_get_alloc_size_rsp {
@@ -130,6 +140,7 @@ struct rpc_msg_init_tensor_req {
130
140
  };
131
141
 
132
142
  struct rpc_msg_alloc_buffer_req {
143
+ uint32_t device;
133
144
  uint64_t size;
134
145
  };
135
146
 
@@ -138,10 +149,18 @@ struct rpc_msg_alloc_buffer_rsp {
138
149
  uint64_t remote_size;
139
150
  };
140
151
 
152
+ struct rpc_msg_get_alignment_req {
153
+ uint32_t device;
154
+ };
155
+
141
156
  struct rpc_msg_get_alignment_rsp {
142
157
  uint64_t alignment;
143
158
  };
144
159
 
160
+ struct rpc_msg_get_max_size_req {
161
+ uint32_t device;
162
+ };
163
+
145
164
  struct rpc_msg_get_max_size_rsp {
146
165
  uint64_t max_size;
147
166
  };
@@ -188,14 +207,19 @@ struct rpc_msg_copy_tensor_rsp {
188
207
  uint8_t result;
189
208
  };
190
209
 
191
- struct rpc_msg_graph_compute_rsp {
192
- uint8_t result;
210
+ struct rpc_msg_get_device_memory_req {
211
+ uint32_t device;
193
212
  };
194
213
 
195
214
  struct rpc_msg_get_device_memory_rsp {
196
215
  uint64_t free_mem;
197
216
  uint64_t total_mem;
198
217
  };
218
+
219
+ struct rpc_msg_graph_recompute_req {
220
+ uint32_t device;
221
+ };
222
+
199
223
  #pragma pack(pop)
200
224
 
201
225
  // RPC data structures
@@ -207,14 +231,41 @@ static ggml_guid_t ggml_backend_rpc_guid() {
207
231
 
208
232
  struct ggml_backend_rpc_buffer_type_context {
209
233
  std::string endpoint;
234
+ uint32_t device;
210
235
  std::string name;
211
- size_t alignment;
212
- size_t max_size;
236
+ size_t alignment;
237
+ size_t max_size;
238
+ };
239
+
240
+ struct graph_cache {
241
+
242
+ bool is_cached(const ggml_cgraph * cgraph) {
243
+ if ((int)last_graph.size() != cgraph->n_nodes) {
244
+ return false;
245
+ }
246
+ for (int i = 0; i < cgraph->n_nodes; i++) {
247
+ if (memcmp(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) {
248
+ return false;
249
+ }
250
+ }
251
+ return true;
252
+ }
253
+
254
+ void add(const ggml_cgraph * cgraph) {
255
+ last_graph.resize(cgraph->n_nodes);
256
+ for (int i = 0; i < cgraph->n_nodes; i++) {
257
+ memcpy(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor));
258
+ }
259
+ }
260
+
261
+ std::vector<ggml_tensor> last_graph;
213
262
  };
214
263
 
215
264
  struct ggml_backend_rpc_context {
216
265
  std::string endpoint;
266
+ uint32_t device;
217
267
  std::string name;
268
+ graph_cache gc;
218
269
  };
219
270
 
220
271
  struct ggml_backend_rpc_buffer_context {
@@ -473,6 +524,7 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
473
524
  std::string host;
474
525
  int port;
475
526
  if (!parse_endpoint(endpoint, host, port)) {
527
+ GGML_LOG_ERROR("Failed to parse endpoint: %s\n", endpoint.c_str());
476
528
  return nullptr;
477
529
  }
478
530
  #ifdef _WIN32
@@ -520,14 +572,23 @@ static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
520
572
  return ctx->base_ptr;
521
573
  }
522
574
 
575
+ static bool ggml_backend_buffer_is_rpc(ggml_backend_buffer_t buffer) {
576
+ return buffer->iface.free_buffer == ggml_backend_rpc_buffer_free_buffer;
577
+ }
578
+
523
579
  static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
524
580
  rpc_tensor result;
581
+ if (!tensor) {
582
+ memset(&result, 0, sizeof(result));
583
+ return result;
584
+ }
585
+
525
586
  result.id = reinterpret_cast<uint64_t>(tensor);
526
587
  result.type = tensor->type;
527
- if (tensor->buffer) {
588
+ if (tensor->buffer && ggml_backend_buffer_is_rpc(tensor->buffer)) {
528
589
  ggml_backend_buffer_t buffer = tensor->buffer;
529
590
  ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
530
- result.buffer = ctx->remote_ptr;
591
+ result.buffer = ctx != nullptr ? ctx->remote_ptr : 0;
531
592
  } else {
532
593
  result.buffer = 0;
533
594
  }
@@ -609,22 +670,25 @@ static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, con
609
670
  }
610
671
 
611
672
  static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
612
- // check if src and dst are on the same server
613
- ggml_backend_buffer_t src_buffer = src->buffer;
614
- ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
615
- ggml_backend_buffer_t dst_buffer = dst->buffer;
616
- ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
617
- if (src_ctx->sock != dst_ctx->sock) {
618
- return false;
673
+ if (ggml_backend_buffer_is_rpc(src->buffer)) {
674
+ // check if src and dst are on the same server
675
+ ggml_backend_buffer_t src_buffer = src->buffer;
676
+ ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
677
+ ggml_backend_buffer_t dst_buffer = dst->buffer;
678
+ ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
679
+ if (src_ctx->sock != dst_ctx->sock) {
680
+ return false;
681
+ }
682
+ ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
683
+ rpc_msg_copy_tensor_req request;
684
+ request.src = serialize_tensor(src);
685
+ request.dst = serialize_tensor(dst);
686
+ rpc_msg_copy_tensor_rsp response;
687
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
688
+ RPC_STATUS_ASSERT(status);
689
+ return response.result;
619
690
  }
620
- ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
621
- rpc_msg_copy_tensor_req request;
622
- request.src = serialize_tensor(src);
623
- request.dst = serialize_tensor(dst);
624
- rpc_msg_copy_tensor_rsp response;
625
- bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
626
- RPC_STATUS_ASSERT(status);
627
- return response.result;
691
+ return false;
628
692
  }
629
693
 
630
694
  static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
@@ -653,7 +717,7 @@ static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t
653
717
 
654
718
  static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
655
719
  ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
656
- rpc_msg_alloc_buffer_req request = {size};
720
+ rpc_msg_alloc_buffer_req request = {buft_ctx->device, size};
657
721
  rpc_msg_alloc_buffer_rsp response;
658
722
  auto sock = get_socket(buft_ctx->endpoint);
659
723
  bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
@@ -669,9 +733,10 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
669
733
  }
670
734
  }
671
735
 
672
- static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
736
+ static size_t get_alignment(const std::shared_ptr<socket_t> & sock, uint32_t device) {
737
+ rpc_msg_get_alignment_req request = {device};
673
738
  rpc_msg_get_alignment_rsp response;
674
- bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response));
739
+ bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, &request, sizeof(request), &response, sizeof(response));
675
740
  RPC_STATUS_ASSERT(status);
676
741
  return response.alignment;
677
742
  }
@@ -681,9 +746,10 @@ static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_typ
681
746
  return buft_ctx->alignment;
682
747
  }
683
748
 
684
- static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
749
+ static size_t get_max_size(const std::shared_ptr<socket_t> & sock, uint32_t device) {
750
+ rpc_msg_get_max_size_req request = {device};
685
751
  rpc_msg_get_max_size_rsp response;
686
- bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, nullptr, 0, &response, sizeof(response));
752
+ bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, &request, sizeof(request), &response, sizeof(response));
687
753
  RPC_STATUS_ASSERT(status);
688
754
  return response.max_size;
689
755
  }
@@ -694,23 +760,41 @@ static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
694
760
  }
695
761
 
696
762
  static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
763
+ // should we query the remote server for the actual size
764
+ bool rpc_get = false;
765
+
697
766
  // See comments in init_tensor.
698
- if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) {
767
+ rpc_get |= ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr);
768
+
769
+ // ops that require additional memory for fleeting data on certain backends
770
+ // ref: https://github.com/ggml-org/llama.cpp/pull/15966
771
+ rpc_get |= tensor->op == GGML_OP_FLASH_ATTN_EXT;
772
+ rpc_get |= tensor->op == GGML_OP_MUL_MAT_ID;
773
+
774
+ if (rpc_get) {
699
775
  ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
700
776
  auto sock = get_socket(buft_ctx->endpoint);
701
777
 
702
- rpc_msg_get_alloc_size_req request;
778
+ rpc_msg_get_alloc_size_req request = {
779
+ /*.device =*/ buft_ctx->device,
780
+ /*.tensor =*/ serialize_tensor(tensor),
781
+ /*.srcs =*/ {},
782
+ };
703
783
 
704
- request.tensor = serialize_tensor(tensor);
784
+ // .get_alloc_size could be a function of the tensor's srcs, so we must serialize them as well
785
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
786
+ request.srcs[i] = serialize_tensor(tensor->src[i]);
787
+ }
705
788
 
789
+ // TODO: cache the alloc responses to avoid extra RPC calls?
706
790
  rpc_msg_get_alloc_size_rsp response;
707
791
  bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response));
708
792
  RPC_STATUS_ASSERT(status);
709
793
 
710
794
  return response.alloc_size;
711
- } else {
712
- return ggml_nbytes(tensor);
713
795
  }
796
+
797
+ return ggml_nbytes(tensor);
714
798
  }
715
799
 
716
800
  static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
@@ -754,7 +838,7 @@ static void add_tensor(ggml_tensor * tensor, std::vector<rpc_tensor> & tensors,
754
838
  tensors.push_back(serialize_tensor(tensor));
755
839
  }
756
840
 
757
- static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
841
+ static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
758
842
  uint32_t n_nodes = cgraph->n_nodes;
759
843
  std::vector<rpc_tensor> tensors;
760
844
  std::unordered_set<ggml_tensor*> visited;
@@ -762,29 +846,45 @@ static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & o
762
846
  add_tensor(cgraph->nodes[i], tensors, visited);
763
847
  }
764
848
  // serialization format:
765
- // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
849
+ // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
766
850
  uint32_t n_tensors = tensors.size();
767
- int output_size = sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
851
+ int output_size = 2*sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
768
852
  output.resize(output_size, 0);
769
- memcpy(output.data(), &n_nodes, sizeof(n_nodes));
853
+ uint8_t * dest = output.data();
854
+ memcpy(dest, &device, sizeof(device));
855
+ dest += sizeof(device);
856
+ memcpy(dest, &n_nodes, sizeof(n_nodes));
857
+ dest += sizeof(n_nodes);
770
858
  for (uint32_t i = 0; i < n_nodes; i++) {
771
- memcpy(output.data() + sizeof(n_nodes) + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));
859
+ memcpy(dest + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));
772
860
  }
773
- uint32_t * out_ntensors = (uint32_t *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t));
774
- *out_ntensors = n_tensors;
775
- rpc_tensor * out_tensors = (rpc_tensor *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t));
861
+ dest += n_nodes * sizeof(uint64_t);
862
+ memcpy(dest, &n_tensors, sizeof(n_tensors));
863
+ dest += sizeof(n_tensors);
864
+ rpc_tensor * out_tensors = (rpc_tensor *)dest;
776
865
  memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor));
777
866
  }
778
867
 
779
868
  static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
780
869
  ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
781
- std::vector<uint8_t> input;
782
- serialize_graph(cgraph, input);
783
- rpc_msg_graph_compute_rsp response;
784
- auto sock = get_socket(rpc_ctx->endpoint);
785
- bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
786
- RPC_STATUS_ASSERT(status);
787
- return (enum ggml_status)response.result;
870
+
871
+ GGML_ASSERT(cgraph->n_nodes > 0);
872
+ bool reuse = rpc_ctx->gc.is_cached(cgraph);
873
+ if (reuse) {
874
+ rpc_msg_graph_recompute_req request;
875
+ request.device = rpc_ctx->device;
876
+ auto sock = get_socket(rpc_ctx->endpoint);
877
+ bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request));
878
+ RPC_STATUS_ASSERT(status);
879
+ } else {
880
+ rpc_ctx->gc.add(cgraph);
881
+ std::vector<uint8_t> input;
882
+ serialize_graph(rpc_ctx->device, cgraph, input);
883
+ auto sock = get_socket(rpc_ctx->endpoint);
884
+ bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size());
885
+ RPC_STATUS_ASSERT(status);
886
+ }
887
+ return GGML_STATUS_SUCCESS;
788
888
  }
789
889
 
790
890
  static ggml_backend_i ggml_backend_rpc_interface = {
@@ -804,12 +904,13 @@ static ggml_backend_i ggml_backend_rpc_interface = {
804
904
  /* .graph_optimize = */ NULL,
805
905
  };
806
906
 
807
- ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
907
+ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device) {
808
908
  static std::mutex mutex;
809
909
  std::lock_guard<std::mutex> lock(mutex);
910
+ std::string buft_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
810
911
  // NOTE: buffer types are allocated and never freed; this is by design
811
912
  static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;
812
- auto it = buft_map.find(endpoint);
913
+ auto it = buft_map.find(buft_name);
813
914
  if (it != buft_map.end()) {
814
915
  return it->second;
815
916
  }
@@ -818,34 +919,38 @@ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
818
919
  GGML_LOG_ERROR("Failed to connect to %s\n", endpoint);
819
920
  return nullptr;
820
921
  }
821
- size_t alignment = get_alignment(sock);
822
- size_t max_size = get_max_size(sock);
922
+ size_t alignment = get_alignment(sock, device);
923
+ size_t max_size = get_max_size(sock, device);
823
924
  ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
824
925
  /* .endpoint = */ endpoint,
825
- /* .name = */ "RPC[" + std::string(endpoint) + "]",
926
+ /* .device = */ device,
927
+ /* .name = */ buft_name,
826
928
  /* .alignment = */ alignment,
827
929
  /* .max_size = */ max_size
828
930
  };
829
-
931
+ auto reg = ggml_backend_rpc_add_server(endpoint);
830
932
  ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
831
933
  /* .iface = */ ggml_backend_rpc_buffer_type_interface,
832
- /* .device = */ ggml_backend_rpc_add_device(endpoint),
934
+ /* .device = */ ggml_backend_reg_dev_get(reg, device),
833
935
  /* .context = */ buft_ctx
834
936
  };
835
- buft_map[endpoint] = buft;
937
+ buft_map[buft_name] = buft;
836
938
  return buft;
837
939
  }
838
940
 
839
- ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
941
+ ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
942
+ std::string dev_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
840
943
  ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
841
- /* .endpoint = */ endpoint,
842
- /* .name = */ "RPC[" + std::string(endpoint) + "]",
944
+ /* .endpoint = */ endpoint,
945
+ /* .device = */ device,
946
+ /* .name = */ dev_name,
947
+ /* .gc = */ {},
843
948
  };
844
-
949
+ auto reg = ggml_backend_rpc_add_server(endpoint);
845
950
  ggml_backend_t backend = new ggml_backend {
846
951
  /* .guid = */ ggml_backend_rpc_guid(),
847
952
  /* .iface = */ ggml_backend_rpc_interface,
848
- /* .device = */ ggml_backend_rpc_add_device(endpoint),
953
+ /* .device = */ ggml_backend_reg_dev_get(reg, device),
849
954
  /* .context = */ ctx
850
955
  };
851
956
  return backend;
@@ -855,37 +960,40 @@ bool ggml_backend_is_rpc(ggml_backend_t backend) {
855
960
  return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
856
961
  }
857
962
 
858
- static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
963
+ static void get_device_memory(const std::shared_ptr<socket_t> & sock, uint32_t device, size_t * free, size_t * total) {
964
+ rpc_msg_get_device_memory_req request;
965
+ request.device = device;
859
966
  rpc_msg_get_device_memory_rsp response;
860
- bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response));
967
+ bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, &request, sizeof(request), &response, sizeof(response));
861
968
  RPC_STATUS_ASSERT(status);
862
969
  *free = response.free_mem;
863
970
  *total = response.total_mem;
864
971
  }
865
972
 
866
- void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
973
+ void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total) {
867
974
  auto sock = get_socket(endpoint);
868
975
  if (sock == nullptr) {
869
976
  *free = 0;
870
977
  *total = 0;
871
978
  return;
872
979
  }
873
- get_device_memory(sock, free, total);
980
+ get_device_memory(sock, device, free, total);
874
981
  }
875
982
 
876
983
  // RPC server-side implementation
877
984
 
878
985
  class rpc_server {
879
986
  public:
880
- rpc_server(ggml_backend_t backend, const char * cache_dir)
881
- : backend(backend), cache_dir(cache_dir) {
987
+ rpc_server(std::vector<ggml_backend_t> all_backends, const char * cache_dir)
988
+ : backends(std::move(all_backends)), cache_dir(cache_dir) {
989
+ stored_graphs.resize(backends.size());
882
990
  }
883
991
  ~rpc_server();
884
992
 
885
993
  void hello(rpc_msg_hello_rsp & response);
886
- void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
887
- void get_alignment(rpc_msg_get_alignment_rsp & response);
888
- void get_max_size(rpc_msg_get_max_size_rsp & response);
994
+ bool alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
995
+ bool get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response);
996
+ bool get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response);
889
997
  bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response);
890
998
  bool free_buffer(const rpc_msg_free_buffer_req & request);
891
999
  bool buffer_clear(const rpc_msg_buffer_clear_req & request);
@@ -893,9 +1001,16 @@ public:
893
1001
  bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response);
894
1002
  bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
895
1003
  bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
896
- bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
1004
+ bool graph_compute(const std::vector<uint8_t> & input);
1005
+ bool graph_recompute(const rpc_msg_graph_recompute_req & request);
897
1006
  bool init_tensor(const rpc_msg_init_tensor_req & request);
898
1007
  bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
1008
+ bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
1009
+
1010
+ struct stored_graph {
1011
+ ggml_context_ptr ctx_ptr;
1012
+ ggml_cgraph * graph;
1013
+ };
899
1014
 
900
1015
  private:
901
1016
  bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
@@ -906,9 +1021,11 @@ private:
906
1021
  std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map);
907
1022
 
908
1023
 
909
- ggml_backend_t backend;
1024
+ std::vector<ggml_backend_t> backends;
910
1025
  const char * cache_dir;
911
1026
  std::unordered_set<ggml_backend_buffer_t> buffers;
1027
+ // store the last computed graph for each backend
1028
+ std::vector<stored_graph> stored_graphs;
912
1029
  };
913
1030
 
914
1031
  void rpc_server::hello(rpc_msg_hello_rsp & response) {
@@ -919,9 +1036,13 @@ void rpc_server::hello(rpc_msg_hello_rsp & response) {
919
1036
  }
920
1037
 
921
1038
  bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
1039
+ uint32_t dev_id = request.device;
1040
+ if (dev_id >= backends.size()) {
1041
+ return false;
1042
+ }
922
1043
  ggml_backend_buffer_type_t buft;
923
1044
  struct ggml_init_params params {
924
- /*.mem_size =*/ ggml_tensor_overhead(),
1045
+ /*.mem_size =*/ ggml_tensor_overhead()*(1 + GGML_MAX_SRC),
925
1046
  /*.mem_buffer =*/ NULL,
926
1047
  /*.no_alloc =*/ true,
927
1048
  };
@@ -929,16 +1050,22 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_
929
1050
  ggml_context_ptr ctx_ptr { ggml_init(params) };
930
1051
  GGML_ASSERT(ctx_ptr != nullptr);
931
1052
  ggml_context * ctx = ctx_ptr.get();
932
- ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
933
1053
 
1054
+ ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
934
1055
  if (tensor == nullptr) {
935
1056
  GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n");
936
1057
  return false;
937
1058
  }
938
- LOG_DBG("[%s] buffer: %p, data: %p\n", __func__, (void*)tensor->buffer, tensor->data);
1059
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
1060
+ if (request.srcs[i].id != 0) {
1061
+ tensor->src[i] = deserialize_tensor(ctx, &request.srcs[i]);
1062
+ }
1063
+ }
1064
+
1065
+ LOG_DBG("[%s] device: %d, buffer: %p, data: %p\n", __func__, dev_id, (void*)tensor->buffer, tensor->data);
939
1066
  if (tensor->buffer == nullptr) {
940
1067
  //No buffer allocated.
941
- buft = ggml_backend_get_default_buffer_type(backend);
1068
+ buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
942
1069
  } else {
943
1070
  buft = tensor->buffer->buft;
944
1071
  }
@@ -948,33 +1075,49 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_
948
1075
  return true;
949
1076
  }
950
1077
 
951
- void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
952
- ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
1078
+ bool rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
1079
+ uint32_t dev_id = request.device;
1080
+ if (dev_id >= backends.size()) {
1081
+ return false;
1082
+ }
1083
+ ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
953
1084
  ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
954
1085
  response.remote_ptr = 0;
955
1086
  response.remote_size = 0;
956
1087
  if (buffer != nullptr) {
957
1088
  response.remote_ptr = reinterpret_cast<uint64_t>(buffer);
958
1089
  response.remote_size = buffer->size;
959
- LOG_DBG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size);
1090
+ LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n",
1091
+ __func__, dev_id, request.size, response.remote_ptr, response.remote_size);
960
1092
  buffers.insert(buffer);
961
1093
  } else {
962
- LOG_DBG("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size);
1094
+ LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> failed\n", __func__, dev_id, request.size);
963
1095
  }
1096
+ return true;
964
1097
  }
965
1098
 
966
- void rpc_server::get_alignment(rpc_msg_get_alignment_rsp & response) {
967
- ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
1099
+ bool rpc_server::get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response) {
1100
+ uint32_t dev_id = request.device;
1101
+ if (dev_id >= backends.size()) {
1102
+ return false;
1103
+ }
1104
+ ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
968
1105
  size_t alignment = ggml_backend_buft_get_alignment(buft);
969
- LOG_DBG("[%s] alignment: %lu\n", __func__, alignment);
1106
+ LOG_DBG("[%s] device: %d, alignment: %lu\n", __func__, dev_id, alignment);
970
1107
  response.alignment = alignment;
1108
+ return true;
971
1109
  }
972
1110
 
973
- void rpc_server::get_max_size(rpc_msg_get_max_size_rsp & response) {
974
- ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
1111
+ bool rpc_server::get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response) {
1112
+ uint32_t dev_id = request.device;
1113
+ if (dev_id >= backends.size()) {
1114
+ return false;
1115
+ }
1116
+ ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
975
1117
  size_t max_size = ggml_backend_buft_get_max_size(buft);
976
- LOG_DBG("[%s] max_size: %lu\n", __func__, max_size);
1118
+ LOG_DBG("[%s] device: %d, max_size: %lu\n", __func__, dev_id, max_size);
977
1119
  response.max_size = max_size;
1120
+ return true;
978
1121
  }
979
1122
 
980
1123
  bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {
@@ -1115,7 +1258,8 @@ bool rpc_server::get_cached_file(uint64_t hash, std::vector<uint8_t> & data) {
1115
1258
  char hash_str[17];
1116
1259
  snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
1117
1260
  fs::path cache_file = fs::path(cache_dir) / hash_str;
1118
- if (!fs::exists(cache_file)) {
1261
+ std::error_code ec;
1262
+ if (!fs::exists(cache_file, ec)) {
1119
1263
  return false;
1120
1264
  }
1121
1265
  std::ifstream ifs(cache_file, std::ios::binary);
@@ -1330,25 +1474,35 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
1330
1474
  return result;
1331
1475
  }
1332
1476
 
1333
- bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
1477
+ bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
1334
1478
  // serialization format:
1335
- // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
1336
- if (input.size() < sizeof(uint32_t)) {
1479
+ // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
1480
+ if (input.size() < 2*sizeof(uint32_t)) {
1481
+ return false;
1482
+ }
1483
+ const uint8_t * src = input.data();
1484
+ uint32_t device;
1485
+ memcpy(&device, src, sizeof(device));
1486
+ src += sizeof(device);
1487
+ if (device >= backends.size()) {
1337
1488
  return false;
1338
1489
  }
1339
1490
  uint32_t n_nodes;
1340
- memcpy(&n_nodes, input.data(), sizeof(n_nodes));
1341
- if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
1491
+ memcpy(&n_nodes, src, sizeof(n_nodes));
1492
+ src += sizeof(n_nodes);
1493
+ if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
1342
1494
  return false;
1343
1495
  }
1344
- const uint64_t * nodes = (const uint64_t *)(input.data() + sizeof(n_nodes));
1496
+ const uint64_t * nodes = (const uint64_t *)src;
1497
+ src += n_nodes*sizeof(uint64_t);
1345
1498
  uint32_t n_tensors;
1346
- memcpy(&n_tensors, input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t), sizeof(n_tensors));
1347
- if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
1499
+ memcpy(&n_tensors, src, sizeof(n_tensors));
1500
+ src += sizeof(n_tensors);
1501
+ if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
1348
1502
  return false;
1349
1503
  }
1350
- const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t) + sizeof(n_tensors));
1351
- LOG_DBG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);
1504
+ const rpc_tensor * tensors = (const rpc_tensor *)src;
1505
+ LOG_DBG("[%s] device: %u, n_nodes: %u, n_tensors: %u\n", __func__, device, n_nodes, n_tensors);
1352
1506
 
1353
1507
  size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
1354
1508
 
@@ -1363,10 +1517,12 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
1363
1517
  struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false);
1364
1518
  graph->n_nodes = n_nodes;
1365
1519
  std::unordered_map<uint64_t, const rpc_tensor*> tensor_ptrs;
1520
+ tensor_ptrs.reserve(n_tensors);
1366
1521
  for (uint32_t i = 0; i < n_tensors; i++) {
1367
- tensor_ptrs[tensors[i].id] = &tensors[i];
1522
+ tensor_ptrs.emplace(tensors[i].id, &tensors[i]);
1368
1523
  }
1369
1524
  std::unordered_map<uint64_t, ggml_tensor*> tensor_map;
1525
+ tensor_map.reserve(n_nodes);
1370
1526
  for (uint32_t i = 0; i < n_nodes; i++) {
1371
1527
  int64_t id;
1372
1528
  memcpy(&id, &nodes[i], sizeof(id));
@@ -1380,8 +1536,39 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
1380
1536
  return false;
1381
1537
  }
1382
1538
  }
1383
- ggml_status status = ggml_backend_graph_compute(backend, graph);
1384
- response.result = status;
1539
+ ggml_status status = ggml_backend_graph_compute(backends[device], graph);
1540
+ GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
1541
+ stored_graphs[device].ctx_ptr.swap(ctx_ptr);
1542
+ stored_graphs[device].graph = graph;
1543
+ return true;
1544
+ }
1545
+
1546
+ bool rpc_server::graph_recompute(const rpc_msg_graph_recompute_req & request) {
1547
+ uint32_t device = request.device;
1548
+ if (device >= backends.size()) {
1549
+ return false;
1550
+ }
1551
+ if (stored_graphs[device].graph == nullptr) {
1552
+ return false;
1553
+ }
1554
+ ggml_cgraph * graph = stored_graphs[device].graph;
1555
+ LOG_DBG("[%s] device: %u\n", __func__, device);
1556
+ ggml_status status = ggml_backend_graph_compute(backends[device], graph);
1557
+ GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
1558
+ return true;
1559
+ }
1560
+
1561
+ bool rpc_server::get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response) {
1562
+ uint32_t dev_id = request.device;
1563
+ if (dev_id >= backends.size()) {
1564
+ return false;
1565
+ }
1566
+ size_t free, total;
1567
+ ggml_backend_dev_t dev = ggml_backend_get_device(backends[dev_id]);
1568
+ ggml_backend_dev_memory(dev, &free, &total);
1569
+ response.free_mem = free;
1570
+ response.total_mem = total;
1571
+ LOG_DBG("[%s] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", __func__, dev_id, response.free_mem, response.total_mem);
1385
1572
  return true;
1386
1573
  }
1387
1574
 
@@ -1391,9 +1578,9 @@ rpc_server::~rpc_server() {
1391
1578
  }
1392
1579
  }
1393
1580
 
1394
- static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
1395
- sockfd_t sockfd, size_t free_mem, size_t total_mem) {
1396
- rpc_server server(backend, cache_dir);
1581
+ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir,
1582
+ sockfd_t sockfd) {
1583
+ rpc_server server(backends, cache_dir);
1397
1584
  uint8_t cmd;
1398
1585
  if (!recv_data(sockfd, &cmd, 1)) {
1399
1586
  return;
@@ -1425,13 +1612,26 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
1425
1612
  // HELLO command is handled above
1426
1613
  return;
1427
1614
  }
1615
+ case RPC_CMD_DEVICE_COUNT: {
1616
+ if (!recv_msg(sockfd, nullptr, 0)) {
1617
+ return;
1618
+ }
1619
+ rpc_msg_device_count_rsp response;
1620
+ response.device_count = backends.size();
1621
+ if (!send_msg(sockfd, &response, sizeof(response))) {
1622
+ return;
1623
+ }
1624
+ break;
1625
+ }
1428
1626
  case RPC_CMD_ALLOC_BUFFER: {
1429
1627
  rpc_msg_alloc_buffer_req request;
1430
1628
  if (!recv_msg(sockfd, &request, sizeof(request))) {
1431
1629
  return;
1432
1630
  }
1433
1631
  rpc_msg_alloc_buffer_rsp response;
1434
- server.alloc_buffer(request, response);
1632
+ if (!server.alloc_buffer(request, response)) {
1633
+ return;
1634
+ }
1435
1635
  if (!send_msg(sockfd, &response, sizeof(response))) {
1436
1636
  return;
1437
1637
  }
@@ -1452,22 +1652,28 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
1452
1652
  break;
1453
1653
  }
1454
1654
  case RPC_CMD_GET_ALIGNMENT: {
1455
- if (!recv_msg(sockfd, nullptr, 0)) {
1655
+ rpc_msg_get_alignment_req request;
1656
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
1456
1657
  return;
1457
1658
  }
1458
1659
  rpc_msg_get_alignment_rsp response;
1459
- server.get_alignment(response);
1660
+ if (!server.get_alignment(request, response)) {
1661
+ return;
1662
+ }
1460
1663
  if (!send_msg(sockfd, &response, sizeof(response))) {
1461
1664
  return;
1462
1665
  }
1463
1666
  break;
1464
1667
  }
1465
1668
  case RPC_CMD_GET_MAX_SIZE: {
1466
- if (!recv_msg(sockfd, nullptr, 0)) {
1669
+ rpc_msg_get_max_size_req request;
1670
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
1467
1671
  return;
1468
1672
  }
1469
1673
  rpc_msg_get_max_size_rsp response;
1470
- server.get_max_size(response);
1674
+ if (!server.get_max_size(request, response)) {
1675
+ return;
1676
+ }
1471
1677
  if (!send_msg(sockfd, &response, sizeof(response))) {
1472
1678
  return;
1473
1679
  }
@@ -1583,22 +1789,30 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
1583
1789
  if (!recv_msg(sockfd, input)) {
1584
1790
  return;
1585
1791
  }
1586
- rpc_msg_graph_compute_rsp response;
1587
- if (!server.graph_compute(input, response)) {
1792
+ if (!server.graph_compute(input)) {
1588
1793
  return;
1589
1794
  }
1590
- if (!send_msg(sockfd, &response, sizeof(response))) {
1795
+ break;
1796
+ }
1797
+ case RPC_CMD_GRAPH_RECOMPUTE: {
1798
+ rpc_msg_graph_recompute_req request;
1799
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
1800
+ return;
1801
+ }
1802
+ if (!server.graph_recompute(request)) {
1591
1803
  return;
1592
1804
  }
1593
1805
  break;
1594
1806
  }
1595
1807
  case RPC_CMD_GET_DEVICE_MEMORY: {
1596
- if (!recv_msg(sockfd, nullptr, 0)) {
1808
+ rpc_msg_get_device_memory_req request;
1809
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
1597
1810
  return;
1598
1811
  }
1599
1812
  rpc_msg_get_device_memory_rsp response;
1600
- response.free_mem = free_mem;
1601
- response.total_mem = total_mem;
1813
+ if (!server.get_device_memory(request, response)) {
1814
+ return;
1815
+ }
1602
1816
  if (!send_msg(sockfd, &response, sizeof(response))) {
1603
1817
  return;
1604
1818
  }
@@ -1612,16 +1826,40 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
1612
1826
  }
1613
1827
  }
1614
1828
 
1615
- void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
1616
- const char * cache_dir,
1617
- size_t free_mem, size_t total_mem) {
1829
+ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
1830
+ size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices) {
1831
+ if (n_devices == 0 || devices == nullptr) {
1832
+ fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n");
1833
+ return;
1834
+ }
1835
+ std::vector<ggml_backend_t> backends;
1618
1836
  printf("Starting RPC server v%d.%d.%d\n",
1619
1837
  RPC_PROTO_MAJOR_VERSION,
1620
1838
  RPC_PROTO_MINOR_VERSION,
1621
1839
  RPC_PROTO_PATCH_VERSION);
1622
1840
  printf(" endpoint : %s\n", endpoint);
1623
1841
  printf(" local cache : %s\n", cache_dir ? cache_dir : "n/a");
1624
- printf(" backend memory : %zu MB\n", free_mem / (1024 * 1024));
1842
+ printf("Devices:\n");
1843
+ for (size_t i = 0; i < n_devices; i++) {
1844
+ auto dev = devices[i];
1845
+ size_t free, total;
1846
+ ggml_backend_dev_memory(dev, &free, &total);
1847
+ printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
1848
+ total / 1024 / 1024, free / 1024 / 1024);
1849
+ auto backend = ggml_backend_dev_init(dev, nullptr);
1850
+ if (!backend) {
1851
+ fprintf(stderr, "Failed to create backend for device %s\n", dev->iface.get_name(dev));
1852
+ return;
1853
+ }
1854
+ backends.push_back(backend);
1855
+ ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
1856
+ if (reg) {
1857
+ auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
1858
+ if (ggml_backend_set_n_threads_fn) {
1859
+ ggml_backend_set_n_threads_fn(backend, n_threads);
1860
+ }
1861
+ }
1862
+ }
1625
1863
 
1626
1864
  std::string host;
1627
1865
  int port;
@@ -1649,22 +1887,27 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint
1649
1887
  fprintf(stderr, "Failed to accept client connection\n");
1650
1888
  return;
1651
1889
  }
1652
- printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
1890
+ printf("Accepted client connection\n");
1653
1891
  fflush(stdout);
1654
- rpc_serve_client(backend, cache_dir, client_socket->fd, free_mem, total_mem);
1892
+ rpc_serve_client(backends, cache_dir, client_socket->fd);
1655
1893
  printf("Client connection closed\n");
1656
1894
  fflush(stdout);
1657
1895
  }
1658
1896
  #ifdef _WIN32
1659
1897
  WSACleanup();
1660
1898
  #endif
1899
+ for (auto backend : backends) {
1900
+ ggml_backend_free(backend);
1901
+ }
1661
1902
  }
1662
1903
 
1663
1904
  // device interface
1664
1905
 
1665
1906
  struct ggml_backend_rpc_device_context {
1666
1907
  std::string endpoint;
1908
+ uint32_t device;
1667
1909
  std::string name;
1910
+ std::string description;
1668
1911
  };
1669
1912
 
1670
1913
  static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
@@ -1676,15 +1919,13 @@ static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
1676
1919
  static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) {
1677
1920
  ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1678
1921
 
1679
- return ctx->name.c_str();
1922
+ return ctx->description.c_str();
1680
1923
  }
1681
1924
 
1682
1925
  static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
1683
1926
  ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1684
1927
 
1685
- ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total);
1686
-
1687
- GGML_UNUSED(dev);
1928
+ ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), ctx->device, free, total);
1688
1929
  }
1689
1930
 
1690
1931
  static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
@@ -1710,7 +1951,7 @@ static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggm
1710
1951
  static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) {
1711
1952
  ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1712
1953
 
1713
- return ggml_backend_rpc_init(ctx->endpoint.c_str());
1954
+ return ggml_backend_rpc_init(ctx->endpoint.c_str(), ctx->device);
1714
1955
 
1715
1956
  GGML_UNUSED(params);
1716
1957
  }
@@ -1718,7 +1959,7 @@ static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const
1718
1959
  static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
1719
1960
  ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1720
1961
 
1721
- return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
1962
+ return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str(), ctx->device);
1722
1963
 
1723
1964
  GGML_UNUSED(dev);
1724
1965
  }
@@ -1736,7 +1977,7 @@ static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_b
1736
1977
  }
1737
1978
  ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
1738
1979
  ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context;
1739
- return buft_ctx->endpoint == dev_ctx->endpoint;
1980
+ return buft_ctx->endpoint == dev_ctx->endpoint && buft_ctx->device == dev_ctx->device;
1740
1981
  }
1741
1982
 
1742
1983
  static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
@@ -1759,28 +2000,34 @@ static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
1759
2000
 
1760
2001
  // backend reg interface
1761
2002
 
1762
- static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
1763
- return "RPC";
2003
+ struct ggml_backend_rpc_reg_context {
2004
+ std::string name;
2005
+ std::vector<ggml_backend_dev_t> devices;
2006
+ };
1764
2007
 
1765
- GGML_UNUSED(reg);
2008
+ static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
2009
+ ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
2010
+ return ctx ? ctx->name.c_str() : "RPC";
1766
2011
  }
1767
2012
 
1768
2013
  static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
1769
- return 0;
1770
-
1771
- GGML_UNUSED(reg);
2014
+ ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
2015
+ return ctx ? ctx->devices.size() : 0;
1772
2016
  }
1773
2017
 
1774
2018
  static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
1775
- GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_add_device instead");
1776
-
1777
- GGML_UNUSED(reg);
1778
- GGML_UNUSED(index);
2019
+ ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
2020
+ if (ctx == nullptr) {
2021
+ GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_rpc_add_server instead");
2022
+ } else {
2023
+ GGML_ASSERT(index < ctx->devices.size());
2024
+ return ctx->devices[index];
2025
+ }
1779
2026
  }
1780
2027
 
1781
2028
  static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {
1782
- if (std::strcmp(name, "ggml_backend_rpc_add_device") == 0) {
1783
- return (void *)ggml_backend_rpc_add_device;
2029
+ if (std::strcmp(name, "ggml_backend_rpc_add_server") == 0) {
2030
+ return (void *)ggml_backend_rpc_add_server;
1784
2031
  }
1785
2032
  if (std::strcmp(name, "ggml_backend_rpc_start_server") == 0) {
1786
2033
  return (void *)ggml_backend_rpc_start_server;
@@ -1807,30 +2054,65 @@ ggml_backend_reg_t ggml_backend_rpc_reg(void) {
1807
2054
  return &ggml_backend_rpc_reg;
1808
2055
  }
1809
2056
 
1810
- ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) {
1811
- static std::unordered_map<std::string, ggml_backend_dev_t> dev_map;
2057
+ static uint32_t ggml_backend_rpc_get_device_count(const char * endpoint) {
2058
+ auto sock = get_socket(endpoint);
2059
+ if (sock == nullptr) {
2060
+ GGML_LOG_ERROR("Failed to connect to %s\n", endpoint);
2061
+ return 0;
2062
+ }
2063
+ rpc_msg_device_count_rsp response;
2064
+ bool status = send_rpc_cmd(sock, RPC_CMD_DEVICE_COUNT, nullptr, 0, &response, sizeof(response));
2065
+ RPC_STATUS_ASSERT(status);
2066
+ return response.device_count;
2067
+ }
2068
+
2069
+ static const ggml_backend_reg_i ggml_backend_rpc_reg_interface = {
2070
+ /* .get_name = */ ggml_backend_rpc_reg_get_name,
2071
+ /* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,
2072
+ /* .get_device = */ ggml_backend_rpc_reg_get_device,
2073
+ /* .get_proc_address = */ ggml_backend_rpc_get_proc_address,
2074
+ };
1812
2075
 
2076
+ ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint) {
2077
+ static std::unordered_map<std::string, ggml_backend_reg_t> reg_map;
1813
2078
  static std::mutex mutex;
2079
+ static uint32_t dev_id = 0;
1814
2080
  std::lock_guard<std::mutex> lock(mutex);
1815
-
1816
- if (dev_map.find(endpoint) != dev_map.end()) {
1817
- return dev_map[endpoint];
2081
+ if (reg_map.find(endpoint) != reg_map.end()) {
2082
+ return reg_map[endpoint];
1818
2083
  }
1819
-
1820
- ggml_backend_rpc_device_context * ctx = new ggml_backend_rpc_device_context {
1821
- /* .endpoint = */ endpoint,
1822
- /* .name = */ "RPC[" + std::string(endpoint) + "]",
1823
- };
1824
-
1825
- ggml_backend_dev_t dev = new ggml_backend_device {
1826
- /* .iface = */ ggml_backend_rpc_device_i,
1827
- /* .reg = */ ggml_backend_rpc_reg(),
1828
- /* .context = */ ctx,
2084
+ uint32_t dev_count = ggml_backend_rpc_get_device_count(endpoint);
2085
+ if (dev_count == 0) {
2086
+ return nullptr;
2087
+ }
2088
+ ggml_backend_rpc_reg_context * ctx = new ggml_backend_rpc_reg_context;
2089
+ ctx->name = "RPC[" + std::string(endpoint) + "]";
2090
+ for (uint32_t ind = 0; ind < dev_count; ind++) {
2091
+ std::string dev_name = "RPC" + std::to_string(dev_id);
2092
+ std::string dev_desc = std::string(endpoint);
2093
+ ggml_backend_rpc_device_context * dev_ctx = new ggml_backend_rpc_device_context {
2094
+ /* .endpoint = */ endpoint,
2095
+ /* .device = */ ind,
2096
+ /* .name = */ dev_name,
2097
+ /* .description = */ dev_desc
2098
+ };
2099
+
2100
+ ggml_backend_dev_t dev = new ggml_backend_device {
2101
+ /* .iface = */ ggml_backend_rpc_device_i,
2102
+ /* .reg = */ ggml_backend_rpc_reg(),
2103
+ /* .context = */ dev_ctx,
2104
+ };
2105
+ ctx->devices.push_back(dev);
2106
+ dev_id++;
2107
+ }
2108
+ ggml_backend_reg_t reg = new ggml_backend_reg {
2109
+ /* .api_version = */ GGML_BACKEND_API_VERSION,
2110
+ /* .iface = */ ggml_backend_rpc_reg_interface,
2111
+ /* .context = */ ctx
1829
2112
  };
1830
-
1831
- dev_map[endpoint] = dev;
1832
-
1833
- return dev;
2113
+ reg_map[endpoint] = reg;
2114
+ return reg;
1834
2115
  }
1835
2116
 
2117
+
1836
2118
  GGML_BACKEND_DL_IMPL(ggml_backend_rpc_reg)