@fugood/llama.node 0.3.2 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (190)
  1. package/CMakeLists.txt +2 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +1 -1
  17. package/src/DetokenizeWorker.cpp +1 -1
  18. package/src/EmbeddingWorker.cpp +2 -2
  19. package/src/LlamaCompletionWorker.cpp +8 -8
  20. package/src/LlamaCompletionWorker.h +2 -2
  21. package/src/LlamaContext.cpp +8 -9
  22. package/src/TokenizeWorker.cpp +1 -1
  23. package/src/common.hpp +4 -4
  24. package/src/llama.cpp/.github/workflows/build.yml +43 -9
  25. package/src/llama.cpp/.github/workflows/docker.yml +3 -0
  26. package/src/llama.cpp/CMakeLists.txt +7 -4
  27. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  28. package/src/llama.cpp/common/CMakeLists.txt +0 -2
  29. package/src/llama.cpp/common/arg.cpp +642 -607
  30. package/src/llama.cpp/common/arg.h +22 -22
  31. package/src/llama.cpp/common/common.cpp +79 -281
  32. package/src/llama.cpp/common/common.h +130 -100
  33. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  34. package/src/llama.cpp/common/log.cpp +50 -50
  35. package/src/llama.cpp/common/log.h +18 -18
  36. package/src/llama.cpp/common/ngram-cache.cpp +36 -36
  37. package/src/llama.cpp/common/ngram-cache.h +19 -19
  38. package/src/llama.cpp/common/sampling.cpp +116 -108
  39. package/src/llama.cpp/common/sampling.h +20 -20
  40. package/src/llama.cpp/docs/build.md +37 -17
  41. package/src/llama.cpp/examples/CMakeLists.txt +1 -1
  42. package/src/llama.cpp/examples/batched/batched.cpp +14 -14
  43. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +10 -11
  44. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
  45. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +9 -9
  46. package/src/llama.cpp/examples/embedding/embedding.cpp +12 -12
  47. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +8 -8
  48. package/src/llama.cpp/examples/export-lora/export-lora.cpp +5 -5
  49. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +7 -7
  50. package/src/llama.cpp/examples/gritlm/gritlm.cpp +18 -18
  51. package/src/llama.cpp/examples/imatrix/imatrix.cpp +20 -11
  52. package/src/llama.cpp/examples/infill/infill.cpp +40 -86
  53. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +42 -151
  54. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  55. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -14
  56. package/src/llama.cpp/examples/llava/clip.cpp +1 -0
  57. package/src/llama.cpp/examples/llava/llava-cli.cpp +23 -23
  58. package/src/llama.cpp/examples/llava/llava.cpp +37 -3
  59. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +21 -21
  60. package/src/llama.cpp/examples/lookahead/lookahead.cpp +26 -26
  61. package/src/llama.cpp/examples/lookup/lookup-create.cpp +7 -7
  62. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  63. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +14 -14
  64. package/src/llama.cpp/examples/lookup/lookup.cpp +29 -29
  65. package/src/llama.cpp/examples/main/main.cpp +64 -109
  66. package/src/llama.cpp/examples/parallel/parallel.cpp +18 -19
  67. package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
  68. package/src/llama.cpp/examples/perplexity/perplexity.cpp +99 -120
  69. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +10 -9
  70. package/src/llama.cpp/examples/retrieval/retrieval.cpp +13 -13
  71. package/src/llama.cpp/examples/rpc/rpc-server.cpp +3 -1
  72. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +34 -17
  73. package/src/llama.cpp/examples/server/CMakeLists.txt +4 -13
  74. package/src/llama.cpp/examples/server/server.cpp +553 -691
  75. package/src/llama.cpp/examples/server/utils.hpp +312 -25
  76. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  77. package/src/llama.cpp/examples/simple/simple.cpp +128 -96
  78. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  79. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
  80. package/src/llama.cpp/examples/speculative/speculative.cpp +54 -51
  81. package/src/llama.cpp/examples/tokenize/tokenize.cpp +2 -2
  82. package/src/llama.cpp/ggml/CMakeLists.txt +15 -9
  83. package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
  84. package/src/llama.cpp/ggml/include/ggml-backend.h +46 -33
  85. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  86. package/src/llama.cpp/ggml/include/ggml-cann.h +9 -7
  87. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  88. package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
  89. package/src/llama.cpp/ggml/include/ggml-cuda.h +12 -12
  90. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  91. package/src/llama.cpp/ggml/include/ggml-metal.h +11 -7
  92. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  93. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  94. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  95. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  96. package/src/llama.cpp/ggml/include/ggml.h +53 -393
  97. package/src/llama.cpp/ggml/src/CMakeLists.txt +66 -1149
  98. package/src/llama.cpp/ggml/src/ggml-aarch64.c +46 -3126
  99. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
  100. package/src/llama.cpp/ggml/src/ggml-alloc.c +23 -27
  101. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  102. package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
  103. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  104. package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
  105. package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
  106. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +6 -25
  107. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
  108. package/src/llama.cpp/ggml/src/ggml-backend.cpp +303 -864
  109. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
  110. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +213 -65
  111. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
  112. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +255 -149
  113. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
  114. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
  115. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
  116. package/src/llama.cpp/ggml/src/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -243
  117. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
  118. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  119. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
  120. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
  121. package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.cpp +667 -1
  122. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
  123. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
  124. package/src/llama.cpp/ggml/src/ggml-impl.h +366 -16
  125. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
  126. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +238 -72
  127. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
  128. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
  129. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
  130. package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
  131. package/src/llama.cpp/ggml/src/ggml-quants.c +187 -10692
  132. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -125
  133. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
  134. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +475 -300
  135. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
  136. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
  137. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +40 -0
  138. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +258 -0
  139. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
  140. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +2 -22
  141. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
  142. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  143. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3584 -4142
  144. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +69 -67
  145. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +3 -3
  146. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  147. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  148. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +6 -0
  149. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  150. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
  151. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  152. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  153. package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
  154. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
  155. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +555 -623
  156. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +125 -206
  157. package/src/llama.cpp/ggml/src/ggml.c +4032 -19890
  158. package/src/llama.cpp/include/llama.h +67 -33
  159. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  160. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  161. package/src/llama.cpp/src/CMakeLists.txt +2 -1
  162. package/src/llama.cpp/src/llama-sampling.cpp +745 -105
  163. package/src/llama.cpp/src/llama-sampling.h +21 -2
  164. package/src/llama.cpp/src/llama-vocab.cpp +49 -9
  165. package/src/llama.cpp/src/llama-vocab.h +35 -11
  166. package/src/llama.cpp/src/llama.cpp +2636 -2406
  167. package/src/llama.cpp/src/unicode-data.cpp +2 -2
  168. package/src/llama.cpp/tests/CMakeLists.txt +1 -2
  169. package/src/llama.cpp/tests/test-arg-parser.cpp +14 -14
  170. package/src/llama.cpp/tests/test-backend-ops.cpp +185 -60
  171. package/src/llama.cpp/tests/test-barrier.cpp +1 -0
  172. package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
  173. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -4
  174. package/src/llama.cpp/tests/test-log.cpp +2 -2
  175. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  176. package/src/llama.cpp/tests/test-quantize-fns.cpp +22 -19
  177. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  178. package/src/llama.cpp/tests/test-rope.cpp +1 -0
  179. package/src/llama.cpp/tests/test-sampling.cpp +162 -137
  180. package/src/llama.cpp/tests/test-tokenizer-0.cpp +7 -7
  181. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  182. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  183. package/src/llama.cpp/common/train.cpp +0 -1515
  184. package/src/llama.cpp/common/train.h +0 -233
  185. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +0 -5
  186. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1639
  187. package/src/llama.cpp/tests/test-grad0.cpp +0 -1683
  188. /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  189. /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
  190. /package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +0 -0
@@ -25,7 +25,7 @@
25
25
  # include <netdb.h>
26
26
  # include <unistd.h>
27
27
  #endif
28
- #include <string.h>
28
+ #include <cstring>
29
29
 
30
30
  #define UNUSED GGML_UNUSED
31
31
 
@@ -57,8 +57,9 @@ struct socket_t {
57
57
  }
58
58
  };
59
59
 
60
- // ggml_tensor is serialized into rpc_tensor
60
+ // all RPC structures must be packed
61
61
  #pragma pack(push, 1)
62
+ // ggml_tensor is serialized into rpc_tensor
62
63
  struct rpc_tensor {
63
64
  uint64_t id;
64
65
  uint32_t type;
@@ -76,7 +77,6 @@ struct rpc_tensor {
76
77
 
77
78
  char padding[4];
78
79
  };
79
- #pragma pack(pop)
80
80
 
81
81
  static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of 8");
82
82
 
@@ -96,6 +96,65 @@ enum rpc_cmd {
96
96
  RPC_CMD_COUNT,
97
97
  };
98
98
 
99
+ struct rpc_msg_alloc_buffer_req {
100
+ uint64_t size;
101
+ };
102
+
103
+ struct rpc_msg_alloc_buffer_rsp {
104
+ uint64_t remote_ptr;
105
+ uint64_t remote_size;
106
+ };
107
+
108
+ struct rpc_msg_get_alignment_rsp {
109
+ uint64_t alignment;
110
+ };
111
+
112
+ struct rpc_msg_get_max_size_rsp {
113
+ uint64_t max_size;
114
+ };
115
+
116
+ struct rpc_msg_buffer_get_base_req {
117
+ uint64_t remote_ptr;
118
+ };
119
+
120
+ struct rpc_msg_buffer_get_base_rsp {
121
+ uint64_t base_ptr;
122
+ };
123
+
124
+ struct rpc_msg_free_buffer_req {
125
+ uint64_t remote_ptr;
126
+ };
127
+
128
+ struct rpc_msg_buffer_clear_req {
129
+ uint64_t remote_ptr;
130
+ uint8_t value;
131
+ };
132
+
133
+ struct rpc_msg_get_tensor_req {
134
+ rpc_tensor tensor;
135
+ uint64_t offset;
136
+ uint64_t size;
137
+ };
138
+
139
+ struct rpc_msg_copy_tensor_req {
140
+ rpc_tensor src;
141
+ rpc_tensor dst;
142
+ };
143
+
144
+ struct rpc_msg_copy_tensor_rsp {
145
+ uint8_t result;
146
+ };
147
+
148
+ struct rpc_msg_graph_compute_rsp {
149
+ uint8_t result;
150
+ };
151
+
152
+ struct rpc_msg_get_device_memory_rsp {
153
+ uint64_t free_mem;
154
+ uint64_t total_mem;
155
+ };
156
+ #pragma pack(pop)
157
+
99
158
  // RPC data structures
100
159
 
101
160
  static ggml_guid_t ggml_backend_rpc_guid() {
@@ -119,7 +178,6 @@ struct ggml_backend_rpc_buffer_context {
119
178
  std::shared_ptr<socket_t> sock;
120
179
  std::unordered_map<ggml_backend_buffer_t, void *> base_cache;
121
180
  uint64_t remote_ptr;
122
- std::string name;
123
181
  };
124
182
 
125
183
  // RPC helper functions
@@ -240,6 +298,38 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
240
298
  return true;
241
299
  }
242
300
 
301
+ static bool send_msg(sockfd_t sockfd, const void * msg, size_t msg_size) {
302
+ if (!send_data(sockfd, &msg_size, sizeof(msg_size))) {
303
+ return false;
304
+ }
305
+ return send_data(sockfd, msg, msg_size);
306
+ }
307
+
308
+ static bool recv_msg(sockfd_t sockfd, void * msg, size_t msg_size) {
309
+ uint64_t size;
310
+ if (!recv_data(sockfd, &size, sizeof(size))) {
311
+ return false;
312
+ }
313
+ if (size != msg_size) {
314
+ return false;
315
+ }
316
+ return recv_data(sockfd, msg, msg_size);
317
+ }
318
+
319
+ static bool recv_msg(sockfd_t sockfd, std::vector<uint8_t> & input) {
320
+ uint64_t size;
321
+ if (!recv_data(sockfd, &size, sizeof(size))) {
322
+ return false;
323
+ }
324
+ try {
325
+ input.resize(size);
326
+ } catch (const std::bad_alloc & e) {
327
+ fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", size);
328
+ return false;
329
+ }
330
+ return recv_data(sockfd, input.data(), size);
331
+ }
332
+
243
333
  static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
244
334
  size_t pos = endpoint.find(':');
245
335
  if (pos == std::string::npos) {
@@ -252,28 +342,27 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int
252
342
 
253
343
  // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
254
344
  // RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
255
- static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
345
+ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) {
256
346
  uint8_t cmd_byte = cmd;
257
347
  if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
258
348
  return false;
259
349
  }
260
- uint64_t input_size = input.size();
261
350
  if (!send_data(sock->fd, &input_size, sizeof(input_size))) {
262
351
  return false;
263
352
  }
264
- if (!send_data(sock->fd, input.data(), input.size())) {
353
+ if (!send_data(sock->fd, input, input_size)) {
265
354
  return false;
266
355
  }
267
- uint64_t output_size;
268
- if (!recv_data(sock->fd, &output_size, sizeof(output_size))) {
356
+ // TODO: currently the output_size is always known, do we need support for commands with variable output size?
357
+ // even if we do, we can skip sending output_size from the server for commands with known output size
358
+ uint64_t out_size;
359
+ if (!recv_data(sock->fd, &out_size, sizeof(out_size))) {
269
360
  return false;
270
361
  }
271
- if (output_size == 0) {
272
- output.clear();
273
- return true;
362
+ if (out_size != output_size) {
363
+ return false;
274
364
  }
275
- output.resize(output_size);
276
- if (!recv_data(sock->fd, output.data(), output_size)) {
365
+ if (!recv_data(sock->fd, output, output_size)) {
277
366
  return false;
278
367
  }
279
368
  return true;
@@ -319,21 +408,11 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
319
408
  return sock;
320
409
  }
321
410
 
322
- static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) {
323
- ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
324
- return ctx->name.c_str();
325
- }
326
-
327
411
  static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
328
412
  ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
329
- // input serialization format: | remote_ptr (8 bytes) |
330
- std::vector<uint8_t> input(sizeof(uint64_t), 0);
331
- uint64_t remote_ptr = ctx->remote_ptr;
332
- memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
333
- std::vector<uint8_t> output;
334
- bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, input, output);
413
+ rpc_msg_free_buffer_req request = {ctx->remote_ptr};
414
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0);
335
415
  GGML_ASSERT(status);
336
- GGML_ASSERT(output.empty());
337
416
  delete ctx;
338
417
  }
339
418
 
@@ -342,20 +421,13 @@ static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
342
421
  if (ctx->base_cache.find(buffer) != ctx->base_cache.end()) {
343
422
  return ctx->base_cache[buffer];
344
423
  }
345
- // input serialization format: | remote_ptr (8 bytes) |
346
- std::vector<uint8_t> input(sizeof(uint64_t), 0);
347
- uint64_t remote_ptr = ctx->remote_ptr;
348
- memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
349
- std::vector<uint8_t> output;
350
- bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, input, output);
424
+ rpc_msg_buffer_get_base_req request = {ctx->remote_ptr};
425
+ rpc_msg_buffer_get_base_rsp response;
426
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response));
351
427
  GGML_ASSERT(status);
352
- GGML_ASSERT(output.size() == sizeof(uint64_t));
353
- // output serialization format: | base_ptr (8 bytes) |
354
- uint64_t base_ptr;
355
- memcpy(&base_ptr, output.data(), sizeof(base_ptr));
356
- void * base = reinterpret_cast<void *>(base_ptr);
357
- ctx->base_cache[buffer] = base;
358
- return base;
428
+ void * base_ptr = reinterpret_cast<void *>(response.base_ptr);
429
+ ctx->base_cache[buffer] = base_ptr;
430
+ return base_ptr;
359
431
  }
360
432
 
361
433
  static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
@@ -405,26 +477,18 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm
405
477
  memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
406
478
  memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
407
479
  memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
408
- std::vector<uint8_t> output;
409
- bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input, output);
480
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size(), nullptr, 0);
410
481
  GGML_ASSERT(status);
411
482
  }
412
483
 
413
484
  static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
414
485
  ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
415
- // input serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) |
416
- int input_size = sizeof(rpc_tensor) + 2*sizeof(uint64_t);
417
- std::vector<uint8_t> input(input_size, 0);
418
- rpc_tensor rpc_tensor = serialize_tensor(tensor);
419
- memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
420
- memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
421
- memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &size, sizeof(size));
422
- std::vector<uint8_t> output;
423
- bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, input, output);
486
+ rpc_msg_get_tensor_req request;
487
+ request.tensor = serialize_tensor(tensor);
488
+ request.offset = offset;
489
+ request.size = size;
490
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, &request, sizeof(request), data, size);
424
491
  GGML_ASSERT(status);
425
- GGML_ASSERT(output.size() == size);
426
- // output serialization format: | data (size bytes) |
427
- memcpy(data, output.data(), size);
428
492
  }
429
493
 
430
494
  static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
@@ -437,35 +501,23 @@ static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, con
437
501
  return false;
438
502
  }
439
503
  ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
440
- // input serialization format: | rpc_tensor src | rpc_tensor dst |
441
- int input_size = 2*sizeof(rpc_tensor);
442
- std::vector<uint8_t> input(input_size, 0);
443
- rpc_tensor rpc_src = serialize_tensor(src);
444
- rpc_tensor rpc_dst = serialize_tensor(dst);
445
- memcpy(input.data(), &rpc_src, sizeof(rpc_src));
446
- memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst));
447
- std::vector<uint8_t> output;
448
- bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, input, output);
504
+ rpc_msg_copy_tensor_req request;
505
+ request.src = serialize_tensor(src);
506
+ request.dst = serialize_tensor(dst);
507
+ rpc_msg_copy_tensor_rsp response;
508
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
449
509
  GGML_ASSERT(status);
450
- // output serialization format: | result (1 byte) |
451
- GGML_ASSERT(output.size() == 1);
452
- return output[0];
510
+ return response.result;
453
511
  }
454
512
 
455
513
  static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
456
514
  ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
457
- // serialization format: | bufptr (8 bytes) | value (1 byte) |
458
- int input_size = sizeof(uint64_t) + sizeof(uint8_t);
459
- std::vector<uint8_t> input(input_size, 0);
460
- memcpy(input.data(), &ctx->remote_ptr, sizeof(ctx->remote_ptr));
461
- memcpy(input.data() + sizeof(ctx->remote_ptr), &value, sizeof(value));
462
- std::vector<uint8_t> output;
463
- bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, input, output);
515
+ rpc_msg_buffer_clear_req request = {ctx->remote_ptr, value};
516
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, &request, sizeof(request), nullptr, 0);
464
517
  GGML_ASSERT(status);
465
518
  }
466
519
 
467
520
  static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
468
- /* .get_name = */ ggml_backend_rpc_buffer_get_name,
469
521
  /* .free_buffer = */ ggml_backend_rpc_buffer_free_buffer,
470
522
  /* .get_base = */ ggml_backend_rpc_buffer_get_base,
471
523
  /* .init_tensor = */ ggml_backend_rpc_buffer_init_tensor,
@@ -484,25 +536,16 @@ static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t
484
536
 
485
537
  static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
486
538
  ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
487
- // input serialization format: | size (8 bytes) |
488
- int input_size = sizeof(uint64_t);
489
- std::vector<uint8_t> input(input_size, 0);
490
- memcpy(input.data(), &size, sizeof(size));
491
- std::vector<uint8_t> output;
539
+ rpc_msg_alloc_buffer_req request = {size};
540
+ rpc_msg_alloc_buffer_rsp response;
492
541
  auto sock = get_socket(buft_ctx->endpoint);
493
- bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, input, output);
542
+ bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
494
543
  GGML_ASSERT(status);
495
- GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
496
- // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
497
- uint64_t remote_ptr;
498
- memcpy(&remote_ptr, output.data(), sizeof(remote_ptr));
499
- size_t remote_size;
500
- memcpy(&remote_size, output.data() + sizeof(uint64_t), sizeof(remote_size));
501
- if (remote_ptr != 0) {
544
+ if (response.remote_ptr != 0) {
502
545
  ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
503
546
  ggml_backend_rpc_buffer_interface,
504
- new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC[" + std::string(buft_ctx->endpoint) + "]"},
505
- remote_size);
547
+ new ggml_backend_rpc_buffer_context{sock, {}, response.remote_ptr},
548
+ response.remote_size);
506
549
  return buffer;
507
550
  } else {
508
551
  return nullptr;
@@ -510,16 +553,10 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
510
553
  }
511
554
 
512
555
  static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
513
- // input serialization format: | 0 bytes |
514
- std::vector<uint8_t> input;
515
- std::vector<uint8_t> output;
516
- bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, input, output);
556
+ rpc_msg_get_alignment_rsp response;
557
+ bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response));
517
558
  GGML_ASSERT(status);
518
- GGML_ASSERT(output.size() == sizeof(uint64_t));
519
- // output serialization format: | alignment (8 bytes) |
520
- uint64_t alignment;
521
- memcpy(&alignment, output.data(), sizeof(alignment));
522
- return alignment;
559
+ return response.alignment;
523
560
  }
524
561
 
525
562
  static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
@@ -528,16 +565,10 @@ static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_typ
528
565
  }
529
566
 
530
567
  static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
531
- // input serialization format: | 0 bytes |
532
- std::vector<uint8_t> input;
533
- std::vector<uint8_t> output;
534
- bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, input, output);
568
+ rpc_msg_get_max_size_rsp response;
569
+ bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, nullptr, 0, &response, sizeof(response));
535
570
  GGML_ASSERT(status);
536
- GGML_ASSERT(output.size() == sizeof(uint64_t));
537
- // output serialization format: | max_size (8 bytes) |
538
- uint64_t max_size;
539
- memcpy(&max_size, output.data(), sizeof(max_size));
540
- return max_size;
571
+ return response.max_size;
541
572
  }
542
573
 
543
574
  static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
@@ -571,11 +602,6 @@ static void ggml_backend_rpc_free(ggml_backend_t backend) {
571
602
  delete backend;
572
603
  }
573
604
 
574
- static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) {
575
- ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
576
- return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
577
- }
578
-
579
605
  static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
580
606
  UNUSED(backend);
581
607
  // this is no-op because we don't have any async operations
@@ -622,34 +648,16 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
622
648
  ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
623
649
  std::vector<uint8_t> input;
624
650
  serialize_graph(cgraph, input);
625
- std::vector<uint8_t> output;
651
+ rpc_msg_graph_compute_rsp response;
626
652
  auto sock = get_socket(rpc_ctx->endpoint);
627
- bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input, output);
653
+ bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
628
654
  GGML_ASSERT(status);
629
- GGML_ASSERT(output.size() == 1);
630
- return (enum ggml_status)output[0];
631
- }
632
-
633
- static bool ggml_backend_rpc_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
634
- UNUSED(backend);
635
- UNUSED(op);
636
- //TODO: call the remote backend and cache the results
637
- return true;
638
- }
639
-
640
- static bool ggml_backend_rpc_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
641
- if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
642
- return false;
643
- }
644
- ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
645
- ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
646
- return buft_ctx->endpoint == rpc_ctx->endpoint;
655
+ return (enum ggml_status)response.result;
647
656
  }
648
657
 
649
658
  static ggml_backend_i ggml_backend_rpc_interface = {
650
659
  /* .get_name = */ ggml_backend_rpc_name,
651
660
  /* .free = */ ggml_backend_rpc_free,
652
- /* .get_default_buffer_type = */ ggml_backend_rpc_get_default_buffer_type,
653
661
  /* .set_tensor_async = */ NULL,
654
662
  /* .get_tensor_async = */ NULL,
655
663
  /* .cpy_tensor_async = */ NULL,
@@ -659,14 +667,11 @@ static ggml_backend_i ggml_backend_rpc_interface = {
659
667
  /* .graph_plan_update = */ NULL,
660
668
  /* .graph_plan_compute = */ NULL,
661
669
  /* .graph_compute = */ ggml_backend_rpc_graph_compute,
662
- /* .supports_op = */ ggml_backend_rpc_supports_op,
663
- /* .supports_buft = */ ggml_backend_rpc_supports_buft,
664
- /* .offload_op = */ NULL,
665
670
  /* .event_record = */ NULL,
666
671
  /* .event_wait = */ NULL,
667
672
  };
668
673
 
669
- GGML_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
674
+ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
670
675
  static std::mutex mutex;
671
676
  std::lock_guard<std::mutex> lock(mutex);
672
677
  // NOTE: buffer types are allocated and never freed; this is by design
@@ -691,7 +696,7 @@ GGML_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * en
691
696
 
692
697
  ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
693
698
  /* .iface = */ ggml_backend_rpc_buffer_type_interface,
694
- /* .device = */ nullptr,
699
+ /* .device = */ ggml_backend_rpc_add_device(endpoint),
695
700
  /* .context = */ buft_ctx
696
701
  };
697
702
  buft_map[endpoint] = buft;
@@ -707,33 +712,25 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
707
712
  ggml_backend_t backend = new ggml_backend {
708
713
  /* .guid = */ ggml_backend_rpc_guid(),
709
714
  /* .interface = */ ggml_backend_rpc_interface,
710
- /* .device = */ nullptr,
715
+ /* .device = */ ggml_backend_rpc_add_device(endpoint),
711
716
  /* .context = */ ctx
712
717
  };
713
718
  return backend;
714
719
  }
715
720
 
716
- GGML_API bool ggml_backend_is_rpc(ggml_backend_t backend) {
721
+ bool ggml_backend_is_rpc(ggml_backend_t backend) {
717
722
  return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
718
723
  }
719
724
 
720
725
  static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
721
- // input serialization format: | 0 bytes |
722
- std::vector<uint8_t> input;
723
- std::vector<uint8_t> output;
724
- bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, input, output);
726
+ rpc_msg_get_device_memory_rsp response;
727
+ bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response));
725
728
  GGML_ASSERT(status);
726
- GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
727
- // output serialization format: | free (8 bytes) | total (8 bytes) |
728
- uint64_t free_mem;
729
- memcpy(&free_mem, output.data(), sizeof(free_mem));
730
- uint64_t total_mem;
731
- memcpy(&total_mem, output.data() + sizeof(uint64_t), sizeof(total_mem));
732
- *free = free_mem;
733
- *total = total_mem;
729
+ *free = response.free_mem;
730
+ *total = response.total_mem;
734
731
  }
735
732
 
736
- GGML_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
733
+ void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
737
734
  auto sock = get_socket(endpoint);
738
735
  if (sock == nullptr) {
739
736
  *free = 0;
@@ -750,16 +747,16 @@ public:
750
747
  rpc_server(ggml_backend_t backend) : backend(backend) {}
751
748
  ~rpc_server();
752
749
 
753
- bool alloc_buffer(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
754
- void get_alignment(std::vector<uint8_t> & output);
755
- void get_max_size(std::vector<uint8_t> & output);
756
- bool buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
757
- bool free_buffer(const std::vector<uint8_t> & input);
758
- bool buffer_clear(const std::vector<uint8_t> & input);
750
+ void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
751
+ void get_alignment(rpc_msg_get_alignment_rsp & response);
752
+ void get_max_size(rpc_msg_get_max_size_rsp & response);
753
+ bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response);
754
+ bool free_buffer(const rpc_msg_free_buffer_req & request);
755
+ bool buffer_clear(const rpc_msg_buffer_clear_req & request);
759
756
  bool set_tensor(const std::vector<uint8_t> & input);
760
- bool get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
761
- bool copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
762
- bool graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
757
+ bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
758
+ bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
759
+ bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
763
760
 
764
761
  private:
765
762
  ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
@@ -773,80 +770,50 @@ private:
773
770
  std::unordered_set<ggml_backend_buffer_t> buffers;
774
771
  };
775
772
 
776
- bool rpc_server::alloc_buffer(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
777
- // input serialization format: | size (8 bytes) |
778
- if (input.size() != sizeof(uint64_t)) {
779
- return false;
780
- }
781
- uint64_t size;
782
- memcpy(&size, input.data(), sizeof(size));
773
+ void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
783
774
  ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
784
- ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, size);
785
- uint64_t remote_ptr = 0;
786
- uint64_t remote_size = 0;
775
+ ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
776
+ response.remote_ptr = 0;
777
+ response.remote_size = 0;
787
778
  if (buffer != nullptr) {
788
- remote_ptr = reinterpret_cast<uint64_t>(buffer);
789
- remote_size = buffer->size;
790
- GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, size, remote_ptr, remote_size);
779
+ response.remote_ptr = reinterpret_cast<uint64_t>(buffer);
780
+ response.remote_size = buffer->size;
781
+ GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size);
791
782
  buffers.insert(buffer);
792
783
  } else {
793
- GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> failed\n", __func__, size);
784
+ GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size);
794
785
  }
795
- // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
796
- output.resize(2*sizeof(uint64_t), 0);
797
- memcpy(output.data(), &remote_ptr, sizeof(remote_ptr));
798
- memcpy(output.data() + sizeof(uint64_t), &remote_size, sizeof(remote_size));
799
- return true;
800
786
  }
801
787
 
802
- void rpc_server::get_alignment(std::vector<uint8_t> & output) {
788
+ void rpc_server::get_alignment(rpc_msg_get_alignment_rsp & response) {
803
789
  ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
804
790
  size_t alignment = ggml_backend_buft_get_alignment(buft);
805
791
  GGML_PRINT_DEBUG("[%s] alignment: %lu\n", __func__, alignment);
806
- // output serialization format: | alignment (8 bytes) |
807
- output.resize(sizeof(uint64_t), 0);
808
- memcpy(output.data(), &alignment, sizeof(alignment));
792
+ response.alignment = alignment;
809
793
  }
810
794
 
811
- void rpc_server::get_max_size(std::vector<uint8_t> & output) {
795
+ void rpc_server::get_max_size(rpc_msg_get_max_size_rsp & response) {
812
796
  ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
813
797
  size_t max_size = ggml_backend_buft_get_max_size(buft);
814
798
  GGML_PRINT_DEBUG("[%s] max_size: %lu\n", __func__, max_size);
815
- // output serialization format: | max_size (8 bytes) |
816
- output.resize(sizeof(uint64_t), 0);
817
- memcpy(output.data(), &max_size, sizeof(max_size));
799
+ response.max_size = max_size;
818
800
  }
819
801
 
820
- bool rpc_server::buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
821
- // input serialization format: | remote_ptr (8 bytes) |
822
- if (input.size() != sizeof(uint64_t)) {
823
- return false;
824
- }
825
- uint64_t remote_ptr;
826
- memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
827
- GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
828
- ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
802
+ bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {
803
+ GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
804
+ ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
829
805
  if (buffers.find(buffer) == buffers.end()) {
830
806
  GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
831
807
  return false;
832
808
  }
833
809
  void * base = ggml_backend_buffer_get_base(buffer);
834
- // output serialization format: | base_ptr (8 bytes) |
835
- uint64_t base_ptr = reinterpret_cast<uint64_t>(base);
836
- output.resize(sizeof(uint64_t), 0);
837
- memcpy(output.data(), &base_ptr, sizeof(base_ptr));
810
+ response.base_ptr = reinterpret_cast<uint64_t>(base);
838
811
  return true;
839
812
  }
840
813
 
841
- bool rpc_server::free_buffer(const std::vector<uint8_t> & input) {
842
- // input serialization format: | remote_ptr (8 bytes) |
843
- if (input.size() != sizeof(uint64_t)) {
844
- return false;
845
- }
846
- uint64_t remote_ptr;
847
- memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
848
- GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
849
- ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
814
+ bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) {
815
+ GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
816
+ ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
850
817
  if (buffers.find(buffer) == buffers.end()) {
851
818
  GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
852
819
  return false;
@@ -856,22 +823,14 @@ bool rpc_server::free_buffer(const std::vector<uint8_t> & input) {
856
823
  return true;
857
824
  }
858
825
 
859
- bool rpc_server::buffer_clear(const std::vector<uint8_t> & input) {
860
- // input serialization format: | remote_ptr (8 bytes) | value (1 byte) |
861
- if (input.size() != sizeof(uint64_t) + sizeof(uint8_t)) {
862
- return false;
863
- }
864
- uint64_t remote_ptr;
865
- memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
866
- uint8_t value;
867
- memcpy(&value, input.data() + sizeof(uint64_t), sizeof(value));
868
- GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, remote_ptr, value);
869
- ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
826
+ bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {
827
+ GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, request.remote_ptr, request.value);
828
+ ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
870
829
  if (buffers.find(buffer) == buffers.end()) {
871
830
  GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
872
831
  return false;
873
832
  }
874
- ggml_backend_buffer_clear(buffer, value);
833
+ ggml_backend_buffer_clear(buffer, request.value);
875
834
  return true;
876
835
  }
877
836
 
@@ -946,74 +905,55 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
946
905
  return true;
947
906
  }
948
907
 
949
- bool rpc_server::get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
950
- // serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) |
951
- if (input.size() != sizeof(rpc_tensor) + 2*sizeof(uint64_t)) {
952
- return false;
953
- }
954
- const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
955
- uint64_t offset;
956
- memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
957
- uint64_t size;
958
- memcpy(&size, input.data() + sizeof(rpc_tensor) + sizeof(offset), sizeof(size));
959
-
908
+ bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response) {
960
909
  struct ggml_init_params params {
961
910
  /*.mem_size =*/ ggml_tensor_overhead(),
962
911
  /*.mem_buffer =*/ NULL,
963
912
  /*.no_alloc =*/ true,
964
913
  };
965
914
  struct ggml_context * ctx = ggml_init(params);
966
- ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
915
+ ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
967
916
  if (tensor == nullptr) {
968
917
  GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__);
969
918
  ggml_free(ctx);
970
919
  return false;
971
920
  }
972
- GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
921
+ GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size);
973
922
 
974
923
  // sanitize tensor->data
975
924
  {
976
925
  const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
977
926
  const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
978
927
 
979
- if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
980
- GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
928
+ if (request.tensor.data + request.offset < p0 ||
929
+ request.tensor.data + request.offset >= p1 ||
930
+ request.size > (p1 - request.tensor.data - request.offset)) {
931
+ GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
981
932
  }
982
933
  }
983
934
 
984
- // output serialization format: | data (size bytes) |
985
- output.resize(size, 0);
986
- ggml_backend_tensor_get(tensor, output.data(), offset, size);
935
+ response.resize(request.size, 0);
936
+ ggml_backend_tensor_get(tensor, response.data(), request.offset, request.size);
987
937
  ggml_free(ctx);
988
938
  return true;
989
939
  }
990
940
 
991
- bool rpc_server::copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
992
- // serialization format: | rpc_tensor src | rpc_tensor dst |
993
- if (input.size() != 2*sizeof(rpc_tensor)) {
994
- return false;
995
- }
996
- const rpc_tensor * rpc_src = (const rpc_tensor *)input.data();
997
- const rpc_tensor * rpc_dst = (const rpc_tensor *)(input.data() + sizeof(rpc_src));
998
-
941
+ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response) {
999
942
  struct ggml_init_params params {
1000
943
  /*.mem_size =*/ 2*ggml_tensor_overhead(),
1001
944
  /*.mem_buffer =*/ NULL,
1002
945
  /*.no_alloc =*/ true,
1003
946
  };
1004
947
  struct ggml_context * ctx = ggml_init(params);
1005
- ggml_tensor * src = deserialize_tensor(ctx, rpc_src);
1006
- ggml_tensor * dst = deserialize_tensor(ctx, rpc_dst);
948
+ ggml_tensor * src = deserialize_tensor(ctx, &request.src);
949
+ ggml_tensor * dst = deserialize_tensor(ctx, &request.dst);
1007
950
  if (src == nullptr || dst == nullptr) {
1008
951
  GGML_PRINT_DEBUG("[%s] error deserializing tensors\n", __func__);
1009
952
  ggml_free(ctx);
1010
953
  return false;
1011
954
  }
1012
955
  GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n", __func__, (void*)src->buffer, (void*)dst->buffer);
1013
- bool result = ggml_backend_buffer_copy_tensor(src, dst);
1014
- // output serialization format: | result (1 byte) |
1015
- output.resize(1, 0);
1016
- output[0] = result;
956
+ response.result = ggml_backend_buffer_copy_tensor(src, dst);
1017
957
  ggml_free(ctx);
1018
958
  return true;
1019
959
  }
@@ -1042,7 +982,7 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
1042
982
  return result;
1043
983
  }
1044
984
 
1045
- bool rpc_server::graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
985
+ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
1046
986
  // serialization format:
1047
987
  // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
1048
988
  if (input.size() < sizeof(uint32_t)) {
@@ -1082,9 +1022,7 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, std::vector<u
1082
1022
  graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map);
1083
1023
  }
1084
1024
  ggml_status status = ggml_backend_graph_compute(backend, graph);
1085
- // output serialization format: | status (1 byte) |
1086
- output.resize(1, 0);
1087
- output[0] = status;
1025
+ response.result = status;
1088
1026
  ggml_free(ctx);
1089
1027
  return true;
1090
1028
  }
@@ -1107,89 +1045,157 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
1107
1045
  fprintf(stderr, "Unknown command: %d\n", cmd);
1108
1046
  break;
1109
1047
  }
1110
- std::vector<uint8_t> input;
1111
- std::vector<uint8_t> output;
1112
- uint64_t input_size;
1113
- if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
1114
- break;
1115
- }
1116
- try {
1117
- input.resize(input_size);
1118
- } catch (const std::bad_alloc & e) {
1119
- fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", input_size);
1120
- break;
1121
- }
1122
- if (!recv_data(sockfd, input.data(), input_size)) {
1123
- break;
1124
- }
1125
- bool ok = true;
1126
1048
  switch (cmd) {
1127
1049
  case RPC_CMD_ALLOC_BUFFER: {
1128
- ok = server.alloc_buffer(input, output);
1050
+ rpc_msg_alloc_buffer_req request;
1051
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
1052
+ return;
1053
+ }
1054
+ rpc_msg_alloc_buffer_rsp response;
1055
+ server.alloc_buffer(request, response);
1056
+ if (!send_msg(sockfd, &response, sizeof(response))) {
1057
+ return;
1058
+ }
1129
1059
  break;
1130
1060
  }
1131
1061
  case RPC_CMD_GET_ALIGNMENT: {
1132
- server.get_alignment(output);
1062
+ if (!recv_msg(sockfd, nullptr, 0)) {
1063
+ return;
1064
+ }
1065
+ rpc_msg_get_alignment_rsp response;
1066
+ server.get_alignment(response);
1067
+ if (!send_msg(sockfd, &response, sizeof(response))) {
1068
+ return;
1069
+ }
1133
1070
  break;
1134
1071
  }
1135
1072
  case RPC_CMD_GET_MAX_SIZE: {
1136
- server.get_max_size(output);
1073
+ if (!recv_msg(sockfd, nullptr, 0)) {
1074
+ return;
1075
+ }
1076
+ rpc_msg_get_max_size_rsp response;
1077
+ server.get_max_size(response);
1078
+ if (!send_msg(sockfd, &response, sizeof(response))) {
1079
+ return;
1080
+ }
1137
1081
  break;
1138
1082
  }
1139
1083
  case RPC_CMD_BUFFER_GET_BASE: {
1140
- ok = server.buffer_get_base(input, output);
1084
+ rpc_msg_buffer_get_base_req request;
1085
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
1086
+ return;
1087
+ }
1088
+ rpc_msg_buffer_get_base_rsp response;
1089
+ if (!server.buffer_get_base(request, response)) {
1090
+ return;
1091
+ }
1092
+ if (!send_msg(sockfd, &response, sizeof(response))) {
1093
+ return;
1094
+ }
1141
1095
  break;
1142
1096
  }
1143
1097
  case RPC_CMD_FREE_BUFFER: {
1144
- ok = server.free_buffer(input);
1098
+ rpc_msg_free_buffer_req request;
1099
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
1100
+ return;
1101
+ }
1102
+ if (!server.free_buffer(request)) {
1103
+ return;
1104
+ }
1105
+ if (!send_msg(sockfd, nullptr, 0)) {
1106
+ return;
1107
+ }
1145
1108
  break;
1146
1109
  }
1147
1110
  case RPC_CMD_BUFFER_CLEAR: {
1148
- ok = server.buffer_clear(input);
1111
+ rpc_msg_buffer_clear_req request;
1112
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
1113
+ return;
1114
+ }
1115
+ if (!server.buffer_clear(request)) {
1116
+ return;
1117
+ }
1118
+ if (!send_msg(sockfd, nullptr, 0)) {
1119
+ return;
1120
+ }
1149
1121
  break;
1150
1122
  }
1151
1123
  case RPC_CMD_SET_TENSOR: {
1152
- ok = server.set_tensor(input);
1124
+ std::vector<uint8_t> input;
1125
+ if (!recv_msg(sockfd, input)) {
1126
+ return;
1127
+ }
1128
+ if (!server.set_tensor(input)) {
1129
+ return;
1130
+ }
1131
+ if (!send_msg(sockfd, nullptr, 0)) {
1132
+ return;
1133
+ }
1153
1134
  break;
1154
1135
  }
1155
1136
  case RPC_CMD_GET_TENSOR: {
1156
- ok = server.get_tensor(input, output);
1137
+ rpc_msg_get_tensor_req request;
1138
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
1139
+ return;
1140
+ }
1141
+ std::vector<uint8_t> response;
1142
+ if (!server.get_tensor(request, response)) {
1143
+ return;
1144
+ }
1145
+ if (!send_msg(sockfd, response.data(), response.size())) {
1146
+ return;
1147
+ }
1157
1148
  break;
1158
1149
  }
1159
1150
  case RPC_CMD_COPY_TENSOR: {
1160
- ok = server.copy_tensor(input, output);
1151
+ rpc_msg_copy_tensor_req request;
1152
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
1153
+ return;
1154
+ }
1155
+ rpc_msg_copy_tensor_rsp response;
1156
+ if (!server.copy_tensor(request, response)) {
1157
+ return;
1158
+ }
1159
+ if (!send_msg(sockfd, &response, sizeof(response))) {
1160
+ return;
1161
+ }
1161
1162
  break;
1162
1163
  }
1163
1164
  case RPC_CMD_GRAPH_COMPUTE: {
1164
- ok = server.graph_compute(input, output);
1165
+ std::vector<uint8_t> input;
1166
+ if (!recv_msg(sockfd, input)) {
1167
+ return;
1168
+ }
1169
+ rpc_msg_graph_compute_rsp response;
1170
+ if (!server.graph_compute(input, response)) {
1171
+ return;
1172
+ }
1173
+ if (!send_msg(sockfd, &response, sizeof(response))) {
1174
+ return;
1175
+ }
1165
1176
  break;
1166
1177
  }
1167
1178
  case RPC_CMD_GET_DEVICE_MEMORY: {
1168
- // output serialization format: | free (8 bytes) | total (8 bytes) |
1169
- output.resize(2*sizeof(uint64_t), 0);
1170
- memcpy(output.data(), &free_mem, sizeof(free_mem));
1171
- memcpy(output.data() + sizeof(uint64_t), &total_mem, sizeof(total_mem));
1179
+ if (!recv_msg(sockfd, nullptr, 0)) {
1180
+ return;
1181
+ }
1182
+ rpc_msg_get_device_memory_rsp response;
1183
+ response.free_mem = free_mem;
1184
+ response.total_mem = total_mem;
1185
+ if (!send_msg(sockfd, &response, sizeof(response))) {
1186
+ return;
1187
+ }
1172
1188
  break;
1173
1189
  }
1174
1190
  default: {
1175
1191
  fprintf(stderr, "Unknown command: %d\n", cmd);
1176
- ok = false;
1192
+ return;
1177
1193
  }
1178
1194
  }
1179
- if (!ok) {
1180
- break;
1181
- }
1182
- uint64_t output_size = output.size();
1183
- if (!send_data(sockfd, &output_size, sizeof(output_size))) {
1184
- break;
1185
- }
1186
- if (!send_data(sockfd, output.data(), output_size)) {
1187
- break;
1188
- }
1189
1195
  }
1190
1196
  }
1191
1197
 
1192
- void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
1198
+ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
1193
1199
  std::string host;
1194
1200
  int port;
1195
1201
  if (!parse_endpoint(endpoint, host, port)) {
@@ -1226,3 +1232,172 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free
1226
1232
  WSACleanup();
1227
1233
  #endif
1228
1234
  }
1235
+
1236
+ // device interface
1237
+
1238
// Per-device state of an RPC device. This is a plain aggregate that is
// brace-initialised positionally (endpoint first, then name), so keep the
// member order.
struct ggml_backend_rpc_device_context {
    std::string endpoint;  // endpoint string the device was created for
    std::string name;      // display name, "RPC[" + endpoint + "]" when built by ggml_backend_rpc_add_device
};
1242
+
1243
+ static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
1244
+ ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1245
+
1246
+ return ctx->name.c_str();
1247
+ }
1248
+
1249
+ static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) {
1250
+ ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1251
+
1252
+ return ctx->name.c_str();
1253
+ }
1254
+
1255
+ static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
1256
+ ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1257
+
1258
+ ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total);
1259
+
1260
+ UNUSED(dev);
1261
+ }
1262
+
1263
+ static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
1264
+ // TODO: obtain value from the server
1265
+ return GGML_BACKEND_DEVICE_TYPE_GPU;
1266
+
1267
+ UNUSED(dev);
1268
+ }
1269
+
1270
+ static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
1271
+ props->name = ggml_backend_rpc_device_get_name(dev);
1272
+ props->description = ggml_backend_rpc_device_get_description(dev);
1273
+ props->type = ggml_backend_rpc_device_get_type(dev);
1274
+ ggml_backend_rpc_device_get_memory(dev, &props->memory_free, &props->memory_total);
1275
+ props->caps = {
1276
+ /* .async = */ false,
1277
+ /* .host_buffer = */ false,
1278
+ /* .buffer_from_host_ptr = */ false,
1279
+ /* .events = */ false,
1280
+ };
1281
+ }
1282
+
1283
+ static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) {
1284
+ ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1285
+
1286
+ return ggml_backend_rpc_init(ctx->endpoint.c_str());
1287
+
1288
+ UNUSED(params);
1289
+ }
1290
+
1291
+ static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
1292
+ ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1293
+
1294
+ return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
1295
+
1296
+ UNUSED(dev);
1297
+ }
1298
+
1299
+ static bool ggml_backend_rpc_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
1300
+ UNUSED(dev);
1301
+ UNUSED(op);
1302
+ //TODO: call the remote backend and cache the results
1303
+ return true;
1304
+ }
1305
+
1306
+ static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
1307
+ if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
1308
+ return false;
1309
+ }
1310
+ ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
1311
+ ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context;
1312
+ return buft_ctx->endpoint == dev_ctx->endpoint;
1313
+ }
1314
+
1315
+ static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
1316
+ /* .get_name = */ ggml_backend_rpc_device_get_name,
1317
+ /* .get_description = */ ggml_backend_rpc_device_get_description,
1318
+ /* .get_memory = */ ggml_backend_rpc_device_get_memory,
1319
+ /* .get_type = */ ggml_backend_rpc_device_get_type,
1320
+ /* .get_props = */ ggml_backend_rpc_device_get_props,
1321
+ /* .init_backend = */ ggml_backend_rpc_device_init,
1322
+ /* .get_buffer_type = */ ggml_backend_rpc_device_get_buffer_type,
1323
+ /* .get_host_buffer_type = */ NULL,
1324
+ /* .buffer_from_host_ptr = */ NULL,
1325
+ /* .supports_op = */ ggml_backend_rpc_device_supports_op,
1326
+ /* .supports_buft = */ ggml_backend_rpc_device_supports_buft,
1327
+ /* .offload_op = */ NULL,
1328
+ /* .event_new = */ NULL,
1329
+ /* .event_free = */ NULL,
1330
+ /* .event_synchronize = */ NULL,
1331
+ };
1332
+
1333
+ // backend reg interface
1334
+
1335
+ static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
1336
+ return "RPC";
1337
+
1338
+ UNUSED(reg);
1339
+ }
1340
+
1341
+ static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
1342
+ return 0;
1343
+
1344
+ UNUSED(reg);
1345
+ }
1346
+
1347
+ static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
1348
+ GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_add_device instead");
1349
+
1350
+ UNUSED(reg);
1351
+ UNUSED(index);
1352
+ }
1353
+
1354
+ static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {
1355
+ if (std::strcmp(name, "ggml_backend_rpc_add_device") == 0) {
1356
+ return (void *)ggml_backend_rpc_add_device;
1357
+ }
1358
+ return NULL;
1359
+
1360
+ UNUSED(reg);
1361
+ }
1362
+
1363
+ static const struct ggml_backend_reg_i ggml_backend_rpc_reg_i = {
1364
+ /* .get_name = */ ggml_backend_rpc_reg_get_name,
1365
+ /* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,
1366
+ /* .get_device = */ ggml_backend_rpc_reg_get_device,
1367
+ /* .get_proc_address = */ ggml_backend_rpc_get_proc_address,
1368
+ };
1369
+
1370
+ ggml_backend_reg_t ggml_backend_rpc_reg(void) {
1371
+ static struct ggml_backend_reg ggml_backend_rpc_reg = {
1372
+ /* .iface = */ ggml_backend_rpc_reg_i,
1373
+ /* .context = */ NULL,
1374
+ };
1375
+
1376
+ return &ggml_backend_rpc_reg;
1377
+ }
1378
+
1379
+ ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) {
1380
+ static std::unordered_map<std::string, ggml_backend_dev_t> dev_map;
1381
+
1382
+ static std::mutex mutex;
1383
+ std::lock_guard<std::mutex> lock(mutex);
1384
+
1385
+ if (dev_map.find(endpoint) != dev_map.end()) {
1386
+ return dev_map[endpoint];
1387
+ }
1388
+
1389
+ ggml_backend_rpc_device_context * ctx = new ggml_backend_rpc_device_context {
1390
+ /* .endpoint = */ endpoint,
1391
+ /* .name = */ "RPC[" + std::string(endpoint) + "]",
1392
+ };
1393
+
1394
+ ggml_backend_dev_t dev = new ggml_backend_device {
1395
+ /* .iface = */ ggml_backend_rpc_device_i,
1396
+ /* .reg = */ ggml_backend_rpc_reg(),
1397
+ /* .context = */ ctx,
1398
+ };
1399
+
1400
+ dev_map[endpoint] = dev;
1401
+
1402
+ return dev;
1403
+ }