@fugood/llama.node 0.3.1 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. package/CMakeLists.txt +1 -8
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +4 -2
  17. package/src/DetokenizeWorker.cpp +1 -1
  18. package/src/EmbeddingWorker.cpp +2 -2
  19. package/src/LlamaCompletionWorker.cpp +10 -10
  20. package/src/LlamaCompletionWorker.h +2 -2
  21. package/src/LlamaContext.cpp +14 -17
  22. package/src/TokenizeWorker.cpp +1 -1
  23. package/src/common.hpp +5 -4
  24. package/src/llama.cpp/.github/workflows/build.yml +137 -29
  25. package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
  26. package/src/llama.cpp/.github/workflows/docker.yml +46 -34
  27. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
  28. package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
  29. package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
  30. package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
  31. package/src/llama.cpp/.github/workflows/server.yml +7 -0
  32. package/src/llama.cpp/CMakeLists.txt +26 -11
  33. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  34. package/src/llama.cpp/common/CMakeLists.txt +10 -10
  35. package/src/llama.cpp/common/arg.cpp +2041 -0
  36. package/src/llama.cpp/common/arg.h +77 -0
  37. package/src/llama.cpp/common/common.cpp +523 -1861
  38. package/src/llama.cpp/common/common.h +234 -106
  39. package/src/llama.cpp/common/console.cpp +3 -0
  40. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  41. package/src/llama.cpp/common/log.cpp +401 -0
  42. package/src/llama.cpp/common/log.h +66 -698
  43. package/src/llama.cpp/common/ngram-cache.cpp +39 -36
  44. package/src/llama.cpp/common/ngram-cache.h +19 -19
  45. package/src/llama.cpp/common/sampling.cpp +356 -350
  46. package/src/llama.cpp/common/sampling.h +62 -139
  47. package/src/llama.cpp/common/stb_image.h +5990 -6398
  48. package/src/llama.cpp/docs/build.md +72 -17
  49. package/src/llama.cpp/examples/CMakeLists.txt +1 -2
  50. package/src/llama.cpp/examples/batched/batched.cpp +49 -65
  51. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +42 -53
  52. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
  53. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +22 -22
  54. package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
  55. package/src/llama.cpp/examples/embedding/embedding.cpp +147 -91
  56. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +37 -37
  57. package/src/llama.cpp/examples/export-lora/export-lora.cpp +39 -38
  58. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
  59. package/src/llama.cpp/examples/{baby-llama → gen-docs}/CMakeLists.txt +2 -2
  60. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
  61. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
  62. package/src/llama.cpp/examples/gritlm/gritlm.cpp +46 -39
  63. package/src/llama.cpp/examples/imatrix/imatrix.cpp +75 -69
  64. package/src/llama.cpp/examples/infill/infill.cpp +131 -192
  65. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +276 -178
  66. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  67. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +40 -36
  68. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  69. package/src/llama.cpp/examples/llava/clip.cpp +686 -150
  70. package/src/llama.cpp/examples/llava/clip.h +11 -2
  71. package/src/llama.cpp/examples/llava/llava-cli.cpp +60 -71
  72. package/src/llama.cpp/examples/llava/llava.cpp +146 -26
  73. package/src/llama.cpp/examples/llava/llava.h +2 -3
  74. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
  75. package/src/llama.cpp/examples/llava/requirements.txt +1 -0
  76. package/src/llama.cpp/examples/lookahead/lookahead.cpp +55 -56
  77. package/src/llama.cpp/examples/lookup/lookup-create.cpp +15 -13
  78. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  79. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +34 -33
  80. package/src/llama.cpp/examples/lookup/lookup.cpp +60 -63
  81. package/src/llama.cpp/examples/main/main.cpp +216 -313
  82. package/src/llama.cpp/examples/parallel/parallel.cpp +58 -59
  83. package/src/llama.cpp/examples/passkey/passkey.cpp +53 -61
  84. package/src/llama.cpp/examples/perplexity/perplexity.cpp +277 -311
  85. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  86. package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
  87. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -12
  88. package/src/llama.cpp/examples/retrieval/retrieval.cpp +57 -52
  89. package/src/llama.cpp/examples/rpc/rpc-server.cpp +27 -2
  90. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +60 -46
  91. package/src/llama.cpp/examples/server/CMakeLists.txt +7 -18
  92. package/src/llama.cpp/examples/server/server.cpp +1347 -1531
  93. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
  94. package/src/llama.cpp/examples/server/utils.hpp +396 -107
  95. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  96. package/src/llama.cpp/examples/simple/simple.cpp +132 -106
  97. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  98. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
  99. package/src/llama.cpp/examples/speculative/speculative.cpp +153 -124
  100. package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
  101. package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
  102. package/src/llama.cpp/examples/tokenize/tokenize.cpp +27 -29
  103. package/src/llama.cpp/ggml/CMakeLists.txt +29 -12
  104. package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
  105. package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
  106. package/src/llama.cpp/ggml/include/ggml-backend.h +166 -68
  107. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  108. package/src/llama.cpp/ggml/include/ggml-cann.h +17 -19
  109. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  110. package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
  111. package/src/llama.cpp/ggml/include/ggml-cuda.h +17 -17
  112. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  113. package/src/llama.cpp/ggml/include/ggml-metal.h +13 -12
  114. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  115. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  116. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  117. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  118. package/src/llama.cpp/ggml/include/ggml.h +272 -505
  119. package/src/llama.cpp/ggml/src/CMakeLists.txt +69 -1110
  120. package/src/llama.cpp/ggml/src/ggml-aarch64.c +52 -2116
  121. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
  122. package/src/llama.cpp/ggml/src/ggml-alloc.c +29 -27
  123. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  124. package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
  125. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  126. package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
  127. package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
  128. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +144 -81
  129. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
  130. package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +394 -635
  131. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
  132. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +217 -70
  133. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
  134. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
  135. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
  136. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
  137. package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
  138. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +458 -353
  139. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
  140. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
  141. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
  142. package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
  143. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
  144. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
  145. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
  146. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +371 -0
  147. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
  148. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  149. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
  150. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
  151. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1885 -0
  152. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
  153. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  154. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
  155. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  156. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
  157. package/src/llama.cpp/ggml/src/ggml-impl.h +380 -584
  158. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
  159. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +233 -87
  160. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
  161. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
  162. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
  163. package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
  164. package/src/llama.cpp/ggml/src/ggml-quants.c +369 -9994
  165. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -110
  166. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
  167. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +560 -335
  168. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
  169. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +6 -0
  170. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +51 -0
  171. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +310 -0
  172. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
  173. package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
  174. package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
  175. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
  176. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
  177. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
  178. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
  179. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +18 -25
  180. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
  181. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  182. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
  183. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3350 -3980
  184. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
  185. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
  186. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +70 -68
  187. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +9 -6
  188. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  189. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  190. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +8 -0
  191. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
  192. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
  193. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
  194. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  195. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
  196. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  197. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  198. package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
  199. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
  200. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2034 -1718
  201. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +2 -0
  202. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +152 -185
  203. package/src/llama.cpp/ggml/src/ggml.c +2075 -16579
  204. package/src/llama.cpp/include/llama.h +296 -285
  205. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
  206. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
  207. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  208. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  209. package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
  210. package/src/llama.cpp/src/CMakeLists.txt +2 -1
  211. package/src/llama.cpp/src/llama-grammar.cpp +721 -122
  212. package/src/llama.cpp/src/llama-grammar.h +120 -15
  213. package/src/llama.cpp/src/llama-impl.h +156 -1
  214. package/src/llama.cpp/src/llama-sampling.cpp +2058 -346
  215. package/src/llama.cpp/src/llama-sampling.h +39 -47
  216. package/src/llama.cpp/src/llama-vocab.cpp +390 -127
  217. package/src/llama.cpp/src/llama-vocab.h +60 -20
  218. package/src/llama.cpp/src/llama.cpp +6215 -3263
  219. package/src/llama.cpp/src/unicode-data.cpp +6 -4
  220. package/src/llama.cpp/src/unicode-data.h +4 -4
  221. package/src/llama.cpp/src/unicode.cpp +15 -7
  222. package/src/llama.cpp/tests/CMakeLists.txt +4 -2
  223. package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
  224. package/src/llama.cpp/tests/test-backend-ops.cpp +1725 -297
  225. package/src/llama.cpp/tests/test-barrier.cpp +94 -0
  226. package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
  227. package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
  228. package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
  229. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +23 -8
  230. package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
  231. package/src/llama.cpp/tests/test-log.cpp +39 -0
  232. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  233. package/src/llama.cpp/tests/test-quantize-fns.cpp +28 -19
  234. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  235. package/src/llama.cpp/tests/test-rope.cpp +2 -1
  236. package/src/llama.cpp/tests/test-sampling.cpp +226 -142
  237. package/src/llama.cpp/tests/test-tokenizer-0.cpp +56 -36
  238. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  239. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  240. package/patches/llama.patch +0 -22
  241. package/src/llama.cpp/.github/workflows/bench.yml +0 -310
  242. package/src/llama.cpp/common/grammar-parser.cpp +0 -536
  243. package/src/llama.cpp/common/grammar-parser.h +0 -29
  244. package/src/llama.cpp/common/train.cpp +0 -1513
  245. package/src/llama.cpp/common/train.h +0 -233
  246. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1640
  247. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
  248. package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
  249. package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +0 -1027
  250. package/src/llama.cpp/tests/test-grad0.cpp +0 -1566
  251. /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  252. /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp}
@@ -1,5 +1,5 @@
 #include "ggml-rpc.h"
-#include "ggml.h"
+#include "ggml-impl.h"
 #include "ggml-backend-impl.h"
 
 #include <cinttypes>
@@ -25,7 +25,7 @@
 # include <netdb.h>
 # include <unistd.h>
 #endif
-#include <string.h>
+#include <cstring>
 
 #define UNUSED GGML_UNUSED
 
@@ -57,8 +57,9 @@ struct socket_t {
     }
 };
 
-// ggml_tensor is serialized into rpc_tensor
+// all RPC structures must be packed
 #pragma pack(push, 1)
+// ggml_tensor is serialized into rpc_tensor
 struct rpc_tensor {
     uint64_t id;
     uint32_t type;
@@ -76,25 +77,84 @@ struct rpc_tensor {
 
     char padding[4];
 };
-#pragma pack(pop)
 
 static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of 8");
 
 // RPC commands
 enum rpc_cmd {
-    ALLOC_BUFFER = 0,
-    GET_ALIGNMENT,
-    GET_MAX_SIZE,
-    BUFFER_GET_BASE,
-    FREE_BUFFER,
-    BUFFER_CLEAR,
-    SET_TENSOR,
-    GET_TENSOR,
-    COPY_TENSOR,
-    GRAPH_COMPUTE,
-    GET_DEVICE_MEMORY,
+    RPC_CMD_ALLOC_BUFFER = 0,
+    RPC_CMD_GET_ALIGNMENT,
+    RPC_CMD_GET_MAX_SIZE,
+    RPC_CMD_BUFFER_GET_BASE,
+    RPC_CMD_FREE_BUFFER,
+    RPC_CMD_BUFFER_CLEAR,
+    RPC_CMD_SET_TENSOR,
+    RPC_CMD_GET_TENSOR,
+    RPC_CMD_COPY_TENSOR,
+    RPC_CMD_GRAPH_COMPUTE,
+    RPC_CMD_GET_DEVICE_MEMORY,
+    RPC_CMD_COUNT,
+};
+
+struct rpc_msg_alloc_buffer_req {
+    uint64_t size;
+};
+
+struct rpc_msg_alloc_buffer_rsp {
+    uint64_t remote_ptr;
+    uint64_t remote_size;
+};
+
+struct rpc_msg_get_alignment_rsp {
+    uint64_t alignment;
+};
+
+struct rpc_msg_get_max_size_rsp {
+    uint64_t max_size;
 };
 
+struct rpc_msg_buffer_get_base_req {
+    uint64_t remote_ptr;
+};
+
+struct rpc_msg_buffer_get_base_rsp {
+    uint64_t base_ptr;
+};
+
+struct rpc_msg_free_buffer_req {
+    uint64_t remote_ptr;
+};
+
+struct rpc_msg_buffer_clear_req {
+    uint64_t remote_ptr;
+    uint8_t value;
+};
+
+struct rpc_msg_get_tensor_req {
+    rpc_tensor tensor;
+    uint64_t offset;
+    uint64_t size;
+};
+
+struct rpc_msg_copy_tensor_req {
+    rpc_tensor src;
+    rpc_tensor dst;
+};
+
+struct rpc_msg_copy_tensor_rsp {
+    uint8_t result;
+};
+
+struct rpc_msg_graph_compute_rsp {
+    uint8_t result;
+};
+
+struct rpc_msg_get_device_memory_rsp {
+    uint64_t free_mem;
+    uint64_t total_mem;
+};
+#pragma pack(pop)
+
 // RPC data structures
 
 static ggml_guid_t ggml_backend_rpc_guid() {
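
A note on the #pragma pack(push, 1) block above: packing guarantees that each request/response struct's in-memory layout is the wire format byte for byte, so sizeof(request) can be handed straight to the transport with no ABI-dependent padding in between. A standalone sketch of what this buys (the struct names here are illustrative stand-ins, not code from the diff):

#include <cstdint>

#pragma pack(push, 1)
struct packed_clear_req {   // same members as rpc_msg_buffer_clear_req above
    uint64_t remote_ptr;    // 8 bytes
    uint8_t  value;         // 1 byte
};
#pragma pack(pop)

struct unpacked_clear_req { // identical members, default alignment
    uint64_t remote_ptr;
    uint8_t  value;
};

// Packed, the size is exactly the sum of the members, so sizeof() doubles as
// the wire size that both peers exchange and verify.
static_assert(sizeof(packed_clear_req) == 9, "8 + 1, no padding");
// Unpacked, most ABIs pad the tail out to a multiple of the largest member
// alignment (typically 16 bytes total here), and those padding bytes would
// carry indeterminate values onto the wire.
static_assert(sizeof(unpacked_clear_req) >= sizeof(packed_clear_req), "tail padding is ABI-dependent");

int main() { return 0; }
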
@@ -118,7 +178,6 @@ struct ggml_backend_rpc_buffer_context {
     std::shared_ptr<socket_t> sock;
     std::unordered_map<ggml_backend_buffer_t, void *> base_cache;
     uint64_t remote_ptr;
-    std::string name;
 };
 
 // RPC helper functions
@@ -197,6 +256,10 @@ static std::shared_ptr<socket_t> create_server_socket(const char * host, int port) {
         fprintf(stderr, "Failed to set SO_REUSEADDR\n");
         return nullptr;
     }
+    if (inet_addr(host) == INADDR_NONE) {
+        fprintf(stderr, "Invalid host address: %s\n", host);
+        return nullptr;
+    }
     struct sockaddr_in serv_addr;
     serv_addr.sin_family = AF_INET;
     serv_addr.sin_addr.s_addr = inet_addr(host);
@@ -235,6 +298,38 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
     return true;
 }
 
+static bool send_msg(sockfd_t sockfd, const void * msg, size_t msg_size) {
+    if (!send_data(sockfd, &msg_size, sizeof(msg_size))) {
+        return false;
+    }
+    return send_data(sockfd, msg, msg_size);
+}
+
+static bool recv_msg(sockfd_t sockfd, void * msg, size_t msg_size) {
+    uint64_t size;
+    if (!recv_data(sockfd, &size, sizeof(size))) {
+        return false;
+    }
+    if (size != msg_size) {
+        return false;
+    }
+    return recv_data(sockfd, msg, msg_size);
+}
+
+static bool recv_msg(sockfd_t sockfd, std::vector<uint8_t> & input) {
+    uint64_t size;
+    if (!recv_data(sockfd, &size, sizeof(size))) {
+        return false;
+    }
+    try {
+        input.resize(size);
+    } catch (const std::bad_alloc & e) {
+        fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", size);
+        return false;
+    }
+    return recv_data(sockfd, input.data(), size);
+}
+
 static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
     size_t pos = endpoint.find(':');
     if (pos == std::string::npos) {
@@ -247,28 +342,27 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
 
 // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
 // RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
-static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
+static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) {
     uint8_t cmd_byte = cmd;
     if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
        return false;
     }
-    uint64_t input_size = input.size();
     if (!send_data(sock->fd, &input_size, sizeof(input_size))) {
         return false;
     }
-    if (!send_data(sock->fd, input.data(), input.size())) {
+    if (!send_data(sock->fd, input, input_size)) {
         return false;
     }
-    uint64_t output_size;
-    if (!recv_data(sock->fd, &output_size, sizeof(output_size))) {
+    // TODO: currently the output_size is always known, do we need support for commands with variable output size?
+    // even if we do, we can skip sending output_size from the server for commands with known output size
+    uint64_t out_size;
+    if (!recv_data(sock->fd, &out_size, sizeof(out_size))) {
         return false;
     }
-    if (output_size == 0) {
-        output.clear();
-        return true;
+    if (out_size != output_size) {
+        return false;
     }
-    output.resize(output_size);
-    if (!recv_data(sock->fd, output.data(), output_size)) {
+    if (!recv_data(sock->fd, output, output_size)) {
        return false;
     }
     return true;
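
The recv_msg() helper and the rewritten send_rpc_cmd() both enforce that a fixed-size message arrives with exactly the announced size, where the old code resized a byte vector to whatever the peer sent. A runnable, standalone simulation of that check over a local socketpair; recv_msg_fixed() is an illustrative stand-in for the diff's recv_msg(), and none of this scaffolding is in the real file:

#include <cstdint>
#include <cstdio>
#include <sys/socket.h>
#include <unistd.h>

static bool recv_msg_fixed(int fd, void * msg, size_t msg_size) {
    uint64_t size = 0;
    if (recv(fd, &size, sizeof(size), MSG_WAITALL) != (ssize_t) sizeof(size)) {
        return false;
    }
    if (size != msg_size) {
        return false; // the new check: fixed-size commands reject any other size
    }
    return recv(fd, msg, msg_size, MSG_WAITALL) == (ssize_t) msg_size;
}

int main() {
    int sv[2];
    if (socketpair(AF_UNIX, SOCK_STREAM, 0, sv) != 0) {
        return 1;
    }
    // the "peer" announces a 9-byte payload, but the receiver expects 8 bytes
    uint64_t announced = 9;
    uint8_t  payload[9] = {0};
    (void) send(sv[0], &announced, sizeof(announced), 0);
    (void) send(sv[0], payload, sizeof(payload), 0);

    uint64_t value = 0;
    bool ok = recv_msg_fixed(sv[1], &value, sizeof(value));
    printf("message accepted: %s\n", ok ? "yes" : "no"); // prints "no"

    close(sv[0]);
    close(sv[1]);
    return 0;
}
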
@@ -314,43 +408,26 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
     return sock;
 }
 
-GGML_CALL static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) {
+static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
-    return ctx->name.c_str();
-}
-
-GGML_CALL static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
-    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
-    // input serialization format: | remote_ptr (8 bytes) |
-    std::vector<uint8_t> input(sizeof(uint64_t), 0);
-    uint64_t remote_ptr = ctx->remote_ptr;
-    memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
-    std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(ctx->sock, FREE_BUFFER, input, output);
+    rpc_msg_free_buffer_req request = {ctx->remote_ptr};
+    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0);
     GGML_ASSERT(status);
-    GGML_ASSERT(output.empty());
     delete ctx;
 }
 
-GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
+static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
     if (ctx->base_cache.find(buffer) != ctx->base_cache.end()) {
         return ctx->base_cache[buffer];
     }
-    // input serialization format: | remote_ptr (8 bytes) |
-    std::vector<uint8_t> input(sizeof(uint64_t), 0);
-    uint64_t remote_ptr = ctx->remote_ptr;
-    memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
-    std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(ctx->sock, BUFFER_GET_BASE, input, output);
+    rpc_msg_buffer_get_base_req request = {ctx->remote_ptr};
+    rpc_msg_buffer_get_base_rsp response;
+    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response));
     GGML_ASSERT(status);
-    GGML_ASSERT(output.size() == sizeof(uint64_t));
-    // output serialization format: | base_ptr (8 bytes) |
-    uint64_t base_ptr;
-    memcpy(&base_ptr, output.data(), sizeof(base_ptr));
-    void * base = reinterpret_cast<void *>(base_ptr);
-    ctx->base_cache[buffer] = base;
-    return base;
+    void * base_ptr = reinterpret_cast<void *>(response.base_ptr);
+    ctx->base_cache[buffer] = base_ptr;
+    return base_ptr;
 }
 
 static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
@@ -383,7 +460,7 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
     return result;
 }
 
-GGML_CALL static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
+static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
     UNUSED(buffer);
     if (ggml_is_quantized(tensor->type)) {
         // TODO: this check is due to MATRIX_ROW_PADDING in CUDA and should be generalized
@@ -391,7 +468,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
     }
 }
 
-GGML_CALL static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
     // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
     size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size;
@@ -400,29 +477,21 @@ GGML_CALL static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
     memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
     memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
-    std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(ctx->sock, SET_TENSOR, input, output);
+    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size(), nullptr, 0);
     GGML_ASSERT(status);
 }
 
-GGML_CALL static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
-    // input serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) |
-    int input_size = sizeof(rpc_tensor) + 2*sizeof(uint64_t);
-    std::vector<uint8_t> input(input_size, 0);
-    rpc_tensor rpc_tensor = serialize_tensor(tensor);
-    memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
-    memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
-    memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &size, sizeof(size));
-    std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(ctx->sock, GET_TENSOR, input, output);
+    rpc_msg_get_tensor_req request;
+    request.tensor = serialize_tensor(tensor);
+    request.offset = offset;
+    request.size = size;
+    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, &request, sizeof(request), data, size);
     GGML_ASSERT(status);
-    GGML_ASSERT(output.size() == size);
-    // output serialization format: | data (size bytes) |
-    memcpy(data, output.data(), size);
 }
 
-GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
+static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
     // check if src and dst are on the same server
     ggml_backend_buffer_t src_buffer = src->buffer;
     ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
@@ -432,38 +501,27 @@ GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
         return false;
     }
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
-    // input serialization format: | rpc_tensor src | rpc_tensor dst |
-    int input_size = 2*sizeof(rpc_tensor);
-    std::vector<uint8_t> input(input_size, 0);
-    rpc_tensor rpc_src = serialize_tensor(src);
-    rpc_tensor rpc_dst = serialize_tensor(dst);
-    memcpy(input.data(), &rpc_src, sizeof(rpc_src));
-    memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst));
-    std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(ctx->sock, COPY_TENSOR, input, output);
+    rpc_msg_copy_tensor_req request;
+    request.src = serialize_tensor(src);
+    request.dst = serialize_tensor(dst);
+    rpc_msg_copy_tensor_rsp response;
+    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
     GGML_ASSERT(status);
-    // output serialization format: | result (1 byte) |
-    GGML_ASSERT(output.size() == 1);
-    return output[0];
+    return response.result;
 }
 
-GGML_CALL static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
-    // serialization format: | bufptr (8 bytes) | value (1 byte) |
-    int input_size = sizeof(uint64_t) + sizeof(uint8_t);
-    std::vector<uint8_t> input(input_size, 0);
-    memcpy(input.data(), &ctx->remote_ptr, sizeof(ctx->remote_ptr));
-    memcpy(input.data() + sizeof(ctx->remote_ptr), &value, sizeof(value));
-    std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(ctx->sock, BUFFER_CLEAR, input, output);
+    rpc_msg_buffer_clear_req request = {ctx->remote_ptr, value};
+    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, &request, sizeof(request), nullptr, 0);
     GGML_ASSERT(status);
 }
 
 static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
-    /* .get_name = */ ggml_backend_rpc_buffer_get_name,
     /* .free_buffer = */ ggml_backend_rpc_buffer_free_buffer,
     /* .get_base = */ ggml_backend_rpc_buffer_get_base,
     /* .init_tensor = */ ggml_backend_rpc_buffer_init_tensor,
+    /* .memset_tensor = */ NULL,
     /* .set_tensor = */ ggml_backend_rpc_buffer_set_tensor,
     /* .get_tensor = */ ggml_backend_rpc_buffer_get_tensor,
     /* .cpy_tensor = */ ggml_backend_rpc_buffer_cpy_tensor,
@@ -471,32 +529,23 @@ static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
     /* .reset = */ NULL,
 };
 
-GGML_CALL static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t buft) {
+static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t buft) {
     ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
     return buft_ctx->name.c_str();
 }
 
-GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
     ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
-    // input serialization format: | size (8 bytes) |
-    int input_size = sizeof(uint64_t);
-    std::vector<uint8_t> input(input_size, 0);
-    memcpy(input.data(), &size, sizeof(size));
-    std::vector<uint8_t> output;
+    rpc_msg_alloc_buffer_req request = {size};
+    rpc_msg_alloc_buffer_rsp response;
     auto sock = get_socket(buft_ctx->endpoint);
-    bool status = send_rpc_cmd(sock, ALLOC_BUFFER, input, output);
+    bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
     GGML_ASSERT(status);
-    GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
-    // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
-    uint64_t remote_ptr;
-    memcpy(&remote_ptr, output.data(), sizeof(remote_ptr));
-    size_t remote_size;
-    memcpy(&remote_size, output.data() + sizeof(uint64_t), sizeof(remote_size));
-    if (remote_ptr != 0) {
+    if (response.remote_ptr != 0) {
         ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
             ggml_backend_rpc_buffer_interface,
-            new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC[" + std::string(buft_ctx->endpoint) + "]"},
-            remote_size);
+            new ggml_backend_rpc_buffer_context{sock, {}, response.remote_ptr},
+            response.remote_size);
         return buffer;
     } else {
         return nullptr;
@@ -504,42 +553,30 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
     }
 }
 
 static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
-    // input serialization format: | 0 bytes |
-    std::vector<uint8_t> input;
-    std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(sock, GET_ALIGNMENT, input, output);
+    rpc_msg_get_alignment_rsp response;
+    bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response));
     GGML_ASSERT(status);
-    GGML_ASSERT(output.size() == sizeof(uint64_t));
-    // output serialization format: | alignment (8 bytes) |
-    uint64_t alignment;
-    memcpy(&alignment, output.data(), sizeof(alignment));
-    return alignment;
+    return response.alignment;
 }
 
-GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
     ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
     return buft_ctx->alignment;
 }
 
 static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
-    // input serialization format: | 0 bytes |
-    std::vector<uint8_t> input;
-    std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(sock, GET_MAX_SIZE, input, output);
+    rpc_msg_get_max_size_rsp response;
+    bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, nullptr, 0, &response, sizeof(response));
     GGML_ASSERT(status);
-    GGML_ASSERT(output.size() == sizeof(uint64_t));
-    // output serialization format: | max_size (8 bytes) |
-    uint64_t max_size;
-    memcpy(&max_size, output.data(), sizeof(max_size));
-    return max_size;
+    return response.max_size;
}
 
-GGML_CALL static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
+static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
     ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
     return buft_ctx->max_size;
 }
 
-GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
+static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
     UNUSED(buft);
     return ggml_nbytes(tensor);
 }
@@ -553,24 +590,19 @@ static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
     /* .is_host = */ NULL,
 };
 
-GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
+static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
 
     return rpc_ctx->name.c_str();
 }
 
-GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) {
+static void ggml_backend_rpc_free(ggml_backend_t backend) {
     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
     delete rpc_ctx;
     delete backend;
 }
 
-GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) {
-    ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
-    return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
-}
-
-GGML_CALL static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
+static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
     UNUSED(backend);
     // this is no-op because we don't have any async operations
 }
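
Each of the client calls above now goes through the framing documented at send_rpc_cmd(): one command byte, an 8-byte request size, the request payload, then an 8-byte response size and the response payload back. A standalone sketch of the RPC_CMD_GET_ALIGNMENT round-trip at the socket level; write_all()/read_all() are assumed stand-ins for the file's send_data()/recv_data() loops, and the hard-coded command byte 1 simply reflects that enumerator's position in the enum shown earlier:

#include <cstdint>
#include <cstddef>
#include <sys/socket.h>
#include <sys/types.h>

// loop until the full buffer is sent (stand-in for send_data())
static bool write_all(int fd, const void * buf, size_t len) {
    const char * p = static_cast<const char *>(buf);
    while (len > 0) {
        ssize_t n = send(fd, p, len, 0);
        if (n <= 0) {
            return false;
        }
        p   += n;
        len -= (size_t) n;
    }
    return true;
}

// loop until the full buffer is received (stand-in for recv_data())
static bool read_all(int fd, void * buf, size_t len) {
    char * p = static_cast<char *>(buf);
    while (len > 0) {
        ssize_t n = recv(fd, p, len, 0);
        if (n <= 0) {
            return false;
        }
        p   += n;
        len -= (size_t) n;
    }
    return true;
}

// request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data |
// response: | response_size (8 bytes) | response_data |
bool rpc_get_alignment(int fd, uint64_t & alignment) {
    uint8_t  cmd      = 1; // RPC_CMD_GET_ALIGNMENT: second enumerator above
    uint64_t req_size = 0; // this command carries no request payload
    if (!write_all(fd, &cmd, sizeof(cmd)) || !write_all(fd, &req_size, sizeof(req_size))) {
        return false;
    }
    uint64_t rsp_size = 0;
    if (!read_all(fd, &rsp_size, sizeof(rsp_size)) || rsp_size != sizeof(alignment)) {
        return false; // mirror the new send_rpc_cmd(): reject unexpected sizes
    }
    return read_all(fd, &alignment, sizeof(alignment));
}
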
@@ -612,38 +644,20 @@ static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & o
612
644
  memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor));
613
645
  }
614
646
 
615
- GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
647
+ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
616
648
  ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
617
649
  std::vector<uint8_t> input;
618
650
  serialize_graph(cgraph, input);
619
- std::vector<uint8_t> output;
651
+ rpc_msg_graph_compute_rsp response;
620
652
  auto sock = get_socket(rpc_ctx->endpoint);
621
- bool status = send_rpc_cmd(sock, GRAPH_COMPUTE, input, output);
653
+ bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
622
654
  GGML_ASSERT(status);
623
- GGML_ASSERT(output.size() == 1);
624
- return (enum ggml_status)output[0];
625
- }
626
-
627
- GGML_CALL static bool ggml_backend_rpc_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
628
- UNUSED(backend);
629
- UNUSED(op);
630
- //TODO: call the remote backend and cache the results
631
- return true;
632
- }
633
-
634
- GGML_CALL static bool ggml_backend_rpc_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
635
- if (buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
636
- return false;
637
- }
638
- ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
639
- ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
640
- return buft_ctx->endpoint == rpc_ctx->endpoint;
655
+ return (enum ggml_status)response.result;
641
656
  }
642
657
 
643
658
  static ggml_backend_i ggml_backend_rpc_interface = {
644
659
  /* .get_name = */ ggml_backend_rpc_name,
645
660
  /* .free = */ ggml_backend_rpc_free,
646
- /* .get_default_buffer_type = */ ggml_backend_rpc_get_default_buffer_type,
647
661
  /* .set_tensor_async = */ NULL,
648
662
  /* .get_tensor_async = */ NULL,
649
663
  /* .cpy_tensor_async = */ NULL,
@@ -653,17 +667,11 @@ static ggml_backend_i ggml_backend_rpc_interface = {
653
667
  /* .graph_plan_update = */ NULL,
654
668
  /* .graph_plan_compute = */ NULL,
655
669
  /* .graph_compute = */ ggml_backend_rpc_graph_compute,
656
- /* .supports_op = */ ggml_backend_rpc_supports_op,
657
- /* .supports_buft = */ ggml_backend_rpc_supports_buft,
658
- /* .offload_op = */ NULL,
659
- /* .event_new = */ NULL,
660
- /* .event_free = */ NULL,
661
670
  /* .event_record = */ NULL,
662
671
  /* .event_wait = */ NULL,
663
- /* .event_synchronize = */ NULL,
664
672
  };
665
673
 
666
- GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
674
+ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
667
675
  static std::mutex mutex;
668
676
  std::lock_guard<std::mutex> lock(mutex);
669
677
  // NOTE: buffer types are allocated and never freed; this is by design
@@ -674,6 +682,7 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const
674
682
  }
675
683
  auto sock = get_socket(endpoint);
676
684
  if (sock == nullptr) {
685
+ fprintf(stderr, "Failed to connect to %s\n", endpoint);
677
686
  return nullptr;
678
687
  }
679
688
  size_t alignment = get_alignment(sock);
@@ -687,13 +696,14 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const
687
696
 
688
697
  ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
689
698
  /* .iface = */ ggml_backend_rpc_buffer_type_interface,
699
+ /* .device = */ ggml_backend_rpc_add_device(endpoint),
690
700
  /* .context = */ buft_ctx
691
701
  };
692
702
  buft_map[endpoint] = buft;
693
703
  return buft;
694
704
  }
695
705
 
696
- GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
706
+ ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
697
707
  ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
698
708
  /* .endpoint = */ endpoint,
699
709
  /* .name = */ "RPC[" + std::string(endpoint) + "]",
@@ -702,32 +712,25 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
702
712
  ggml_backend_t backend = new ggml_backend {
703
713
  /* .guid = */ ggml_backend_rpc_guid(),
704
714
  /* .interface = */ ggml_backend_rpc_interface,
715
+ /* .device = */ ggml_backend_rpc_add_device(endpoint),
705
716
  /* .context = */ ctx
706
717
  };
707
718
  return backend;
708
719
  }
709
720
 
710
- GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {
721
+ bool ggml_backend_is_rpc(ggml_backend_t backend) {
711
722
  return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
712
723
  }
713
724
 
714
725
  static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
715
- // input serialization format: | 0 bytes |
716
- std::vector<uint8_t> input;
717
- std::vector<uint8_t> output;
718
- bool status = send_rpc_cmd(sock, GET_DEVICE_MEMORY, input, output);
726
+ rpc_msg_get_device_memory_rsp response;
727
+ bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response));
719
728
  GGML_ASSERT(status);
720
- GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
721
- // output serialization format: | free (8 bytes) | total (8 bytes) |
722
- uint64_t free_mem;
723
- memcpy(&free_mem, output.data(), sizeof(free_mem));
724
- uint64_t total_mem;
725
- memcpy(&total_mem, output.data() + sizeof(uint64_t), sizeof(total_mem));
726
- *free = free_mem;
727
- *total = total_mem;
729
+ *free = response.free_mem;
730
+ *total = response.total_mem;
728
731
  }
729
732
 
730
- GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
733
+ void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
731
734
  auto sock = get_socket(endpoint);
732
735
  if (sock == nullptr) {
733
736
  *free = 0;
@@ -744,16 +747,16 @@ public:
744
747
  rpc_server(ggml_backend_t backend) : backend(backend) {}
745
748
  ~rpc_server();
746
749
 
747
- bool alloc_buffer(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
748
- void get_alignment(std::vector<uint8_t> & output);
749
- void get_max_size(std::vector<uint8_t> & output);
750
- bool buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
751
- bool free_buffer(const std::vector<uint8_t> & input);
752
- bool buffer_clear(const std::vector<uint8_t> & input);
750
+ void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
751
+ void get_alignment(rpc_msg_get_alignment_rsp & response);
752
+ void get_max_size(rpc_msg_get_max_size_rsp & response);
753
+ bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response);
754
+ bool free_buffer(const rpc_msg_free_buffer_req & request);
755
+ bool buffer_clear(const rpc_msg_buffer_clear_req & request);
753
756
  bool set_tensor(const std::vector<uint8_t> & input);
754
- bool get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
755
- bool copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
756
- bool graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
757
+ bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
758
+ bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
759
+ bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
757
760
 
758
761
  private:
759
762
  ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
@@ -767,80 +770,50 @@ private:
767
770
  std::unordered_set<ggml_backend_buffer_t> buffers;
768
771
  };
769
772
 
770
- bool rpc_server::alloc_buffer(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
771
- // input serialization format: | size (8 bytes) |
772
- if (input.size() != sizeof(uint64_t)) {
773
- return false;
774
- }
775
- uint64_t size;
776
- memcpy(&size, input.data(), sizeof(size));
773
+ void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
777
774
  ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
778
- ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, size);
779
- uint64_t remote_ptr = 0;
780
- uint64_t remote_size = 0;
775
+ ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
776
+ response.remote_ptr = 0;
777
+ response.remote_size = 0;
781
778
  if (buffer != nullptr) {
782
- remote_ptr = reinterpret_cast<uint64_t>(buffer);
783
- remote_size = buffer->size;
784
- GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, size, remote_ptr, remote_size);
779
+ response.remote_ptr = reinterpret_cast<uint64_t>(buffer);
780
+ response.remote_size = buffer->size;
781
+ GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size);
785
782
  buffers.insert(buffer);
786
783
  } else {
787
- GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> failed\n", __func__, size);
784
+ GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size);
788
785
  }
789
- // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
790
- output.resize(2*sizeof(uint64_t), 0);
791
- memcpy(output.data(), &remote_ptr, sizeof(remote_ptr));
792
- memcpy(output.data() + sizeof(uint64_t), &remote_size, sizeof(remote_size));
793
- return true;
794
786
  }
795
787
 
796
- void rpc_server::get_alignment(std::vector<uint8_t> & output) {
788
+ void rpc_server::get_alignment(rpc_msg_get_alignment_rsp & response) {
797
789
  ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
798
790
  size_t alignment = ggml_backend_buft_get_alignment(buft);
799
791
  GGML_PRINT_DEBUG("[%s] alignment: %lu\n", __func__, alignment);
800
- // output serialization format: | alignment (8 bytes) |
801
- output.resize(sizeof(uint64_t), 0);
802
- memcpy(output.data(), &alignment, sizeof(alignment));
792
+ response.alignment = alignment;
803
793
  }
804
794
 
805
- void rpc_server::get_max_size(std::vector<uint8_t> & output) {
795
+ void rpc_server::get_max_size(rpc_msg_get_max_size_rsp & response) {
806
796
  ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
807
797
  size_t max_size = ggml_backend_buft_get_max_size(buft);
808
798
  GGML_PRINT_DEBUG("[%s] max_size: %lu\n", __func__, max_size);
809
- // output serialization format: | max_size (8 bytes) |
810
- output.resize(sizeof(uint64_t), 0);
811
- memcpy(output.data(), &max_size, sizeof(max_size));
799
+ response.max_size = max_size;
812
800
  }
813
801
 
814
- bool rpc_server::buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
815
- // input serialization format: | remote_ptr (8 bytes) |
816
- if (input.size() != sizeof(uint64_t)) {
817
- return false;
818
- }
819
- uint64_t remote_ptr;
820
- memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
821
- GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
822
- ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
802
+ bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {
803
+ GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
804
+ ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
823
805
  if (buffers.find(buffer) == buffers.end()) {
824
806
  GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
825
807
  return false;
826
808
  }
827
809
  void * base = ggml_backend_buffer_get_base(buffer);
828
- // output serialization format: | base_ptr (8 bytes) |
829
- uint64_t base_ptr = reinterpret_cast<uint64_t>(base);
830
- output.resize(sizeof(uint64_t), 0);
831
- memcpy(output.data(), &base_ptr, sizeof(base_ptr));
810
+ response.base_ptr = reinterpret_cast<uint64_t>(base);
832
811
  return true;
833
812
  }
834
813
 
835
- bool rpc_server::free_buffer(const std::vector<uint8_t> & input) {
836
- // input serialization format: | remote_ptr (8 bytes) |
837
- if (input.size() != sizeof(uint64_t)) {
838
- return false;
839
- }
840
- uint64_t remote_ptr;
841
- memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
842
- GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
843
- ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
814
+ bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) {
815
+ GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
816
+ ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
844
817
  if (buffers.find(buffer) == buffers.end()) {
845
818
  GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
846
819
  return false;
@@ -850,22 +823,14 @@ bool rpc_server::free_buffer(const std::vector<uint8_t> & input) {
850
823
  return true;
851
824
  }
852
825
 
853
- bool rpc_server::buffer_clear(const std::vector<uint8_t> & input) {
854
- // input serialization format: | remote_ptr (8 bytes) | value (1 byte) |
855
- if (input.size() != sizeof(uint64_t) + sizeof(uint8_t)) {
856
- return false;
857
- }
858
- uint64_t remote_ptr;
859
- memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
860
- uint8_t value;
861
- memcpy(&value, input.data() + sizeof(uint64_t), sizeof(value));
862
- GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, remote_ptr, value);
863
- ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
826
+ bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {
827
+ GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, request.remote_ptr, request.value);
828
+ ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
864
829
  if (buffers.find(buffer) == buffers.end()) {
865
830
  GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
866
831
  return false;
867
832
  }
868
- ggml_backend_buffer_clear(buffer, value);
833
+ ggml_backend_buffer_clear(buffer, request.value);
869
834
  return true;
870
835
  }
871
836
 
@@ -877,8 +842,18 @@ ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rp
877
842
  }
878
843
  result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);
879
844
  if (result->buffer && buffers.find(result->buffer) == buffers.end()) {
880
- return nullptr;
845
+ result->buffer = nullptr;
846
+ }
847
+
848
+ if (result->buffer) {
849
+ // require that the tensor data does not go beyond the buffer end
850
+ uint64_t tensor_size = (uint64_t) ggml_nbytes(result);
851
+ uint64_t buffer_start = (uint64_t) ggml_backend_buffer_get_base(result->buffer);
852
+ uint64_t buffer_size = (uint64_t) ggml_backend_buffer_get_size(result->buffer);
853
+ GGML_ASSERT(tensor->data + tensor_size >= tensor->data); // check for overflow
854
+ GGML_ASSERT(tensor->data >= buffer_start && tensor->data + tensor_size <= buffer_start + buffer_size);
881
855
  }
856
+
882
857
  result->op = (ggml_op) tensor->op;
883
858
  for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
884
859
  result->op_params[i] = tensor->op_params[i];
@@ -898,7 +873,7 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
898
873
  const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
899
874
  uint64_t offset;
900
875
  memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
901
- size_t size = input.size() - sizeof(rpc_tensor) - sizeof(offset);
876
+ const size_t size = input.size() - sizeof(rpc_tensor) - sizeof(offset);
902
877
 
903
878
  struct ggml_init_params params {
904
879
  /*.mem_size =*/ ggml_tensor_overhead(),
@@ -913,69 +888,72 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
913
888
  return false;
914
889
  }
915
890
  GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
891
+
892
+ // sanitize tensor->data
893
+ {
894
+ const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
895
+ const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
896
+
897
+ if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
898
+ GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
899
+ }
900
+ }
901
+
916
902
  const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset);
917
903
  ggml_backend_tensor_set(tensor, data, offset, size);
918
904
  ggml_free(ctx);
919
905
  return true;
920
906
  }
921
907
 
922
- bool rpc_server::get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
- // serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) |
- if (input.size() != sizeof(rpc_tensor) + 2*sizeof(uint64_t)) {
- return false;
- }
- const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
- uint64_t offset;
- memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
- uint64_t size;
- memcpy(&size, input.data() + sizeof(rpc_tensor) + sizeof(offset), sizeof(size));
-
+ bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response) {
  struct ggml_init_params params {
  /*.mem_size =*/ ggml_tensor_overhead(),
  /*.mem_buffer =*/ NULL,
  /*.no_alloc =*/ true,
  };
  struct ggml_context * ctx = ggml_init(params);
- ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
+ ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
  if (tensor == nullptr) {
  GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__);
  ggml_free(ctx);
  return false;
  }
- GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
- // output serialization format: | data (size bytes) |
- output.resize(size, 0);
- ggml_backend_tensor_get(tensor, output.data(), offset, size);
+ GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size);
+
+ // sanitize tensor->data
+ {
+ const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
+ const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
+
+ if (request.tensor.data + request.offset < p0 ||
+ request.tensor.data + request.offset >= p1 ||
+ request.size > (p1 - request.tensor.data - request.offset)) {
+ GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
+ }
+ }
+
+ response.resize(request.size, 0);
+ ggml_backend_tensor_get(tensor, response.data(), request.offset, request.size);
  ggml_free(ctx);
  return true;
  }
 
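The manual length check and `memcpy` parsing disappear because the request now arrives as one fixed-size struct. The real `rpc_msg_get_tensor_req` is defined earlier in this file; judging only from the fields accessed above (`request.tensor`, `request.offset`, `request.size`), it plausibly looks like the following sketch. Both the field order and the stand-in `rpc_tensor` layout here are assumptions for illustration, not the authoritative definitions:

```cpp
#include <cstdint>

// stand-in for the real rpc_tensor POD defined earlier in this file
struct rpc_tensor_sketch {
    uint64_t id;
    uint64_t buffer;
    uint64_t data;
    // ... the remaining serialized tensor fields are elided ...
};

// hypothetical mirror of rpc_msg_get_tensor_req, inferred from the
// request.tensor / request.offset / request.size accesses above
struct rpc_msg_get_tensor_req_sketch {
    rpc_tensor_sketch tensor; // which tensor to read from
    uint64_t          offset; // byte offset into its data
    uint64_t          size;   // number of bytes to return
};
```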
- bool rpc_server::copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
- // serialization format: | rpc_tensor src | rpc_tensor dst |
- if (input.size() != 2*sizeof(rpc_tensor)) {
- return false;
- }
- const rpc_tensor * rpc_src = (const rpc_tensor *)input.data();
- const rpc_tensor * rpc_dst = (const rpc_tensor *)(input.data() + sizeof(rpc_src));
-
+ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response) {
  struct ggml_init_params params {
  /*.mem_size =*/ 2*ggml_tensor_overhead(),
  /*.mem_buffer =*/ NULL,
  /*.no_alloc =*/ true,
  };
  struct ggml_context * ctx = ggml_init(params);
- ggml_tensor * src = deserialize_tensor(ctx, rpc_src);
- ggml_tensor * dst = deserialize_tensor(ctx, rpc_dst);
+ ggml_tensor * src = deserialize_tensor(ctx, &request.src);
+ ggml_tensor * dst = deserialize_tensor(ctx, &request.dst);
  if (src == nullptr || dst == nullptr) {
  GGML_PRINT_DEBUG("[%s] error deserializing tensors\n", __func__);
  ggml_free(ctx);
  return false;
  }
  GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n", __func__, (void*)src->buffer, (void*)dst->buffer);
- bool result = ggml_backend_buffer_copy_tensor(src, dst);
- // output serialization format: | result (1 byte) |
- output.resize(1, 0);
- output[0] = result;
+ response.result = ggml_backend_buffer_copy_tensor(src, dst);
  ggml_free(ctx);
  return true;
  }
@@ -1004,7 +982,7 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
  return result;
  }
 
- bool rpc_server::graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
+ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
  // serialization format:
  // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
  if (input.size() < sizeof(uint32_t)) {
@@ -1024,7 +1002,7 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, std::vector<u
  const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t) + sizeof(n_tensors));
  GGML_PRINT_DEBUG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);
 
- static size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
+ size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
  struct ggml_init_params params = {
  /*.mem_size =*/ buf_size,
  /*.mem_buffer =*/ NULL,
@@ -1044,9 +1022,7 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, std::vector<u
  graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map);
  }
  ggml_status status = ggml_backend_graph_compute(backend, graph);
- // output serialization format: | status (1 byte) |
- output.resize(1, 0);
- output[0] = status;
+ response.result = status;
  ggml_free(ctx);
  return true;
  }
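
Dropping `static` above is a genuine bug fix, not a style change: a function-local `static` is initialized only on the first call, so `buf_size` was computed from the first graph's `n_nodes`/`n_tensors` and silently reused for every later, possibly larger, graph, under-sizing the ggml context. A minimal sketch of the difference:

```cpp
#include <cstdio>

static size_t sized_once(size_t n) {
    static size_t buf_size = 16 * n; // evaluated only on the first call
    return buf_size;
}

static size_t sized_each_call(size_t n) {
    size_t buf_size = 16 * n;        // evaluated on every call
    return buf_size;
}

int main() {
    printf("%zu %zu\n", sized_once(2),    sized_each_call(2));    // 32 32
    printf("%zu %zu\n", sized_once(1000), sized_each_call(1000)); // 32 16000 (!)
    return 0;
}
```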
@@ -1064,84 +1040,162 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
  if (!recv_data(sockfd, &cmd, 1)) {
  break;
  }
- std::vector<uint8_t> input;
- std::vector<uint8_t> output;
- uint64_t input_size;
- if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
+ if (cmd >= RPC_CMD_COUNT) {
+ // fail fast if the command is invalid
+ fprintf(stderr, "Unknown command: %d\n", cmd);
  break;
  }
- input.resize(input_size);
- if (!recv_data(sockfd, input.data(), input_size)) {
- break;
- }
- bool ok = true;
  switch (cmd) {
- case ALLOC_BUFFER: {
- ok = server.alloc_buffer(input, output);
+ case RPC_CMD_ALLOC_BUFFER: {
+ rpc_msg_alloc_buffer_req request;
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
+ return;
+ }
+ rpc_msg_alloc_buffer_rsp response;
+ server.alloc_buffer(request, response);
+ if (!send_msg(sockfd, &response, sizeof(response))) {
+ return;
+ }
  break;
  }
- case GET_ALIGNMENT: {
- server.get_alignment(output);
+ case RPC_CMD_GET_ALIGNMENT: {
+ if (!recv_msg(sockfd, nullptr, 0)) {
+ return;
+ }
+ rpc_msg_get_alignment_rsp response;
+ server.get_alignment(response);
+ if (!send_msg(sockfd, &response, sizeof(response))) {
+ return;
+ }
  break;
  }
- case GET_MAX_SIZE: {
- server.get_max_size(output);
+ case RPC_CMD_GET_MAX_SIZE: {
+ if (!recv_msg(sockfd, nullptr, 0)) {
+ return;
+ }
+ rpc_msg_get_max_size_rsp response;
+ server.get_max_size(response);
+ if (!send_msg(sockfd, &response, sizeof(response))) {
+ return;
+ }
  break;
  }
- case BUFFER_GET_BASE: {
- ok = server.buffer_get_base(input, output);
+ case RPC_CMD_BUFFER_GET_BASE: {
+ rpc_msg_buffer_get_base_req request;
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
+ return;
+ }
+ rpc_msg_buffer_get_base_rsp response;
+ if (!server.buffer_get_base(request, response)) {
+ return;
+ }
+ if (!send_msg(sockfd, &response, sizeof(response))) {
+ return;
+ }
  break;
  }
- case FREE_BUFFER: {
- ok = server.free_buffer(input);
+ case RPC_CMD_FREE_BUFFER: {
+ rpc_msg_free_buffer_req request;
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
+ return;
+ }
+ if (!server.free_buffer(request)) {
+ return;
+ }
+ if (!send_msg(sockfd, nullptr, 0)) {
+ return;
+ }
  break;
  }
- case BUFFER_CLEAR: {
- ok = server.buffer_clear(input);
+ case RPC_CMD_BUFFER_CLEAR: {
+ rpc_msg_buffer_clear_req request;
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
+ return;
+ }
+ if (!server.buffer_clear(request)) {
+ return;
+ }
+ if (!send_msg(sockfd, nullptr, 0)) {
+ return;
+ }
  break;
  }
- case SET_TENSOR: {
- ok = server.set_tensor(input);
+ case RPC_CMD_SET_TENSOR: {
+ std::vector<uint8_t> input;
+ if (!recv_msg(sockfd, input)) {
+ return;
+ }
+ if (!server.set_tensor(input)) {
+ return;
+ }
+ if (!send_msg(sockfd, nullptr, 0)) {
+ return;
+ }
  break;
  }
- case GET_TENSOR: {
- ok = server.get_tensor(input, output);
+ case RPC_CMD_GET_TENSOR: {
+ rpc_msg_get_tensor_req request;
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
+ return;
+ }
+ std::vector<uint8_t> response;
+ if (!server.get_tensor(request, response)) {
+ return;
+ }
+ if (!send_msg(sockfd, response.data(), response.size())) {
+ return;
+ }
  break;
  }
- case COPY_TENSOR: {
- ok = server.copy_tensor(input, output);
+ case RPC_CMD_COPY_TENSOR: {
+ rpc_msg_copy_tensor_req request;
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
+ return;
+ }
+ rpc_msg_copy_tensor_rsp response;
+ if (!server.copy_tensor(request, response)) {
+ return;
+ }
+ if (!send_msg(sockfd, &response, sizeof(response))) {
+ return;
+ }
  break;
  }
- case GRAPH_COMPUTE: {
- ok = server.graph_compute(input, output);
+ case RPC_CMD_GRAPH_COMPUTE: {
+ std::vector<uint8_t> input;
+ if (!recv_msg(sockfd, input)) {
+ return;
+ }
+ rpc_msg_graph_compute_rsp response;
+ if (!server.graph_compute(input, response)) {
+ return;
+ }
+ if (!send_msg(sockfd, &response, sizeof(response))) {
+ return;
+ }
  break;
  }
- case GET_DEVICE_MEMORY: {
- // output serialization format: | free (8 bytes) | total (8 bytes) |
- output.resize(2*sizeof(uint64_t), 0);
- memcpy(output.data(), &free_mem, sizeof(free_mem));
- memcpy(output.data() + sizeof(uint64_t), &total_mem, sizeof(total_mem));
+ case RPC_CMD_GET_DEVICE_MEMORY: {
+ if (!recv_msg(sockfd, nullptr, 0)) {
+ return;
+ }
+ rpc_msg_get_device_memory_rsp response;
+ response.free_mem = free_mem;
+ response.total_mem = total_mem;
+ if (!send_msg(sockfd, &response, sizeof(response))) {
+ return;
+ }
  break;
  }
  default: {
  fprintf(stderr, "Unknown command: %d\n", cmd);
- ok = false;
+ return;
  }
  }
- if (!ok) {
- break;
- }
- uint64_t output_size = output.size();
- if (!send_data(sockfd, &output_size, sizeof(output_size))) {
- break;
- }
- if (!send_data(sockfd, output.data(), output_size)) {
- break;
- }
  }
  }
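
Every fixed-size command above follows the same receive-request / handle / send-response shape; only `SET_TENSOR` and `GRAPH_COMPUTE` keep a variable-size input vector. As an illustration only (the real code deliberately keeps each `case` explicit), the common shape could be factored like this, reusing `sockfd_t`, `recv_msg`, and `send_msg` with the signatures visible in the calls above:

```cpp
// Illustrative sketch, not part of the diff: the fixed-size request/response
// pattern shared by most RPC_CMD_* cases, factored into one template.
template <typename Req, typename Rsp, typename Handler>
static bool handle_fixed_size_cmd(sockfd_t sockfd, Handler && handler) {
    Req request;
    if (!recv_msg(sockfd, &request, sizeof(request))) {   // read exactly sizeof(Req) bytes
        return false;
    }
    Rsp response;
    if (!handler(request, response)) {                    // run the server-side method
        return false;
    }
    return send_msg(sockfd, &response, sizeof(response)); // reply with exactly sizeof(Rsp) bytes
}
```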
 
- void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
+ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
  std::string host;
  int port;
  if (!parse_endpoint(endpoint, host, port)) {
@@ -1169,10 +1223,181 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free
  return;
  }
  printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
+ fflush(stdout);
  rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
  printf("Client connection closed\n");
+ fflush(stdout);
  }
  #ifdef _WIN32
  WSACleanup();
  #endif
  }
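
With the rename, the server entry point joins the public `ggml_backend_rpc_*` API surface. A hedged usage sketch for a standalone server process; the endpoint, the advertised memory figures, and the choice of CPU backend are placeholders, and the call accepts clients in a loop so it does not return under normal operation:

```cpp
#include "ggml-backend.h"
#include "ggml-rpc.h"

int main() {
    // serve the local CPU backend over the network
    ggml_backend_t backend = ggml_backend_cpu_init();

    // memory figures reported to clients (placeholder values)
    const size_t free_mem  = 8ull  * 1024 * 1024 * 1024;
    const size_t total_mem = 16ull * 1024 * 1024 * 1024;

    // blocks: accepts connections and serves each client in turn
    ggml_backend_rpc_start_server(backend, "0.0.0.0:50052", free_mem, total_mem);

    ggml_backend_free(backend);
    return 0;
}
```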
+
+ // device interface
+
+ struct ggml_backend_rpc_device_context {
+ std::string endpoint;
+ std::string name;
+ };
+
+ static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
+ ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
+
+ return ctx->name.c_str();
+ }
+
+ static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) {
+ ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
+
+ return ctx->name.c_str();
+ }
+
+ static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+ ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
+
+ ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total);
+
+ UNUSED(dev);
+ }
+
+ static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
+ // TODO: obtain value from the server
+ return GGML_BACKEND_DEVICE_TYPE_GPU;
+
+ UNUSED(dev);
+ }
+
+ static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+ props->name = ggml_backend_rpc_device_get_name(dev);
+ props->description = ggml_backend_rpc_device_get_description(dev);
+ props->type = ggml_backend_rpc_device_get_type(dev);
+ ggml_backend_rpc_device_get_memory(dev, &props->memory_free, &props->memory_total);
+ props->caps = {
+ /* .async = */ false,
+ /* .host_buffer = */ false,
+ /* .buffer_from_host_ptr = */ false,
+ /* .events = */ false,
+ };
+ }
+
+ static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) {
+ ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
+
+ return ggml_backend_rpc_init(ctx->endpoint.c_str());
+
+ UNUSED(params);
+ }
+
+ static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
+ ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
+
+ return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
+
+ UNUSED(dev);
+ }
+
+ static bool ggml_backend_rpc_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+ UNUSED(dev);
+ UNUSED(op);
+ //TODO: call the remote backend and cache the results
+ return true;
+ }
+
+ static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+ if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
+ return false;
+ }
+ ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
+ ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context;
+ return buft_ctx->endpoint == dev_ctx->endpoint;
+ }
+
+ static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
+ /* .get_name = */ ggml_backend_rpc_device_get_name,
+ /* .get_description = */ ggml_backend_rpc_device_get_description,
+ /* .get_memory = */ ggml_backend_rpc_device_get_memory,
+ /* .get_type = */ ggml_backend_rpc_device_get_type,
+ /* .get_props = */ ggml_backend_rpc_device_get_props,
+ /* .init_backend = */ ggml_backend_rpc_device_init,
+ /* .get_buffer_type = */ ggml_backend_rpc_device_get_buffer_type,
+ /* .get_host_buffer_type = */ NULL,
+ /* .buffer_from_host_ptr = */ NULL,
+ /* .supports_op = */ ggml_backend_rpc_device_supports_op,
+ /* .supports_buft = */ ggml_backend_rpc_device_supports_buft,
+ /* .offload_op = */ NULL,
+ /* .event_new = */ NULL,
+ /* .event_free = */ NULL,
+ /* .event_synchronize = */ NULL,
+ };
+
+ // backend reg interface
+
+ static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
+ return "RPC";
+
+ UNUSED(reg);
+ }
+
+ static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
+ return 0;
+
+ UNUSED(reg);
+ }
+
+ static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
+ GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_add_device instead");
+
+ UNUSED(reg);
+ UNUSED(index);
+ }
+
+ static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+ if (std::strcmp(name, "ggml_backend_rpc_add_device") == 0) {
+ return (void *)ggml_backend_rpc_add_device;
+ }
+ return NULL;
+
+ UNUSED(reg);
+ }
+
+ static const struct ggml_backend_reg_i ggml_backend_rpc_reg_i = {
+ /* .get_name = */ ggml_backend_rpc_reg_get_name,
+ /* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,
+ /* .get_device = */ ggml_backend_rpc_reg_get_device,
+ /* .get_proc_address = */ ggml_backend_rpc_get_proc_address,
+ };
+
+ ggml_backend_reg_t ggml_backend_rpc_reg(void) {
+ static struct ggml_backend_reg ggml_backend_rpc_reg = {
+ /* .iface = */ ggml_backend_rpc_reg_i,
+ /* .context = */ NULL,
+ };
+
+ return &ggml_backend_rpc_reg;
+ }
+
+ ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) {
+ static std::unordered_map<std::string, ggml_backend_dev_t> dev_map;
+
+ static std::mutex mutex;
+ std::lock_guard<std::mutex> lock(mutex);
+
+ if (dev_map.find(endpoint) != dev_map.end()) {
+ return dev_map[endpoint];
+ }
+
+ ggml_backend_rpc_device_context * ctx = new ggml_backend_rpc_device_context {
+ /* .endpoint = */ endpoint,
+ /* .name = */ "RPC[" + std::string(endpoint) + "]",
+ };
+
+ ggml_backend_dev_t dev = new ggml_backend_device {
+ /* .iface = */ ggml_backend_rpc_device_i,
+ /* .reg = */ ggml_backend_rpc_reg(),
+ /* .context = */ ctx,
+ };
+
+ dev_map[endpoint] = dev;
+
+ return dev;
+ }
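
The new registration path means client code no longer constructs RPC backends by hand; an endpoint is wrapped in a `ggml_backend_dev_t` and used through the generic device API. A hedged usage sketch (the endpoint is a placeholder, and the generic `ggml_backend_dev_*` accessors are assumed from the ggml-backend headers of the same tree):

```cpp
#include "ggml-backend.h"
#include "ggml-rpc.h"
#include <cstdio>

int main() {
    // one device per endpoint; repeated calls with the same endpoint
    // return the cached ggml_backend_dev_t (see dev_map above)
    ggml_backend_dev_t dev = ggml_backend_rpc_add_device("192.168.1.42:50052");

    // proxies to the memory figures advertised by the remote server
    size_t free_mem = 0, total_mem = 0;
    ggml_backend_dev_memory(dev, &free_mem, &total_mem);
    printf("%s: free=%zu total=%zu\n", ggml_backend_dev_name(dev), free_mem, total_mem);

    // connects to the endpoint and returns an RPC backend instance
    ggml_backend_t backend = ggml_backend_dev_init(dev, /*params=*/nullptr);
    // ... allocate buffers and compute graphs on the remote backend ...
    ggml_backend_free(backend);
    return 0;
}
```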