local-llm-rn 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (626)
  1. package/cpp/CMakeLists.txt +285 -0
  2. package/cpp/common/CMakeLists.txt +149 -0
  3. package/cpp/common/arg.cpp +3799 -0
  4. package/cpp/common/arg.h +131 -0
  5. package/cpp/common/base64.hpp +392 -0
  6. package/cpp/common/build-info.cpp.in +4 -0
  7. package/cpp/common/chat-parser-xml-toolcall.cpp +879 -0
  8. package/cpp/common/chat-parser-xml-toolcall.h +45 -0
  9. package/cpp/common/chat-parser.cpp +1649 -0
  10. package/cpp/common/chat-parser.h +133 -0
  11. package/cpp/common/chat-peg-parser.cpp +124 -0
  12. package/cpp/common/chat-peg-parser.h +105 -0
  13. package/cpp/common/chat.cpp +3355 -0
  14. package/cpp/common/chat.h +252 -0
  15. package/cpp/common/common.cpp +1824 -0
  16. package/cpp/common/common.h +930 -0
  17. package/cpp/common/console.cpp +1137 -0
  18. package/cpp/common/console.h +41 -0
  19. package/cpp/common/debug.cpp +167 -0
  20. package/cpp/common/debug.h +43 -0
  21. package/cpp/common/download.cpp +792 -0
  22. package/cpp/common/download.h +84 -0
  23. package/cpp/common/http.h +84 -0
  24. package/cpp/common/jinja/README.md +88 -0
  25. package/cpp/common/jinja/caps.cpp +285 -0
  26. package/cpp/common/jinja/caps.h +30 -0
  27. package/cpp/common/jinja/lexer.cpp +341 -0
  28. package/cpp/common/jinja/lexer.h +157 -0
  29. package/cpp/common/jinja/parser.cpp +591 -0
  30. package/cpp/common/jinja/parser.h +21 -0
  31. package/cpp/common/jinja/runtime.cpp +867 -0
  32. package/cpp/common/jinja/runtime.h +638 -0
  33. package/cpp/common/jinja/string.cpp +213 -0
  34. package/cpp/common/jinja/string.h +61 -0
  35. package/cpp/common/jinja/utils.h +149 -0
  36. package/cpp/common/jinja/value.cpp +1393 -0
  37. package/cpp/common/jinja/value.h +756 -0
  38. package/cpp/common/json-partial.cpp +324 -0
  39. package/cpp/common/json-partial.h +39 -0
  40. package/cpp/common/json-schema-to-grammar.cpp +1153 -0
  41. package/cpp/common/json-schema-to-grammar.h +43 -0
  42. package/cpp/common/llguidance.cpp +258 -0
  43. package/cpp/common/log.cpp +446 -0
  44. package/cpp/common/log.h +119 -0
  45. package/cpp/common/ngram-cache.cpp +285 -0
  46. package/cpp/common/ngram-cache.h +101 -0
  47. package/cpp/common/ngram-map.cpp +530 -0
  48. package/cpp/common/ngram-map.h +115 -0
  49. package/cpp/common/ngram-mod.cpp +60 -0
  50. package/cpp/common/ngram-mod.h +38 -0
  51. package/cpp/common/peg-parser.cpp +1712 -0
  52. package/cpp/common/peg-parser.h +459 -0
  53. package/cpp/common/preset.cpp +483 -0
  54. package/cpp/common/preset.h +83 -0
  55. package/cpp/common/regex-partial.cpp +204 -0
  56. package/cpp/common/regex-partial.h +56 -0
  57. package/cpp/common/sampling.cpp +745 -0
  58. package/cpp/common/sampling.h +119 -0
  59. package/cpp/common/speculative.cpp +1074 -0
  60. package/cpp/common/speculative.h +41 -0
  61. package/cpp/common/unicode.cpp +64 -0
  62. package/cpp/common/unicode.h +22 -0
  63. package/cpp/ggml/CMakeLists.txt +494 -0
  64. package/cpp/ggml/cmake/GitVars.cmake +22 -0
  65. package/cpp/ggml/cmake/common.cmake +50 -0
  66. package/cpp/ggml/cmake/ggml-config.cmake.in +191 -0
  67. package/cpp/ggml/include/ggml-alloc.h +85 -0
  68. package/cpp/ggml/include/ggml-backend.h +373 -0
  69. package/cpp/ggml/include/ggml-blas.h +25 -0
  70. package/cpp/ggml/include/ggml-cann.h +123 -0
  71. package/cpp/ggml/include/ggml-cpp.h +39 -0
  72. package/cpp/ggml/include/ggml-cpu.h +151 -0
  73. package/cpp/ggml/include/ggml-cuda.h +47 -0
  74. package/cpp/ggml/include/ggml-hexagon.h +19 -0
  75. package/cpp/ggml/include/ggml-metal.h +61 -0
  76. package/cpp/ggml/include/ggml-opencl.h +26 -0
  77. package/cpp/ggml/include/ggml-opt.h +256 -0
  78. package/cpp/ggml/include/ggml-rpc.h +30 -0
  79. package/cpp/ggml/include/ggml-sycl.h +49 -0
  80. package/cpp/ggml/include/ggml-virtgpu.h +14 -0
  81. package/cpp/ggml/include/ggml-vulkan.h +29 -0
  82. package/cpp/ggml/include/ggml-webgpu.h +19 -0
  83. package/cpp/ggml/include/ggml-zdnn.h +17 -0
  84. package/cpp/ggml/include/ggml-zendnn.h +22 -0
  85. package/cpp/ggml/include/ggml.h +2753 -0
  86. package/cpp/ggml/include/gguf.h +204 -0
  87. package/cpp/ggml/src/CMakeLists.txt +492 -0
  88. package/cpp/ggml/src/ggml-alloc.c +1244 -0
  89. package/cpp/ggml/src/ggml-backend-dl.cpp +48 -0
  90. package/cpp/ggml/src/ggml-backend-dl.h +45 -0
  91. package/cpp/ggml/src/ggml-backend-impl.h +255 -0
  92. package/cpp/ggml/src/ggml-backend-reg.cpp +566 -0
  93. package/cpp/ggml/src/ggml-backend.cpp +2270 -0
  94. package/cpp/ggml/src/ggml-blas/CMakeLists.txt +101 -0
  95. package/cpp/ggml/src/ggml-blas/ggml-blas.cpp +518 -0
  96. package/cpp/ggml/src/ggml-common.h +1878 -0
  97. package/cpp/ggml/src/ggml-cpu/CMakeLists.txt +691 -0
  98. package/cpp/ggml/src/ggml-cpu/amx/amx.cpp +247 -0
  99. package/cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
  100. package/cpp/ggml/src/ggml-cpu/amx/common.h +91 -0
  101. package/cpp/ggml/src/ggml-cpu/amx/mmq.cpp +2512 -0
  102. package/cpp/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  103. package/cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +98 -0
  104. package/cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4052 -0
  105. package/cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +4935 -0
  106. package/cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2159 -0
  107. package/cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  108. package/cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2305 -0
  109. package/cpp/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  110. package/cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2726 -0
  111. package/cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +342 -0
  112. package/cpp/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  113. package/cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1468 -0
  114. package/cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1221 -0
  115. package/cpp/ggml/src/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
  116. package/cpp/ggml/src/ggml-cpu/arch/x86/quants.c +3820 -0
  117. package/cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +6307 -0
  118. package/cpp/ggml/src/ggml-cpu/arch-fallback.h +313 -0
  119. package/cpp/ggml/src/ggml-cpu/binary-ops.cpp +154 -0
  120. package/cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
  121. package/cpp/ggml/src/ggml-cpu/cmake/FindSIMD.cmake +100 -0
  122. package/cpp/ggml/src/ggml-cpu/common.h +95 -0
  123. package/cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +529 -0
  124. package/cpp/ggml/src/ggml-cpu/ggml-cpu.c +3734 -0
  125. package/cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +701 -0
  126. package/cpp/ggml/src/ggml-cpu/hbm.cpp +55 -0
  127. package/cpp/ggml/src/ggml-cpu/hbm.h +8 -0
  128. package/cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +938 -0
  129. package/cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +90 -0
  130. package/cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +798 -0
  131. package/cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
  132. package/cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +4033 -0
  133. package/cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +25 -0
  134. package/cpp/ggml/src/ggml-cpu/ops.cpp +10978 -0
  135. package/cpp/ggml/src/ggml-cpu/ops.h +116 -0
  136. package/cpp/ggml/src/ggml-cpu/quants.c +1193 -0
  137. package/cpp/ggml/src/ggml-cpu/quants.h +97 -0
  138. package/cpp/ggml/src/ggml-cpu/repack.cpp +3316 -0
  139. package/cpp/ggml/src/ggml-cpu/repack.h +173 -0
  140. package/cpp/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  141. package/cpp/ggml/src/ggml-cpu/simd-mappings.h +1279 -0
  142. package/cpp/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
  143. package/cpp/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  144. package/cpp/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  145. package/cpp/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  146. package/cpp/ggml/src/ggml-cpu/traits.cpp +36 -0
  147. package/cpp/ggml/src/ggml-cpu/traits.h +38 -0
  148. package/cpp/ggml/src/ggml-cpu/unary-ops.cpp +337 -0
  149. package/cpp/ggml/src/ggml-cpu/unary-ops.h +35 -0
  150. package/cpp/ggml/src/ggml-cpu/vec.cpp +629 -0
  151. package/cpp/ggml/src/ggml-cpu/vec.h +1585 -0
  152. package/cpp/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
  153. package/cpp/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3232 -0
  154. package/cpp/ggml/src/ggml-hexagon/htp/CMakeLists.txt +45 -0
  155. package/cpp/ggml/src/ggml-hexagon/htp/act-ops.c +815 -0
  156. package/cpp/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  157. package/cpp/ggml/src/ggml-hexagon/htp/binary-ops.c +827 -0
  158. package/cpp/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  159. package/cpp/ggml/src/ggml-hexagon/htp/cpy-ops.c +251 -0
  160. package/cpp/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +666 -0
  161. package/cpp/ggml/src/ggml-hexagon/htp/get-rows-ops.c +111 -0
  162. package/cpp/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
  163. package/cpp/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
  164. package/cpp/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  165. package/cpp/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  166. package/cpp/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  167. package/cpp/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  168. package/cpp/ggml/src/ggml-hexagon/htp/htp-msg.h +154 -0
  169. package/cpp/ggml/src/ggml-hexagon/htp/htp-ops.h +65 -0
  170. package/cpp/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  171. package/cpp/ggml/src/ggml-hexagon/htp/hvx-arith.h +470 -0
  172. package/cpp/ggml/src/ggml-hexagon/htp/hvx-base.h +173 -0
  173. package/cpp/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  174. package/cpp/ggml/src/ggml-hexagon/htp/hvx-div.h +116 -0
  175. package/cpp/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  176. package/cpp/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  177. package/cpp/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  178. package/cpp/ggml/src/ggml-hexagon/htp/hvx-inverse.h +176 -0
  179. package/cpp/ggml/src/ggml-hexagon/htp/hvx-reduce.h +266 -0
  180. package/cpp/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  181. package/cpp/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  182. package/cpp/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  183. package/cpp/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  184. package/cpp/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -0
  185. package/cpp/ggml/src/ggml-hexagon/htp/main.c +1150 -0
  186. package/cpp/ggml/src/ggml-hexagon/htp/matmul-ops.c +2595 -0
  187. package/cpp/ggml/src/ggml-hexagon/htp/rope-ops.c +498 -0
  188. package/cpp/ggml/src/ggml-hexagon/htp/set-rows-ops.c +167 -0
  189. package/cpp/ggml/src/ggml-hexagon/htp/softmax-ops.c +421 -0
  190. package/cpp/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +130 -0
  191. package/cpp/ggml/src/ggml-hexagon/htp/unary-ops.c +384 -0
  192. package/cpp/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
  193. package/cpp/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  194. package/cpp/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  195. package/cpp/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  196. package/cpp/ggml/src/ggml-hexagon/libdl.h +79 -0
  197. package/cpp/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  198. package/cpp/ggml/src/ggml-hexagon/op-desc.h +153 -0
  199. package/cpp/ggml/src/ggml-impl.h +724 -0
  200. package/cpp/ggml/src/ggml-metal/CMakeLists.txt +124 -0
  201. package/cpp/ggml/src/ggml-metal/ggml-metal-common.cpp +457 -0
  202. package/cpp/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  203. package/cpp/ggml/src/ggml-metal/ggml-metal-context.h +41 -0
  204. package/cpp/ggml/src/ggml-metal/ggml-metal-context.m +702 -0
  205. package/cpp/ggml/src/ggml-metal/ggml-metal-device.cpp +1890 -0
  206. package/cpp/ggml/src/ggml-metal/ggml-metal-device.h +290 -0
  207. package/cpp/ggml/src/ggml-metal/ggml-metal-device.m +1749 -0
  208. package/cpp/ggml/src/ggml-metal/ggml-metal-impl.h +1054 -0
  209. package/cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp +4370 -0
  210. package/cpp/ggml/src/ggml-metal/ggml-metal-ops.h +94 -0
  211. package/cpp/ggml/src/ggml-metal/ggml-metal.cpp +937 -0
  212. package/cpp/ggml/src/ggml-metal/ggml-metal.metal +9819 -0
  213. package/cpp/ggml/src/ggml-musa/CMakeLists.txt +125 -0
  214. package/cpp/ggml/src/ggml-musa/mudnn.cu +112 -0
  215. package/cpp/ggml/src/ggml-musa/mudnn.cuh +12 -0
  216. package/cpp/ggml/src/ggml-opencl/CMakeLists.txt +150 -0
  217. package/cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +11553 -0
  218. package/cpp/ggml/src/ggml-opencl/kernels/add.cl +190 -0
  219. package/cpp/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  220. package/cpp/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  221. package/cpp/ggml/src/ggml-opencl/kernels/clamp.cl +20 -0
  222. package/cpp/ggml/src/ggml-opencl/kernels/concat.cl +51 -0
  223. package/cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  224. package/cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  225. package/cpp/ggml/src/ggml-opencl/kernels/cpy.cl +184 -0
  226. package/cpp/ggml/src/ggml-opencl/kernels/cvt.cl +417 -0
  227. package/cpp/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl +58 -0
  228. package/cpp/ggml/src/ggml-opencl/kernels/div.cl +138 -0
  229. package/cpp/ggml/src/ggml-opencl/kernels/embed_kernel.py +26 -0
  230. package/cpp/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
  231. package/cpp/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  232. package/cpp/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  233. package/cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +371 -0
  234. package/cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  235. package/cpp/ggml/src/ggml-opencl/kernels/gelu.cl +89 -0
  236. package/cpp/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  237. package/cpp/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  238. package/cpp/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl +268 -0
  239. package/cpp/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl +274 -0
  240. package/cpp/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  241. package/cpp/ggml/src/ggml-opencl/kernels/get_rows.cl +187 -0
  242. package/cpp/ggml/src/ggml-opencl/kernels/glu.cl +378 -0
  243. package/cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +121 -0
  244. package/cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +57 -0
  245. package/cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +57 -0
  246. package/cpp/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
  247. package/cpp/ggml/src/ggml-opencl/kernels/mul.cl +152 -0
  248. package/cpp/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl +139 -0
  249. package/cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  250. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  251. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +146 -0
  252. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +147 -0
  253. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  254. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  255. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  256. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  257. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  258. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl +118 -0
  259. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl +118 -0
  260. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl +94 -0
  261. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl +84 -0
  262. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl +118 -0
  263. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  264. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  265. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  266. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  267. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  268. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  269. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  270. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl +192 -0
  271. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl +307 -0
  272. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl +265 -0
  273. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl +272 -0
  274. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl +254 -0
  275. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  276. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  277. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  278. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32.cl +194 -0
  279. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  280. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  281. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  282. package/cpp/ggml/src/ggml-opencl/kernels/norm.cl +161 -0
  283. package/cpp/ggml/src/ggml-opencl/kernels/pad.cl +39 -0
  284. package/cpp/ggml/src/ggml-opencl/kernels/relu.cl +16 -0
  285. package/cpp/ggml/src/ggml-opencl/kernels/repeat.cl +38 -0
  286. package/cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +190 -0
  287. package/cpp/ggml/src/ggml-opencl/kernels/rope.cl +747 -0
  288. package/cpp/ggml/src/ggml-opencl/kernels/scale.cl +27 -0
  289. package/cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +208 -0
  290. package/cpp/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  291. package/cpp/ggml/src/ggml-opencl/kernels/silu.cl +30 -0
  292. package/cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +108 -0
  293. package/cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +108 -0
  294. package/cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +107 -0
  295. package/cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +107 -0
  296. package/cpp/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
  297. package/cpp/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  298. package/cpp/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  299. package/cpp/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  300. package/cpp/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  301. package/cpp/ggml/src/ggml-opencl/kernels/sub.cl +138 -0
  302. package/cpp/ggml/src/ggml-opencl/kernels/sum_rows.cl +140 -0
  303. package/cpp/ggml/src/ggml-opencl/kernels/tanh.cl +109 -0
  304. package/cpp/ggml/src/ggml-opencl/kernels/transpose.cl +117 -0
  305. package/cpp/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  306. package/cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  307. package/cpp/ggml/src/ggml-opencl/kernels/upscale.cl +120 -0
  308. package/cpp/ggml/src/ggml-opt.cpp +1093 -0
  309. package/cpp/ggml/src/ggml-quants.c +5325 -0
  310. package/cpp/ggml/src/ggml-quants.h +106 -0
  311. package/cpp/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
  312. package/cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +2118 -0
  313. package/cpp/ggml/src/ggml-threading.cpp +12 -0
  314. package/cpp/ggml/src/ggml-threading.h +14 -0
  315. package/cpp/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  316. package/cpp/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  317. package/cpp/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  318. package/cpp/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  319. package/cpp/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  320. package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  321. package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  322. package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  323. package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  324. package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  325. package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  326. package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  327. package/cpp/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  328. package/cpp/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  329. package/cpp/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  330. package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  331. package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  332. package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  333. package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  334. package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  335. package/cpp/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  336. package/cpp/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  337. package/cpp/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  338. package/cpp/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  339. package/cpp/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  340. package/cpp/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  341. package/cpp/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  342. package/cpp/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  343. package/cpp/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  344. package/cpp/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  345. package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  346. package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  347. package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  348. package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  349. package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  350. package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  351. package/cpp/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  352. package/cpp/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  353. package/cpp/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  354. package/cpp/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  355. package/cpp/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  356. package/cpp/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  357. package/cpp/ggml/src/ggml-webgpu/CMakeLists.txt +80 -0
  358. package/cpp/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1231 -0
  359. package/cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3150 -0
  360. package/cpp/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  361. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  362. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  363. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  364. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +107 -0
  365. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +923 -0
  366. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
  367. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  368. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +182 -0
  369. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
  370. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +668 -0
  371. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  372. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  373. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +713 -0
  374. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +103 -0
  375. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +138 -0
  376. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +188 -0
  377. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +194 -0
  378. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  379. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
  380. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  381. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
  382. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +109 -0
  383. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  384. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  385. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  386. package/cpp/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  387. package/cpp/ggml/src/ggml-zdnn/common.hpp +59 -0
  388. package/cpp/ggml/src/ggml-zdnn/ggml-zdnn.cpp +633 -0
  389. package/cpp/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  390. package/cpp/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  391. package/cpp/ggml/src/ggml-zdnn/utils.cpp +79 -0
  392. package/cpp/ggml/src/ggml-zdnn/utils.hpp +19 -0
  393. package/cpp/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  394. package/cpp/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
  395. package/cpp/ggml/src/ggml.c +7669 -0
  396. package/cpp/ggml/src/ggml.cpp +26 -0
  397. package/cpp/ggml/src/gguf.cpp +1699 -0
  398. package/cpp/include/llama-cpp.h +32 -0
  399. package/cpp/include/llama.h +1568 -0
  400. package/cpp/mtmd/CMakeLists.txt +98 -0
  401. package/cpp/mtmd/README.md +63 -0
  402. package/cpp/mtmd/clip-graph.h +117 -0
  403. package/cpp/mtmd/clip-impl.h +586 -0
  404. package/cpp/mtmd/clip-model.h +390 -0
  405. package/cpp/mtmd/clip.cpp +4154 -0
  406. package/cpp/mtmd/clip.h +121 -0
  407. package/cpp/mtmd/deprecation-warning.cpp +22 -0
  408. package/cpp/mtmd/legacy-models/convert_image_encoder_to_gguf.py +412 -0
  409. package/cpp/mtmd/legacy-models/glmedge-convert-image-encoder-to-gguf.py +280 -0
  410. package/cpp/mtmd/legacy-models/glmedge-surgery.py +33 -0
  411. package/cpp/mtmd/legacy-models/llava_surgery.py +38 -0
  412. package/cpp/mtmd/legacy-models/llava_surgery_v2.py +180 -0
  413. package/cpp/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py +892 -0
  414. package/cpp/mtmd/legacy-models/minicpmv-surgery.py +47 -0
  415. package/cpp/mtmd/models/cogvlm.cpp +98 -0
  416. package/cpp/mtmd/models/conformer.cpp +216 -0
  417. package/cpp/mtmd/models/glm4v.cpp +122 -0
  418. package/cpp/mtmd/models/internvl.cpp +69 -0
  419. package/cpp/mtmd/models/kimik25.cpp +101 -0
  420. package/cpp/mtmd/models/kimivl.cpp +63 -0
  421. package/cpp/mtmd/models/llama4.cpp +96 -0
  422. package/cpp/mtmd/models/llava.cpp +374 -0
  423. package/cpp/mtmd/models/minicpmv.cpp +114 -0
  424. package/cpp/mtmd/models/mobilenetv5.cpp +451 -0
  425. package/cpp/mtmd/models/models.h +128 -0
  426. package/cpp/mtmd/models/nemotron-v2-vl.cpp +35 -0
  427. package/cpp/mtmd/models/paddleocr.cpp +52 -0
  428. package/cpp/mtmd/models/pixtral.cpp +86 -0
  429. package/cpp/mtmd/models/qwen2vl.cpp +183 -0
  430. package/cpp/mtmd/models/qwen3vl.cpp +193 -0
  431. package/cpp/mtmd/models/siglip.cpp +86 -0
  432. package/cpp/mtmd/models/whisper-enc.cpp +115 -0
  433. package/cpp/mtmd/models/youtuvl.cpp +179 -0
  434. package/cpp/mtmd/mtmd-audio.cpp +730 -0
  435. package/cpp/mtmd/mtmd-audio.h +113 -0
  436. package/cpp/mtmd/mtmd-cli.cpp +437 -0
  437. package/cpp/mtmd/mtmd-helper.cpp +521 -0
  438. package/cpp/mtmd/mtmd-helper.h +96 -0
  439. package/cpp/mtmd/mtmd.cpp +1156 -0
  440. package/cpp/mtmd/mtmd.h +319 -0
  441. package/cpp/mtmd/requirements.txt +5 -0
  442. package/cpp/mtmd/test-1.jpeg +0 -0
  443. package/cpp/mtmd/test-2.mp3 +0 -0
  444. package/cpp/mtmd/tests.sh +192 -0
  445. package/cpp/src/CMakeLists.txt +169 -0
  446. package/cpp/src/llama-adapter.cpp +488 -0
  447. package/cpp/src/llama-adapter.h +89 -0
  448. package/cpp/src/llama-arch.cpp +2855 -0
  449. package/cpp/src/llama-arch.h +619 -0
  450. package/cpp/src/llama-batch.cpp +917 -0
  451. package/cpp/src/llama-batch.h +173 -0
  452. package/cpp/src/llama-chat.cpp +896 -0
  453. package/cpp/src/llama-chat.h +71 -0
  454. package/cpp/src/llama-context.cpp +3512 -0
  455. package/cpp/src/llama-context.h +359 -0
  456. package/cpp/src/llama-cparams.cpp +5 -0
  457. package/cpp/src/llama-cparams.h +44 -0
  458. package/cpp/src/llama-grammar.cpp +1464 -0
  459. package/cpp/src/llama-grammar.h +194 -0
  460. package/cpp/src/llama-graph.cpp +2685 -0
  461. package/cpp/src/llama-graph.h +1026 -0
  462. package/cpp/src/llama-hparams.cpp +234 -0
  463. package/cpp/src/llama-hparams.h +339 -0
  464. package/cpp/src/llama-impl.cpp +171 -0
  465. package/cpp/src/llama-impl.h +73 -0
  466. package/cpp/src/llama-io.cpp +15 -0
  467. package/cpp/src/llama-io.h +35 -0
  468. package/cpp/src/llama-kv-cache-iswa.cpp +330 -0
  469. package/cpp/src/llama-kv-cache-iswa.h +137 -0
  470. package/cpp/src/llama-kv-cache.cpp +2271 -0
  471. package/cpp/src/llama-kv-cache.h +388 -0
  472. package/cpp/src/llama-kv-cells.h +533 -0
  473. package/cpp/src/llama-memory-hybrid-iswa.cpp +275 -0
  474. package/cpp/src/llama-memory-hybrid-iswa.h +140 -0
  475. package/cpp/src/llama-memory-hybrid.cpp +268 -0
  476. package/cpp/src/llama-memory-hybrid.h +139 -0
  477. package/cpp/src/llama-memory-recurrent.cpp +1165 -0
  478. package/cpp/src/llama-memory-recurrent.h +182 -0
  479. package/cpp/src/llama-memory.cpp +59 -0
  480. package/cpp/src/llama-memory.h +122 -0
  481. package/cpp/src/llama-mmap.cpp +785 -0
  482. package/cpp/src/llama-mmap.h +92 -0
  483. package/cpp/src/llama-model-loader.cpp +1414 -0
  484. package/cpp/src/llama-model-loader.h +203 -0
  485. package/cpp/src/llama-model-saver.cpp +286 -0
  486. package/cpp/src/llama-model-saver.h +37 -0
  487. package/cpp/src/llama-model.cpp +9253 -0
  488. package/cpp/src/llama-model.h +576 -0
  489. package/cpp/src/llama-quant.cpp +1119 -0
  490. package/cpp/src/llama-quant.h +1 -0
  491. package/cpp/src/llama-sampler.cpp +3885 -0
  492. package/cpp/src/llama-sampler.h +42 -0
  493. package/cpp/src/llama-vocab.cpp +3970 -0
  494. package/cpp/src/llama-vocab.h +187 -0
  495. package/cpp/src/llama.cpp +1313 -0
  496. package/cpp/src/models/afmoe.cpp +191 -0
  497. package/cpp/src/models/apertus.cpp +125 -0
  498. package/cpp/src/models/arcee.cpp +135 -0
  499. package/cpp/src/models/arctic.cpp +138 -0
  500. package/cpp/src/models/arwkv7.cpp +86 -0
  501. package/cpp/src/models/baichuan.cpp +122 -0
  502. package/cpp/src/models/bailingmoe.cpp +144 -0
  503. package/cpp/src/models/bailingmoe2.cpp +135 -0
  504. package/cpp/src/models/bert.cpp +178 -0
  505. package/cpp/src/models/bitnet.cpp +160 -0
  506. package/cpp/src/models/bloom.cpp +101 -0
  507. package/cpp/src/models/chameleon.cpp +178 -0
  508. package/cpp/src/models/chatglm.cpp +132 -0
  509. package/cpp/src/models/codeshell.cpp +111 -0
  510. package/cpp/src/models/cogvlm.cpp +102 -0
  511. package/cpp/src/models/cohere2-iswa.cpp +134 -0
  512. package/cpp/src/models/command-r.cpp +122 -0
  513. package/cpp/src/models/dbrx.cpp +123 -0
  514. package/cpp/src/models/deci.cpp +135 -0
  515. package/cpp/src/models/deepseek.cpp +144 -0
  516. package/cpp/src/models/deepseek2.cpp +262 -0
  517. package/cpp/src/models/delta-net-base.cpp +376 -0
  518. package/cpp/src/models/dots1.cpp +134 -0
  519. package/cpp/src/models/dream.cpp +105 -0
  520. package/cpp/src/models/ernie4-5-moe.cpp +150 -0
  521. package/cpp/src/models/ernie4-5.cpp +110 -0
  522. package/cpp/src/models/eurobert.cpp +97 -0
  523. package/cpp/src/models/exaone-moe.cpp +146 -0
  524. package/cpp/src/models/exaone.cpp +114 -0
  525. package/cpp/src/models/exaone4.cpp +123 -0
  526. package/cpp/src/models/falcon-h1.cpp +111 -0
  527. package/cpp/src/models/falcon.cpp +120 -0
  528. package/cpp/src/models/gemma-embedding.cpp +116 -0
  529. package/cpp/src/models/gemma.cpp +112 -0
  530. package/cpp/src/models/gemma2-iswa.cpp +128 -0
  531. package/cpp/src/models/gemma3.cpp +155 -0
  532. package/cpp/src/models/gemma3n-iswa.cpp +384 -0
  533. package/cpp/src/models/glm4-moe.cpp +170 -0
  534. package/cpp/src/models/glm4.cpp +157 -0
  535. package/cpp/src/models/gpt2.cpp +105 -0
  536. package/cpp/src/models/gptneox.cpp +144 -0
  537. package/cpp/src/models/granite-hybrid.cpp +196 -0
  538. package/cpp/src/models/granite.cpp +211 -0
  539. package/cpp/src/models/grok.cpp +159 -0
  540. package/cpp/src/models/grovemoe.cpp +141 -0
  541. package/cpp/src/models/hunyuan-dense.cpp +132 -0
  542. package/cpp/src/models/hunyuan-moe.cpp +154 -0
  543. package/cpp/src/models/internlm2.cpp +120 -0
  544. package/cpp/src/models/jais.cpp +86 -0
  545. package/cpp/src/models/jais2.cpp +123 -0
  546. package/cpp/src/models/jamba.cpp +106 -0
  547. package/cpp/src/models/kimi-linear.cpp +392 -0
  548. package/cpp/src/models/lfm2.cpp +190 -0
  549. package/cpp/src/models/llada-moe.cpp +122 -0
  550. package/cpp/src/models/llada.cpp +99 -0
  551. package/cpp/src/models/llama-iswa.cpp +178 -0
  552. package/cpp/src/models/llama.cpp +168 -0
  553. package/cpp/src/models/maincoder.cpp +117 -0
  554. package/cpp/src/models/mamba-base.cpp +285 -0
  555. package/cpp/src/models/mamba.cpp +54 -0
  556. package/cpp/src/models/mimo2-iswa.cpp +123 -0
  557. package/cpp/src/models/minicpm3.cpp +200 -0
  558. package/cpp/src/models/minimax-m2.cpp +124 -0
  559. package/cpp/src/models/mistral3.cpp +160 -0
  560. package/cpp/src/models/models.h +684 -0
  561. package/cpp/src/models/modern-bert.cpp +109 -0
  562. package/cpp/src/models/mpt.cpp +126 -0
  563. package/cpp/src/models/nemotron-h.cpp +148 -0
  564. package/cpp/src/models/nemotron.cpp +122 -0
  565. package/cpp/src/models/neo-bert.cpp +104 -0
  566. package/cpp/src/models/olmo.cpp +121 -0
  567. package/cpp/src/models/olmo2.cpp +150 -0
  568. package/cpp/src/models/olmoe.cpp +124 -0
  569. package/cpp/src/models/openai-moe-iswa.cpp +127 -0
  570. package/cpp/src/models/openelm.cpp +124 -0
  571. package/cpp/src/models/orion.cpp +123 -0
  572. package/cpp/src/models/paddleocr.cpp +122 -0
  573. package/cpp/src/models/pangu-embedded.cpp +121 -0
  574. package/cpp/src/models/phi2.cpp +121 -0
  575. package/cpp/src/models/phi3.cpp +152 -0
  576. package/cpp/src/models/plamo.cpp +110 -0
  577. package/cpp/src/models/plamo2.cpp +318 -0
  578. package/cpp/src/models/plamo3.cpp +128 -0
  579. package/cpp/src/models/plm.cpp +169 -0
  580. package/cpp/src/models/qwen.cpp +108 -0
  581. package/cpp/src/models/qwen2.cpp +126 -0
  582. package/cpp/src/models/qwen2moe.cpp +151 -0
  583. package/cpp/src/models/qwen2vl.cpp +117 -0
  584. package/cpp/src/models/qwen3.cpp +117 -0
  585. package/cpp/src/models/qwen35.cpp +386 -0
  586. package/cpp/src/models/qwen35moe.cpp +420 -0
  587. package/cpp/src/models/qwen3moe.cpp +124 -0
  588. package/cpp/src/models/qwen3next.cpp +525 -0
  589. package/cpp/src/models/qwen3vl-moe.cpp +140 -0
  590. package/cpp/src/models/qwen3vl.cpp +132 -0
  591. package/cpp/src/models/refact.cpp +94 -0
  592. package/cpp/src/models/rnd1.cpp +126 -0
  593. package/cpp/src/models/rwkv6-base.cpp +164 -0
  594. package/cpp/src/models/rwkv6.cpp +94 -0
  595. package/cpp/src/models/rwkv6qwen2.cpp +86 -0
  596. package/cpp/src/models/rwkv7-base.cpp +137 -0
  597. package/cpp/src/models/rwkv7.cpp +90 -0
  598. package/cpp/src/models/seed-oss.cpp +124 -0
  599. package/cpp/src/models/smallthinker.cpp +126 -0
  600. package/cpp/src/models/smollm3.cpp +128 -0
  601. package/cpp/src/models/stablelm.cpp +146 -0
  602. package/cpp/src/models/starcoder.cpp +100 -0
  603. package/cpp/src/models/starcoder2.cpp +121 -0
  604. package/cpp/src/models/step35-iswa.cpp +168 -0
  605. package/cpp/src/models/t5-dec.cpp +166 -0
  606. package/cpp/src/models/t5-enc.cpp +96 -0
  607. package/cpp/src/models/wavtokenizer-dec.cpp +149 -0
  608. package/cpp/src/models/xverse.cpp +108 -0
  609. package/cpp/src/unicode-data.cpp +7034 -0
  610. package/cpp/src/unicode-data.h +20 -0
  611. package/cpp/src/unicode.cpp +1103 -0
  612. package/cpp/src/unicode.h +111 -0
  613. package/cpp/vendor/nlohmann/json.hpp +25526 -0
  614. package/cpp/vendor/nlohmann/json_fwd.hpp +187 -0
  615. package/cpp/vendor/stb/stb_image.h +7988 -0
  616. package/ios/LocalLLM-Bridging-Header.h +2 -0
  617. package/ios/LocalLLM.h +5 -0
  618. package/ios/LocalLLM.mm +1267 -0
  619. package/local-llm-rn.podspec +60 -0
  620. package/package.json +35 -0
  621. package/src/NativeLocalLLM.ts +73 -0
  622. package/src/device.ts +50 -0
  623. package/src/download-adapter.ts +17 -0
  624. package/src/index.ts +21 -0
  625. package/src/native-bridge.ts +142 -0
  626. package/src/rn-downloader.ts +37 -0
@@ -0,0 +1,2118 @@
1
#include "ggml-rpc.h"
#include "ggml-impl.h"
#include "ggml-backend-impl.h"
#include "ggml-cpp.h"

#include <cinttypes>
#include <string>
#include <vector>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <unordered_set>
#ifdef _WIN32
#  define WIN32_LEAN_AND_MEAN
#  ifndef NOMINMAX
#     define NOMINMAX
#  endif
#  include <windows.h>
#  include <winsock2.h>
#else
#  include <arpa/inet.h>
#  include <sys/socket.h>
#  include <sys/types.h>
#  include <netinet/in.h>
#  include <netinet/tcp.h>
#  include <netdb.h>
#  include <unistd.h>
#endif
#include <cerrno>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <filesystem>
#include <algorithm>
#include <stdexcept>
33
+
34
// Value of the GGML_RPC_DEBUG environment variable, read once during static
// initialization. A non-null value turns on LOG_DBG output.
static const char * RPC_DEBUG = std::getenv("GGML_RPC_DEBUG");

// Debug-level logging that is only emitted when GGML_RPC_DEBUG is set.
// Expands to a single statement, so it is safe in unbraced if/else bodies.
#define LOG_DBG(...)                     \
    do {                                 \
        if (RPC_DEBUG != nullptr) {      \
            GGML_LOG_DEBUG(__VA_ARGS__); \
        }                                \
    } while (0)
38
+
39
+
40
namespace fs = std::filesystem;

// Largest number of bytes handed to a single send()/recv() call; larger transfers
// are split into chunks of at most this size.
static constexpr size_t MAX_CHUNK_SIZE = size_t(1) << 30; // 1 GiB

// Native socket handle type. MSVC has no ssize_t, so one is provided there.
#ifdef _WIN32
using sockfd_t = SOCKET;
using ssize_t  = __int64;
#else
using sockfd_t = int;
#endif
50
+
51
+ // cross-platform socket
52
+ struct socket_t {
53
+ sockfd_t fd;
54
+ socket_t(sockfd_t fd) : fd(fd) {}
55
+ ~socket_t() {
56
+ LOG_DBG("[%s] closing socket %d\n", __func__, this->fd);
57
+ #ifdef _WIN32
58
+ closesocket(this->fd);
59
+ #else
60
+ close(this->fd);
61
+ #endif
62
+ }
63
+ };
64
+
65
// Aborts with a descriptive message when an RPC round-trip with the server failed
// (connection lost, short read, or unexpected response size).
// The body is wrapped in do { } while (0) so the macro behaves as one statement.
// The previous bare `if` form would silently capture an `else` that follows the
// call site (dangling else).
#define RPC_STATUS_ASSERT(x)                                                            \
    do {                                                                                \
        if (!(x)) {                                                                     \
            GGML_ABORT("Remote RPC server crashed or returned malformed response");     \
        }                                                                               \
    } while (0)
67
+
68
+ // all RPC structures must be packed
69
+ #pragma pack(push, 1)
70
+ // ggml_tensor is serialized into rpc_tensor
71
+ struct rpc_tensor {
72
+ uint64_t id;
73
+ uint32_t type;
74
+ uint64_t buffer;
75
+ uint32_t ne[GGML_MAX_DIMS];
76
+ uint32_t nb[GGML_MAX_DIMS];
77
+ uint32_t op;
78
+ int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
79
+ int32_t flags;
80
+ uint64_t src[GGML_MAX_SRC];
81
+ uint64_t view_src;
82
+ uint64_t view_offs;
83
+ uint64_t data;
84
+ char name[GGML_MAX_NAME];
85
+
86
+ char padding[4];
87
+ };
88
+
89
+ static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of 8");
90
+
91
// RPC command identifiers. They travel as a single byte on the wire, so every
// value is spelled out explicitly to make accidental renumbering visible.
enum rpc_cmd {
    RPC_CMD_ALLOC_BUFFER      = 0,
    RPC_CMD_GET_ALIGNMENT     = 1,
    RPC_CMD_GET_MAX_SIZE      = 2,
    RPC_CMD_BUFFER_GET_BASE   = 3,
    RPC_CMD_FREE_BUFFER       = 4,
    RPC_CMD_BUFFER_CLEAR      = 5,
    RPC_CMD_SET_TENSOR        = 6,
    RPC_CMD_SET_TENSOR_HASH   = 7,
    RPC_CMD_GET_TENSOR        = 8,
    RPC_CMD_COPY_TENSOR       = 9,
    RPC_CMD_GRAPH_COMPUTE     = 10,
    RPC_CMD_GET_DEVICE_MEMORY = 11,
    RPC_CMD_INIT_TENSOR       = 12,
    RPC_CMD_GET_ALLOC_SIZE    = 13,
    RPC_CMD_HELLO             = 14,
    RPC_CMD_DEVICE_COUNT      = 15,
    RPC_CMD_GRAPH_RECOMPUTE   = 16,
    RPC_CMD_COUNT             = 17,
};

// HELLO carries the version handshake, so its value is pinned.
static_assert(RPC_CMD_HELLO == 14, "RPC_CMD_HELLO must be always 14");

// set_tensor payloads larger than this are first offered by hash
// (RPC_CMD_SET_TENSOR_HASH); if the server already has matching data, the upload is skipped.
static constexpr size_t HASH_THRESHOLD = 10 * 1024 * 1024;
117
+
118
+ struct rpc_msg_hello_rsp {
119
+ uint8_t major;
120
+ uint8_t minor;
121
+ uint8_t patch;
122
+ };
123
+
124
+ struct rpc_msg_device_count_rsp {
125
+ uint32_t device_count;
126
+ };
127
+
128
+ struct rpc_msg_get_alloc_size_req {
129
+ uint32_t device;
130
+ rpc_tensor tensor;
131
+ rpc_tensor srcs[GGML_MAX_SRC];
132
+ };
133
+
134
+ struct rpc_msg_get_alloc_size_rsp {
135
+ uint64_t alloc_size;
136
+ };
137
+
138
+ struct rpc_msg_init_tensor_req {
139
+ rpc_tensor tensor;
140
+ };
141
+
142
+ struct rpc_msg_alloc_buffer_req {
143
+ uint32_t device;
144
+ uint64_t size;
145
+ };
146
+
147
+ struct rpc_msg_alloc_buffer_rsp {
148
+ uint64_t remote_ptr;
149
+ uint64_t remote_size;
150
+ };
151
+
152
+ struct rpc_msg_get_alignment_req {
153
+ uint32_t device;
154
+ };
155
+
156
+ struct rpc_msg_get_alignment_rsp {
157
+ uint64_t alignment;
158
+ };
159
+
160
+ struct rpc_msg_get_max_size_req {
161
+ uint32_t device;
162
+ };
163
+
164
+ struct rpc_msg_get_max_size_rsp {
165
+ uint64_t max_size;
166
+ };
167
+
168
+ struct rpc_msg_buffer_get_base_req {
169
+ uint64_t remote_ptr;
170
+ };
171
+
172
+ struct rpc_msg_buffer_get_base_rsp {
173
+ uint64_t base_ptr;
174
+ };
175
+
176
+ struct rpc_msg_free_buffer_req {
177
+ uint64_t remote_ptr;
178
+ };
179
+
180
+ struct rpc_msg_buffer_clear_req {
181
+ uint64_t remote_ptr;
182
+ uint8_t value;
183
+ };
184
+
185
+ struct rpc_msg_set_tensor_hash_req {
186
+ rpc_tensor tensor;
187
+ uint64_t offset;
188
+ uint64_t hash;
189
+ };
190
+
191
+ struct rpc_msg_set_tensor_hash_rsp {
192
+ uint8_t result;
193
+ };
194
+
195
+ struct rpc_msg_get_tensor_req {
196
+ rpc_tensor tensor;
197
+ uint64_t offset;
198
+ uint64_t size;
199
+ };
200
+
201
+ struct rpc_msg_copy_tensor_req {
202
+ rpc_tensor src;
203
+ rpc_tensor dst;
204
+ };
205
+
206
+ struct rpc_msg_copy_tensor_rsp {
207
+ uint8_t result;
208
+ };
209
+
210
+ struct rpc_msg_get_device_memory_req {
211
+ uint32_t device;
212
+ };
213
+
214
+ struct rpc_msg_get_device_memory_rsp {
215
+ uint64_t free_mem;
216
+ uint64_t total_mem;
217
+ };
218
+
219
+ struct rpc_msg_graph_recompute_req {
220
+ uint32_t device;
221
+ };
222
+
223
+ #pragma pack(pop)
224
+
225
+ // RPC data structures
226
+
227
+ static ggml_guid_t ggml_backend_rpc_guid() {
228
+ static ggml_guid guid = {0x99, 0x68, 0x5b, 0x6c, 0xd2, 0x83, 0x3d, 0x24, 0x25, 0x36, 0x72, 0xe1, 0x5b, 0x0e, 0x14, 0x03};
229
+ return &guid;
230
+ }
231
+
232
// Context attached to an RPC buffer type (buft->context).
struct ggml_backend_rpc_buffer_type_context {
    std::string endpoint;  // server address passed to get_socket(), "host:port"
    uint32_t    device;    // device index sent with alloc/alloc-size requests
    std::string name;      // returned by ggml_backend_rpc_buffer_type_name()
    size_t      alignment; // cached value returned by the get_alignment callback
    size_t      max_size;  // cached value returned by the get_max_size callback
};
239
+
240
+ struct graph_cache {
241
+
242
+ bool is_cached(const ggml_cgraph * cgraph) {
243
+ if ((int)last_graph.size() != cgraph->n_nodes) {
244
+ return false;
245
+ }
246
+ for (int i = 0; i < cgraph->n_nodes; i++) {
247
+ if (memcmp(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) {
248
+ return false;
249
+ }
250
+ }
251
+ return true;
252
+ }
253
+
254
+ void add(const ggml_cgraph * cgraph) {
255
+ last_graph.resize(cgraph->n_nodes);
256
+ for (int i = 0; i < cgraph->n_nodes; i++) {
257
+ memcpy(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor));
258
+ }
259
+ }
260
+
261
+ std::vector<ggml_tensor> last_graph;
262
+ };
263
+
264
+ struct ggml_backend_rpc_context {
265
+ std::string endpoint;
266
+ uint32_t device;
267
+ std::string name;
268
+ graph_cache gc;
269
+ };
270
+
271
+ struct ggml_backend_rpc_buffer_context {
272
+ std::shared_ptr<socket_t> sock;
273
+ void * base_ptr;
274
+ uint64_t remote_ptr;
275
+ };
276
+
277
// RPC helper functions

// 64-bit FNV-1a hash of the `len` bytes starting at `data`.
static uint64_t fnv_hash(const uint8_t * data, size_t len) {
    constexpr uint64_t FNV_OFFSET_BASIS = 0xcbf29ce484222325ULL;
    constexpr uint64_t FNV_PRIME        = 0x100000001b3ULL;

    uint64_t hash = FNV_OFFSET_BASIS;
    for (const uint8_t * p = data, * end = data + len; p != end; ++p) {
        hash = (hash ^ *p) * FNV_PRIME;
    }
    return hash;
}
290
+
291
+ static std::shared_ptr<socket_t> make_socket(sockfd_t fd) {
292
+ #ifdef _WIN32
293
+ if (fd == INVALID_SOCKET) {
294
+ return nullptr;
295
+ }
296
+ #else
297
+ if (fd < 0) {
298
+ return nullptr;
299
+ }
300
+ #endif
301
+ return std::make_shared<socket_t>(fd);
302
+ }
303
+
304
+ static bool set_no_delay(sockfd_t sockfd) {
305
+ int flag = 1;
306
+ // set TCP_NODELAY to disable Nagle's algorithm
307
+ int ret = setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int));
308
+ return ret == 0;
309
+ }
310
+
311
+ static bool set_reuse_addr(sockfd_t sockfd) {
312
+ int flag = 1;
313
+ int ret = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, (char *)&flag, sizeof(int));
314
+ return ret == 0;
315
+ }
316
+
317
+ static std::shared_ptr<socket_t> socket_connect(const char * host, int port) {
318
+ struct sockaddr_in addr;
319
+ auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
320
+ auto sock_ptr = make_socket(sockfd);
321
+ if (sock_ptr == nullptr) {
322
+ return nullptr;
323
+ }
324
+ if (!set_no_delay(sockfd)) {
325
+ GGML_LOG_ERROR("Failed to set TCP_NODELAY\n");
326
+ return nullptr;
327
+ }
328
+ addr.sin_family = AF_INET;
329
+ addr.sin_port = htons(port);
330
+ struct hostent * server = gethostbyname(host);
331
+ if (server == NULL) {
332
+ GGML_LOG_ERROR("Cannot resolve host '%s'\n", host);
333
+ return nullptr;
334
+ }
335
+ memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length);
336
+ if (connect(sock_ptr->fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
337
+ return nullptr;
338
+ }
339
+ return sock_ptr;
340
+ }
341
+
342
+ static std::shared_ptr<socket_t> socket_accept(sockfd_t srv_sockfd) {
343
+ auto client_socket_fd = accept(srv_sockfd, NULL, NULL);
344
+ auto client_socket = make_socket(client_socket_fd);
345
+ if (client_socket == nullptr) {
346
+ return nullptr;
347
+ }
348
+ if (!set_no_delay(client_socket_fd)) {
349
+ GGML_LOG_ERROR("Failed to set TCP_NODELAY\n");
350
+ return nullptr;
351
+ }
352
+ return client_socket;
353
+ }
354
+
355
+ static std::shared_ptr<socket_t> create_server_socket(const char * host, int port) {
356
+ auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
357
+ auto sock = make_socket(sockfd);
358
+ if (sock == nullptr) {
359
+ return nullptr;
360
+ }
361
+ if (!set_reuse_addr(sockfd)) {
362
+ GGML_LOG_ERROR("Failed to set SO_REUSEADDR\n");
363
+ return nullptr;
364
+ }
365
+ if (inet_addr(host) == INADDR_NONE) {
366
+ GGML_LOG_ERROR("Invalid host address: %s\n", host);
367
+ return nullptr;
368
+ }
369
+ struct sockaddr_in serv_addr;
370
+ serv_addr.sin_family = AF_INET;
371
+ serv_addr.sin_addr.s_addr = inet_addr(host);
372
+ serv_addr.sin_port = htons(port);
373
+
374
+ if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) {
375
+ return nullptr;
376
+ }
377
+ if (listen(sockfd, 1) < 0) {
378
+ return nullptr;
379
+ }
380
+ return sock;
381
+ }
382
+
383
+ static bool send_data(sockfd_t sockfd, const void * data, size_t size) {
384
+ size_t bytes_sent = 0;
385
+ while (bytes_sent < size) {
386
+ size_t size_to_send = std::min(size - bytes_sent, MAX_CHUNK_SIZE);
387
+ ssize_t n = send(sockfd, (const char *)data + bytes_sent, size_to_send, 0);
388
+ if (n < 0) {
389
+ GGML_LOG_ERROR("send failed (bytes_sent=%zu, size_to_send=%zu)\n",
390
+ bytes_sent, size_to_send);
391
+ return false;
392
+ }
393
+ bytes_sent += (size_t)n;
394
+ }
395
+ return true;
396
+ }
397
+
398
+ static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
399
+ size_t bytes_recv = 0;
400
+ while (bytes_recv < size) {
401
+ size_t size_to_recv = std::min(size - bytes_recv, MAX_CHUNK_SIZE);
402
+ ssize_t n = recv(sockfd, (char *)data + bytes_recv, size_to_recv, 0);
403
+ if (n < 0) {
404
+ GGML_LOG_ERROR("recv failed (bytes_recv=%zu, size_to_recv=%zu)\n",
405
+ bytes_recv, size_to_recv);
406
+ return false;
407
+ }
408
+ if (n == 0) {
409
+ LOG_DBG("recv returned 0 (peer closed?)\n");
410
+ return false;
411
+ }
412
+ bytes_recv += (size_t)n;
413
+ }
414
+ return true;
415
+ }
416
+
417
+ static bool send_msg(sockfd_t sockfd, const void * msg, size_t msg_size) {
418
+ if (!send_data(sockfd, &msg_size, sizeof(msg_size))) {
419
+ return false;
420
+ }
421
+ return send_data(sockfd, msg, msg_size);
422
+ }
423
+
424
+ static bool recv_msg(sockfd_t sockfd, void * msg, size_t msg_size) {
425
+ uint64_t size;
426
+ if (!recv_data(sockfd, &size, sizeof(size))) {
427
+ return false;
428
+ }
429
+ if (size != msg_size) {
430
+ return false;
431
+ }
432
+ return recv_data(sockfd, msg, msg_size);
433
+ }
434
+
435
+ static bool recv_msg(sockfd_t sockfd, std::vector<uint8_t> & input) {
436
+ uint64_t size;
437
+ if (!recv_data(sockfd, &size, sizeof(size))) {
438
+ return false;
439
+ }
440
+ try {
441
+ input.resize(size);
442
+ } catch (const std::bad_alloc & e) {
443
+ GGML_LOG_ERROR("Failed to allocate input buffer of size %" PRIu64 "\n", size);
444
+ return false;
445
+ }
446
+ return recv_data(sockfd, input.data(), size);
447
+ }
448
+
449
// Splits "host:port" at the first ':' into `host` and `port`.
// Returns false when there is no ':', the port is not a number (std::stoi would
// otherwise throw out of this function), or the port is outside 0..65535.
// The outputs are only written on success.
static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
    size_t pos = endpoint.find(':');
    if (pos == std::string::npos) {
        return false;
    }
    int parsed_port = 0;
    try {
        parsed_port = std::stoi(endpoint.substr(pos + 1));
    } catch (const std::invalid_argument &) {
        return false;
    } catch (const std::out_of_range &) {
        return false;
    }
    if (parsed_port < 0 || parsed_port > 65535) {
        return false;
    }
    host = endpoint.substr(0, pos);
    port = parsed_port;
    return true;
}
458
+
459
+ // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
460
+ // No response
461
+ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size) {
462
+ uint8_t cmd_byte = cmd;
463
+ if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
464
+ return false;
465
+ }
466
+ if (!send_data(sock->fd, &input_size, sizeof(input_size))) {
467
+ return false;
468
+ }
469
+ if (!send_data(sock->fd, input, input_size)) {
470
+ return false;
471
+ }
472
+ return true;
473
+ }
474
+
475
+ // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
476
+ // RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
477
+ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) {
478
+ if (!send_rpc_cmd(sock, cmd, input, input_size)) {
479
+ return false;
480
+ }
481
+ // TODO: currently the output_size is always known, do we need support for commands with variable output size?
482
+ // even if we do, we can skip sending output_size from the server for commands with known output size
483
+ uint64_t out_size;
484
+ if (!recv_data(sock->fd, &out_size, sizeof(out_size))) {
485
+ return false;
486
+ }
487
+ if (out_size != output_size) {
488
+ return false;
489
+ }
490
+ if (!recv_data(sock->fd, output, output_size)) {
491
+ return false;
492
+ }
493
+ return true;
494
+ }
495
+
496
+ // RPC client-side implementation
497
+
498
+ static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
499
+ rpc_msg_hello_rsp response;
500
+ bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
501
+ RPC_STATUS_ASSERT(status);
502
+ if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
503
+ GGML_LOG_ERROR("RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
504
+ return false;
505
+ }
506
+ if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) {
507
+ GGML_LOG_INFO("WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
508
+ }
509
+ return true;
510
+ }
511
+
512
+ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
513
+ static std::mutex mutex;
514
+ std::lock_guard<std::mutex> lock(mutex);
515
+ static std::unordered_map<std::string, std::weak_ptr<socket_t>> sockets;
516
+ static bool initialized = false;
517
+
518
+ auto it = sockets.find(endpoint);
519
+ if (it != sockets.end()) {
520
+ if (auto sock = it->second.lock()) {
521
+ return sock;
522
+ }
523
+ }
524
+ std::string host;
525
+ int port;
526
+ if (!parse_endpoint(endpoint, host, port)) {
527
+ GGML_LOG_ERROR("Failed to parse endpoint: %s\n", endpoint.c_str());
528
+ return nullptr;
529
+ }
530
+ #ifdef _WIN32
531
+ if (!initialized) {
532
+ WSADATA wsaData;
533
+ int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
534
+ if (res != 0) {
535
+ return nullptr;
536
+ }
537
+ initialized = true;
538
+ }
539
+ #else
540
+ GGML_UNUSED(initialized);
541
+ #endif
542
+ auto sock = socket_connect(host.c_str(), port);
543
+ if (sock == nullptr) {
544
+ return nullptr;
545
+ }
546
+ if (!check_server_version(sock)) {
547
+ return nullptr;
548
+ }
549
+ LOG_DBG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
550
+ sockets[endpoint] = sock;
551
+ return sock;
552
+ }
553
+
554
+ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
555
+ ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
556
+ rpc_msg_free_buffer_req request = {ctx->remote_ptr};
557
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0);
558
+ RPC_STATUS_ASSERT(status);
559
+ delete ctx;
560
+ }
561
+
562
+ static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
563
+ ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
564
+ if (ctx->base_ptr != nullptr) {
565
+ return ctx->base_ptr;
566
+ }
567
+ rpc_msg_buffer_get_base_req request = {ctx->remote_ptr};
568
+ rpc_msg_buffer_get_base_rsp response;
569
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response));
570
+ RPC_STATUS_ASSERT(status);
571
+ ctx->base_ptr = reinterpret_cast<void *>(response.base_ptr);
572
+ return ctx->base_ptr;
573
+ }
574
+
575
+ static bool ggml_backend_buffer_is_rpc(ggml_backend_buffer_t buffer) {
576
+ return buffer->iface.free_buffer == ggml_backend_rpc_buffer_free_buffer;
577
+ }
578
+
579
+ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
580
+ rpc_tensor result;
581
+ if (!tensor) {
582
+ memset(&result, 0, sizeof(result));
583
+ return result;
584
+ }
585
+
586
+ result.id = reinterpret_cast<uint64_t>(tensor);
587
+ result.type = tensor->type;
588
+ if (tensor->buffer && ggml_backend_buffer_is_rpc(tensor->buffer)) {
589
+ ggml_backend_buffer_t buffer = tensor->buffer;
590
+ ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
591
+ result.buffer = ctx != nullptr ? ctx->remote_ptr : 0;
592
+ } else {
593
+ result.buffer = 0;
594
+ }
595
+ for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
596
+ result.ne[i] = tensor->ne[i];
597
+ result.nb[i] = tensor->nb[i];
598
+ }
599
+ result.op = tensor->op;
600
+ for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
601
+ result.op_params[i] = tensor->op_params[i];
602
+ }
603
+ result.flags = tensor->flags;
604
+ for (uint32_t i = 0; i < GGML_MAX_SRC; i++) {
605
+ result.src[i] = reinterpret_cast<uint64_t>(tensor->src[i]);
606
+ }
607
+ result.view_src = reinterpret_cast<uint64_t>(tensor->view_src);
608
+ result.view_offs = tensor->view_offs;
609
+ result.data = reinterpret_cast<uint64_t>(tensor->data);
610
+
611
+ // Avoid sending uninitialized data over the wire
612
+ memset(result.name, 0, sizeof(result.name));
613
+ memset(result.padding, 0, sizeof(result.padding));
614
+
615
+ snprintf(result.name, GGML_MAX_NAME, "%s", tensor->name);
616
+ return result;
617
+ }
618
+
619
+ static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
620
+ ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
621
+
622
+ // CUDA backend on the server pads everything to 512 due to CUDA limitations.
623
+ // Due to bandwidth constraints, we only call the server init tensor functions if necessary.
624
+ // In particular, only quantized tensors need padding
625
+ if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) {
626
+ rpc_msg_init_tensor_req request;
627
+
628
+ request.tensor = serialize_tensor(tensor);
629
+
630
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_INIT_TENSOR, &request, sizeof(request), nullptr, 0);
631
+ RPC_STATUS_ASSERT(status);
632
+ }
633
+ return GGML_STATUS_SUCCESS;
634
+ }
635
+
636
+ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
637
+ ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
638
+ rpc_tensor rpc_tensor = serialize_tensor(tensor);
639
+ if (size > HASH_THRESHOLD) {
640
+ rpc_msg_set_tensor_hash_req request;
641
+ request.tensor = rpc_tensor;
642
+ request.offset = offset;
643
+ request.hash = fnv_hash((const uint8_t*)data, size);
644
+ rpc_msg_set_tensor_hash_rsp response;
645
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, &request, sizeof(request), &response, sizeof(response));
646
+ RPC_STATUS_ASSERT(status);
647
+ if (response.result) {
648
+ // the server has the same data, no need to send it
649
+ return;
650
+ }
651
+ }
652
+ // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes)
653
+ size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size;
654
+ std::vector<uint8_t> input(input_size, 0);
655
+ memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
656
+ memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
657
+ memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
658
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size());
659
+ RPC_STATUS_ASSERT(status);
660
+ }
661
+
662
+ static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
663
+ ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
664
+ rpc_msg_get_tensor_req request;
665
+ request.tensor = serialize_tensor(tensor);
666
+ request.offset = offset;
667
+ request.size = size;
668
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, &request, sizeof(request), data, size);
669
+ RPC_STATUS_ASSERT(status);
670
+ }
671
+
672
+ static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
673
+ if (ggml_backend_buffer_is_rpc(src->buffer)) {
674
+ // check if src and dst are on the same server
675
+ ggml_backend_buffer_t src_buffer = src->buffer;
676
+ ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
677
+ ggml_backend_buffer_t dst_buffer = dst->buffer;
678
+ ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
679
+ if (src_ctx->sock != dst_ctx->sock) {
680
+ return false;
681
+ }
682
+ ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
683
+ rpc_msg_copy_tensor_req request;
684
+ request.src = serialize_tensor(src);
685
+ request.dst = serialize_tensor(dst);
686
+ rpc_msg_copy_tensor_rsp response;
687
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
688
+ RPC_STATUS_ASSERT(status);
689
+ return response.result;
690
+ }
691
+ return false;
692
+ }
693
+
694
+ static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
695
+ ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
696
+ rpc_msg_buffer_clear_req request = {ctx->remote_ptr, value};
697
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, &request, sizeof(request), nullptr, 0);
698
+ RPC_STATUS_ASSERT(status);
699
+ }
700
+
701
+ static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
702
+ /* .free_buffer = */ ggml_backend_rpc_buffer_free_buffer,
703
+ /* .get_base = */ ggml_backend_rpc_buffer_get_base,
704
+ /* .init_tensor = */ ggml_backend_rpc_buffer_init_tensor,
705
+ /* .memset_tensor = */ NULL,
706
+ /* .set_tensor = */ ggml_backend_rpc_buffer_set_tensor,
707
+ /* .get_tensor = */ ggml_backend_rpc_buffer_get_tensor,
708
+ /* .cpy_tensor = */ ggml_backend_rpc_buffer_cpy_tensor,
709
+ /* .clear = */ ggml_backend_rpc_buffer_clear,
710
+ /* .reset = */ NULL,
711
+ };
712
+
713
+ static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t buft) {
714
+ ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
715
+ return buft_ctx->name.c_str();
716
+ }
717
+
718
+ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
719
+ ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
720
+ rpc_msg_alloc_buffer_req request = {buft_ctx->device, size};
721
+ rpc_msg_alloc_buffer_rsp response;
722
+ auto sock = get_socket(buft_ctx->endpoint);
723
+ bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
724
+ RPC_STATUS_ASSERT(status);
725
+ if (response.remote_ptr != 0) {
726
+ ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
727
+ ggml_backend_rpc_buffer_interface,
728
+ new ggml_backend_rpc_buffer_context{sock, nullptr, response.remote_ptr},
729
+ response.remote_size);
730
+ return buffer;
731
+ } else {
732
+ return nullptr;
733
+ }
734
+ }
735
+
736
+ static size_t get_alignment(const std::shared_ptr<socket_t> & sock, uint32_t device) {
737
+ rpc_msg_get_alignment_req request = {device};
738
+ rpc_msg_get_alignment_rsp response;
739
+ bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, &request, sizeof(request), &response, sizeof(response));
740
+ RPC_STATUS_ASSERT(status);
741
+ return response.alignment;
742
+ }
743
+
744
+ static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
745
+ ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
746
+ return buft_ctx->alignment;
747
+ }
748
+
749
+ static size_t get_max_size(const std::shared_ptr<socket_t> & sock, uint32_t device) {
750
+ rpc_msg_get_max_size_req request = {device};
751
+ rpc_msg_get_max_size_rsp response;
752
+ bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, &request, sizeof(request), &response, sizeof(response));
753
+ RPC_STATUS_ASSERT(status);
754
+ return response.max_size;
755
+ }
756
+
757
+ static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
758
+ ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
759
+ return buft_ctx->max_size;
760
+ }
761
+
762
+ static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
763
+ // should we query the remote server for the actual size
764
+ bool rpc_get = false;
765
+
766
+ // See comments in init_tensor.
767
+ rpc_get |= ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr);
768
+
769
+ // ops that require additional memory for fleeting data on certain backends
770
+ // ref: https://github.com/ggml-org/llama.cpp/pull/15966
771
+ rpc_get |= tensor->op == GGML_OP_FLASH_ATTN_EXT;
772
+ rpc_get |= tensor->op == GGML_OP_MUL_MAT_ID;
773
+
774
+ if (rpc_get) {
775
+ ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
776
+ auto sock = get_socket(buft_ctx->endpoint);
777
+
778
+ rpc_msg_get_alloc_size_req request = {
779
+ /*.device =*/ buft_ctx->device,
780
+ /*.tensor =*/ serialize_tensor(tensor),
781
+ /*.srcs =*/ {},
782
+ };
783
+
784
+ // .get_alloc_size could be a function of the tensor's srcs, so we must serialize them as well
785
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
786
+ request.srcs[i] = serialize_tensor(tensor->src[i]);
787
+ }
788
+
789
+ // TODO: cache the alloc responses to avoid extra RPC calls?
790
+ rpc_msg_get_alloc_size_rsp response;
791
+ bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response));
792
+ RPC_STATUS_ASSERT(status);
793
+
794
+ return response.alloc_size;
795
+ }
796
+
797
+ return ggml_nbytes(tensor);
798
+ }
799
+
800
+ static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
801
+ /* .get_name = */ ggml_backend_rpc_buffer_type_name,
802
+ /* .alloc_buffer = */ ggml_backend_rpc_buffer_type_alloc_buffer,
803
+ /* .get_alignment = */ ggml_backend_rpc_buffer_type_get_alignment,
804
+ /* .get_max_size = */ ggml_backend_rpc_get_max_size,
805
+ /* .get_alloc_size = */ ggml_backend_rpc_buffer_type_get_alloc_size,
806
+ /* .is_host = */ NULL,
807
+ };
808
+
809
+ static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
810
+ ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
811
+
812
+ return rpc_ctx->name.c_str();
813
+ }
814
+
815
+ static void ggml_backend_rpc_free(ggml_backend_t backend) {
816
+ ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
817
+ delete rpc_ctx;
818
+ delete backend;
819
+ }
820
+
821
+ static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
822
+ GGML_UNUSED(backend);
823
+ // this is no-op because we don't have any async operations
824
+ }
825
+
826
+ static void add_tensor(ggml_tensor * tensor, std::vector<rpc_tensor> & tensors, std::unordered_set<ggml_tensor*> & visited) {
827
+ if (tensor == nullptr) {
828
+ return;
829
+ }
830
+ if (visited.find(tensor) != visited.end()) {
831
+ return;
832
+ }
833
+ visited.insert(tensor);
834
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
835
+ add_tensor(tensor->src[i], tensors, visited);
836
+ }
837
+ add_tensor(tensor->view_src, tensors, visited);
838
+ tensors.push_back(serialize_tensor(tensor));
839
+ }
840
+
841
+ static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
842
+ uint32_t n_nodes = cgraph->n_nodes;
843
+ std::vector<rpc_tensor> tensors;
844
+ std::unordered_set<ggml_tensor*> visited;
845
+ for (uint32_t i = 0; i < n_nodes; i++) {
846
+ add_tensor(cgraph->nodes[i], tensors, visited);
847
+ }
848
+ // serialization format:
849
+ // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
850
+ uint32_t n_tensors = tensors.size();
851
+ int output_size = 2*sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
852
+ output.resize(output_size, 0);
853
+ uint8_t * dest = output.data();
854
+ memcpy(dest, &device, sizeof(device));
855
+ dest += sizeof(device);
856
+ memcpy(dest, &n_nodes, sizeof(n_nodes));
857
+ dest += sizeof(n_nodes);
858
+ for (uint32_t i = 0; i < n_nodes; i++) {
859
+ memcpy(dest + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));
860
+ }
861
+ dest += n_nodes * sizeof(uint64_t);
862
+ memcpy(dest, &n_tensors, sizeof(n_tensors));
863
+ dest += sizeof(n_tensors);
864
+ rpc_tensor * out_tensors = (rpc_tensor *)dest;
865
+ memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor));
866
+ }
867
+
868
+ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
869
+ ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
870
+
871
+ GGML_ASSERT(cgraph->n_nodes > 0);
872
+ bool reuse = rpc_ctx->gc.is_cached(cgraph);
873
+ if (reuse) {
874
+ rpc_msg_graph_recompute_req request;
875
+ request.device = rpc_ctx->device;
876
+ auto sock = get_socket(rpc_ctx->endpoint);
877
+ bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request));
878
+ RPC_STATUS_ASSERT(status);
879
+ } else {
880
+ rpc_ctx->gc.add(cgraph);
881
+ std::vector<uint8_t> input;
882
+ serialize_graph(rpc_ctx->device, cgraph, input);
883
+ auto sock = get_socket(rpc_ctx->endpoint);
884
+ bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size());
885
+ RPC_STATUS_ASSERT(status);
886
+ }
887
+ return GGML_STATUS_SUCCESS;
888
+ }
889
+
890
+ static ggml_backend_i ggml_backend_rpc_interface = {
891
+ /* .get_name = */ ggml_backend_rpc_name,
892
+ /* .free = */ ggml_backend_rpc_free,
893
+ /* .set_tensor_async = */ NULL,
894
+ /* .get_tensor_async = */ NULL,
895
+ /* .cpy_tensor_async = */ NULL,
896
+ /* .synchronize = */ ggml_backend_rpc_synchronize,
897
+ /* .graph_plan_create = */ NULL,
898
+ /* .graph_plan_free = */ NULL,
899
+ /* .graph_plan_update = */ NULL,
900
+ /* .graph_plan_compute = */ NULL,
901
+ /* .graph_compute = */ ggml_backend_rpc_graph_compute,
902
+ /* .event_record = */ NULL,
903
+ /* .event_wait = */ NULL,
904
+ /* .graph_optimize = */ NULL,
905
+ };
906
+
907
+ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device) {
908
+ static std::mutex mutex;
909
+ std::lock_guard<std::mutex> lock(mutex);
910
+ std::string buft_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
911
+ // NOTE: buffer types are allocated and never freed; this is by design
912
+ static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;
913
+ auto it = buft_map.find(buft_name);
914
+ if (it != buft_map.end()) {
915
+ return it->second;
916
+ }
917
+ auto sock = get_socket(endpoint);
918
+ if (sock == nullptr) {
919
+ GGML_LOG_ERROR("Failed to connect to %s\n", endpoint);
920
+ return nullptr;
921
+ }
922
+ size_t alignment = get_alignment(sock, device);
923
+ size_t max_size = get_max_size(sock, device);
924
+ ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
925
+ /* .endpoint = */ endpoint,
926
+ /* .device = */ device,
927
+ /* .name = */ buft_name,
928
+ /* .alignment = */ alignment,
929
+ /* .max_size = */ max_size
930
+ };
931
+ auto reg = ggml_backend_rpc_add_server(endpoint);
932
+ ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
933
+ /* .iface = */ ggml_backend_rpc_buffer_type_interface,
934
+ /* .device = */ ggml_backend_reg_dev_get(reg, device),
935
+ /* .context = */ buft_ctx
936
+ };
937
+ buft_map[buft_name] = buft;
938
+ return buft;
939
+ }
940
+
941
+ ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
942
+ std::string dev_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
943
+ ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
944
+ /* .endpoint = */ endpoint,
945
+ /* .device = */ device,
946
+ /* .name = */ dev_name,
947
+ /* .gc = */ {},
948
+ };
949
+ auto reg = ggml_backend_rpc_add_server(endpoint);
950
+ ggml_backend_t backend = new ggml_backend {
951
+ /* .guid = */ ggml_backend_rpc_guid(),
952
+ /* .iface = */ ggml_backend_rpc_interface,
953
+ /* .device = */ ggml_backend_reg_dev_get(reg, device),
954
+ /* .context = */ ctx
955
+ };
956
+ return backend;
957
+ }
958
+
959
+ bool ggml_backend_is_rpc(ggml_backend_t backend) {
960
+ return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
961
+ }
962
+
963
+ static void get_device_memory(const std::shared_ptr<socket_t> & sock, uint32_t device, size_t * free, size_t * total) {
964
+ rpc_msg_get_device_memory_req request;
965
+ request.device = device;
966
+ rpc_msg_get_device_memory_rsp response;
967
+ bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, &request, sizeof(request), &response, sizeof(response));
968
+ RPC_STATUS_ASSERT(status);
969
+ *free = response.free_mem;
970
+ *total = response.total_mem;
971
+ }
972
+
973
+ void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total) {
974
+ auto sock = get_socket(endpoint);
975
+ if (sock == nullptr) {
976
+ *free = 0;
977
+ *total = 0;
978
+ return;
979
+ }
980
+ get_device_memory(sock, device, free, total);
981
+ }
982
+
983
+ // RPC server-side implementation
984
+
985
+ class rpc_server {
986
+ public:
987
+ rpc_server(std::vector<ggml_backend_t> all_backends, const char * cache_dir)
988
+ : backends(std::move(all_backends)), cache_dir(cache_dir) {
989
+ stored_graphs.resize(backends.size());
990
+ }
991
+ ~rpc_server();
992
+
993
+ void hello(rpc_msg_hello_rsp & response);
994
+ bool alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
995
+ bool get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response);
996
+ bool get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response);
997
+ bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response);
998
+ bool free_buffer(const rpc_msg_free_buffer_req & request);
999
+ bool buffer_clear(const rpc_msg_buffer_clear_req & request);
1000
+ bool set_tensor(const std::vector<uint8_t> & input);
1001
+ bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response);
1002
+ bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
1003
+ bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
1004
+ bool graph_compute(const std::vector<uint8_t> & input);
1005
+ bool graph_recompute(const rpc_msg_graph_recompute_req & request);
1006
+ bool init_tensor(const rpc_msg_init_tensor_req & request);
1007
+ bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
1008
+ bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
1009
+
1010
+ struct stored_graph {
1011
+ ggml_context_ptr ctx_ptr;
1012
+ ggml_cgraph * graph;
1013
+ };
1014
+
1015
+ private:
1016
+ bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
1017
+ ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
1018
+ ggml_tensor * create_node(uint64_t id,
1019
+ struct ggml_context * ctx,
1020
+ const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
1021
+ std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map);
1022
+
1023
+
1024
+ std::vector<ggml_backend_t> backends;
1025
+ const char * cache_dir;
1026
+ std::unordered_set<ggml_backend_buffer_t> buffers;
1027
+ // store the last computed graph for each backend
1028
+ std::vector<stored_graph> stored_graphs;
1029
+ };
1030
+
1031
+ void rpc_server::hello(rpc_msg_hello_rsp & response) {
1032
+ response.major = RPC_PROTO_MAJOR_VERSION;
1033
+ response.minor = RPC_PROTO_MINOR_VERSION;
1034
+ response.patch = RPC_PROTO_PATCH_VERSION;
1035
+ LOG_DBG("[%s] version: %d.%d.%d\n", __func__, response.major, response.minor, response.patch);
1036
+ }
1037
+
1038
+ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
1039
+ uint32_t dev_id = request.device;
1040
+ if (dev_id >= backends.size()) {
1041
+ return false;
1042
+ }
1043
+ ggml_backend_buffer_type_t buft;
1044
+ struct ggml_init_params params {
1045
+ /*.mem_size =*/ ggml_tensor_overhead()*(1 + GGML_MAX_SRC),
1046
+ /*.mem_buffer =*/ NULL,
1047
+ /*.no_alloc =*/ true,
1048
+ };
1049
+
1050
+ ggml_context_ptr ctx_ptr { ggml_init(params) };
1051
+ GGML_ASSERT(ctx_ptr != nullptr);
1052
+ ggml_context * ctx = ctx_ptr.get();
1053
+
1054
+ ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
1055
+ if (tensor == nullptr) {
1056
+ GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n");
1057
+ return false;
1058
+ }
1059
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
1060
+ if (request.srcs[i].id != 0) {
1061
+ tensor->src[i] = deserialize_tensor(ctx, &request.srcs[i]);
1062
+ }
1063
+ }
1064
+
1065
+ LOG_DBG("[%s] device: %d, buffer: %p, data: %p\n", __func__, dev_id, (void*)tensor->buffer, tensor->data);
1066
+ if (tensor->buffer == nullptr) {
1067
+ //No buffer allocated.
1068
+ buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
1069
+ } else {
1070
+ buft = tensor->buffer->buft;
1071
+ }
1072
+
1073
+ response.alloc_size = ggml_backend_buft_get_alloc_size(buft, tensor);
1074
+
1075
+ return true;
1076
+ }
1077
+
1078
+ bool rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
1079
+ uint32_t dev_id = request.device;
1080
+ if (dev_id >= backends.size()) {
1081
+ return false;
1082
+ }
1083
+ ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
1084
+ ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
1085
+ response.remote_ptr = 0;
1086
+ response.remote_size = 0;
1087
+ if (buffer != nullptr) {
1088
+ response.remote_ptr = reinterpret_cast<uint64_t>(buffer);
1089
+ response.remote_size = buffer->size;
1090
+ LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n",
1091
+ __func__, dev_id, request.size, response.remote_ptr, response.remote_size);
1092
+ buffers.insert(buffer);
1093
+ } else {
1094
+ LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> failed\n", __func__, dev_id, request.size);
1095
+ }
1096
+ return true;
1097
+ }
1098
+
1099
+ bool rpc_server::get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response) {
1100
+ uint32_t dev_id = request.device;
1101
+ if (dev_id >= backends.size()) {
1102
+ return false;
1103
+ }
1104
+ ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
1105
+ size_t alignment = ggml_backend_buft_get_alignment(buft);
1106
+ LOG_DBG("[%s] device: %d, alignment: %lu\n", __func__, dev_id, alignment);
1107
+ response.alignment = alignment;
1108
+ return true;
1109
+ }
1110
+
1111
+ bool rpc_server::get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response) {
1112
+ uint32_t dev_id = request.device;
1113
+ if (dev_id >= backends.size()) {
1114
+ return false;
1115
+ }
1116
+ ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
1117
+ size_t max_size = ggml_backend_buft_get_max_size(buft);
1118
+ LOG_DBG("[%s] device: %d, max_size: %lu\n", __func__, dev_id, max_size);
1119
+ response.max_size = max_size;
1120
+ return true;
1121
+ }
1122
+
1123
+ bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {
1124
+ LOG_DBG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
1125
+ ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
1126
+ if (buffers.find(buffer) == buffers.end()) {
1127
+ GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
1128
+ return false;
1129
+ }
1130
+ void * base = ggml_backend_buffer_get_base(buffer);
1131
+ response.base_ptr = reinterpret_cast<uint64_t>(base);
1132
+ return true;
1133
+ }
1134
+
1135
+ bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) {
1136
+ LOG_DBG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
1137
+ ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
1138
+ if (buffers.find(buffer) == buffers.end()) {
1139
+ GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
1140
+ return false;
1141
+ }
1142
+ ggml_backend_buffer_free(buffer);
1143
+ buffers.erase(buffer);
1144
+ return true;
1145
+ }
1146
+
1147
+ bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {
1148
+ LOG_DBG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, request.remote_ptr, request.value);
1149
+ ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
1150
+ if (buffers.find(buffer) == buffers.end()) {
1151
+ GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
1152
+ return false;
1153
+ }
1154
+ ggml_backend_buffer_clear(buffer, request.value);
1155
+ return true;
1156
+ }
1157
+
1158
+ ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) {
1159
+ // Validate tensor type before using it
1160
+ if (tensor->type >= GGML_TYPE_COUNT) {
1161
+ GGML_LOG_ERROR("[%s] invalid tensor type received: %u\n", __func__, tensor->type);
1162
+ return nullptr;
1163
+ }
1164
+
1165
+ ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type,
1166
+ tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
1167
+
1168
+ // ggml_new_tensor_4d might fail if dimensions are invalid, although less likely to crash than invalid type
1169
+ if (result == nullptr) {
1170
+ GGML_LOG_ERROR("[%s] ggml_new_tensor_4d failed for type %u\\n", __func__, tensor->type);
1171
+ return nullptr;
1172
+ }
1173
+
1174
+ for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
1175
+ result->nb[i] = tensor->nb[i];
1176
+ }
1177
+ result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);
1178
+ if (result->buffer && buffers.find(result->buffer) == buffers.end()) {
1179
+ result->buffer = nullptr;
1180
+ }
1181
+
1182
+ if (result->buffer) {
1183
+ // require that the tensor data does not go beyond the buffer end
1184
+ uint64_t tensor_size = (uint64_t) ggml_nbytes(result);
1185
+ uint64_t buffer_start = (uint64_t) ggml_backend_buffer_get_base(result->buffer);
1186
+ uint64_t buffer_size = (uint64_t) ggml_backend_buffer_get_size(result->buffer);
1187
+ GGML_ASSERT(tensor->data + tensor_size >= tensor->data); // check for overflow
1188
+ GGML_ASSERT(tensor->data >= buffer_start && tensor->data + tensor_size <= buffer_start + buffer_size);
1189
+ }
1190
+
1191
+ result->op = (ggml_op) tensor->op;
1192
+ for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
1193
+ result->op_params[i] = tensor->op_params[i];
1194
+ }
1195
+ result->flags = tensor->flags;
1196
+ result->data = reinterpret_cast<void *>(tensor->data);
1197
+ ggml_set_name(result, tensor->name);
1198
+ return result;
1199
+ }
1200
+
1201
+
1202
+ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
1203
+ // serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
1204
+ if (input.size() < sizeof(rpc_tensor) + sizeof(uint64_t)) {
1205
+ return false;
1206
+ }
1207
+ const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
1208
+ uint64_t offset;
1209
+ memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
1210
+ const size_t size = input.size() - sizeof(rpc_tensor) - sizeof(offset);
1211
+
1212
+ struct ggml_init_params params {
1213
+ /*.mem_size =*/ ggml_tensor_overhead(),
1214
+ /*.mem_buffer =*/ NULL,
1215
+ /*.no_alloc =*/ true,
1216
+ };
1217
+ ggml_context_ptr ctx_ptr { ggml_init(params) };
1218
+ GGML_ASSERT(ctx_ptr != nullptr);
1219
+ ggml_context * ctx = ctx_ptr.get();
1220
+ ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
1221
+ if (tensor == nullptr || tensor->buffer == nullptr) {
1222
+ GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
1223
+ return false;
1224
+ }
1225
+ LOG_DBG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
1226
+
1227
+ // sanitize tensor->data
1228
+ {
1229
+ const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
1230
+ const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
1231
+
1232
+ if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
1233
+ GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu) out of buffer bounds [0x%zx, 0x%zx)\n",
1234
+ __func__, in_tensor->data, offset, size, p0, p1);
1235
+ return false;
1236
+ }
1237
+ }
1238
+
1239
+ const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset);
1240
+ if (cache_dir && size > HASH_THRESHOLD) {
1241
+ uint64_t hash = fnv_hash((const uint8_t*)data, size);
1242
+ char hash_str[17];
1243
+ snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
1244
+ // save to cache_dir/hash_str
1245
+ fs::path cache_file = fs::path(cache_dir) / hash_str;
1246
+ std::ofstream ofs(cache_file, std::ios::binary);
1247
+ ofs.write((const char *)data, size);
1248
+ GGML_LOG_INFO("[%s] saved to '%s'\n", __func__, cache_file.c_str());
1249
+ }
1250
+ ggml_backend_tensor_set(tensor, data, offset, size);
1251
+ return true;
1252
+ }
1253
+
1254
+ bool rpc_server::get_cached_file(uint64_t hash, std::vector<uint8_t> & data) {
1255
+ if (!cache_dir) {
1256
+ return false;
1257
+ }
1258
+ char hash_str[17];
1259
+ snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
1260
+ fs::path cache_file = fs::path(cache_dir) / hash_str;
1261
+ std::error_code ec;
1262
+ if (!fs::exists(cache_file, ec)) {
1263
+ return false;
1264
+ }
1265
+ std::ifstream ifs(cache_file, std::ios::binary);
1266
+ ifs.seekg(0, std::ios::end);
1267
+ size_t size = ifs.tellg();
1268
+ ifs.seekg(0, std::ios::beg);
1269
+ data.resize(size);
1270
+ ifs.read((char *)data.data(), size);
1271
+ return true;
1272
+ }
1273
+
1274
+ bool rpc_server::set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response)
1275
+ {
1276
+ std::vector<uint8_t> cached_file;
1277
+ if (!get_cached_file(request.hash, cached_file)) {
1278
+ response.result = 0;
1279
+ return true;
1280
+ }
1281
+ size_t size = cached_file.size();
1282
+ struct ggml_init_params params {
1283
+ /*.mem_size =*/ ggml_tensor_overhead(),
1284
+ /*.mem_buffer =*/ NULL,
1285
+ /*.no_alloc =*/ true,
1286
+ };
1287
+ ggml_context_ptr ctx_ptr { ggml_init(params) };
1288
+ GGML_ASSERT(ctx_ptr != nullptr);
1289
+ ggml_context * ctx = ctx_ptr.get();
1290
+ ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
1291
+ if (tensor == nullptr || tensor->buffer == nullptr) {
1292
+ GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
1293
+ return false;
1294
+ }
1295
+ LOG_DBG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n",
1296
+ __func__, (void*)tensor->buffer, tensor->data, request.offset, size, request.hash);
1297
+
1298
+ // sanitize tensor->data
1299
+ {
1300
+ const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
1301
+ const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
1302
+
1303
+ if (request.tensor.data + request.offset < p0
1304
+ || request.tensor.data + request.offset >= p1
1305
+ || size > (p1 - request.tensor.data - request.offset)) {
1306
+ GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu, hash=0x%" PRIx64 ") out of buffer bounds [0x%zx, 0x%zx)\n",
1307
+ __func__, request.tensor.data, request.offset, size, request.hash, p0, p1);
1308
+ return false;
1309
+ }
1310
+ }
1311
+ ggml_backend_tensor_set(tensor, cached_file.data(), request.offset, size);
1312
+ response.result = 1;
1313
+ return true;
1314
+ }
1315
+
1316
+ bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) {
1317
+ struct ggml_init_params params {
1318
+ /*.mem_size =*/ ggml_tensor_overhead(),
1319
+ /*.mem_buffer =*/ NULL,
1320
+ /*.no_alloc =*/ true,
1321
+ };
1322
+ ggml_context_ptr ctx_ptr { ggml_init(params) };
1323
+ GGML_ASSERT(ctx_ptr != nullptr);
1324
+ ggml_context * ctx = ctx_ptr.get();
1325
+ ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
1326
+ if (tensor == nullptr) {
1327
+ GGML_LOG_ERROR("Null tensor pointer passed to server init_tensor function.\n");
1328
+ return false;
1329
+ }
1330
+ LOG_DBG("[%s] buffer: %p, data: %p\n", __func__, (void*)tensor->buffer, tensor->data);
1331
+ // Call the backend's buffer_init_tensor function
1332
+ ggml_backend_buffer_t buffer = tensor->buffer;
1333
+ if (buffer && buffer->iface.init_tensor) {
1334
+ buffer->iface.init_tensor(buffer, tensor);
1335
+ } else {
1336
+ GGML_LOG_ERROR("Null buffer for tensor passed to init_tensor function\n");
1337
+ }
1338
+
1339
+ if (tensor->extra != nullptr) {
1340
+ // This pointer can either be passed around client/server, or probably better stored server-side and kept track of.
1341
+ // Currently unimplemented.
1342
+ GGML_LOG_ERROR("tensor->extra populated by the backend, this is currently unsupported.\n");
1343
+ return false;
1344
+ }
1345
+
1346
+ return true;
1347
+ }
1348
+
1349
+ bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response) {
1350
+ struct ggml_init_params params {
1351
+ /*.mem_size =*/ ggml_tensor_overhead(),
1352
+ /*.mem_buffer =*/ NULL,
1353
+ /*.no_alloc =*/ true,
1354
+ };
1355
+ ggml_context_ptr ctx_ptr { ggml_init(params) };
1356
+ GGML_ASSERT(ctx_ptr != nullptr);
1357
+ ggml_context * ctx = ctx_ptr.get();
1358
+ ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
1359
+ if (tensor == nullptr || tensor->buffer == nullptr) {
1360
+ GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
1361
+ return false;
1362
+ }
1363
+ LOG_DBG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size);
1364
+
1365
+ // sanitize tensor->data
1366
+ {
1367
+ const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
1368
+ const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
1369
+
1370
+ if (request.tensor.data + request.offset < p0 ||
1371
+ request.tensor.data + request.offset >= p1 ||
1372
+ request.size > (p1 - request.tensor.data - request.offset)) {
1373
+ GGML_LOG_ERROR("[%s] requested tensor region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%" PRIu64 ") out of buffer bounds [0x%zx, 0x%zx)\n",
1374
+ __func__, request.tensor.data, request.offset, request.size, p0, p1);
1375
+ return false;
1376
+ }
1377
+ }
1378
+
1379
+ response.resize(request.size, 0);
1380
+ ggml_backend_tensor_get(tensor, response.data(), request.offset, request.size);
1381
+ return true;
1382
+ }
1383
+
1384
+ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response) {
1385
+ struct ggml_init_params params {
1386
+ /*.mem_size =*/ 2*ggml_tensor_overhead(),
1387
+ /*.mem_buffer =*/ NULL,
1388
+ /*.no_alloc =*/ true,
1389
+ };
1390
+ ggml_context_ptr ctx_ptr { ggml_init(params) };
1391
+ GGML_ASSERT(ctx_ptr != nullptr);
1392
+ ggml_context * ctx = ctx_ptr.get();
1393
+
1394
+ ggml_tensor * src = deserialize_tensor(ctx, &request.src);
1395
+ ggml_tensor * dst = deserialize_tensor(ctx, &request.dst);
1396
+ if (src == nullptr || dst == nullptr || src->buffer == nullptr || dst->buffer == nullptr) {
1397
+ GGML_LOG_ERROR("[%s] error deserializing tensors\n", __func__);
1398
+ return false;
1399
+ }
1400
+
1401
+ uint64_t src_size = (uint64_t) ggml_nbytes(src);
1402
+ uint64_t dst_data = (uint64_t) dst->data;
1403
+ uint64_t dst_base = (uint64_t) ggml_backend_buffer_get_base(dst->buffer);
1404
+ uint64_t dst_buf_sz = (uint64_t) ggml_backend_buffer_get_size(dst->buffer);
1405
+
1406
+ if (dst_data + src_size > dst_base + dst_buf_sz) {
1407
+ GGML_LOG_ERROR("[%s] out-of-bounds write in rpc_server::copy_tensor:\n"
1408
+ " write range : [0x%" PRIx64 ", 0x%" PRIx64 "]\n"
1409
+ " buffer base: [0x%" PRIx64 ", 0x%" PRIx64 "]\n",
1410
+ __func__,
1411
+ dst_data,
1412
+ dst_data + src_size,
1413
+ dst_base,
1414
+ dst_base + dst_buf_sz);
1415
+ return false;
1416
+ }
1417
+
1418
+ LOG_DBG("[%s] src->buffer: %p, dst->buffer: %p\n",
1419
+ __func__, (void*) src->buffer, (void*) dst->buffer);
1420
+
1421
+ response.result = ggml_backend_buffer_copy_tensor(src, dst);
1422
+ return true;
1423
+ }
1424
+
1425
+ ggml_tensor * rpc_server::create_node(uint64_t id,
1426
+ struct ggml_context * ctx,
1427
+ const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
1428
+ std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map) {
1429
+ if (tensor_map.find(id) != tensor_map.end()) {
1430
+ return tensor_map[id];
1431
+ }
1432
+ // Safely find the tensor pointer
1433
+ auto it_ptr = tensor_ptrs.find(id);
1434
+ if (it_ptr == tensor_ptrs.end()) {
1435
+ return nullptr;
1436
+ }
1437
+ const rpc_tensor * tensor = it_ptr->second;
1438
+
1439
+ struct ggml_tensor * result = deserialize_tensor(ctx, tensor);
1440
+ if (result == nullptr) {
1441
+ return nullptr;
1442
+ }
1443
+ tensor_map[id] = result;
1444
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
1445
+ // Check if the source ID is 0 before calling create_node recursively
1446
+ if (tensor->src[i] == 0) {
1447
+ result->src[i] = nullptr;
1448
+ } else {
1449
+ result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
1450
+ // If the recursive call failed for a non-zero ID, propagate the error
1451
+ if (result->src[i] == nullptr) {
1452
+ GGML_LOG_ERROR("[%s] failed to create source node %d (src_id=%" PRIu64 ") for node id %" PRIu64 "\n",
1453
+ __func__, i, tensor->src[i], id);
1454
+ // Must return nullptr to signal failure up the call stack
1455
+ return nullptr;
1456
+ }
1457
+ }
1458
+ }
1459
+
1460
+ // Handle view_src similarly
1461
+ if (tensor->view_src == 0) {
1462
+ result->view_src = nullptr;
1463
+ } else {
1464
+ result->view_src = create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map);
1465
+ // If the recursive call failed for a non-zero ID, propagate the error
1466
+ if (result->view_src == nullptr) {
1467
+ GGML_LOG_ERROR("[%s] failed to create view_src node (view_src_id=%" PRIu64 ") for node id %" PRIu64 "\n",
1468
+ __func__, tensor->view_src, id);
1469
+ // Must return nullptr to signal failure up the call stack
1470
+ return nullptr;
1471
+ }
1472
+ }
1473
+ result->view_offs = tensor->view_offs;
1474
+ return result;
1475
+ }
1476
+
1477
+ bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
1478
+ // serialization format:
1479
+ // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
1480
+ if (input.size() < 2*sizeof(uint32_t)) {
1481
+ return false;
1482
+ }
1483
+ const uint8_t * src = input.data();
1484
+ uint32_t device;
1485
+ memcpy(&device, src, sizeof(device));
1486
+ src += sizeof(device);
1487
+ if (device >= backends.size()) {
1488
+ return false;
1489
+ }
1490
+ uint32_t n_nodes;
1491
+ memcpy(&n_nodes, src, sizeof(n_nodes));
1492
+ src += sizeof(n_nodes);
1493
+ if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
1494
+ return false;
1495
+ }
1496
+ const uint64_t * nodes = (const uint64_t *)src;
1497
+ src += n_nodes*sizeof(uint64_t);
1498
+ uint32_t n_tensors;
1499
+ memcpy(&n_tensors, src, sizeof(n_tensors));
1500
+ src += sizeof(n_tensors);
1501
+ if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
1502
+ return false;
1503
+ }
1504
+ const rpc_tensor * tensors = (const rpc_tensor *)src;
1505
+ LOG_DBG("[%s] device: %u, n_nodes: %u, n_tensors: %u\n", __func__, device, n_nodes, n_tensors);
1506
+
1507
+ size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
1508
+
1509
+ struct ggml_init_params params = {
1510
+ /*.mem_size =*/ buf_size,
1511
+ /*.mem_buffer =*/ NULL,
1512
+ /*.no_alloc =*/ true,
1513
+ };
1514
+ ggml_context_ptr ctx_ptr { ggml_init(params) };
1515
+ GGML_ASSERT(ctx_ptr != nullptr);
1516
+ ggml_context * ctx = ctx_ptr.get();
1517
+ struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false);
1518
+ graph->n_nodes = n_nodes;
1519
+ std::unordered_map<uint64_t, const rpc_tensor*> tensor_ptrs;
1520
+ tensor_ptrs.reserve(n_tensors);
1521
+ for (uint32_t i = 0; i < n_tensors; i++) {
1522
+ tensor_ptrs.emplace(tensors[i].id, &tensors[i]);
1523
+ }
1524
+ std::unordered_map<uint64_t, ggml_tensor*> tensor_map;
1525
+ tensor_map.reserve(n_nodes);
1526
+ for (uint32_t i = 0; i < n_nodes; i++) {
1527
+ int64_t id;
1528
+ memcpy(&id, &nodes[i], sizeof(id));
1529
+ graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map);
1530
+
1531
+ // Check if create_node failed for a *non-zero* ID.
1532
+ // If id was 0, create_node returning nullptr is expected.
1533
+ // If id was non-zero and create_node returned nullptr, it indicates a deserialization error.
1534
+ if (graph->nodes[i] == nullptr && id != 0) {
1535
+ GGML_LOG_ERROR("[%s] failed to create graph node %d (id=%" PRId64 ")\n", __func__, i, id);
1536
+ return false;
1537
+ }
1538
+ }
1539
+ ggml_status status = ggml_backend_graph_compute(backends[device], graph);
1540
+ GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
1541
+ stored_graphs[device].ctx_ptr.swap(ctx_ptr);
1542
+ stored_graphs[device].graph = graph;
1543
+ return true;
1544
+ }
1545
+
1546
+ bool rpc_server::graph_recompute(const rpc_msg_graph_recompute_req & request) {
1547
+ uint32_t device = request.device;
1548
+ if (device >= backends.size()) {
1549
+ return false;
1550
+ }
1551
+ if (stored_graphs[device].graph == nullptr) {
1552
+ return false;
1553
+ }
1554
+ ggml_cgraph * graph = stored_graphs[device].graph;
1555
+ LOG_DBG("[%s] device: %u\n", __func__, device);
1556
+ ggml_status status = ggml_backend_graph_compute(backends[device], graph);
1557
+ GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
1558
+ return true;
1559
+ }
1560
+
1561
+ bool rpc_server::get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response) {
1562
+ uint32_t dev_id = request.device;
1563
+ if (dev_id >= backends.size()) {
1564
+ return false;
1565
+ }
1566
+ size_t free, total;
1567
+ ggml_backend_dev_t dev = ggml_backend_get_device(backends[dev_id]);
1568
+ ggml_backend_dev_memory(dev, &free, &total);
1569
+ response.free_mem = free;
1570
+ response.total_mem = total;
1571
+ LOG_DBG("[%s] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", __func__, dev_id, response.free_mem, response.total_mem);
1572
+ return true;
1573
+ }
1574
+
1575
+ rpc_server::~rpc_server() {
1576
+ for (auto buffer : buffers) {
1577
+ ggml_backend_buffer_free(buffer);
1578
+ }
1579
+ }
1580
+
1581
+ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir,
1582
+ sockfd_t sockfd) {
1583
+ rpc_server server(backends, cache_dir);
1584
+ uint8_t cmd;
1585
+ if (!recv_data(sockfd, &cmd, 1)) {
1586
+ return;
1587
+ }
1588
+ // the first command sent by the client must be HELLO
1589
+ if (cmd != RPC_CMD_HELLO) {
1590
+ GGML_LOG_ERROR("Expected HELLO command, update client\n");
1591
+ return;
1592
+ }
1593
+ if (!recv_msg(sockfd, nullptr, 0)) {
1594
+ return;
1595
+ }
1596
+ rpc_msg_hello_rsp response;
1597
+ server.hello(response);
1598
+ if (!send_msg(sockfd, &response, sizeof(response))) {
1599
+ return;
1600
+ }
1601
+ while (true) {
1602
+ if (!recv_data(sockfd, &cmd, 1)) {
1603
+ break;
1604
+ }
1605
+ if (cmd >= RPC_CMD_COUNT) {
1606
+ // fail fast if the command is invalid
1607
+ GGML_LOG_ERROR("Unknown command: %d\n", cmd);
1608
+ break;
1609
+ }
1610
+ switch (cmd) {
1611
+ case RPC_CMD_HELLO: {
1612
+ // HELLO command is handled above
1613
+ return;
1614
+ }
1615
+ case RPC_CMD_DEVICE_COUNT: {
1616
+ if (!recv_msg(sockfd, nullptr, 0)) {
1617
+ return;
1618
+ }
1619
+ rpc_msg_device_count_rsp response;
1620
+ response.device_count = backends.size();
1621
+ if (!send_msg(sockfd, &response, sizeof(response))) {
1622
+ return;
1623
+ }
1624
+ break;
1625
+ }
1626
+ case RPC_CMD_ALLOC_BUFFER: {
1627
+ rpc_msg_alloc_buffer_req request;
1628
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
1629
+ return;
1630
+ }
1631
+ rpc_msg_alloc_buffer_rsp response;
1632
+ if (!server.alloc_buffer(request, response)) {
1633
+ return;
1634
+ }
1635
+ if (!send_msg(sockfd, &response, sizeof(response))) {
1636
+ return;
1637
+ }
1638
+ break;
1639
+ }
1640
+ case RPC_CMD_GET_ALLOC_SIZE: {
1641
+ rpc_msg_get_alloc_size_req request;
1642
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
1643
+ return;
1644
+ }
1645
+ rpc_msg_get_alloc_size_rsp response;
1646
+ if (!server.get_alloc_size(request, response)) {
1647
+ return;
1648
+ }
1649
+ if (!send_msg(sockfd, &response, sizeof(response))) {
1650
+ return;
1651
+ }
1652
+ break;
1653
+ }
1654
+ case RPC_CMD_GET_ALIGNMENT: {
1655
+ rpc_msg_get_alignment_req request;
1656
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
1657
+ return;
1658
+ }
1659
+ rpc_msg_get_alignment_rsp response;
1660
+ if (!server.get_alignment(request, response)) {
1661
+ return;
1662
+ }
1663
+ if (!send_msg(sockfd, &response, sizeof(response))) {
1664
+ return;
1665
+ }
1666
+ break;
1667
+ }
1668
+ case RPC_CMD_GET_MAX_SIZE: {
1669
+ rpc_msg_get_max_size_req request;
1670
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
1671
+ return;
1672
+ }
1673
+ rpc_msg_get_max_size_rsp response;
1674
+ if (!server.get_max_size(request, response)) {
1675
+ return;
1676
+ }
1677
+ if (!send_msg(sockfd, &response, sizeof(response))) {
1678
+ return;
1679
+ }
1680
+ break;
1681
+ }
1682
+ case RPC_CMD_BUFFER_GET_BASE: {
1683
+ rpc_msg_buffer_get_base_req request;
1684
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
1685
+ return;
1686
+ }
1687
+ rpc_msg_buffer_get_base_rsp response;
1688
+ if (!server.buffer_get_base(request, response)) {
1689
+ return;
1690
+ }
1691
+ if (!send_msg(sockfd, &response, sizeof(response))) {
1692
+ return;
1693
+ }
1694
+ break;
1695
+ }
1696
+ case RPC_CMD_FREE_BUFFER: {
1697
+ rpc_msg_free_buffer_req request;
1698
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
1699
+ return;
1700
+ }
1701
+ if (!server.free_buffer(request)) {
1702
+ return;
1703
+ }
1704
+ if (!send_msg(sockfd, nullptr, 0)) {
1705
+ return;
1706
+ }
1707
+ break;
1708
+ }
1709
+ case RPC_CMD_BUFFER_CLEAR: {
1710
+ rpc_msg_buffer_clear_req request;
1711
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
1712
+ return;
1713
+ }
1714
+ if (!server.buffer_clear(request)) {
1715
+ return;
1716
+ }
1717
+ if (!send_msg(sockfd, nullptr, 0)) {
1718
+ return;
1719
+ }
1720
+ break;
1721
+ }
1722
+ case RPC_CMD_SET_TENSOR: {
1723
+ std::vector<uint8_t> input;
1724
+ if (!recv_msg(sockfd, input)) {
1725
+ return;
1726
+ }
1727
+ if (!server.set_tensor(input)) {
1728
+ return;
1729
+ }
1730
+ break;
1731
+ }
1732
+ case RPC_CMD_SET_TENSOR_HASH: {
1733
+ rpc_msg_set_tensor_hash_req request;
1734
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
1735
+ return;
1736
+ }
1737
+ rpc_msg_set_tensor_hash_rsp response;
1738
+ if (!server.set_tensor_hash(request, response)) {
1739
+ return;
1740
+ }
1741
+ if (!send_msg(sockfd, &response, sizeof(response))) {
1742
+ return;
1743
+ }
1744
+ break;
1745
+ }
1746
+ case RPC_CMD_INIT_TENSOR: {
1747
+ rpc_msg_init_tensor_req request;
1748
+ if (!recv_msg(sockfd, &request,sizeof(request))) {
1749
+ return;
1750
+ }
1751
+ if (!server.init_tensor(request)) {
1752
+ return;
1753
+ }
1754
+ if (!send_msg(sockfd, nullptr, 0)) {
1755
+ return;
1756
+ }
1757
+ break;
1758
+ }
1759
+ case RPC_CMD_GET_TENSOR: {
1760
+ rpc_msg_get_tensor_req request;
1761
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
1762
+ return;
1763
+ }
1764
+ std::vector<uint8_t> response;
1765
+ if (!server.get_tensor(request, response)) {
1766
+ return;
1767
+ }
1768
+ if (!send_msg(sockfd, response.data(), response.size())) {
1769
+ return;
1770
+ }
1771
+ break;
1772
+ }
1773
+ case RPC_CMD_COPY_TENSOR: {
1774
+ rpc_msg_copy_tensor_req request;
1775
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
1776
+ return;
1777
+ }
1778
+ rpc_msg_copy_tensor_rsp response;
1779
+ if (!server.copy_tensor(request, response)) {
1780
+ return;
1781
+ }
1782
+ if (!send_msg(sockfd, &response, sizeof(response))) {
1783
+ return;
1784
+ }
1785
+ break;
1786
+ }
1787
+ case RPC_CMD_GRAPH_COMPUTE: {
1788
+ std::vector<uint8_t> input;
1789
+ if (!recv_msg(sockfd, input)) {
1790
+ return;
1791
+ }
1792
+ if (!server.graph_compute(input)) {
1793
+ return;
1794
+ }
1795
+ break;
1796
+ }
1797
+ case RPC_CMD_GRAPH_RECOMPUTE: {
1798
+ rpc_msg_graph_recompute_req request;
1799
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
1800
+ return;
1801
+ }
1802
+ if (!server.graph_recompute(request)) {
1803
+ return;
1804
+ }
1805
+ break;
1806
+ }
1807
+ case RPC_CMD_GET_DEVICE_MEMORY: {
1808
+ rpc_msg_get_device_memory_req request;
1809
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
1810
+ return;
1811
+ }
1812
+ rpc_msg_get_device_memory_rsp response;
1813
+ if (!server.get_device_memory(request, response)) {
1814
+ return;
1815
+ }
1816
+ if (!send_msg(sockfd, &response, sizeof(response))) {
1817
+ return;
1818
+ }
1819
+ break;
1820
+ }
1821
+ default: {
1822
+ GGML_LOG_ERROR("Unknown command: %d\n", cmd);
1823
+ return;
1824
+ }
1825
+ }
1826
+ }
1827
+ }
1828
+
1829
+ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
1830
+ size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices) {
1831
+ if (n_devices == 0 || devices == nullptr) {
1832
+ fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n");
1833
+ return;
1834
+ }
1835
+ std::vector<ggml_backend_t> backends;
1836
+ printf("Starting RPC server v%d.%d.%d\n",
1837
+ RPC_PROTO_MAJOR_VERSION,
1838
+ RPC_PROTO_MINOR_VERSION,
1839
+ RPC_PROTO_PATCH_VERSION);
1840
+ printf(" endpoint : %s\n", endpoint);
1841
+ printf(" local cache : %s\n", cache_dir ? cache_dir : "n/a");
1842
+ printf("Devices:\n");
1843
+ for (size_t i = 0; i < n_devices; i++) {
1844
+ auto dev = devices[i];
1845
+ size_t free, total;
1846
+ ggml_backend_dev_memory(dev, &free, &total);
1847
+ printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
1848
+ total / 1024 / 1024, free / 1024 / 1024);
1849
+ auto backend = ggml_backend_dev_init(dev, nullptr);
1850
+ if (!backend) {
1851
+ fprintf(stderr, "Failed to create backend for device %s\n", dev->iface.get_name(dev));
1852
+ return;
1853
+ }
1854
+ backends.push_back(backend);
1855
+ ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
1856
+ if (reg) {
1857
+ auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
1858
+ if (ggml_backend_set_n_threads_fn) {
1859
+ ggml_backend_set_n_threads_fn(backend, n_threads);
1860
+ }
1861
+ }
1862
+ }
1863
+
1864
+ std::string host;
1865
+ int port;
1866
+ if (!parse_endpoint(endpoint, host, port)) {
1867
+ return;
1868
+ }
1869
+ #ifdef _WIN32
1870
+ {
1871
+ WSADATA wsaData;
1872
+ int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
1873
+ if (res != 0) {
1874
+ fprintf(stderr, "WSAStartup failed: %d\n", res);
1875
+ return;
1876
+ }
1877
+ }
1878
+ #endif
1879
+ auto server_socket = create_server_socket(host.c_str(), port);
1880
+ if (server_socket == nullptr) {
1881
+ fprintf(stderr, "Failed to create server socket\n");
1882
+ return;
1883
+ }
1884
+ while (true) {
1885
+ auto client_socket = socket_accept(server_socket->fd);
1886
+ if (client_socket == nullptr) {
1887
+ fprintf(stderr, "Failed to accept client connection\n");
1888
+ return;
1889
+ }
1890
+ printf("Accepted client connection\n");
1891
+ fflush(stdout);
1892
+ rpc_serve_client(backends, cache_dir, client_socket->fd);
1893
+ printf("Client connection closed\n");
1894
+ fflush(stdout);
1895
+ }
1896
+ #ifdef _WIN32
1897
+ WSACleanup();
1898
+ #endif
1899
+ for (auto backend : backends) {
1900
+ ggml_backend_free(backend);
1901
+ }
1902
+ }
1903
+
1904
+ // device interface
1905
+
1906
// Per-device state for one device exposed by a remote RPC server.
// Instances are built with aggregate initialization, so the field order
// matters. `device` has a default so a value-/default-constructed context
// never carries an indeterminate index.
struct ggml_backend_rpc_device_context {
    std::string endpoint;        // server endpoint string this device lives on
    uint32_t    device = 0;      // device index on that server
    std::string name;            // local display name (e.g. "RPC0")
    std::string description;     // human-readable description
};
1912
+
1913
+ static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
1914
+ ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1915
+
1916
+ return ctx->name.c_str();
1917
+ }
1918
+
1919
+ static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) {
1920
+ ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1921
+
1922
+ return ctx->description.c_str();
1923
+ }
1924
+
1925
+ static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
1926
+ ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1927
+
1928
+ ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), ctx->device, free, total);
1929
+ }
1930
+
1931
+ static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
1932
+ // TODO: obtain value from the server
1933
+ return GGML_BACKEND_DEVICE_TYPE_GPU;
1934
+
1935
+ GGML_UNUSED(dev);
1936
+ }
1937
+
1938
+ static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
1939
+ props->name = ggml_backend_rpc_device_get_name(dev);
1940
+ props->description = ggml_backend_rpc_device_get_description(dev);
1941
+ props->type = ggml_backend_rpc_device_get_type(dev);
1942
+ ggml_backend_rpc_device_get_memory(dev, &props->memory_free, &props->memory_total);
1943
+ props->caps = {
1944
+ /* .async = */ false,
1945
+ /* .host_buffer = */ false,
1946
+ /* .buffer_from_host_ptr = */ false,
1947
+ /* .events = */ false,
1948
+ };
1949
+ }
1950
+
1951
+ static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) {
1952
+ ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1953
+
1954
+ return ggml_backend_rpc_init(ctx->endpoint.c_str(), ctx->device);
1955
+
1956
+ GGML_UNUSED(params);
1957
+ }
1958
+
1959
+ static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
1960
+ ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1961
+
1962
+ return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str(), ctx->device);
1963
+
1964
+ GGML_UNUSED(dev);
1965
+ }
1966
+
1967
+ static bool ggml_backend_rpc_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
1968
+ GGML_UNUSED(dev);
1969
+ GGML_UNUSED(op);
1970
+ //TODO: call the remote backend and cache the results
1971
+ return true;
1972
+ }
1973
+
1974
+ static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
1975
+ if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
1976
+ return false;
1977
+ }
1978
+ ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
1979
+ ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context;
1980
+ return buft_ctx->endpoint == dev_ctx->endpoint && buft_ctx->device == dev_ctx->device;
1981
+ }
1982
+
1983
+ static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
1984
+ /* .get_name = */ ggml_backend_rpc_device_get_name,
1985
+ /* .get_description = */ ggml_backend_rpc_device_get_description,
1986
+ /* .get_memory = */ ggml_backend_rpc_device_get_memory,
1987
+ /* .get_type = */ ggml_backend_rpc_device_get_type,
1988
+ /* .get_props = */ ggml_backend_rpc_device_get_props,
1989
+ /* .init_backend = */ ggml_backend_rpc_device_init,
1990
+ /* .get_buffer_type = */ ggml_backend_rpc_device_get_buffer_type,
1991
+ /* .get_host_buffer_type = */ NULL,
1992
+ /* .buffer_from_host_ptr = */ NULL,
1993
+ /* .supports_op = */ ggml_backend_rpc_device_supports_op,
1994
+ /* .supports_buft = */ ggml_backend_rpc_device_supports_buft,
1995
+ /* .offload_op = */ NULL,
1996
+ /* .event_new = */ NULL,
1997
+ /* .event_free = */ NULL,
1998
+ /* .event_synchronize = */ NULL,
1999
+ };
2000
+
2001
+ // backend reg interface
2002
+
2003
+ struct ggml_backend_rpc_reg_context {
2004
+ std::string name;
2005
+ std::vector<ggml_backend_dev_t> devices;
2006
+ };
2007
+
2008
+ static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
2009
+ ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
2010
+ return ctx ? ctx->name.c_str() : "RPC";
2011
+ }
2012
+
2013
+ static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
2014
+ ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
2015
+ return ctx ? ctx->devices.size() : 0;
2016
+ }
2017
+
2018
+ static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
2019
+ ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
2020
+ if (ctx == nullptr) {
2021
+ GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_rpc_add_server instead");
2022
+ } else {
2023
+ GGML_ASSERT(index < ctx->devices.size());
2024
+ return ctx->devices[index];
2025
+ }
2026
+ }
2027
+
2028
+ static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {
2029
+ if (std::strcmp(name, "ggml_backend_rpc_add_server") == 0) {
2030
+ return (void *)ggml_backend_rpc_add_server;
2031
+ }
2032
+ if (std::strcmp(name, "ggml_backend_rpc_start_server") == 0) {
2033
+ return (void *)ggml_backend_rpc_start_server;
2034
+ }
2035
+ return NULL;
2036
+
2037
+ GGML_UNUSED(reg);
2038
+ }
2039
+
2040
+ static const struct ggml_backend_reg_i ggml_backend_rpc_reg_i = {
2041
+ /* .get_name = */ ggml_backend_rpc_reg_get_name,
2042
+ /* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,
2043
+ /* .get_device = */ ggml_backend_rpc_reg_get_device,
2044
+ /* .get_proc_address = */ ggml_backend_rpc_get_proc_address,
2045
+ };
2046
+
2047
+ ggml_backend_reg_t ggml_backend_rpc_reg(void) {
2048
+ static struct ggml_backend_reg ggml_backend_rpc_reg = {
2049
+ /* .api_version = */ GGML_BACKEND_API_VERSION,
2050
+ /* .iface = */ ggml_backend_rpc_reg_i,
2051
+ /* .context = */ NULL,
2052
+ };
2053
+
2054
+ return &ggml_backend_rpc_reg;
2055
+ }
2056
+
2057
+ static uint32_t ggml_backend_rpc_get_device_count(const char * endpoint) {
2058
+ auto sock = get_socket(endpoint);
2059
+ if (sock == nullptr) {
2060
+ GGML_LOG_ERROR("Failed to connect to %s\n", endpoint);
2061
+ return 0;
2062
+ }
2063
+ rpc_msg_device_count_rsp response;
2064
+ bool status = send_rpc_cmd(sock, RPC_CMD_DEVICE_COUNT, nullptr, 0, &response, sizeof(response));
2065
+ RPC_STATUS_ASSERT(status);
2066
+ return response.device_count;
2067
+ }
2068
+
2069
+ static const ggml_backend_reg_i ggml_backend_rpc_reg_interface = {
2070
+ /* .get_name = */ ggml_backend_rpc_reg_get_name,
2071
+ /* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,
2072
+ /* .get_device = */ ggml_backend_rpc_reg_get_device,
2073
+ /* .get_proc_address = */ ggml_backend_rpc_get_proc_address,
2074
+ };
2075
+
2076
+ ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint) {
2077
+ static std::unordered_map<std::string, ggml_backend_reg_t> reg_map;
2078
+ static std::mutex mutex;
2079
+ static uint32_t dev_id = 0;
2080
+ std::lock_guard<std::mutex> lock(mutex);
2081
+ if (reg_map.find(endpoint) != reg_map.end()) {
2082
+ return reg_map[endpoint];
2083
+ }
2084
+ uint32_t dev_count = ggml_backend_rpc_get_device_count(endpoint);
2085
+ if (dev_count == 0) {
2086
+ return nullptr;
2087
+ }
2088
+ ggml_backend_rpc_reg_context * ctx = new ggml_backend_rpc_reg_context;
2089
+ ctx->name = "RPC[" + std::string(endpoint) + "]";
2090
+ for (uint32_t ind = 0; ind < dev_count; ind++) {
2091
+ std::string dev_name = "RPC" + std::to_string(dev_id);
2092
+ std::string dev_desc = std::string(endpoint);
2093
+ ggml_backend_rpc_device_context * dev_ctx = new ggml_backend_rpc_device_context {
2094
+ /* .endpoint = */ endpoint,
2095
+ /* .device = */ ind,
2096
+ /* .name = */ dev_name,
2097
+ /* .description = */ dev_desc
2098
+ };
2099
+
2100
+ ggml_backend_dev_t dev = new ggml_backend_device {
2101
+ /* .iface = */ ggml_backend_rpc_device_i,
2102
+ /* .reg = */ ggml_backend_rpc_reg(),
2103
+ /* .context = */ dev_ctx,
2104
+ };
2105
+ ctx->devices.push_back(dev);
2106
+ dev_id++;
2107
+ }
2108
+ ggml_backend_reg_t reg = new ggml_backend_reg {
2109
+ /* .api_version = */ GGML_BACKEND_API_VERSION,
2110
+ /* .iface = */ ggml_backend_rpc_reg_interface,
2111
+ /* .context = */ ctx
2112
+ };
2113
+ reg_map[endpoint] = reg;
2114
+ return reg;
2115
+ }
2116
+
2117
+
2118
// Expose ggml_backend_rpc_reg through the backend dynamic-loading entry point.
// NOTE(review): GGML_BACKEND_DL_IMPL is defined outside this chunk; presumably it
// only emits the loader symbol when built as a loadable backend module — confirm.
GGML_BACKEND_DL_IMPL(ggml_backend_rpc_reg)