@fugood/llama.node 0.3.2 → 0.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +2 -0
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/package.json +1 -1
- package/src/DetokenizeWorker.cpp +1 -1
- package/src/EmbeddingWorker.cpp +2 -2
- package/src/LlamaCompletionWorker.cpp +8 -8
- package/src/LlamaCompletionWorker.h +2 -2
- package/src/LlamaContext.cpp +8 -9
- package/src/TokenizeWorker.cpp +1 -1
- package/src/common.hpp +4 -4
- package/src/llama.cpp/.github/workflows/build.yml +43 -9
- package/src/llama.cpp/.github/workflows/docker.yml +3 -0
- package/src/llama.cpp/CMakeLists.txt +7 -4
- package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
- package/src/llama.cpp/common/CMakeLists.txt +0 -2
- package/src/llama.cpp/common/arg.cpp +642 -607
- package/src/llama.cpp/common/arg.h +22 -22
- package/src/llama.cpp/common/common.cpp +79 -281
- package/src/llama.cpp/common/common.h +130 -100
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
- package/src/llama.cpp/common/log.cpp +50 -50
- package/src/llama.cpp/common/log.h +18 -18
- package/src/llama.cpp/common/ngram-cache.cpp +36 -36
- package/src/llama.cpp/common/ngram-cache.h +19 -19
- package/src/llama.cpp/common/sampling.cpp +116 -108
- package/src/llama.cpp/common/sampling.h +20 -20
- package/src/llama.cpp/docs/build.md +37 -17
- package/src/llama.cpp/examples/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/batched/batched.cpp +14 -14
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +10 -11
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +9 -9
- package/src/llama.cpp/examples/embedding/embedding.cpp +12 -12
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +8 -8
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +5 -5
- package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +7 -7
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +18 -18
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +20 -11
- package/src/llama.cpp/examples/infill/infill.cpp +40 -86
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +42 -151
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -14
- package/src/llama.cpp/examples/llava/clip.cpp +1 -0
- package/src/llama.cpp/examples/llava/llava-cli.cpp +23 -23
- package/src/llama.cpp/examples/llava/llava.cpp +37 -3
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +21 -21
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +26 -26
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +7 -7
- package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +14 -14
- package/src/llama.cpp/examples/lookup/lookup.cpp +29 -29
- package/src/llama.cpp/examples/main/main.cpp +64 -109
- package/src/llama.cpp/examples/parallel/parallel.cpp +18 -19
- package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +99 -120
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +10 -9
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +13 -13
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +3 -1
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +34 -17
- package/src/llama.cpp/examples/server/CMakeLists.txt +4 -13
- package/src/llama.cpp/examples/server/server.cpp +553 -691
- package/src/llama.cpp/examples/server/utils.hpp +312 -25
- package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/simple/simple.cpp +128 -96
- package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
- package/src/llama.cpp/examples/speculative/speculative.cpp +54 -51
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +2 -2
- package/src/llama.cpp/ggml/CMakeLists.txt +15 -9
- package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +46 -33
- package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
- package/src/llama.cpp/ggml/include/ggml-cann.h +9 -7
- package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
- package/src/llama.cpp/ggml/include/ggml-cuda.h +12 -12
- package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
- package/src/llama.cpp/ggml/include/ggml-metal.h +11 -7
- package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
- package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
- package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
- package/src/llama.cpp/ggml/include/ggml.h +53 -393
- package/src/llama.cpp/ggml/src/CMakeLists.txt +66 -1149
- package/src/llama.cpp/ggml/src/ggml-aarch64.c +46 -3126
- package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
- package/src/llama.cpp/ggml/src/ggml-alloc.c +23 -27
- package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
- package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
- package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
- package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
- package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +6 -25
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +303 -864
- package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
- package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +213 -65
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
- package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +255 -149
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
- package/src/llama.cpp/ggml/src/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -243
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
- package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.cpp +667 -1
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +366 -16
- package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
- package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +238 -72
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
- package/src/llama.cpp/ggml/src/ggml-quants.c +187 -10692
- package/src/llama.cpp/ggml/src/ggml-quants.h +78 -125
- package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
- package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +475 -300
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +40 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +258 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +2 -22
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
- package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3584 -4142
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +69 -67
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +3 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +6 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
- package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
- package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +555 -623
- package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +125 -206
- package/src/llama.cpp/ggml/src/ggml.c +4032 -19890
- package/src/llama.cpp/include/llama.h +67 -33
- package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
- package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
- package/src/llama.cpp/src/CMakeLists.txt +2 -1
- package/src/llama.cpp/src/llama-sampling.cpp +745 -105
- package/src/llama.cpp/src/llama-sampling.h +21 -2
- package/src/llama.cpp/src/llama-vocab.cpp +49 -9
- package/src/llama.cpp/src/llama-vocab.h +35 -11
- package/src/llama.cpp/src/llama.cpp +2636 -2406
- package/src/llama.cpp/src/unicode-data.cpp +2 -2
- package/src/llama.cpp/tests/CMakeLists.txt +1 -2
- package/src/llama.cpp/tests/test-arg-parser.cpp +14 -14
- package/src/llama.cpp/tests/test-backend-ops.cpp +185 -60
- package/src/llama.cpp/tests/test-barrier.cpp +1 -0
- package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -4
- package/src/llama.cpp/tests/test-log.cpp +2 -2
- package/src/llama.cpp/tests/test-opt.cpp +853 -142
- package/src/llama.cpp/tests/test-quantize-fns.cpp +22 -19
- package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
- package/src/llama.cpp/tests/test-rope.cpp +1 -0
- package/src/llama.cpp/tests/test-sampling.cpp +162 -137
- package/src/llama.cpp/tests/test-tokenizer-0.cpp +7 -7
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
- package/src/llama.cpp/common/train.cpp +0 -1515
- package/src/llama.cpp/common/train.h +0 -233
- package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1639
- package/src/llama.cpp/tests/test-grad0.cpp +0 -1683
- /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
- /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
- /package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +0 -0
|
@@ -25,7 +25,7 @@
|
|
|
25
25
|
# include <netdb.h>
|
|
26
26
|
# include <unistd.h>
|
|
27
27
|
#endif
|
|
28
|
-
#include <
|
|
28
|
+
#include <cstring>
|
|
29
29
|
|
|
30
30
|
#define UNUSED GGML_UNUSED
|
|
31
31
|
|
|
@@ -57,8 +57,9 @@ struct socket_t {
|
|
|
57
57
|
}
|
|
58
58
|
};
|
|
59
59
|
|
|
60
|
-
//
|
|
60
|
+
// all RPC structures must be packed
|
|
61
61
|
#pragma pack(push, 1)
|
|
62
|
+
// ggml_tensor is serialized into rpc_tensor
|
|
62
63
|
struct rpc_tensor {
|
|
63
64
|
uint64_t id;
|
|
64
65
|
uint32_t type;
|
|
@@ -76,7 +77,6 @@ struct rpc_tensor {
|
|
|
76
77
|
|
|
77
78
|
char padding[4];
|
|
78
79
|
};
|
|
79
|
-
#pragma pack(pop)
|
|
80
80
|
|
|
81
81
|
static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of 8");
|
|
82
82
|
|
|
@@ -96,6 +96,65 @@ enum rpc_cmd {
|
|
|
96
96
|
RPC_CMD_COUNT,
|
|
97
97
|
};
|
|
98
98
|
|
|
99
|
+
struct rpc_msg_alloc_buffer_req {
|
|
100
|
+
uint64_t size;
|
|
101
|
+
};
|
|
102
|
+
|
|
103
|
+
struct rpc_msg_alloc_buffer_rsp {
|
|
104
|
+
uint64_t remote_ptr;
|
|
105
|
+
uint64_t remote_size;
|
|
106
|
+
};
|
|
107
|
+
|
|
108
|
+
struct rpc_msg_get_alignment_rsp {
|
|
109
|
+
uint64_t alignment;
|
|
110
|
+
};
|
|
111
|
+
|
|
112
|
+
struct rpc_msg_get_max_size_rsp {
|
|
113
|
+
uint64_t max_size;
|
|
114
|
+
};
|
|
115
|
+
|
|
116
|
+
struct rpc_msg_buffer_get_base_req {
|
|
117
|
+
uint64_t remote_ptr;
|
|
118
|
+
};
|
|
119
|
+
|
|
120
|
+
struct rpc_msg_buffer_get_base_rsp {
|
|
121
|
+
uint64_t base_ptr;
|
|
122
|
+
};
|
|
123
|
+
|
|
124
|
+
struct rpc_msg_free_buffer_req {
|
|
125
|
+
uint64_t remote_ptr;
|
|
126
|
+
};
|
|
127
|
+
|
|
128
|
+
struct rpc_msg_buffer_clear_req {
|
|
129
|
+
uint64_t remote_ptr;
|
|
130
|
+
uint8_t value;
|
|
131
|
+
};
|
|
132
|
+
|
|
133
|
+
struct rpc_msg_get_tensor_req {
|
|
134
|
+
rpc_tensor tensor;
|
|
135
|
+
uint64_t offset;
|
|
136
|
+
uint64_t size;
|
|
137
|
+
};
|
|
138
|
+
|
|
139
|
+
struct rpc_msg_copy_tensor_req {
|
|
140
|
+
rpc_tensor src;
|
|
141
|
+
rpc_tensor dst;
|
|
142
|
+
};
|
|
143
|
+
|
|
144
|
+
struct rpc_msg_copy_tensor_rsp {
|
|
145
|
+
uint8_t result;
|
|
146
|
+
};
|
|
147
|
+
|
|
148
|
+
struct rpc_msg_graph_compute_rsp {
|
|
149
|
+
uint8_t result;
|
|
150
|
+
};
|
|
151
|
+
|
|
152
|
+
struct rpc_msg_get_device_memory_rsp {
|
|
153
|
+
uint64_t free_mem;
|
|
154
|
+
uint64_t total_mem;
|
|
155
|
+
};
|
|
156
|
+
#pragma pack(pop)
|
|
157
|
+
|
|
99
158
|
// RPC data structures
|
|
100
159
|
|
|
101
160
|
static ggml_guid_t ggml_backend_rpc_guid() {
|
|
@@ -119,7 +178,6 @@ struct ggml_backend_rpc_buffer_context {
|
|
|
119
178
|
std::shared_ptr<socket_t> sock;
|
|
120
179
|
std::unordered_map<ggml_backend_buffer_t, void *> base_cache;
|
|
121
180
|
uint64_t remote_ptr;
|
|
122
|
-
std::string name;
|
|
123
181
|
};
|
|
124
182
|
|
|
125
183
|
// RPC helper functions
|
|
@@ -240,6 +298,38 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
|
|
|
240
298
|
return true;
|
|
241
299
|
}
|
|
242
300
|
|
|
301
|
+
static bool send_msg(sockfd_t sockfd, const void * msg, size_t msg_size) {
|
|
302
|
+
if (!send_data(sockfd, &msg_size, sizeof(msg_size))) {
|
|
303
|
+
return false;
|
|
304
|
+
}
|
|
305
|
+
return send_data(sockfd, msg, msg_size);
|
|
306
|
+
}
|
|
307
|
+
|
|
308
|
+
static bool recv_msg(sockfd_t sockfd, void * msg, size_t msg_size) {
|
|
309
|
+
uint64_t size;
|
|
310
|
+
if (!recv_data(sockfd, &size, sizeof(size))) {
|
|
311
|
+
return false;
|
|
312
|
+
}
|
|
313
|
+
if (size != msg_size) {
|
|
314
|
+
return false;
|
|
315
|
+
}
|
|
316
|
+
return recv_data(sockfd, msg, msg_size);
|
|
317
|
+
}
|
|
318
|
+
|
|
319
|
+
static bool recv_msg(sockfd_t sockfd, std::vector<uint8_t> & input) {
|
|
320
|
+
uint64_t size;
|
|
321
|
+
if (!recv_data(sockfd, &size, sizeof(size))) {
|
|
322
|
+
return false;
|
|
323
|
+
}
|
|
324
|
+
try {
|
|
325
|
+
input.resize(size);
|
|
326
|
+
} catch (const std::bad_alloc & e) {
|
|
327
|
+
fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", size);
|
|
328
|
+
return false;
|
|
329
|
+
}
|
|
330
|
+
return recv_data(sockfd, input.data(), size);
|
|
331
|
+
}
|
|
332
|
+
|
|
243
333
|
static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
|
|
244
334
|
size_t pos = endpoint.find(':');
|
|
245
335
|
if (pos == std::string::npos) {
|
|
@@ -252,28 +342,27 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int
|
|
|
252
342
|
|
|
253
343
|
// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
|
|
254
344
|
// RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
|
|
255
|
-
static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const
|
|
345
|
+
static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) {
|
|
256
346
|
uint8_t cmd_byte = cmd;
|
|
257
347
|
if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
|
|
258
348
|
return false;
|
|
259
349
|
}
|
|
260
|
-
uint64_t input_size = input.size();
|
|
261
350
|
if (!send_data(sock->fd, &input_size, sizeof(input_size))) {
|
|
262
351
|
return false;
|
|
263
352
|
}
|
|
264
|
-
if (!send_data(sock->fd, input
|
|
353
|
+
if (!send_data(sock->fd, input, input_size)) {
|
|
265
354
|
return false;
|
|
266
355
|
}
|
|
267
|
-
|
|
268
|
-
if
|
|
356
|
+
// TODO: currently the output_size is always known, do we need support for commands with variable output size?
|
|
357
|
+
// even if we do, we can skip sending output_size from the server for commands with known output size
|
|
358
|
+
uint64_t out_size;
|
|
359
|
+
if (!recv_data(sock->fd, &out_size, sizeof(out_size))) {
|
|
269
360
|
return false;
|
|
270
361
|
}
|
|
271
|
-
if (
|
|
272
|
-
|
|
273
|
-
return true;
|
|
362
|
+
if (out_size != output_size) {
|
|
363
|
+
return false;
|
|
274
364
|
}
|
|
275
|
-
output
|
|
276
|
-
if (!recv_data(sock->fd, output.data(), output_size)) {
|
|
365
|
+
if (!recv_data(sock->fd, output, output_size)) {
|
|
277
366
|
return false;
|
|
278
367
|
}
|
|
279
368
|
return true;
|
|
@@ -319,21 +408,11 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
|
|
|
319
408
|
return sock;
|
|
320
409
|
}
|
|
321
410
|
|
|
322
|
-
static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) {
|
|
323
|
-
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
|
324
|
-
return ctx->name.c_str();
|
|
325
|
-
}
|
|
326
|
-
|
|
327
411
|
static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
|
328
412
|
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
uint64_t remote_ptr = ctx->remote_ptr;
|
|
332
|
-
memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
|
|
333
|
-
std::vector<uint8_t> output;
|
|
334
|
-
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, input, output);
|
|
413
|
+
rpc_msg_free_buffer_req request = {ctx->remote_ptr};
|
|
414
|
+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0);
|
|
335
415
|
GGML_ASSERT(status);
|
|
336
|
-
GGML_ASSERT(output.empty());
|
|
337
416
|
delete ctx;
|
|
338
417
|
}
|
|
339
418
|
|
|
@@ -342,20 +421,13 @@ static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
|
|
|
342
421
|
if (ctx->base_cache.find(buffer) != ctx->base_cache.end()) {
|
|
343
422
|
return ctx->base_cache[buffer];
|
|
344
423
|
}
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
|
|
349
|
-
std::vector<uint8_t> output;
|
|
350
|
-
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, input, output);
|
|
424
|
+
rpc_msg_buffer_get_base_req request = {ctx->remote_ptr};
|
|
425
|
+
rpc_msg_buffer_get_base_rsp response;
|
|
426
|
+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response));
|
|
351
427
|
GGML_ASSERT(status);
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
memcpy(&base_ptr, output.data(), sizeof(base_ptr));
|
|
356
|
-
void * base = reinterpret_cast<void *>(base_ptr);
|
|
357
|
-
ctx->base_cache[buffer] = base;
|
|
358
|
-
return base;
|
|
428
|
+
void * base_ptr = reinterpret_cast<void *>(response.base_ptr);
|
|
429
|
+
ctx->base_cache[buffer] = base_ptr;
|
|
430
|
+
return base_ptr;
|
|
359
431
|
}
|
|
360
432
|
|
|
361
433
|
static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
|
|
@@ -405,26 +477,18 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm
|
|
|
405
477
|
memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
|
|
406
478
|
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
|
|
407
479
|
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
|
|
408
|
-
|
|
409
|
-
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input, output);
|
|
480
|
+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size(), nullptr, 0);
|
|
410
481
|
GGML_ASSERT(status);
|
|
411
482
|
}
|
|
412
483
|
|
|
413
484
|
static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
|
414
485
|
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
|
|
421
|
-
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &size, sizeof(size));
|
|
422
|
-
std::vector<uint8_t> output;
|
|
423
|
-
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, input, output);
|
|
486
|
+
rpc_msg_get_tensor_req request;
|
|
487
|
+
request.tensor = serialize_tensor(tensor);
|
|
488
|
+
request.offset = offset;
|
|
489
|
+
request.size = size;
|
|
490
|
+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, &request, sizeof(request), data, size);
|
|
424
491
|
GGML_ASSERT(status);
|
|
425
|
-
GGML_ASSERT(output.size() == size);
|
|
426
|
-
// output serialization format: | data (size bytes) |
|
|
427
|
-
memcpy(data, output.data(), size);
|
|
428
492
|
}
|
|
429
493
|
|
|
430
494
|
static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
|
|
@@ -437,35 +501,23 @@ static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, con
|
|
|
437
501
|
return false;
|
|
438
502
|
}
|
|
439
503
|
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
memcpy(input.data(), &rpc_src, sizeof(rpc_src));
|
|
446
|
-
memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst));
|
|
447
|
-
std::vector<uint8_t> output;
|
|
448
|
-
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, input, output);
|
|
504
|
+
rpc_msg_copy_tensor_req request;
|
|
505
|
+
request.src = serialize_tensor(src);
|
|
506
|
+
request.dst = serialize_tensor(dst);
|
|
507
|
+
rpc_msg_copy_tensor_rsp response;
|
|
508
|
+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
|
|
449
509
|
GGML_ASSERT(status);
|
|
450
|
-
|
|
451
|
-
GGML_ASSERT(output.size() == 1);
|
|
452
|
-
return output[0];
|
|
510
|
+
return response.result;
|
|
453
511
|
}
|
|
454
512
|
|
|
455
513
|
static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
|
456
514
|
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
std::vector<uint8_t> input(input_size, 0);
|
|
460
|
-
memcpy(input.data(), &ctx->remote_ptr, sizeof(ctx->remote_ptr));
|
|
461
|
-
memcpy(input.data() + sizeof(ctx->remote_ptr), &value, sizeof(value));
|
|
462
|
-
std::vector<uint8_t> output;
|
|
463
|
-
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, input, output);
|
|
515
|
+
rpc_msg_buffer_clear_req request = {ctx->remote_ptr, value};
|
|
516
|
+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, &request, sizeof(request), nullptr, 0);
|
|
464
517
|
GGML_ASSERT(status);
|
|
465
518
|
}
|
|
466
519
|
|
|
467
520
|
static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
|
|
468
|
-
/* .get_name = */ ggml_backend_rpc_buffer_get_name,
|
|
469
521
|
/* .free_buffer = */ ggml_backend_rpc_buffer_free_buffer,
|
|
470
522
|
/* .get_base = */ ggml_backend_rpc_buffer_get_base,
|
|
471
523
|
/* .init_tensor = */ ggml_backend_rpc_buffer_init_tensor,
|
|
@@ -484,25 +536,16 @@ static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t
|
|
|
484
536
|
|
|
485
537
|
static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
|
486
538
|
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
std::vector<uint8_t> input(input_size, 0);
|
|
490
|
-
memcpy(input.data(), &size, sizeof(size));
|
|
491
|
-
std::vector<uint8_t> output;
|
|
539
|
+
rpc_msg_alloc_buffer_req request = {size};
|
|
540
|
+
rpc_msg_alloc_buffer_rsp response;
|
|
492
541
|
auto sock = get_socket(buft_ctx->endpoint);
|
|
493
|
-
bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER,
|
|
542
|
+
bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
|
|
494
543
|
GGML_ASSERT(status);
|
|
495
|
-
|
|
496
|
-
// output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
|
|
497
|
-
uint64_t remote_ptr;
|
|
498
|
-
memcpy(&remote_ptr, output.data(), sizeof(remote_ptr));
|
|
499
|
-
size_t remote_size;
|
|
500
|
-
memcpy(&remote_size, output.data() + sizeof(uint64_t), sizeof(remote_size));
|
|
501
|
-
if (remote_ptr != 0) {
|
|
544
|
+
if (response.remote_ptr != 0) {
|
|
502
545
|
ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
|
|
503
546
|
ggml_backend_rpc_buffer_interface,
|
|
504
|
-
new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr
|
|
505
|
-
remote_size);
|
|
547
|
+
new ggml_backend_rpc_buffer_context{sock, {}, response.remote_ptr},
|
|
548
|
+
response.remote_size);
|
|
506
549
|
return buffer;
|
|
507
550
|
} else {
|
|
508
551
|
return nullptr;
|
|
@@ -510,16 +553,10 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
|
|
|
510
553
|
}
|
|
511
554
|
|
|
512
555
|
static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
std::vector<uint8_t> output;
|
|
516
|
-
bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, input, output);
|
|
556
|
+
rpc_msg_get_alignment_rsp response;
|
|
557
|
+
bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response));
|
|
517
558
|
GGML_ASSERT(status);
|
|
518
|
-
|
|
519
|
-
// output serialization format: | alignment (8 bytes) |
|
|
520
|
-
uint64_t alignment;
|
|
521
|
-
memcpy(&alignment, output.data(), sizeof(alignment));
|
|
522
|
-
return alignment;
|
|
559
|
+
return response.alignment;
|
|
523
560
|
}
|
|
524
561
|
|
|
525
562
|
static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
|
|
@@ -528,16 +565,10 @@ static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_typ
|
|
|
528
565
|
}
|
|
529
566
|
|
|
530
567
|
static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
std::vector<uint8_t> output;
|
|
534
|
-
bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, input, output);
|
|
568
|
+
rpc_msg_get_max_size_rsp response;
|
|
569
|
+
bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, nullptr, 0, &response, sizeof(response));
|
|
535
570
|
GGML_ASSERT(status);
|
|
536
|
-
|
|
537
|
-
// output serialization format: | max_size (8 bytes) |
|
|
538
|
-
uint64_t max_size;
|
|
539
|
-
memcpy(&max_size, output.data(), sizeof(max_size));
|
|
540
|
-
return max_size;
|
|
571
|
+
return response.max_size;
|
|
541
572
|
}
|
|
542
573
|
|
|
543
574
|
static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
|
|
@@ -571,11 +602,6 @@ static void ggml_backend_rpc_free(ggml_backend_t backend) {
|
|
|
571
602
|
delete backend;
|
|
572
603
|
}
|
|
573
604
|
|
|
574
|
-
static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) {
|
|
575
|
-
ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
|
|
576
|
-
return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
|
|
577
|
-
}
|
|
578
|
-
|
|
579
605
|
static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
|
|
580
606
|
UNUSED(backend);
|
|
581
607
|
// this is no-op because we don't have any async operations
|
|
@@ -622,34 +648,16 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
|
|
|
622
648
|
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
|
|
623
649
|
std::vector<uint8_t> input;
|
|
624
650
|
serialize_graph(cgraph, input);
|
|
625
|
-
|
|
651
|
+
rpc_msg_graph_compute_rsp response;
|
|
626
652
|
auto sock = get_socket(rpc_ctx->endpoint);
|
|
627
|
-
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input,
|
|
653
|
+
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
|
|
628
654
|
GGML_ASSERT(status);
|
|
629
|
-
|
|
630
|
-
return (enum ggml_status)output[0];
|
|
631
|
-
}
|
|
632
|
-
|
|
633
|
-
static bool ggml_backend_rpc_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
|
|
634
|
-
UNUSED(backend);
|
|
635
|
-
UNUSED(op);
|
|
636
|
-
//TODO: call the remote backend and cache the results
|
|
637
|
-
return true;
|
|
638
|
-
}
|
|
639
|
-
|
|
640
|
-
static bool ggml_backend_rpc_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
|
|
641
|
-
if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
|
|
642
|
-
return false;
|
|
643
|
-
}
|
|
644
|
-
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
|
|
645
|
-
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
|
|
646
|
-
return buft_ctx->endpoint == rpc_ctx->endpoint;
|
|
655
|
+
return (enum ggml_status)response.result;
|
|
647
656
|
}
|
|
648
657
|
|
|
649
658
|
static ggml_backend_i ggml_backend_rpc_interface = {
|
|
650
659
|
/* .get_name = */ ggml_backend_rpc_name,
|
|
651
660
|
/* .free = */ ggml_backend_rpc_free,
|
|
652
|
-
/* .get_default_buffer_type = */ ggml_backend_rpc_get_default_buffer_type,
|
|
653
661
|
/* .set_tensor_async = */ NULL,
|
|
654
662
|
/* .get_tensor_async = */ NULL,
|
|
655
663
|
/* .cpy_tensor_async = */ NULL,
|
|
@@ -659,14 +667,11 @@ static ggml_backend_i ggml_backend_rpc_interface = {
|
|
|
659
667
|
/* .graph_plan_update = */ NULL,
|
|
660
668
|
/* .graph_plan_compute = */ NULL,
|
|
661
669
|
/* .graph_compute = */ ggml_backend_rpc_graph_compute,
|
|
662
|
-
/* .supports_op = */ ggml_backend_rpc_supports_op,
|
|
663
|
-
/* .supports_buft = */ ggml_backend_rpc_supports_buft,
|
|
664
|
-
/* .offload_op = */ NULL,
|
|
665
670
|
/* .event_record = */ NULL,
|
|
666
671
|
/* .event_wait = */ NULL,
|
|
667
672
|
};
|
|
668
673
|
|
|
669
|
-
|
|
674
|
+
ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
|
|
670
675
|
static std::mutex mutex;
|
|
671
676
|
std::lock_guard<std::mutex> lock(mutex);
|
|
672
677
|
// NOTE: buffer types are allocated and never freed; this is by design
|
|
@@ -691,7 +696,7 @@ GGML_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * en
|
|
|
691
696
|
|
|
692
697
|
ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
|
|
693
698
|
/* .iface = */ ggml_backend_rpc_buffer_type_interface,
|
|
694
|
-
/* .device = */
|
|
699
|
+
/* .device = */ ggml_backend_rpc_add_device(endpoint),
|
|
695
700
|
/* .context = */ buft_ctx
|
|
696
701
|
};
|
|
697
702
|
buft_map[endpoint] = buft;
|
|
@@ -707,33 +712,25 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
|
|
|
707
712
|
ggml_backend_t backend = new ggml_backend {
|
|
708
713
|
/* .guid = */ ggml_backend_rpc_guid(),
|
|
709
714
|
/* .interface = */ ggml_backend_rpc_interface,
|
|
710
|
-
/* .device = */
|
|
715
|
+
/* .device = */ ggml_backend_rpc_add_device(endpoint),
|
|
711
716
|
/* .context = */ ctx
|
|
712
717
|
};
|
|
713
718
|
return backend;
|
|
714
719
|
}
|
|
715
720
|
|
|
716
|
-
|
|
721
|
+
bool ggml_backend_is_rpc(ggml_backend_t backend) {
|
|
717
722
|
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
|
|
718
723
|
}
|
|
719
724
|
|
|
720
725
|
static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
std::vector<uint8_t> output;
|
|
724
|
-
bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, input, output);
|
|
726
|
+
rpc_msg_get_device_memory_rsp response;
|
|
727
|
+
bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response));
|
|
725
728
|
GGML_ASSERT(status);
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
uint64_t free_mem;
|
|
729
|
-
memcpy(&free_mem, output.data(), sizeof(free_mem));
|
|
730
|
-
uint64_t total_mem;
|
|
731
|
-
memcpy(&total_mem, output.data() + sizeof(uint64_t), sizeof(total_mem));
|
|
732
|
-
*free = free_mem;
|
|
733
|
-
*total = total_mem;
|
|
729
|
+
*free = response.free_mem;
|
|
730
|
+
*total = response.total_mem;
|
|
734
731
|
}
|
|
735
732
|
|
|
736
|
-
|
|
733
|
+
void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
|
|
737
734
|
auto sock = get_socket(endpoint);
|
|
738
735
|
if (sock == nullptr) {
|
|
739
736
|
*free = 0;
|
|
@@ -750,16 +747,16 @@ public:
|
|
|
750
747
|
rpc_server(ggml_backend_t backend) : backend(backend) {}
|
|
751
748
|
~rpc_server();
|
|
752
749
|
|
|
753
|
-
|
|
754
|
-
void get_alignment(
|
|
755
|
-
void get_max_size(
|
|
756
|
-
bool buffer_get_base(const
|
|
757
|
-
bool free_buffer(const
|
|
758
|
-
bool buffer_clear(const
|
|
750
|
+
void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
|
|
751
|
+
void get_alignment(rpc_msg_get_alignment_rsp & response);
|
|
752
|
+
void get_max_size(rpc_msg_get_max_size_rsp & response);
|
|
753
|
+
bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response);
|
|
754
|
+
bool free_buffer(const rpc_msg_free_buffer_req & request);
|
|
755
|
+
bool buffer_clear(const rpc_msg_buffer_clear_req & request);
|
|
759
756
|
bool set_tensor(const std::vector<uint8_t> & input);
|
|
760
|
-
bool get_tensor(const
|
|
761
|
-
bool copy_tensor(const
|
|
762
|
-
bool graph_compute(const std::vector<uint8_t> & input,
|
|
757
|
+
bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
|
|
758
|
+
bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
|
|
759
|
+
bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
|
|
763
760
|
|
|
764
761
|
private:
|
|
765
762
|
ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
|
|
@@ -773,80 +770,50 @@ private:
|
|
|
773
770
|
std::unordered_set<ggml_backend_buffer_t> buffers;
|
|
774
771
|
};
|
|
775
772
|
|
|
776
|
-
|
|
777
|
-
// input serialization format: | size (8 bytes) |
|
|
778
|
-
if (input.size() != sizeof(uint64_t)) {
|
|
779
|
-
return false;
|
|
780
|
-
}
|
|
781
|
-
uint64_t size;
|
|
782
|
-
memcpy(&size, input.data(), sizeof(size));
|
|
773
|
+
void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
|
|
783
774
|
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
|
|
784
|
-
ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, size);
|
|
785
|
-
|
|
786
|
-
|
|
775
|
+
ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
|
|
776
|
+
response.remote_ptr = 0;
|
|
777
|
+
response.remote_size = 0;
|
|
787
778
|
if (buffer != nullptr) {
|
|
788
|
-
remote_ptr = reinterpret_cast<uint64_t>(buffer);
|
|
789
|
-
remote_size = buffer->size;
|
|
790
|
-
GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, size, remote_ptr, remote_size);
|
|
779
|
+
response.remote_ptr = reinterpret_cast<uint64_t>(buffer);
|
|
780
|
+
response.remote_size = buffer->size;
|
|
781
|
+
GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size);
|
|
791
782
|
buffers.insert(buffer);
|
|
792
783
|
} else {
|
|
793
|
-
GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> failed\n", __func__, size);
|
|
784
|
+
GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size);
|
|
794
785
|
}
|
|
795
|
-
// output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
|
|
796
|
-
output.resize(2*sizeof(uint64_t), 0);
|
|
797
|
-
memcpy(output.data(), &remote_ptr, sizeof(remote_ptr));
|
|
798
|
-
memcpy(output.data() + sizeof(uint64_t), &remote_size, sizeof(remote_size));
|
|
799
|
-
return true;
|
|
800
786
|
}
|
|
801
787
|
|
|
802
|
-
void rpc_server::get_alignment(
|
|
788
|
+
void rpc_server::get_alignment(rpc_msg_get_alignment_rsp & response) {
|
|
803
789
|
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
|
|
804
790
|
size_t alignment = ggml_backend_buft_get_alignment(buft);
|
|
805
791
|
GGML_PRINT_DEBUG("[%s] alignment: %lu\n", __func__, alignment);
|
|
806
|
-
|
|
807
|
-
output.resize(sizeof(uint64_t), 0);
|
|
808
|
-
memcpy(output.data(), &alignment, sizeof(alignment));
|
|
792
|
+
response.alignment = alignment;
|
|
809
793
|
}
|
|
810
794
|
|
|
811
|
-
void rpc_server::get_max_size(
|
|
795
|
+
void rpc_server::get_max_size(rpc_msg_get_max_size_rsp & response) {
|
|
812
796
|
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
|
|
813
797
|
size_t max_size = ggml_backend_buft_get_max_size(buft);
|
|
814
798
|
GGML_PRINT_DEBUG("[%s] max_size: %lu\n", __func__, max_size);
|
|
815
|
-
|
|
816
|
-
output.resize(sizeof(uint64_t), 0);
|
|
817
|
-
memcpy(output.data(), &max_size, sizeof(max_size));
|
|
799
|
+
response.max_size = max_size;
|
|
818
800
|
}
|
|
819
801
|
|
|
820
|
-
bool rpc_server::buffer_get_base(const
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
return false;
|
|
824
|
-
}
|
|
825
|
-
uint64_t remote_ptr;
|
|
826
|
-
memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
|
|
827
|
-
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
|
|
828
|
-
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
|
|
802
|
+
bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {
|
|
803
|
+
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
|
|
804
|
+
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
|
|
829
805
|
if (buffers.find(buffer) == buffers.end()) {
|
|
830
806
|
GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
|
|
831
807
|
return false;
|
|
832
808
|
}
|
|
833
809
|
void * base = ggml_backend_buffer_get_base(buffer);
|
|
834
|
-
|
|
835
|
-
uint64_t base_ptr = reinterpret_cast<uint64_t>(base);
|
|
836
|
-
output.resize(sizeof(uint64_t), 0);
|
|
837
|
-
memcpy(output.data(), &base_ptr, sizeof(base_ptr));
|
|
810
|
+
response.base_ptr = reinterpret_cast<uint64_t>(base);
|
|
838
811
|
return true;
|
|
839
812
|
}
|
|
840
813
|
|
|
841
|
-
bool rpc_server::free_buffer(const
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
return false;
|
|
845
|
-
}
|
|
846
|
-
uint64_t remote_ptr;
|
|
847
|
-
memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
|
|
848
|
-
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
|
|
849
|
-
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
|
|
814
|
+
bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) {
|
|
815
|
+
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
|
|
816
|
+
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
|
|
850
817
|
if (buffers.find(buffer) == buffers.end()) {
|
|
851
818
|
GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
|
|
852
819
|
return false;
|
|
@@ -856,22 +823,14 @@ bool rpc_server::free_buffer(const std::vector<uint8_t> & input) {
|
|
|
856
823
|
return true;
|
|
857
824
|
}
|
|
858
825
|
|
|
859
|
-
bool rpc_server::buffer_clear(const
|
|
860
|
-
|
|
861
|
-
|
|
862
|
-
return false;
|
|
863
|
-
}
|
|
864
|
-
uint64_t remote_ptr;
|
|
865
|
-
memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
|
|
866
|
-
uint8_t value;
|
|
867
|
-
memcpy(&value, input.data() + sizeof(uint64_t), sizeof(value));
|
|
868
|
-
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, remote_ptr, value);
|
|
869
|
-
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
|
|
826
|
+
bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {
|
|
827
|
+
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, request.remote_ptr, request.value);
|
|
828
|
+
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
|
|
870
829
|
if (buffers.find(buffer) == buffers.end()) {
|
|
871
830
|
GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
|
|
872
831
|
return false;
|
|
873
832
|
}
|
|
874
|
-
ggml_backend_buffer_clear(buffer, value);
|
|
833
|
+
ggml_backend_buffer_clear(buffer, request.value);
|
|
875
834
|
return true;
|
|
876
835
|
}
|
|
877
836
|
|
|
@@ -946,74 +905,55 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
|
|
|
946
905
|
return true;
|
|
947
906
|
}
|
|
948
907
|
|
|
949
|
-
bool rpc_server::get_tensor(const
|
|
950
|
-
// serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) |
|
|
951
|
-
if (input.size() != sizeof(rpc_tensor) + 2*sizeof(uint64_t)) {
|
|
952
|
-
return false;
|
|
953
|
-
}
|
|
954
|
-
const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
|
|
955
|
-
uint64_t offset;
|
|
956
|
-
memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
|
|
957
|
-
uint64_t size;
|
|
958
|
-
memcpy(&size, input.data() + sizeof(rpc_tensor) + sizeof(offset), sizeof(size));
|
|
959
|
-
|
|
908
|
+
bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response) {
|
|
960
909
|
struct ggml_init_params params {
|
|
961
910
|
/*.mem_size =*/ ggml_tensor_overhead(),
|
|
962
911
|
/*.mem_buffer =*/ NULL,
|
|
963
912
|
/*.no_alloc =*/ true,
|
|
964
913
|
};
|
|
965
914
|
struct ggml_context * ctx = ggml_init(params);
|
|
966
|
-
ggml_tensor * tensor = deserialize_tensor(ctx,
|
|
915
|
+
ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
|
|
967
916
|
if (tensor == nullptr) {
|
|
968
917
|
GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__);
|
|
969
918
|
ggml_free(ctx);
|
|
970
919
|
return false;
|
|
971
920
|
}
|
|
972
|
-
GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
|
|
921
|
+
GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size);
|
|
973
922
|
|
|
974
923
|
// sanitize tensor->data
|
|
975
924
|
{
|
|
976
925
|
const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
|
|
977
926
|
const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
|
|
978
927
|
|
|
979
|
-
if (
|
|
980
|
-
|
|
928
|
+
if (request.tensor.data + request.offset < p0 ||
|
|
929
|
+
request.tensor.data + request.offset >= p1 ||
|
|
930
|
+
request.size > (p1 - request.tensor.data - request.offset)) {
|
|
931
|
+
GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
|
|
981
932
|
}
|
|
982
933
|
}
|
|
983
934
|
|
|
984
|
-
|
|
985
|
-
|
|
986
|
-
ggml_backend_tensor_get(tensor, output.data(), offset, size);
|
|
935
|
+
response.resize(request.size, 0);
|
|
936
|
+
ggml_backend_tensor_get(tensor, response.data(), request.offset, request.size);
|
|
987
937
|
ggml_free(ctx);
|
|
988
938
|
return true;
|
|
989
939
|
}
|
|
990
940
|
|
|
991
|
-
bool rpc_server::copy_tensor(const
|
|
992
|
-
// serialization format: | rpc_tensor src | rpc_tensor dst |
|
|
993
|
-
if (input.size() != 2*sizeof(rpc_tensor)) {
|
|
994
|
-
return false;
|
|
995
|
-
}
|
|
996
|
-
const rpc_tensor * rpc_src = (const rpc_tensor *)input.data();
|
|
997
|
-
const rpc_tensor * rpc_dst = (const rpc_tensor *)(input.data() + sizeof(rpc_src));
|
|
998
|
-
|
|
941
|
+
bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response) {
|
|
999
942
|
struct ggml_init_params params {
|
|
1000
943
|
/*.mem_size =*/ 2*ggml_tensor_overhead(),
|
|
1001
944
|
/*.mem_buffer =*/ NULL,
|
|
1002
945
|
/*.no_alloc =*/ true,
|
|
1003
946
|
};
|
|
1004
947
|
struct ggml_context * ctx = ggml_init(params);
|
|
1005
|
-
ggml_tensor * src = deserialize_tensor(ctx,
|
|
1006
|
-
ggml_tensor * dst = deserialize_tensor(ctx,
|
|
948
|
+
ggml_tensor * src = deserialize_tensor(ctx, &request.src);
|
|
949
|
+
ggml_tensor * dst = deserialize_tensor(ctx, &request.dst);
|
|
1007
950
|
if (src == nullptr || dst == nullptr) {
|
|
1008
951
|
GGML_PRINT_DEBUG("[%s] error deserializing tensors\n", __func__);
|
|
1009
952
|
ggml_free(ctx);
|
|
1010
953
|
return false;
|
|
1011
954
|
}
|
|
1012
955
|
GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n", __func__, (void*)src->buffer, (void*)dst->buffer);
|
|
1013
|
-
|
|
1014
|
-
// output serialization format: | result (1 byte) |
|
|
1015
|
-
output.resize(1, 0);
|
|
1016
|
-
output[0] = result;
|
|
956
|
+
response.result = ggml_backend_buffer_copy_tensor(src, dst);
|
|
1017
957
|
ggml_free(ctx);
|
|
1018
958
|
return true;
|
|
1019
959
|
}
|
|
@@ -1042,7 +982,7 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
|
|
|
1042
982
|
return result;
|
|
1043
983
|
}
|
|
1044
984
|
|
|
1045
|
-
bool rpc_server::graph_compute(const std::vector<uint8_t> & input,
|
|
985
|
+
bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
|
|
1046
986
|
// serialization format:
|
|
1047
987
|
// | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
|
|
1048
988
|
if (input.size() < sizeof(uint32_t)) {
|
|
@@ -1082,9 +1022,7 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, std::vector<u
|
|
|
1082
1022
|
graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map);
|
|
1083
1023
|
}
|
|
1084
1024
|
ggml_status status = ggml_backend_graph_compute(backend, graph);
|
|
1085
|
-
|
|
1086
|
-
output.resize(1, 0);
|
|
1087
|
-
output[0] = status;
|
|
1025
|
+
response.result = status;
|
|
1088
1026
|
ggml_free(ctx);
|
|
1089
1027
|
return true;
|
|
1090
1028
|
}
|
|
@@ -1107,89 +1045,157 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
|
|
|
1107
1045
|
fprintf(stderr, "Unknown command: %d\n", cmd);
|
|
1108
1046
|
break;
|
|
1109
1047
|
}
|
|
1110
|
-
std::vector<uint8_t> input;
|
|
1111
|
-
std::vector<uint8_t> output;
|
|
1112
|
-
uint64_t input_size;
|
|
1113
|
-
if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
|
|
1114
|
-
break;
|
|
1115
|
-
}
|
|
1116
|
-
try {
|
|
1117
|
-
input.resize(input_size);
|
|
1118
|
-
} catch (const std::bad_alloc & e) {
|
|
1119
|
-
fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", input_size);
|
|
1120
|
-
break;
|
|
1121
|
-
}
|
|
1122
|
-
if (!recv_data(sockfd, input.data(), input_size)) {
|
|
1123
|
-
break;
|
|
1124
|
-
}
|
|
1125
|
-
bool ok = true;
|
|
1126
1048
|
switch (cmd) {
|
|
1127
1049
|
case RPC_CMD_ALLOC_BUFFER: {
|
|
1128
|
-
|
|
1050
|
+
rpc_msg_alloc_buffer_req request;
|
|
1051
|
+
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
|
1052
|
+
return;
|
|
1053
|
+
}
|
|
1054
|
+
rpc_msg_alloc_buffer_rsp response;
|
|
1055
|
+
server.alloc_buffer(request, response);
|
|
1056
|
+
if (!send_msg(sockfd, &response, sizeof(response))) {
|
|
1057
|
+
return;
|
|
1058
|
+
}
|
|
1129
1059
|
break;
|
|
1130
1060
|
}
|
|
1131
1061
|
case RPC_CMD_GET_ALIGNMENT: {
|
|
1132
|
-
|
|
1062
|
+
if (!recv_msg(sockfd, nullptr, 0)) {
|
|
1063
|
+
return;
|
|
1064
|
+
}
|
|
1065
|
+
rpc_msg_get_alignment_rsp response;
|
|
1066
|
+
server.get_alignment(response);
|
|
1067
|
+
if (!send_msg(sockfd, &response, sizeof(response))) {
|
|
1068
|
+
return;
|
|
1069
|
+
}
|
|
1133
1070
|
break;
|
|
1134
1071
|
}
|
|
1135
1072
|
case RPC_CMD_GET_MAX_SIZE: {
|
|
1136
|
-
|
|
1073
|
+
if (!recv_msg(sockfd, nullptr, 0)) {
|
|
1074
|
+
return;
|
|
1075
|
+
}
|
|
1076
|
+
rpc_msg_get_max_size_rsp response;
|
|
1077
|
+
server.get_max_size(response);
|
|
1078
|
+
if (!send_msg(sockfd, &response, sizeof(response))) {
|
|
1079
|
+
return;
|
|
1080
|
+
}
|
|
1137
1081
|
break;
|
|
1138
1082
|
}
|
|
1139
1083
|
case RPC_CMD_BUFFER_GET_BASE: {
|
|
1140
|
-
|
|
1084
|
+
rpc_msg_buffer_get_base_req request;
|
|
1085
|
+
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
|
1086
|
+
return;
|
|
1087
|
+
}
|
|
1088
|
+
rpc_msg_buffer_get_base_rsp response;
|
|
1089
|
+
if (!server.buffer_get_base(request, response)) {
|
|
1090
|
+
return;
|
|
1091
|
+
}
|
|
1092
|
+
if (!send_msg(sockfd, &response, sizeof(response))) {
|
|
1093
|
+
return;
|
|
1094
|
+
}
|
|
1141
1095
|
break;
|
|
1142
1096
|
}
|
|
1143
1097
|
case RPC_CMD_FREE_BUFFER: {
|
|
1144
|
-
|
|
1098
|
+
rpc_msg_free_buffer_req request;
|
|
1099
|
+
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
|
1100
|
+
return;
|
|
1101
|
+
}
|
|
1102
|
+
if (!server.free_buffer(request)) {
|
|
1103
|
+
return;
|
|
1104
|
+
}
|
|
1105
|
+
if (!send_msg(sockfd, nullptr, 0)) {
|
|
1106
|
+
return;
|
|
1107
|
+
}
|
|
1145
1108
|
break;
|
|
1146
1109
|
}
|
|
1147
1110
|
case RPC_CMD_BUFFER_CLEAR: {
|
|
1148
|
-
|
|
1111
|
+
rpc_msg_buffer_clear_req request;
|
|
1112
|
+
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
|
1113
|
+
return;
|
|
1114
|
+
}
|
|
1115
|
+
if (!server.buffer_clear(request)) {
|
|
1116
|
+
return;
|
|
1117
|
+
}
|
|
1118
|
+
if (!send_msg(sockfd, nullptr, 0)) {
|
|
1119
|
+
return;
|
|
1120
|
+
}
|
|
1149
1121
|
break;
|
|
1150
1122
|
}
|
|
1151
1123
|
case RPC_CMD_SET_TENSOR: {
|
|
1152
|
-
|
|
1124
|
+
std::vector<uint8_t> input;
|
|
1125
|
+
if (!recv_msg(sockfd, input)) {
|
|
1126
|
+
return;
|
|
1127
|
+
}
|
|
1128
|
+
if (!server.set_tensor(input)) {
|
|
1129
|
+
return;
|
|
1130
|
+
}
|
|
1131
|
+
if (!send_msg(sockfd, nullptr, 0)) {
|
|
1132
|
+
return;
|
|
1133
|
+
}
|
|
1153
1134
|
break;
|
|
1154
1135
|
}
|
|
1155
1136
|
case RPC_CMD_GET_TENSOR: {
|
|
1156
|
-
|
|
1137
|
+
rpc_msg_get_tensor_req request;
|
|
1138
|
+
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
|
1139
|
+
return;
|
|
1140
|
+
}
|
|
1141
|
+
std::vector<uint8_t> response;
|
|
1142
|
+
if (!server.get_tensor(request, response)) {
|
|
1143
|
+
return;
|
|
1144
|
+
}
|
|
1145
|
+
if (!send_msg(sockfd, response.data(), response.size())) {
|
|
1146
|
+
return;
|
|
1147
|
+
}
|
|
1157
1148
|
break;
|
|
1158
1149
|
}
|
|
1159
1150
|
case RPC_CMD_COPY_TENSOR: {
|
|
1160
|
-
|
|
1151
|
+
rpc_msg_copy_tensor_req request;
|
|
1152
|
+
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
|
1153
|
+
return;
|
|
1154
|
+
}
|
|
1155
|
+
rpc_msg_copy_tensor_rsp response;
|
|
1156
|
+
if (!server.copy_tensor(request, response)) {
|
|
1157
|
+
return;
|
|
1158
|
+
}
|
|
1159
|
+
if (!send_msg(sockfd, &response, sizeof(response))) {
|
|
1160
|
+
return;
|
|
1161
|
+
}
|
|
1161
1162
|
break;
|
|
1162
1163
|
}
|
|
1163
1164
|
case RPC_CMD_GRAPH_COMPUTE: {
|
|
1164
|
-
|
|
1165
|
+
std::vector<uint8_t> input;
|
|
1166
|
+
if (!recv_msg(sockfd, input)) {
|
|
1167
|
+
return;
|
|
1168
|
+
}
|
|
1169
|
+
rpc_msg_graph_compute_rsp response;
|
|
1170
|
+
if (!server.graph_compute(input, response)) {
|
|
1171
|
+
return;
|
|
1172
|
+
}
|
|
1173
|
+
if (!send_msg(sockfd, &response, sizeof(response))) {
|
|
1174
|
+
return;
|
|
1175
|
+
}
|
|
1165
1176
|
break;
|
|
1166
1177
|
}
|
|
1167
1178
|
case RPC_CMD_GET_DEVICE_MEMORY: {
|
|
1168
|
-
|
|
1169
|
-
|
|
1170
|
-
|
|
1171
|
-
|
|
1179
|
+
if (!recv_msg(sockfd, nullptr, 0)) {
|
|
1180
|
+
return;
|
|
1181
|
+
}
|
|
1182
|
+
rpc_msg_get_device_memory_rsp response;
|
|
1183
|
+
response.free_mem = free_mem;
|
|
1184
|
+
response.total_mem = total_mem;
|
|
1185
|
+
if (!send_msg(sockfd, &response, sizeof(response))) {
|
|
1186
|
+
return;
|
|
1187
|
+
}
|
|
1172
1188
|
break;
|
|
1173
1189
|
}
|
|
1174
1190
|
default: {
|
|
1175
1191
|
fprintf(stderr, "Unknown command: %d\n", cmd);
|
|
1176
|
-
|
|
1192
|
+
return;
|
|
1177
1193
|
}
|
|
1178
1194
|
}
|
|
1179
|
-
if (!ok) {
|
|
1180
|
-
break;
|
|
1181
|
-
}
|
|
1182
|
-
uint64_t output_size = output.size();
|
|
1183
|
-
if (!send_data(sockfd, &output_size, sizeof(output_size))) {
|
|
1184
|
-
break;
|
|
1185
|
-
}
|
|
1186
|
-
if (!send_data(sockfd, output.data(), output_size)) {
|
|
1187
|
-
break;
|
|
1188
|
-
}
|
|
1189
1195
|
}
|
|
1190
1196
|
}
|
|
1191
1197
|
|
|
1192
|
-
void
|
|
1198
|
+
void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
|
|
1193
1199
|
std::string host;
|
|
1194
1200
|
int port;
|
|
1195
1201
|
if (!parse_endpoint(endpoint, host, port)) {
|
|
@@ -1226,3 +1232,172 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free
|
|
|
1226
1232
|
WSACleanup();
|
|
1227
1233
|
#endif
|
|
1228
1234
|
}
|
|
1235
|
+
|
|
1236
|
+
// device interface
|
|
1237
|
+
|
|
1238
|
+
struct ggml_backend_rpc_device_context {
|
|
1239
|
+
std::string endpoint;
|
|
1240
|
+
std::string name;
|
|
1241
|
+
};
|
|
1242
|
+
|
|
1243
|
+
static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
|
|
1244
|
+
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
|
|
1245
|
+
|
|
1246
|
+
return ctx->name.c_str();
|
|
1247
|
+
}
|
|
1248
|
+
|
|
1249
|
+
static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) {
|
|
1250
|
+
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
|
|
1251
|
+
|
|
1252
|
+
return ctx->name.c_str();
|
|
1253
|
+
}
|
|
1254
|
+
|
|
1255
|
+
static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
|
1256
|
+
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
|
|
1257
|
+
|
|
1258
|
+
ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total);
|
|
1259
|
+
|
|
1260
|
+
UNUSED(dev);
|
|
1261
|
+
}
|
|
1262
|
+
|
|
1263
|
+
static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
|
|
1264
|
+
// TODO: obtain value from the server
|
|
1265
|
+
return GGML_BACKEND_DEVICE_TYPE_GPU;
|
|
1266
|
+
|
|
1267
|
+
UNUSED(dev);
|
|
1268
|
+
}
|
|
1269
|
+
|
|
1270
|
+
static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
|
|
1271
|
+
props->name = ggml_backend_rpc_device_get_name(dev);
|
|
1272
|
+
props->description = ggml_backend_rpc_device_get_description(dev);
|
|
1273
|
+
props->type = ggml_backend_rpc_device_get_type(dev);
|
|
1274
|
+
ggml_backend_rpc_device_get_memory(dev, &props->memory_free, &props->memory_total);
|
|
1275
|
+
props->caps = {
|
|
1276
|
+
/* .async = */ false,
|
|
1277
|
+
/* .host_buffer = */ false,
|
|
1278
|
+
/* .buffer_from_host_ptr = */ false,
|
|
1279
|
+
/* .events = */ false,
|
|
1280
|
+
};
|
|
1281
|
+
}
|
|
1282
|
+
|
|
1283
|
+
static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) {
|
|
1284
|
+
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
|
|
1285
|
+
|
|
1286
|
+
return ggml_backend_rpc_init(ctx->endpoint.c_str());
|
|
1287
|
+
|
|
1288
|
+
UNUSED(params);
|
|
1289
|
+
}
|
|
1290
|
+
|
|
1291
|
+
static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
|
|
1292
|
+
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
|
|
1293
|
+
|
|
1294
|
+
return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
|
|
1295
|
+
|
|
1296
|
+
UNUSED(dev);
|
|
1297
|
+
}
|
|
1298
|
+
|
|
1299
|
+
static bool ggml_backend_rpc_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
|
|
1300
|
+
UNUSED(dev);
|
|
1301
|
+
UNUSED(op);
|
|
1302
|
+
//TODO: call the remote backend and cache the results
|
|
1303
|
+
return true;
|
|
1304
|
+
}
|
|
1305
|
+
|
|
1306
|
+
static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
|
|
1307
|
+
if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
|
|
1308
|
+
return false;
|
|
1309
|
+
}
|
|
1310
|
+
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
|
|
1311
|
+
ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context;
|
|
1312
|
+
return buft_ctx->endpoint == dev_ctx->endpoint;
|
|
1313
|
+
}
|
|
1314
|
+
|
|
1315
|
+
static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
|
|
1316
|
+
/* .get_name = */ ggml_backend_rpc_device_get_name,
|
|
1317
|
+
/* .get_description = */ ggml_backend_rpc_device_get_description,
|
|
1318
|
+
/* .get_memory = */ ggml_backend_rpc_device_get_memory,
|
|
1319
|
+
/* .get_type = */ ggml_backend_rpc_device_get_type,
|
|
1320
|
+
/* .get_props = */ ggml_backend_rpc_device_get_props,
|
|
1321
|
+
/* .init_backend = */ ggml_backend_rpc_device_init,
|
|
1322
|
+
/* .get_buffer_type = */ ggml_backend_rpc_device_get_buffer_type,
|
|
1323
|
+
/* .get_host_buffer_type = */ NULL,
|
|
1324
|
+
/* .buffer_from_host_ptr = */ NULL,
|
|
1325
|
+
/* .supports_op = */ ggml_backend_rpc_device_supports_op,
|
|
1326
|
+
/* .supports_buft = */ ggml_backend_rpc_device_supports_buft,
|
|
1327
|
+
/* .offload_op = */ NULL,
|
|
1328
|
+
/* .event_new = */ NULL,
|
|
1329
|
+
/* .event_free = */ NULL,
|
|
1330
|
+
/* .event_synchronize = */ NULL,
|
|
1331
|
+
};
|
|
1332
|
+
|
|
1333
|
+
// backend reg interface
|
|
1334
|
+
|
|
1335
|
+
static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
|
|
1336
|
+
return "RPC";
|
|
1337
|
+
|
|
1338
|
+
UNUSED(reg);
|
|
1339
|
+
}
|
|
1340
|
+
|
|
1341
|
+
static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
|
|
1342
|
+
return 0;
|
|
1343
|
+
|
|
1344
|
+
UNUSED(reg);
|
|
1345
|
+
}
|
|
1346
|
+
|
|
1347
|
+
static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
|
|
1348
|
+
GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_add_device instead");
|
|
1349
|
+
|
|
1350
|
+
UNUSED(reg);
|
|
1351
|
+
UNUSED(index);
|
|
1352
|
+
}
|
|
1353
|
+
|
|
1354
|
+
static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {
|
|
1355
|
+
if (std::strcmp(name, "ggml_backend_rpc_add_device") == 0) {
|
|
1356
|
+
return (void *)ggml_backend_rpc_add_device;
|
|
1357
|
+
}
|
|
1358
|
+
return NULL;
|
|
1359
|
+
|
|
1360
|
+
UNUSED(reg);
|
|
1361
|
+
}
|
|
1362
|
+
|
|
1363
|
+
static const struct ggml_backend_reg_i ggml_backend_rpc_reg_i = {
|
|
1364
|
+
/* .get_name = */ ggml_backend_rpc_reg_get_name,
|
|
1365
|
+
/* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,
|
|
1366
|
+
/* .get_device = */ ggml_backend_rpc_reg_get_device,
|
|
1367
|
+
/* .get_proc_address = */ ggml_backend_rpc_get_proc_address,
|
|
1368
|
+
};
|
|
1369
|
+
|
|
1370
|
+
ggml_backend_reg_t ggml_backend_rpc_reg(void) {
|
|
1371
|
+
static struct ggml_backend_reg ggml_backend_rpc_reg = {
|
|
1372
|
+
/* .iface = */ ggml_backend_rpc_reg_i,
|
|
1373
|
+
/* .context = */ NULL,
|
|
1374
|
+
};
|
|
1375
|
+
|
|
1376
|
+
return &ggml_backend_rpc_reg;
|
|
1377
|
+
}
|
|
1378
|
+
|
|
1379
|
+
ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) {
|
|
1380
|
+
static std::unordered_map<std::string, ggml_backend_dev_t> dev_map;
|
|
1381
|
+
|
|
1382
|
+
static std::mutex mutex;
|
|
1383
|
+
std::lock_guard<std::mutex> lock(mutex);
|
|
1384
|
+
|
|
1385
|
+
if (dev_map.find(endpoint) != dev_map.end()) {
|
|
1386
|
+
return dev_map[endpoint];
|
|
1387
|
+
}
|
|
1388
|
+
|
|
1389
|
+
ggml_backend_rpc_device_context * ctx = new ggml_backend_rpc_device_context {
|
|
1390
|
+
/* .endpoint = */ endpoint,
|
|
1391
|
+
/* .name = */ "RPC[" + std::string(endpoint) + "]",
|
|
1392
|
+
};
|
|
1393
|
+
|
|
1394
|
+
ggml_backend_dev_t dev = new ggml_backend_device {
|
|
1395
|
+
/* .iface = */ ggml_backend_rpc_device_i,
|
|
1396
|
+
/* .reg = */ ggml_backend_rpc_reg(),
|
|
1397
|
+
/* .context = */ ctx,
|
|
1398
|
+
};
|
|
1399
|
+
|
|
1400
|
+
dev_map[endpoint] = dev;
|
|
1401
|
+
|
|
1402
|
+
return dev;
|
|
1403
|
+
}
|