@fugood/llama.node 0.3.16 → 0.3.17

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (202)
  1. package/CMakeLists.txt +3 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +5 -0
  19. package/package.json +1 -1
  20. package/src/LlamaCompletionWorker.cpp +8 -0
  21. package/src/LlamaCompletionWorker.h +1 -0
  22. package/src/LlamaContext.cpp +3 -2
  23. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +124 -0
  24. package/src/llama.cpp/.github/workflows/build.yml +70 -27
  25. package/src/llama.cpp/.github/workflows/docker.yml +6 -6
  26. package/src/llama.cpp/.github/workflows/server.yml +7 -11
  27. package/src/llama.cpp/CMakeLists.txt +23 -1
  28. package/src/llama.cpp/common/CMakeLists.txt +6 -3
  29. package/src/llama.cpp/common/arg.cpp +809 -105
  30. package/src/llama.cpp/common/arg.h +9 -0
  31. package/src/llama.cpp/common/chat.cpp +1 -1
  32. package/src/llama.cpp/common/common.cpp +31 -521
  33. package/src/llama.cpp/common/common.h +17 -36
  34. package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
  35. package/src/llama.cpp/common/llguidance.cpp +30 -47
  36. package/src/llama.cpp/common/minja/chat-template.hpp +15 -7
  37. package/src/llama.cpp/common/minja/minja.hpp +119 -93
  38. package/src/llama.cpp/common/sampling.cpp +3 -0
  39. package/src/llama.cpp/docs/build.md +122 -7
  40. package/src/llama.cpp/examples/CMakeLists.txt +0 -9
  41. package/src/llama.cpp/examples/batched/batched.cpp +1 -1
  42. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +1 -1
  43. package/src/llama.cpp/examples/embedding/embedding.cpp +7 -1
  44. package/src/llama.cpp/examples/export-lora/export-lora.cpp +1 -1
  45. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +15 -16
  46. package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
  47. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +210 -8
  48. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  49. package/src/llama.cpp/examples/llava/CMakeLists.txt +39 -24
  50. package/src/llama.cpp/examples/llava/clip-impl.h +345 -0
  51. package/src/llama.cpp/examples/llava/clip.cpp +2152 -1803
  52. package/src/llama.cpp/examples/llava/clip.h +39 -22
  53. package/src/llama.cpp/examples/llava/deprecation-warning.cpp +22 -0
  54. package/src/llama.cpp/examples/llava/llava.cpp +64 -52
  55. package/src/llama.cpp/examples/llava/mtmd-cli.cpp +344 -0
  56. package/src/llama.cpp/examples/llava/mtmd.cpp +708 -0
  57. package/src/llama.cpp/examples/llava/mtmd.h +168 -0
  58. package/src/llama.cpp/examples/llava/{qwen2vl-cli.cpp → qwen2vl-test.cpp} +83 -31
  59. package/src/llama.cpp/examples/main/main.cpp +16 -5
  60. package/src/llama.cpp/examples/parallel/parallel.cpp +3 -1
  61. package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
  62. package/src/llama.cpp/examples/perplexity/perplexity.cpp +17 -3
  63. package/src/llama.cpp/examples/quantize/quantize.cpp +115 -2
  64. package/src/llama.cpp/examples/rpc/CMakeLists.txt +4 -2
  65. package/src/llama.cpp/examples/rpc/rpc-server.cpp +163 -8
  66. package/src/llama.cpp/examples/run/CMakeLists.txt +12 -1
  67. package/src/llama.cpp/examples/run/run.cpp +14 -28
  68. package/src/llama.cpp/examples/server/httplib.h +313 -247
  69. package/src/llama.cpp/examples/server/server.cpp +238 -139
  70. package/src/llama.cpp/examples/server/utils.hpp +51 -2
  71. package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
  72. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  73. package/src/llama.cpp/examples/sycl/build.sh +2 -2
  74. package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
  75. package/src/llama.cpp/examples/tts/tts.cpp +6 -9
  76. package/src/llama.cpp/ggml/CMakeLists.txt +8 -2
  77. package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
  78. package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
  79. package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
  80. package/src/llama.cpp/ggml/include/ggml.h +66 -99
  81. package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -7
  82. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
  83. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
  84. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
  85. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
  86. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
  87. package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
  88. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
  89. package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
  90. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +48 -22
  91. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  92. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
  93. package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
  94. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
  95. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +896 -192
  96. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
  97. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +754 -404
  98. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1003 -13519
  99. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
  100. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +2 -7
  101. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +0 -1
  102. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +3 -4
  103. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +533 -88
  104. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8809 -0
  105. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
  106. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
  107. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  108. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
  109. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +258 -0
  110. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
  111. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
  112. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
  113. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
  114. package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
  115. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +70 -3
  116. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
  117. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -260
  118. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +293 -40
  119. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +96 -22
  120. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  121. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +350 -0
  122. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  123. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
  124. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +2 -292
  125. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
  126. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +967 -438
  127. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
  128. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
  129. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
  130. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +204 -280
  131. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
  132. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
  133. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
  134. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
  135. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
  136. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
  137. package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
  138. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +23 -0
  139. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +646 -114
  140. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +12 -0
  141. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +17 -8
  142. package/src/llama.cpp/ggml/src/ggml.c +141 -245
  143. package/src/llama.cpp/ggml/src/gguf.cpp +1 -0
  144. package/src/llama.cpp/include/llama.h +30 -11
  145. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
  146. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
  147. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
  148. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
  149. package/src/llama.cpp/requirements/requirements-all.txt +2 -0
  150. package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
  151. package/src/llama.cpp/src/CMakeLists.txt +3 -2
  152. package/src/llama.cpp/src/llama-adapter.cpp +37 -1
  153. package/src/llama.cpp/src/llama-arch.cpp +160 -17
  154. package/src/llama.cpp/src/llama-arch.h +16 -0
  155. package/src/llama.cpp/src/llama-chat.cpp +82 -17
  156. package/src/llama.cpp/src/llama-chat.h +6 -2
  157. package/src/llama.cpp/src/llama-context.cpp +108 -92
  158. package/src/llama.cpp/src/llama-context.h +1 -2
  159. package/src/llama.cpp/src/llama-graph.cpp +189 -119
  160. package/src/llama.cpp/src/llama-graph.h +26 -6
  161. package/src/llama.cpp/src/llama-hparams.h +13 -0
  162. package/src/llama.cpp/src/llama-kv-cache.cpp +70 -123
  163. package/src/llama.cpp/src/llama-kv-cache.h +41 -115
  164. package/src/llama.cpp/src/llama-memory.h +1 -1
  165. package/src/llama.cpp/src/llama-mmap.cpp +1 -1
  166. package/src/llama.cpp/src/llama-model-loader.cpp +10 -5
  167. package/src/llama.cpp/src/llama-model-loader.h +5 -3
  168. package/src/llama.cpp/src/llama-model.cpp +1760 -534
  169. package/src/llama.cpp/src/llama-model.h +13 -1
  170. package/src/llama.cpp/src/llama-quant.cpp +29 -8
  171. package/src/llama.cpp/src/llama-sampling.cpp +7 -1
  172. package/src/llama.cpp/src/llama-vocab.cpp +44 -6
  173. package/src/llama.cpp/src/llama.cpp +1 -1
  174. package/src/llama.cpp/tests/CMakeLists.txt +43 -30
  175. package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
  176. package/src/llama.cpp/tests/test-backend-ops.cpp +82 -43
  177. package/src/llama.cpp/tests/test-chat-template.cpp +34 -13
  178. package/src/llama.cpp/tests/test-chat.cpp +12 -2
  179. package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
  180. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
  181. package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
  182. package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
  183. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
  184. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
  185. package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
  186. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
  187. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
  188. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
  189. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
  190. package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
  191. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
  192. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
  193. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  194. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  195. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  196. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  197. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  198. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  199. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  200. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  201. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  202. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp
@@ -1,6 +1,7 @@
 #include "ggml-rpc.h"
 #include "ggml-impl.h"
 #include "ggml-backend-impl.h"
+#include "ggml-cpp.h"

 #include <cinttypes>
 #include <string>
@@ -26,6 +27,10 @@
 #  include <unistd.h>
 #endif
 #include <cstring>
+#include <fstream>
+#include <filesystem>
+
+namespace fs = std::filesystem;

 #ifdef _WIN32
 typedef SOCKET sockfd_t;
@@ -80,15 +85,26 @@ enum rpc_cmd {
     RPC_CMD_FREE_BUFFER,
     RPC_CMD_BUFFER_CLEAR,
     RPC_CMD_SET_TENSOR,
+    RPC_CMD_SET_TENSOR_HASH,
     RPC_CMD_GET_TENSOR,
     RPC_CMD_COPY_TENSOR,
     RPC_CMD_GRAPH_COMPUTE,
     RPC_CMD_GET_DEVICE_MEMORY,
     RPC_CMD_INIT_TENSOR,
     RPC_CMD_GET_ALLOC_SIZE,
+    RPC_CMD_HELLO,
     RPC_CMD_COUNT,
 };

+// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
+const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
+
+struct rpc_msg_hello_rsp {
+    uint8_t major;
+    uint8_t minor;
+    uint8_t patch;
+};
+
 struct rpc_msg_get_alloc_size_req {
     rpc_tensor tensor;
 };
@@ -135,6 +151,10 @@ struct rpc_msg_buffer_clear_req {
     uint8_t value;
 };

+struct rpc_msg_set_tensor_hash_rsp {
+    uint8_t result;
+};
+
 struct rpc_msg_get_tensor_req {
     rpc_tensor tensor;
     uint64_t offset;
@@ -187,6 +207,18 @@ struct ggml_backend_rpc_buffer_context {

 // RPC helper functions

+// Computes FNV-1a hash of the data
+static uint64_t fnv_hash(const uint8_t * data, size_t len) {
+    const uint64_t fnv_prime = 0x100000001b3ULL;
+    uint64_t hash = 0xcbf29ce484222325ULL;
+
+    for (size_t i = 0; i < len; ++i) {
+        hash ^= data[i];
+        hash *= fnv_prime;
+    }
+    return hash;
+}
+
 static std::shared_ptr<socket_t> make_socket(sockfd_t fd) {
 #ifdef _WIN32
     if (fd == INVALID_SOCKET) {
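Note: this FNV-1a hash is what makes the new tensor cache content-addressed: the server stores each cached payload under the 16-hex-digit rendering of its hash (see set_tensor further down in this diff). A standalone sketch of that mapping, using only the standard library and the same constants as the diff:

    #include <cinttypes>
    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // FNV-1a, same prime and offset basis as fnv_hash() in the diff above.
    static uint64_t fnv1a(const uint8_t * data, size_t len) {
        uint64_t hash = 0xcbf29ce484222325ULL;
        for (size_t i = 0; i < len; ++i) {
            hash ^= data[i];
            hash *= 0x100000001b3ULL;
        }
        return hash;
    }

    int main() {
        const char * payload = "example tensor bytes";
        // The server names cache files with the "%016" PRIx64 rendering of the hash.
        char name[17];
        snprintf(name, sizeof(name), "%016" PRIx64,
                 fnv1a((const uint8_t *) payload, strlen(payload)));
        printf("cache file name: %s\n", name);
        return 0;
    }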
@@ -346,8 +378,8 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int
 }

 // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
-// RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
-static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) {
+// No response
+static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size) {
     uint8_t cmd_byte = cmd;
     if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
         return false;
@@ -358,6 +390,15 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
     if (!send_data(sock->fd, input, input_size)) {
         return false;
     }
+    return true;
+}
+
+// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
+// RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
+static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) {
+    if (!send_rpc_cmd(sock, cmd, input, input_size)) {
+        return false;
+    }
     // TODO: currently the output_size is always known, do we need support for commands with variable output size?
     // even if we do, we can skip sending output_size from the server for commands with known output size
     uint64_t out_size;
@@ -375,6 +416,20 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm

 // RPC client-side implementation

+static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
+    rpc_msg_hello_rsp response;
+    bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
+    GGML_ASSERT(status);
+    if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
+        fprintf(stderr, "RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
+        return false;
+    }
+    if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) {
+        fprintf(stderr, "WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
+    }
+    return true;
+}
+
 static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
     static std::mutex mutex;
     std::lock_guard<std::mutex> lock(mutex);
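Note: check_server_version() encodes an asymmetric compatibility rule: the major version must match exactly, a server whose minor version is newer than the client's is rejected (it could rely on commands the client cannot answer), and any remaining minor/patch skew only prints a warning. A self-contained restatement of that rule, with hypothetical stand-in values for the RPC_PROTO_{MAJOR,MINOR}_VERSION constants from ggml-rpc.h:

    #include <cstdint>
    #include <cstdio>

    // Hypothetical stand-ins; the real constants live in ggml-rpc.h (+6 -1 in this release).
    constexpr uint8_t CLIENT_MAJOR = 1;
    constexpr uint8_t CLIENT_MINOR = 0;

    static bool client_accepts(uint8_t srv_major, uint8_t srv_minor) {
        if (srv_major != CLIENT_MAJOR) return false; // breaking protocol change
        if (srv_minor > CLIENT_MINOR)  return false; // server is newer than this client
        return true;                                 // equal or older: ok (patch skew only warns)
    }

    int main() {
        printf("same version accepted: %d\n", client_accepts(CLIENT_MAJOR, CLIENT_MINOR));
        printf("newer server rejected: %d\n", client_accepts(CLIENT_MAJOR, CLIENT_MINOR + 1));
        return 0;
    }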
@@ -408,6 +463,9 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
     if (sock == nullptr) {
         return nullptr;
     }
+    if (!check_server_version(sock)) {
+        return nullptr;
+    }
     GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
     sockets[endpoint] = sock;
     return sock;
@@ -483,14 +541,30 @@ static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_

 static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
-    // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
+    rpc_tensor rpc_tensor = serialize_tensor(tensor);
+    if (size > HASH_THRESHOLD) {
+        // input serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes)
+        size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + sizeof(uint64_t);
+        std::vector<uint8_t> input(input_size, 0);
+        uint64_t hash = fnv_hash((const uint8_t*)data, size);
+        memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
+        memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
+        memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &hash, sizeof(hash));
+        rpc_msg_set_tensor_hash_rsp response;
+        bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, input.data(), input.size(), &response, sizeof(response));
+        GGML_ASSERT(status);
+        if (response.result) {
+            // the server has the same data, no need to send it
+            return;
+        }
+    }
+    // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes)
     size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size;
     std::vector<uint8_t> input(input_size, 0);
-    rpc_tensor rpc_tensor = serialize_tensor(tensor);
     memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
     memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
     memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
-    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size(), nullptr, 0);
+    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size());
     GGML_ASSERT(status);
 }
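Note: with this change, uploading a tensor above HASH_THRESHOLD (10 MiB) costs at most two round trips: a sizeof(rpc_tensor) + 16 byte hash probe first, then the full payload only on a cache miss; smaller tensors skip the probe entirely. A condensed, runnable sketch of that decision, with stub stand-ins (hypothetical names, not the diff's API) for the two RPC round trips shown above:

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    static const size_t HASH_THRESHOLD = 10 * 1024 * 1024; // same constant as the diff

    // Stub stand-ins for the RPC round trips.
    static bool server_has_hash(uint64_t)                   { return false; } // SET_TENSOR_HASH probe
    static void send_full_tensor(const void *, size_t size) { printf("sent %zu bytes\n", size); }
    static uint64_t fnv1a(const uint8_t * p, size_t n) {
        uint64_t h = 0xcbf29ce484222325ULL;
        for (size_t i = 0; i < n; ++i) { h ^= p[i]; h *= 0x100000001b3ULL; }
        return h;
    }

    static void upload_tensor(const void * data, size_t size) {
        if (size > HASH_THRESHOLD && server_has_hash(fnv1a((const uint8_t *) data, size))) {
            return; // cache hit: an 8-byte hash replaced a >10 MiB transfer
        }
        send_full_tensor(data, size); // cache miss, or tensor too small to probe
    }

    int main() {
        uint8_t small[4] = {1, 2, 3, 4};
        upload_tensor(small, sizeof(small));
        return 0;
    }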
 
@@ -772,9 +846,12 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, si

 class rpc_server {
 public:
-    rpc_server(ggml_backend_t backend) : backend(backend) {}
+    rpc_server(ggml_backend_t backend, const char * cache_dir)
+        : backend(backend), cache_dir(cache_dir) {
+    }
     ~rpc_server();

+    void hello(rpc_msg_hello_rsp & response);
     void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
     void get_alignment(rpc_msg_get_alignment_rsp & response);
     void get_max_size(rpc_msg_get_max_size_rsp & response);
@@ -782,6 +859,7 @@ public:
     bool free_buffer(const rpc_msg_free_buffer_req & request);
     bool buffer_clear(const rpc_msg_buffer_clear_req & request);
     bool set_tensor(const std::vector<uint8_t> & input);
+    bool set_tensor_hash(const std::vector<uint8_t> & input, rpc_msg_set_tensor_hash_rsp & response);
     bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
     bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
     bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
@@ -789,6 +867,7 @@ public:
     bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);

 private:
+    bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
     ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
     ggml_tensor * create_node(uint64_t id,
                               struct ggml_context * ctx,
@@ -797,9 +876,17 @@ private:


     ggml_backend_t backend;
+    const char * cache_dir;
     std::unordered_set<ggml_backend_buffer_t> buffers;
 };

+void rpc_server::hello(rpc_msg_hello_rsp & response) {
+    response.major = RPC_PROTO_MAJOR_VERSION;
+    response.minor = RPC_PROTO_MINOR_VERSION;
+    response.patch = RPC_PROTO_PATCH_VERSION;
+    GGML_PRINT_DEBUG("[%s] version: %d.%d.%d\n", __func__, response.major, response.minor, response.patch);
+}
+
 bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
     ggml_backend_buffer_type_t buft;
     struct ggml_init_params params {
@@ -808,12 +895,13 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_
         /*.no_alloc =*/ true,
     };

-    struct ggml_context * ctx = ggml_init(params);
+    ggml_context_ptr ctx_ptr { ggml_init(params) };
+    GGML_ASSERT(ctx_ptr != nullptr);
+    ggml_context * ctx = ctx_ptr.get();
     ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);

     if (tensor == nullptr) {
         GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n");
-        ggml_free(ctx);
         return false;
     }

@@ -826,7 +914,6 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_

     response.alloc_size = ggml_backend_buft_get_alloc_size(buft,tensor);

-    ggml_free(ctx);
     return true;
 }

@@ -895,8 +982,21 @@ bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {
 }

 ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) {
+    // Validate tensor type before using it
+    if (tensor->type >= GGML_TYPE_COUNT) {
+        GGML_LOG_ERROR("[%s] invalid tensor type received: %u\n", __func__, tensor->type);
+        return nullptr;
+    }
+
     ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type,
                                               tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
+
+    // ggml_new_tensor_4d might fail if dimensions are invalid, although less likely to crash than invalid type
+    if (result == nullptr) {
+        GGML_LOG_ERROR("[%s] ggml_new_tensor_4d failed for type %u\n", __func__, tensor->type);
+        return nullptr;
+    }
+
     for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
         result->nb[i] = tensor->nb[i];
     }
@@ -940,11 +1040,12 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
         /*.mem_buffer =*/ NULL,
         /*.no_alloc =*/ true,
     };
-    struct ggml_context * ctx = ggml_init(params);
+    ggml_context_ptr ctx_ptr { ggml_init(params) };
+    GGML_ASSERT(ctx_ptr != nullptr);
+    ggml_context * ctx = ctx_ptr.get();
     ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
     if (tensor == nullptr) {
         GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
-        ggml_free(ctx);
         return false;
     }
     GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
@@ -955,13 +1056,90 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
         const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);

         if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
-            GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
+            GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu) out of buffer bounds [0x%zx, 0x%zx)\n",
+                           __func__, in_tensor->data, offset, size, p0, p1);
+            return false;
         }
     }

     const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset);
+    if (cache_dir && size > HASH_THRESHOLD) {
+        uint64_t hash = fnv_hash((const uint8_t*)data, size);
+        char hash_str[17];
+        snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
+        // save to cache_dir/hash_str
+        fs::path cache_file = fs::path(cache_dir) / hash_str;
+        std::ofstream ofs(cache_file, std::ios::binary);
+        ofs.write((const char *)data, size);
+        printf("[%s] saved to '%s'\n", __func__, cache_file.c_str());
+    }
     ggml_backend_tensor_set(tensor, data, offset, size);
-    ggml_free(ctx);
+    return true;
+}
+
+bool rpc_server::get_cached_file(uint64_t hash, std::vector<uint8_t> & data) {
+    if (!cache_dir) {
+        return false;
+    }
+    char hash_str[17];
+    snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
+    fs::path cache_file = fs::path(cache_dir) / hash_str;
+    if (!fs::exists(cache_file)) {
+        return false;
+    }
+    std::ifstream ifs(cache_file, std::ios::binary);
+    ifs.seekg(0, std::ios::end);
+    size_t size = ifs.tellg();
+    ifs.seekg(0, std::ios::beg);
+    data.resize(size);
+    ifs.read((char *)data.data(), size);
+    return true;
+}
+
+bool rpc_server::set_tensor_hash(const std::vector<uint8_t> & input, rpc_msg_set_tensor_hash_rsp & response)
+{
+    // serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes) |
+    if (input.size() != sizeof(rpc_tensor) + 16) {
+        return false;
+    }
+    const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
+    uint64_t offset;
+    memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
+    const uint64_t * hash = (const uint64_t *)(input.data() + sizeof(rpc_tensor) + sizeof(offset));
+    std::vector<uint8_t> cached_file;
+    if (!get_cached_file(*hash, cached_file)) {
+        response.result = 0;
+        return true;
+    }
+    size_t size = cached_file.size();
+    struct ggml_init_params params {
+        /*.mem_size   =*/ ggml_tensor_overhead(),
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true,
+    };
+    ggml_context_ptr ctx_ptr { ggml_init(params) };
+    GGML_ASSERT(ctx_ptr != nullptr);
+    ggml_context * ctx = ctx_ptr.get();
+    ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
+    if (tensor == nullptr) {
+        GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
+        return false;
+    }
+    GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size, *hash);
+
+    // sanitize tensor->data
+    {
+        const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
+        const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
+
+        if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
+            GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu, hash=0x%" PRIx64 ") out of buffer bounds [0x%zx, 0x%zx)\n",
+                           __func__, in_tensor->data, offset, size, *hash, p0, p1);
+            return false;
+        }
+    }
+    ggml_backend_tensor_set(tensor, cached_file.data(), offset, size);
+    response.result = 1;
     return true;
 }
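Note: on this path a cache miss is not an error: the handler answers result = 0 and the client falls back to a full RPC_CMD_SET_TENSOR, while only a malformed request makes set_tensor_hash return false, which the dispatch loop below treats as a reason to drop the connection. The three outcomes, summarized as an illustrative enum (not part of the diff; the wire response is just the single result byte):

    // Outcomes of a SET_TENSOR_HASH probe, as implemented above.
    enum class probe_outcome {
        hit,       // cached file matched: tensor filled from disk, result = 1
        miss,      // no cached file: result = 0, client re-sends the full tensor
        malformed, // bad request size or tensor: handler returns false, connection closes
    };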
 
@@ -971,11 +1149,12 @@ bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) {
         /*.mem_buffer =*/ NULL,
         /*.no_alloc =*/ true,
     };
-    struct ggml_context * ctx = ggml_init(params);
+    ggml_context_ptr ctx_ptr { ggml_init(params) };
+    GGML_ASSERT(ctx_ptr != nullptr);
+    ggml_context * ctx = ctx_ptr.get();
     ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
     if (tensor == nullptr) {
         GGML_LOG_ERROR("Null tensor pointer passed to server init_tensor function.\n");
-        ggml_free(ctx);
         return false;
     }

@@ -991,11 +1170,9 @@ bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) {
         // This pointer can either be passed around client/server, or probably better stored server-side and kept track of.
         // Currently unimplemented.
         GGML_LOG_ERROR("tensor->extra populated by the backend, this is currently unsupported.\n");
-        ggml_free(ctx);
         return false;
     }

-    ggml_free(ctx);
     return true;
 }

@@ -1005,11 +1182,12 @@ bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<
         /*.mem_buffer =*/ NULL,
         /*.no_alloc =*/ true,
     };
-    struct ggml_context * ctx = ggml_init(params);
+    ggml_context_ptr ctx_ptr { ggml_init(params) };
+    GGML_ASSERT(ctx_ptr != nullptr);
+    ggml_context * ctx = ctx_ptr.get();
     ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
     if (tensor == nullptr) {
         GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
-        ggml_free(ctx);
         return false;
     }
     GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size);
@@ -1022,13 +1200,14 @@ bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<
         if (request.tensor.data + request.offset < p0 ||
             request.tensor.data + request.offset >= p1 ||
             request.size > (p1 - request.tensor.data - request.offset)) {
-            GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
+            GGML_LOG_ERROR("[%s] requested tensor region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%" PRIu64 ") out of buffer bounds [0x%zx, 0x%zx)\n",
+                           __func__, request.tensor.data, request.offset, request.size, p0, p1);
+            return false;
         }
     }

     response.resize(request.size, 0);
     ggml_backend_tensor_get(tensor, response.data(), request.offset, request.size);
-    ggml_free(ctx);
     return true;
 }

@@ -1038,12 +1217,14 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co
         /*.mem_buffer =*/ NULL,
         /*.no_alloc =*/ true,
     };
-    struct ggml_context * ctx = ggml_init(params);
+    ggml_context_ptr ctx_ptr { ggml_init(params) };
+    GGML_ASSERT(ctx_ptr != nullptr);
+    ggml_context * ctx = ctx_ptr.get();
+
     ggml_tensor * src = deserialize_tensor(ctx, &request.src);
     ggml_tensor * dst = deserialize_tensor(ctx, &request.dst);
     if (src == nullptr || dst == nullptr) {
         GGML_LOG_ERROR("[%s] error deserializing tensors\n", __func__);
-        ggml_free(ctx);
         return false;
     }

@@ -1061,7 +1242,6 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co
                        dst_data + src_size,
                        dst_base,
                        dst_base + dst_buf_sz);
-        ggml_free(ctx);
         return false;
     }

@@ -1069,7 +1249,6 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co
                      __func__, (void*) src->buffer, (void*) dst->buffer);

     response.result = ggml_backend_buffer_copy_tensor(src, dst);
-    ggml_free(ctx);
     return true;
 }

@@ -1077,22 +1256,50 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
                                       struct ggml_context * ctx,
                                       const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
                                       std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map) {
-    if (id == 0) {
-        return nullptr;
-    }
     if (tensor_map.find(id) != tensor_map.end()) {
         return tensor_map[id];
     }
-    const rpc_tensor * tensor = tensor_ptrs.at(id);
+    // Safely find the tensor pointer
+    auto it_ptr = tensor_ptrs.find(id);
+    if (it_ptr == tensor_ptrs.end()) {
+        return nullptr;
+    }
+    const rpc_tensor * tensor = it_ptr->second;
+
     struct ggml_tensor * result = deserialize_tensor(ctx, tensor);
     if (result == nullptr) {
         return nullptr;
     }
     tensor_map[id] = result;
     for (int i = 0; i < GGML_MAX_SRC; i++) {
-        result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
+        // Check if the source ID is 0 before calling create_node recursively
+        if (tensor->src[i] == 0) {
+            result->src[i] = nullptr;
+        } else {
+            result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
+            // If the recursive call failed for a non-zero ID, propagate the error
+            if (result->src[i] == nullptr) {
+                GGML_LOG_ERROR("[%s] failed to create source node %d (src_id=%" PRIu64 ") for node id %" PRIu64 "\n",
+                               __func__, i, tensor->src[i], id);
+                // Must return nullptr to signal failure up the call stack
+                return nullptr;
+            }
+        }
+    }
+
+    // Handle view_src similarly
+    if (tensor->view_src == 0) {
+        result->view_src = nullptr;
+    } else {
+        result->view_src = create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map);
+        // If the recursive call failed for a non-zero ID, propagate the error
+        if (result->view_src == nullptr) {
+            GGML_LOG_ERROR("[%s] failed to create view_src node (view_src_id=%" PRIu64 ") for node id %" PRIu64 "\n",
+                           __func__, tensor->view_src, id);
+            // Must return nullptr to signal failure up the call stack
+            return nullptr;
+        }
     }
-    result->view_src = create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map);
     result->view_offs = tensor->view_offs;
     return result;
 }
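Note: tensor id 0 is the serialized encoding of a null pointer, which is why the old blanket "if (id == 0) return nullptr;" was ambiguous: callers could not distinguish a legitimately absent source from a deserialization failure. The rewrite checks for 0 at each reference site instead, so nullptr now always means failure there. The same memoized-DAG pattern in isolation (illustrative types, not the diff's API; nodes are intentionally leaked for brevity):

    #include <cstdint>
    #include <cstdio>
    #include <unordered_map>

    // Illustrative wire format: each node names one child by id; id 0 means "no child".
    struct wire_node { uint64_t child; };
    struct node      { node * child = nullptr; };

    static node * build(uint64_t id,
                        const std::unordered_map<uint64_t, wire_node> & wire,
                        std::unordered_map<uint64_t, node *> & memo) {
        auto hit = memo.find(id);
        if (hit != memo.end()) return hit->second;   // shared subgraph: reuse
        auto it = wire.find(id);
        if (it == wire.end()) return nullptr;        // unknown id: failure, not "no node"
        node * n = new node();
        memo[id] = n;                                // memoize before recursing
        if (it->second.child != 0) {                 // id 0 is checked at the reference site
            n->child = build(it->second.child, wire, memo);
            if (n->child == nullptr) return nullptr; // propagate failure
        }
        return n;
    }

    int main() {
        std::unordered_map<uint64_t, wire_node> wire = { {1, {2}}, {2, {0}} };
        std::unordered_map<uint64_t, node *> memo;
        printf("graph built: %s\n", build(1, wire, memo) ? "ok" : "failed");
        return 0;
    }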
@@ -1118,12 +1325,15 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
     GGML_PRINT_DEBUG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);

     size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
+
     struct ggml_init_params params = {
         /*.mem_size =*/ buf_size,
         /*.mem_buffer =*/ NULL,
         /*.no_alloc =*/ true,
     };
-    struct ggml_context * ctx = ggml_init(params);
+    ggml_context_ptr ctx_ptr { ggml_init(params) };
+    GGML_ASSERT(ctx_ptr != nullptr);
+    ggml_context * ctx = ctx_ptr.get();
     struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false);
     graph->n_nodes = n_nodes;
     std::unordered_map<uint64_t, const rpc_tensor*> tensor_ptrs;
@@ -1135,10 +1345,17 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
         int64_t id;
         memcpy(&id, &nodes[i], sizeof(id));
         graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map);
+
+        // Check if create_node failed for a *non-zero* ID.
+        // If id was 0, create_node returning nullptr is expected.
+        // If id was non-zero and create_node returned nullptr, it indicates a deserialization error.
+        if (graph->nodes[i] == nullptr && id != 0) {
+            GGML_LOG_ERROR("[%s] failed to create graph node %d (id=%" PRId64 ")\n", __func__, i, id);
+            return false;
+        }
     }
     ggml_status status = ggml_backend_graph_compute(backend, graph);
     response.result = status;
-    ggml_free(ctx);
     return true;
 }

@@ -1148,10 +1365,27 @@ rpc_server::~rpc_server() {
     }
 }

-static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
-    rpc_server server(backend);
+static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
+                             sockfd_t sockfd, size_t free_mem, size_t total_mem) {
+    rpc_server server(backend, cache_dir);
+    uint8_t cmd;
+    if (!recv_data(sockfd, &cmd, 1)) {
+        return;
+    }
+    // the first command sent by the client must be HELLO
+    if (cmd != RPC_CMD_HELLO) {
+        fprintf(stderr, "Expected HELLO command, update client\n");
+        return;
+    }
+    if (!recv_msg(sockfd, nullptr, 0)) {
+        return;
+    }
+    rpc_msg_hello_rsp response;
+    server.hello(response);
+    if (!send_msg(sockfd, &response, sizeof(response))) {
+        return;
+    }
     while (true) {
-        uint8_t cmd;
         if (!recv_data(sockfd, &cmd, 1)) {
             break;
         }
@@ -1161,6 +1395,10 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
             break;
         }
         switch (cmd) {
+            case RPC_CMD_HELLO: {
+                // HELLO command is handled above
+                return;
+            }
             case RPC_CMD_ALLOC_BUFFER: {
                 rpc_msg_alloc_buffer_req request;
                 if (!recv_msg(sockfd, &request, sizeof(request))) {
@@ -1179,7 +1417,9 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
                     return;
                 }
                 rpc_msg_get_alloc_size_rsp response;
-                server.get_alloc_size(request, response);
+                if (!server.get_alloc_size(request, response)) {
+                    return;
+                }
                 if (!send_msg(sockfd, &response, sizeof(response))) {
                     return;
                 }
@@ -1255,7 +1495,18 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
                 if (!server.set_tensor(input)) {
                     return;
                 }
-                if (!send_msg(sockfd, nullptr, 0)) {
+                break;
+            }
+            case RPC_CMD_SET_TENSOR_HASH: {
+                std::vector<uint8_t> input;
+                if (!recv_msg(sockfd, input)) {
+                    return;
+                }
+                rpc_msg_set_tensor_hash_rsp response;
+                if (!server.set_tensor_hash(input, response)) {
+                    return;
+                }
+                if (!send_msg(sockfd, &response, sizeof(response))) {
                     return;
                 }
                 break;
@@ -1335,7 +1586,9 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
     }
 }

-void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
+void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
+                                   const char * cache_dir,
+                                   size_t free_mem, size_t total_mem) {
     std::string host;
     int port;
     if (!parse_endpoint(endpoint, host, port)) {
@@ -1364,7 +1617,7 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint
         }
         printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
         fflush(stdout);
-        rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
+        rpc_serve_client(backend, cache_dir, client_socket->fd, free_mem, total_mem);
         printf("Client connection closed\n");
         fflush(stdout);
     }
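Note: the extra cache_dir parameter is a breaking change to the public ggml_backend_rpc_start_server() signature (matched by the +6 -1 change to ggml/include/ggml-rpc.h in the file list above). A hedged sketch of a caller; how the backend itself is created is outside this diff, so it is reduced to an assumed helper, and the endpoint/port here is only an example of the host:port form that parse_endpoint() expects:

    #include "ggml-rpc.h"

    // Assumed helper: create whatever local backend this server should expose.
    // The real rpc-server example (+163 -8 in this release) handles this along
    // with command-line parsing.
    extern ggml_backend_t make_local_backend(size_t * free_mem, size_t * total_mem);

    int main() {
        size_t free_mem = 0, total_mem = 0;
        ggml_backend_t backend = make_local_backend(&free_mem, &total_mem);
        // nullptr disables the on-disk tensor cache; a writable directory
        // enables the SET_TENSOR_HASH deduplication added in this diff.
        ggml_backend_rpc_start_server(backend, "0.0.0.0:50052",
                                      /*cache_dir=*/nullptr, free_mem, total_mem);
        return 0;
    }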