@fugood/llama.node 0.3.16 → 0.3.17
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +3 -0
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +5 -0
- package/package.json +1 -1
- package/src/LlamaCompletionWorker.cpp +8 -0
- package/src/LlamaCompletionWorker.h +1 -0
- package/src/LlamaContext.cpp +3 -2
- package/src/llama.cpp/.github/workflows/build-linux-cross.yml +124 -0
- package/src/llama.cpp/.github/workflows/build.yml +70 -27
- package/src/llama.cpp/.github/workflows/docker.yml +6 -6
- package/src/llama.cpp/.github/workflows/server.yml +7 -11
- package/src/llama.cpp/CMakeLists.txt +23 -1
- package/src/llama.cpp/common/CMakeLists.txt +6 -3
- package/src/llama.cpp/common/arg.cpp +809 -105
- package/src/llama.cpp/common/arg.h +9 -0
- package/src/llama.cpp/common/chat.cpp +1 -1
- package/src/llama.cpp/common/common.cpp +31 -521
- package/src/llama.cpp/common/common.h +17 -36
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
- package/src/llama.cpp/common/llguidance.cpp +30 -47
- package/src/llama.cpp/common/minja/chat-template.hpp +15 -7
- package/src/llama.cpp/common/minja/minja.hpp +119 -93
- package/src/llama.cpp/common/sampling.cpp +3 -0
- package/src/llama.cpp/docs/build.md +122 -7
- package/src/llama.cpp/examples/CMakeLists.txt +0 -9
- package/src/llama.cpp/examples/batched/batched.cpp +1 -1
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +7 -1
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +1 -1
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +15 -16
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +210 -8
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/llava/CMakeLists.txt +39 -24
- package/src/llama.cpp/examples/llava/clip-impl.h +345 -0
- package/src/llama.cpp/examples/llava/clip.cpp +2152 -1803
- package/src/llama.cpp/examples/llava/clip.h +39 -22
- package/src/llama.cpp/examples/llava/deprecation-warning.cpp +22 -0
- package/src/llama.cpp/examples/llava/llava.cpp +64 -52
- package/src/llama.cpp/examples/llava/mtmd-cli.cpp +344 -0
- package/src/llama.cpp/examples/llava/mtmd.cpp +708 -0
- package/src/llama.cpp/examples/llava/mtmd.h +168 -0
- package/src/llama.cpp/examples/llava/{qwen2vl-cli.cpp → qwen2vl-test.cpp} +83 -31
- package/src/llama.cpp/examples/main/main.cpp +16 -5
- package/src/llama.cpp/examples/parallel/parallel.cpp +3 -1
- package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +17 -3
- package/src/llama.cpp/examples/quantize/quantize.cpp +115 -2
- package/src/llama.cpp/examples/rpc/CMakeLists.txt +4 -2
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +163 -8
- package/src/llama.cpp/examples/run/CMakeLists.txt +12 -1
- package/src/llama.cpp/examples/run/run.cpp +14 -28
- package/src/llama.cpp/examples/server/httplib.h +313 -247
- package/src/llama.cpp/examples/server/server.cpp +238 -139
- package/src/llama.cpp/examples/server/utils.hpp +51 -2
- package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
- package/src/llama.cpp/examples/sycl/build.sh +2 -2
- package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
- package/src/llama.cpp/examples/tts/tts.cpp +6 -9
- package/src/llama.cpp/ggml/CMakeLists.txt +8 -2
- package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
- package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
- package/src/llama.cpp/ggml/include/ggml.h +66 -99
- package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -7
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
- package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +48 -22
- package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +896 -192
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +754 -404
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1003 -13519
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +2 -7
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +0 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +3 -4
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +533 -88
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8809 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +258 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
- package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +70 -3
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -260
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +293 -40
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +96 -22
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +350 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +2 -292
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +967 -438
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +204 -280
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +23 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +646 -114
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +12 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +17 -8
- package/src/llama.cpp/ggml/src/ggml.c +141 -245
- package/src/llama.cpp/ggml/src/gguf.cpp +1 -0
- package/src/llama.cpp/include/llama.h +30 -11
- package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
- package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
- package/src/llama.cpp/requirements/requirements-all.txt +2 -0
- package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
- package/src/llama.cpp/src/CMakeLists.txt +3 -2
- package/src/llama.cpp/src/llama-adapter.cpp +37 -1
- package/src/llama.cpp/src/llama-arch.cpp +160 -17
- package/src/llama.cpp/src/llama-arch.h +16 -0
- package/src/llama.cpp/src/llama-chat.cpp +82 -17
- package/src/llama.cpp/src/llama-chat.h +6 -2
- package/src/llama.cpp/src/llama-context.cpp +108 -92
- package/src/llama.cpp/src/llama-context.h +1 -2
- package/src/llama.cpp/src/llama-graph.cpp +189 -119
- package/src/llama.cpp/src/llama-graph.h +26 -6
- package/src/llama.cpp/src/llama-hparams.h +13 -0
- package/src/llama.cpp/src/llama-kv-cache.cpp +70 -123
- package/src/llama.cpp/src/llama-kv-cache.h +41 -115
- package/src/llama.cpp/src/llama-memory.h +1 -1
- package/src/llama.cpp/src/llama-mmap.cpp +1 -1
- package/src/llama.cpp/src/llama-model-loader.cpp +10 -5
- package/src/llama.cpp/src/llama-model-loader.h +5 -3
- package/src/llama.cpp/src/llama-model.cpp +1760 -534
- package/src/llama.cpp/src/llama-model.h +13 -1
- package/src/llama.cpp/src/llama-quant.cpp +29 -8
- package/src/llama.cpp/src/llama-sampling.cpp +7 -1
- package/src/llama.cpp/src/llama-vocab.cpp +44 -6
- package/src/llama.cpp/src/llama.cpp +1 -1
- package/src/llama.cpp/tests/CMakeLists.txt +43 -30
- package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
- package/src/llama.cpp/tests/test-backend-ops.cpp +82 -43
- package/src/llama.cpp/tests/test-chat-template.cpp +34 -13
- package/src/llama.cpp/tests/test-chat.cpp +12 -2
- package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
- package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
- package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
- package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
- package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
- package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
- package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
- package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
- package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp

@@ -1,6 +1,7 @@
 #include "ggml-rpc.h"
 #include "ggml-impl.h"
 #include "ggml-backend-impl.h"
+#include "ggml-cpp.h"
 
 #include <cinttypes>
 #include <string>
@@ -26,6 +27,10 @@
 #  include <unistd.h>
 #endif
 #include <cstring>
+#include <fstream>
+#include <filesystem>
+
+namespace fs = std::filesystem;
 
 #ifdef _WIN32
 typedef SOCKET sockfd_t;
@@ -80,15 +85,26 @@ enum rpc_cmd {
     RPC_CMD_FREE_BUFFER,
     RPC_CMD_BUFFER_CLEAR,
     RPC_CMD_SET_TENSOR,
+    RPC_CMD_SET_TENSOR_HASH,
     RPC_CMD_GET_TENSOR,
     RPC_CMD_COPY_TENSOR,
     RPC_CMD_GRAPH_COMPUTE,
     RPC_CMD_GET_DEVICE_MEMORY,
     RPC_CMD_INIT_TENSOR,
     RPC_CMD_GET_ALLOC_SIZE,
+    RPC_CMD_HELLO,
     RPC_CMD_COUNT,
 };
 
+// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
+const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
+
+struct rpc_msg_hello_rsp {
+    uint8_t major;
+    uint8_t minor;
+    uint8_t patch;
+};
+
 struct rpc_msg_get_alloc_size_req {
     rpc_tensor tensor;
 };
@@ -135,6 +151,10 @@ struct rpc_msg_buffer_clear_req {
     uint8_t value;
 };
 
+struct rpc_msg_set_tensor_hash_rsp {
+    uint8_t result;
+};
+
 struct rpc_msg_get_tensor_req {
     rpc_tensor tensor;
     uint64_t offset;
@@ -187,6 +207,18 @@ struct ggml_backend_rpc_buffer_context {
 
 // RPC helper functions
 
+// Computes FNV-1a hash of the data
+static uint64_t fnv_hash(const uint8_t * data, size_t len) {
+    const uint64_t fnv_prime = 0x100000001b3ULL;
+    uint64_t hash = 0xcbf29ce484222325ULL;
+
+    for (size_t i = 0; i < len; ++i) {
+        hash ^= data[i];
+        hash *= fnv_prime;
+    }
+    return hash;
+}
+
 static std::shared_ptr<socket_t> make_socket(sockfd_t fd) {
 #ifdef _WIN32
     if (fd == INVALID_SOCKET) {
@@ -346,8 +378,8 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
 }
 
 // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
-// RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
-static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) {
+// No response
+static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size) {
     uint8_t cmd_byte = cmd;
     if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
         return false;
@@ -358,6 +390,15 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) {
     if (!send_data(sock->fd, input, input_size)) {
         return false;
     }
+    return true;
+}
+
+// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
+// RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
+static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) {
+    if (!send_rpc_cmd(sock, cmd, input, input_size)) {
+        return false;
+    }
     // TODO: currently the output_size is always known, do we need support for commands with variable output size?
     // even if we do, we can skip sending output_size from the server for commands with known output size
     uint64_t out_size;
@@ -375,6 +416,20 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) {
 
 // RPC client-side implementation
 
+static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
+    rpc_msg_hello_rsp response;
+    bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
+    GGML_ASSERT(status);
+    if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
+        fprintf(stderr, "RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
+        return false;
+    }
+    if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) {
+        fprintf(stderr, "WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
+    }
+    return true;
+}
+
 static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
     static std::mutex mutex;
     std::lock_guard<std::mutex> lock(mutex);
@@ -408,6 +463,9 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
     if (sock == nullptr) {
         return nullptr;
     }
+    if (!check_server_version(sock)) {
+        return nullptr;
+    }
     GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
     sockets[endpoint] = sock;
     return sock;
@@ -483,14 +541,30 @@ static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
 
 static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
-    // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
+    rpc_tensor rpc_tensor = serialize_tensor(tensor);
+    if (size > HASH_THRESHOLD) {
+        // input serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes) |
+        size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + sizeof(uint64_t);
+        std::vector<uint8_t> input(input_size, 0);
+        uint64_t hash = fnv_hash((const uint8_t*)data, size);
+        memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
+        memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
+        memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &hash, sizeof(hash));
+        rpc_msg_set_tensor_hash_rsp response;
+        bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, input.data(), input.size(), &response, sizeof(response));
+        GGML_ASSERT(status);
+        if (response.result) {
+            // the server has the same data, no need to send it
+            return;
+        }
+    }
+    // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
     size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size;
     std::vector<uint8_t> input(input_size, 0);
-    rpc_tensor rpc_tensor = serialize_tensor(tensor);
     memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
     memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
     memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
-    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size(), nullptr, 0);
+    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size());
     GGML_ASSERT(status);
 }
 
@@ -772,9 +846,12 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
 
 class rpc_server {
 public:
-    rpc_server(ggml_backend_t backend) : backend(backend) {}
+    rpc_server(ggml_backend_t backend, const char * cache_dir)
+        : backend(backend), cache_dir(cache_dir) {
+    }
     ~rpc_server();
 
+    void hello(rpc_msg_hello_rsp & response);
     void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
     void get_alignment(rpc_msg_get_alignment_rsp & response);
     void get_max_size(rpc_msg_get_max_size_rsp & response);
@@ -782,6 +859,7 @@ public:
    bool free_buffer(const rpc_msg_free_buffer_req & request);
    bool buffer_clear(const rpc_msg_buffer_clear_req & request);
    bool set_tensor(const std::vector<uint8_t> & input);
+    bool set_tensor_hash(const std::vector<uint8_t> & input, rpc_msg_set_tensor_hash_rsp & response);
    bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
    bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
    bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
@@ -789,6 +867,7 @@ public:
    bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
 
 private:
+    bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
     ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
     ggml_tensor * create_node(uint64_t id,
                               struct ggml_context * ctx,
@@ -797,9 +876,17 @@ private:
 
 
     ggml_backend_t backend;
+    const char * cache_dir;
     std::unordered_set<ggml_backend_buffer_t> buffers;
 };
 
+void rpc_server::hello(rpc_msg_hello_rsp & response) {
+    response.major = RPC_PROTO_MAJOR_VERSION;
+    response.minor = RPC_PROTO_MINOR_VERSION;
+    response.patch = RPC_PROTO_PATCH_VERSION;
+    GGML_PRINT_DEBUG("[%s] version: %d.%d.%d\n", __func__, response.major, response.minor, response.patch);
+}
+
 bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
     ggml_backend_buffer_type_t buft;
     struct ggml_init_params params {
@@ -808,12 +895,13 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
         /*.no_alloc   =*/ true,
     };
 
-    struct ggml_context * ctx = ggml_init(params);
+    ggml_context_ptr ctx_ptr { ggml_init(params) };
+    GGML_ASSERT(ctx_ptr != nullptr);
+    ggml_context * ctx = ctx_ptr.get();
     ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
 
     if (tensor == nullptr) {
         GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n");
-        ggml_free(ctx);
         return false;
     }
 
@@ -826,7 +914,6 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
 
     response.alloc_size = ggml_backend_buft_get_alloc_size(buft,tensor);
 
-    ggml_free(ctx);
     return true;
 }
 
@@ -895,8 +982,21 @@ bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {
 }
 
 ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) {
+    // Validate tensor type before using it
+    if (tensor->type >= GGML_TYPE_COUNT) {
+        GGML_LOG_ERROR("[%s] invalid tensor type received: %u\n", __func__, tensor->type);
+        return nullptr;
+    }
+
     ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type,
         tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
+
+    // ggml_new_tensor_4d might fail if dimensions are invalid, although less likely to crash than invalid type
+    if (result == nullptr) {
+        GGML_LOG_ERROR("[%s] ggml_new_tensor_4d failed for type %u\n", __func__, tensor->type);
+        return nullptr;
+    }
+
     for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
         result->nb[i] = tensor->nb[i];
     }
@@ -940,11 +1040,12 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
         /*.mem_buffer =*/ NULL,
         /*.no_alloc   =*/ true,
     };
-    struct ggml_context * ctx = ggml_init(params);
+    ggml_context_ptr ctx_ptr { ggml_init(params) };
+    GGML_ASSERT(ctx_ptr != nullptr);
+    ggml_context * ctx = ctx_ptr.get();
     ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
     if (tensor == nullptr) {
         GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
-        ggml_free(ctx);
         return false;
     }
     GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
@@ -955,13 +1056,90 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
         const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
 
         if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
-            GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
+            GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu) out of buffer bounds [0x%zx, 0x%zx)\n",
+                           __func__, in_tensor->data, offset, size, p0, p1);
+            return false;
         }
     }
 
     const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset);
+    if (cache_dir && size > HASH_THRESHOLD) {
+        uint64_t hash = fnv_hash((const uint8_t*)data, size);
+        char hash_str[17];
+        snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
+        // save to cache_dir/hash_str
+        fs::path cache_file = fs::path(cache_dir) / hash_str;
+        std::ofstream ofs(cache_file, std::ios::binary);
+        ofs.write((const char *)data, size);
+        printf("[%s] saved to '%s'\n", __func__, cache_file.c_str());
+    }
     ggml_backend_tensor_set(tensor, data, offset, size);
-    ggml_free(ctx);
+    return true;
+}
+
+bool rpc_server::get_cached_file(uint64_t hash, std::vector<uint8_t> & data) {
+    if (!cache_dir) {
+        return false;
+    }
+    char hash_str[17];
+    snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
+    fs::path cache_file = fs::path(cache_dir) / hash_str;
+    if (!fs::exists(cache_file)) {
+        return false;
+    }
+    std::ifstream ifs(cache_file, std::ios::binary);
+    ifs.seekg(0, std::ios::end);
+    size_t size = ifs.tellg();
+    ifs.seekg(0, std::ios::beg);
+    data.resize(size);
+    ifs.read((char *)data.data(), size);
+    return true;
+}
+
+bool rpc_server::set_tensor_hash(const std::vector<uint8_t> & input, rpc_msg_set_tensor_hash_rsp & response)
+{
+    // serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes) |
+    if (input.size() != sizeof(rpc_tensor) + 16) {
+        return false;
+    }
+    const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
+    uint64_t offset;
+    memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
+    const uint64_t * hash = (const uint64_t *)(input.data() + sizeof(rpc_tensor) + sizeof(offset));
+    std::vector<uint8_t> cached_file;
+    if (!get_cached_file(*hash, cached_file)) {
+        response.result = 0;
+        return true;
+    }
+    size_t size = cached_file.size();
+    struct ggml_init_params params {
+        /*.mem_size   =*/ ggml_tensor_overhead(),
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true,
+    };
+    ggml_context_ptr ctx_ptr { ggml_init(params) };
+    GGML_ASSERT(ctx_ptr != nullptr);
+    ggml_context * ctx = ctx_ptr.get();
+    ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
+    if (tensor == nullptr) {
+        GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
+        return false;
+    }
+    GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size, *hash);
+
+    // sanitize tensor->data
+    {
+        const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
+        const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
+
+        if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
+            GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu, hash=0x%" PRIx64 ") out of buffer bounds [0x%zx, 0x%zx)\n",
+                           __func__, in_tensor->data, offset, size, *hash, p0, p1);
+            return false;
+        }
+    }
+    ggml_backend_tensor_set(tensor, cached_file.data(), offset, size);
+    response.result = 1;
     return true;
 }
 
@@ -971,11 +1149,12 @@ bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) {
         /*.mem_buffer =*/ NULL,
         /*.no_alloc   =*/ true,
     };
-    struct ggml_context * ctx = ggml_init(params);
+    ggml_context_ptr ctx_ptr { ggml_init(params) };
+    GGML_ASSERT(ctx_ptr != nullptr);
+    ggml_context * ctx = ctx_ptr.get();
     ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
     if (tensor == nullptr) {
         GGML_LOG_ERROR("Null tensor pointer passed to server init_tensor function.\n");
-        ggml_free(ctx);
         return false;
     }
 
@@ -991,11 +1170,9 @@ bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) {
         // This pointer can either be passed around client/server, or probably better stored server-side and kept track of.
         // Currently unimplemented.
         GGML_LOG_ERROR("tensor->extra populated by the backend, this is currently unsupported.\n");
-        ggml_free(ctx);
         return false;
     }
 
-    ggml_free(ctx);
     return true;
 }
 
@@ -1005,11 +1182,12 @@ bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response) {
         /*.mem_buffer =*/ NULL,
         /*.no_alloc   =*/ true,
     };
-    struct ggml_context * ctx = ggml_init(params);
+    ggml_context_ptr ctx_ptr { ggml_init(params) };
+    GGML_ASSERT(ctx_ptr != nullptr);
+    ggml_context * ctx = ctx_ptr.get();
     ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
     if (tensor == nullptr) {
         GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
-        ggml_free(ctx);
         return false;
     }
     GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size);
@@ -1022,13 +1200,14 @@ bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response) {
         if (request.tensor.data + request.offset < p0 ||
             request.tensor.data + request.offset >= p1 ||
             request.size > (p1 - request.tensor.data - request.offset)) {
-            GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
+            GGML_LOG_ERROR("[%s] requested tensor region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%" PRIu64 ") out of buffer bounds [0x%zx, 0x%zx)\n",
+                           __func__, request.tensor.data, request.offset, request.size, p0, p1);
+            return false;
         }
     }
 
     response.resize(request.size, 0);
     ggml_backend_tensor_get(tensor, response.data(), request.offset, request.size);
-    ggml_free(ctx);
     return true;
 }
 
@@ -1038,12 +1217,14 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response) {
         /*.mem_buffer =*/ NULL,
         /*.no_alloc   =*/ true,
     };
-    struct ggml_context * ctx = ggml_init(params);
+    ggml_context_ptr ctx_ptr { ggml_init(params) };
+    GGML_ASSERT(ctx_ptr != nullptr);
+    ggml_context * ctx = ctx_ptr.get();
+
     ggml_tensor * src = deserialize_tensor(ctx, &request.src);
     ggml_tensor * dst = deserialize_tensor(ctx, &request.dst);
     if (src == nullptr || dst == nullptr) {
         GGML_LOG_ERROR("[%s] error deserializing tensors\n", __func__);
-        ggml_free(ctx);
         return false;
     }
 
@@ -1061,7 +1242,6 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response) {
                          dst_data + src_size,
                          dst_base,
                          dst_base + dst_buf_sz);
-        ggml_free(ctx);
         return false;
     }
 
@@ -1069,7 +1249,6 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response) {
                      __func__, (void*) src->buffer, (void*) dst->buffer);
 
     response.result = ggml_backend_buffer_copy_tensor(src, dst);
-    ggml_free(ctx);
     return true;
 }
 
@@ -1077,22 +1256,50 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
                                       struct ggml_context * ctx,
                                       const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
                                       std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map) {
-    if (id == 0) {
-        return nullptr;
-    }
     if (tensor_map.find(id) != tensor_map.end()) {
         return tensor_map[id];
    }
-    const rpc_tensor * tensor = tensor_ptrs.at(id);
+    // Safely find the tensor pointer
+    auto it_ptr = tensor_ptrs.find(id);
+    if (it_ptr == tensor_ptrs.end()) {
+        return nullptr;
+    }
+    const rpc_tensor * tensor = it_ptr->second;
+
     struct ggml_tensor * result = deserialize_tensor(ctx, tensor);
     if (result == nullptr) {
         return nullptr;
     }
     tensor_map[id] = result;
     for (int i = 0; i < GGML_MAX_SRC; i++) {
-        result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
+        // Check if the source ID is 0 before calling create_node recursively
+        if (tensor->src[i] == 0) {
+            result->src[i] = nullptr;
+        } else {
+            result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
+            // If the recursive call failed for a non-zero ID, propagate the error
+            if (result->src[i] == nullptr) {
+                GGML_LOG_ERROR("[%s] failed to create source node %d (src_id=%" PRIu64 ") for node id %" PRIu64 "\n",
+                               __func__, i, tensor->src[i], id);
+                // Must return nullptr to signal failure up the call stack
+                return nullptr;
+            }
+        }
+    }
+
+    // Handle view_src similarly
+    if (tensor->view_src == 0) {
+        result->view_src = nullptr;
+    } else {
+        result->view_src = create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map);
+        // If the recursive call failed for a non-zero ID, propagate the error
+        if (result->view_src == nullptr) {
+            GGML_LOG_ERROR("[%s] failed to create view_src node (view_src_id=%" PRIu64 ") for node id %" PRIu64 "\n",
+                           __func__, tensor->view_src, id);
+            // Must return nullptr to signal failure up the call stack
+            return nullptr;
+        }
    }
-    result->view_src = create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map);
     result->view_offs = tensor->view_offs;
     return result;
 }
@@ -1118,12 +1325,15 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
     GGML_PRINT_DEBUG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);
 
     size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
+
     struct ggml_init_params params = {
         /*.mem_size   =*/ buf_size,
         /*.mem_buffer =*/ NULL,
         /*.no_alloc   =*/ true,
     };
-    struct ggml_context * ctx = ggml_init(params);
+    ggml_context_ptr ctx_ptr { ggml_init(params) };
+    GGML_ASSERT(ctx_ptr != nullptr);
+    ggml_context * ctx = ctx_ptr.get();
     struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false);
     graph->n_nodes = n_nodes;
     std::unordered_map<uint64_t, const rpc_tensor*> tensor_ptrs;
@@ -1135,10 +1345,17 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
         int64_t id;
         memcpy(&id, &nodes[i], sizeof(id));
         graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map);
+
+        // Check if create_node failed for a *non-zero* ID.
+        // If id was 0, create_node returning nullptr is expected.
+        // If id was non-zero and create_node returned nullptr, it indicates a deserialization error.
+        if (graph->nodes[i] == nullptr && id != 0) {
+            GGML_LOG_ERROR("[%s] failed to create graph node %d (id=%" PRId64 ")\n", __func__, i, id);
+            return false;
+        }
     }
     ggml_status status = ggml_backend_graph_compute(backend, graph);
     response.result = status;
-    ggml_free(ctx);
     return true;
 }
 
@@ -1148,10 +1365,27 @@ rpc_server::~rpc_server() {
     }
 }
 
-static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
-    rpc_server server(backend);
+static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
+                             sockfd_t sockfd, size_t free_mem, size_t total_mem) {
+    rpc_server server(backend, cache_dir);
+    uint8_t cmd;
+    if (!recv_data(sockfd, &cmd, 1)) {
+        return;
+    }
+    // the first command sent by the client must be HELLO
+    if (cmd != RPC_CMD_HELLO) {
+        fprintf(stderr, "Expected HELLO command, update client\n");
+        return;
+    }
+    if (!recv_msg(sockfd, nullptr, 0)) {
+        return;
+    }
+    rpc_msg_hello_rsp response;
+    server.hello(response);
+    if (!send_msg(sockfd, &response, sizeof(response))) {
+        return;
+    }
     while (true) {
-        uint8_t cmd;
         if (!recv_data(sockfd, &cmd, 1)) {
             break;
         }
@@ -1161,6 +1395,10 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
             break;
         }
         switch (cmd) {
+            case RPC_CMD_HELLO: {
+                // HELLO command is handled above
+                return;
+            }
             case RPC_CMD_ALLOC_BUFFER: {
                 rpc_msg_alloc_buffer_req request;
                 if (!recv_msg(sockfd, &request, sizeof(request))) {
@@ -1179,7 +1417,9 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
                     return;
                 }
                 rpc_msg_get_alloc_size_rsp response;
-                server.get_alloc_size(request, response);
+                if (!server.get_alloc_size(request, response)) {
+                    return;
+                }
                 if (!send_msg(sockfd, &response, sizeof(response))) {
                     return;
                 }
@@ -1255,7 +1495,18 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
                 if (!server.set_tensor(input)) {
                     return;
                 }
-                if (!send_msg(sockfd, nullptr, 0)) {
+                break;
+            }
+            case RPC_CMD_SET_TENSOR_HASH: {
+                std::vector<uint8_t> input;
+                if (!recv_msg(sockfd, input)) {
+                    return;
+                }
+                rpc_msg_set_tensor_hash_rsp response;
+                if (!server.set_tensor_hash(input, response)) {
+                    return;
+                }
+                if (!send_msg(sockfd, &response, sizeof(response))) {
                     return;
                 }
                 break;
@@ -1335,7 +1586,9 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
     }
 }
 
-void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
+void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
+                                   const char * cache_dir,
+                                   size_t free_mem, size_t total_mem) {
     std::string host;
     int port;
     if (!parse_endpoint(endpoint, host, port)) {
@@ -1364,7 +1617,7 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
     }
     printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
     fflush(stdout);
-    rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
+    rpc_serve_client(backend, cache_dir, client_socket->fd, free_mem, total_mem);
    printf("Client connection closed\n");
    fflush(stdout);
 }