@fugood/llama.node 0.3.1 → 0.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +1 -8
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/package.json +4 -2
- package/src/DetokenizeWorker.cpp +1 -1
- package/src/EmbeddingWorker.cpp +2 -2
- package/src/LlamaCompletionWorker.cpp +10 -10
- package/src/LlamaCompletionWorker.h +2 -2
- package/src/LlamaContext.cpp +14 -17
- package/src/TokenizeWorker.cpp +1 -1
- package/src/common.hpp +5 -4
- package/src/llama.cpp/.github/workflows/build.yml +137 -29
- package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
- package/src/llama.cpp/.github/workflows/docker.yml +46 -34
- package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
- package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
- package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
- package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
- package/src/llama.cpp/.github/workflows/server.yml +7 -0
- package/src/llama.cpp/CMakeLists.txt +26 -11
- package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
- package/src/llama.cpp/common/CMakeLists.txt +10 -10
- package/src/llama.cpp/common/arg.cpp +2041 -0
- package/src/llama.cpp/common/arg.h +77 -0
- package/src/llama.cpp/common/common.cpp +523 -1861
- package/src/llama.cpp/common/common.h +234 -106
- package/src/llama.cpp/common/console.cpp +3 -0
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
- package/src/llama.cpp/common/log.cpp +401 -0
- package/src/llama.cpp/common/log.h +66 -698
- package/src/llama.cpp/common/ngram-cache.cpp +39 -36
- package/src/llama.cpp/common/ngram-cache.h +19 -19
- package/src/llama.cpp/common/sampling.cpp +356 -350
- package/src/llama.cpp/common/sampling.h +62 -139
- package/src/llama.cpp/common/stb_image.h +5990 -6398
- package/src/llama.cpp/docs/build.md +72 -17
- package/src/llama.cpp/examples/CMakeLists.txt +1 -2
- package/src/llama.cpp/examples/batched/batched.cpp +49 -65
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +42 -53
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +22 -22
- package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
- package/src/llama.cpp/examples/embedding/embedding.cpp +147 -91
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +37 -37
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +39 -38
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
- package/src/llama.cpp/examples/{baby-llama → gen-docs}/CMakeLists.txt +2 -2
- package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +46 -39
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +75 -69
- package/src/llama.cpp/examples/infill/infill.cpp +131 -192
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +276 -178
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +40 -36
- package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
- package/src/llama.cpp/examples/llava/clip.cpp +686 -150
- package/src/llama.cpp/examples/llava/clip.h +11 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +60 -71
- package/src/llama.cpp/examples/llava/llava.cpp +146 -26
- package/src/llama.cpp/examples/llava/llava.h +2 -3
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
- package/src/llama.cpp/examples/llava/requirements.txt +1 -0
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +55 -56
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +15 -13
- package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +34 -33
- package/src/llama.cpp/examples/lookup/lookup.cpp +60 -63
- package/src/llama.cpp/examples/main/main.cpp +216 -313
- package/src/llama.cpp/examples/parallel/parallel.cpp +58 -59
- package/src/llama.cpp/examples/passkey/passkey.cpp +53 -61
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +277 -311
- package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -12
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +57 -52
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +27 -2
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +60 -46
- package/src/llama.cpp/examples/server/CMakeLists.txt +7 -18
- package/src/llama.cpp/examples/server/server.cpp +1347 -1531
- package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
- package/src/llama.cpp/examples/server/utils.hpp +396 -107
- package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/simple/simple.cpp +132 -106
- package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
- package/src/llama.cpp/examples/speculative/speculative.cpp +153 -124
- package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
- package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +27 -29
- package/src/llama.cpp/ggml/CMakeLists.txt +29 -12
- package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
- package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +166 -68
- package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
- package/src/llama.cpp/ggml/include/ggml-cann.h +17 -19
- package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
- package/src/llama.cpp/ggml/include/ggml-cuda.h +17 -17
- package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
- package/src/llama.cpp/ggml/include/ggml-metal.h +13 -12
- package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
- package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
- package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
- package/src/llama.cpp/ggml/include/ggml.h +272 -505
- package/src/llama.cpp/ggml/src/CMakeLists.txt +69 -1110
- package/src/llama.cpp/ggml/src/ggml-aarch64.c +52 -2116
- package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
- package/src/llama.cpp/ggml/src/ggml-alloc.c +29 -27
- package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
- package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
- package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
- package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
- package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +144 -81
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
- package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +394 -635
- package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
- package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +217 -70
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
- package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +458 -353
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
- package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +371 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1885 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +380 -584
- package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
- package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +233 -87
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
- package/src/llama.cpp/ggml/src/ggml-quants.c +369 -9994
- package/src/llama.cpp/ggml/src/ggml-quants.h +78 -110
- package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
- package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +560 -335
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +6 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +51 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +310 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +18 -25
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
- package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3350 -3980
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +70 -68
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +9 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +8 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
- package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
- package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2034 -1718
- package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +2 -0
- package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +152 -185
- package/src/llama.cpp/ggml/src/ggml.c +2075 -16579
- package/src/llama.cpp/include/llama.h +296 -285
- package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
- package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
- package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
- package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
- package/src/llama.cpp/src/CMakeLists.txt +2 -1
- package/src/llama.cpp/src/llama-grammar.cpp +721 -122
- package/src/llama.cpp/src/llama-grammar.h +120 -15
- package/src/llama.cpp/src/llama-impl.h +156 -1
- package/src/llama.cpp/src/llama-sampling.cpp +2058 -346
- package/src/llama.cpp/src/llama-sampling.h +39 -47
- package/src/llama.cpp/src/llama-vocab.cpp +390 -127
- package/src/llama.cpp/src/llama-vocab.h +60 -20
- package/src/llama.cpp/src/llama.cpp +6215 -3263
- package/src/llama.cpp/src/unicode-data.cpp +6 -4
- package/src/llama.cpp/src/unicode-data.h +4 -4
- package/src/llama.cpp/src/unicode.cpp +15 -7
- package/src/llama.cpp/tests/CMakeLists.txt +4 -2
- package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
- package/src/llama.cpp/tests/test-backend-ops.cpp +1725 -297
- package/src/llama.cpp/tests/test-barrier.cpp +94 -0
- package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
- package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
- package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +23 -8
- package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
- package/src/llama.cpp/tests/test-log.cpp +39 -0
- package/src/llama.cpp/tests/test-opt.cpp +853 -142
- package/src/llama.cpp/tests/test-quantize-fns.cpp +28 -19
- package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
- package/src/llama.cpp/tests/test-rope.cpp +2 -1
- package/src/llama.cpp/tests/test-sampling.cpp +226 -142
- package/src/llama.cpp/tests/test-tokenizer-0.cpp +56 -36
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
- package/patches/llama.patch +0 -22
- package/src/llama.cpp/.github/workflows/bench.yml +0 -310
- package/src/llama.cpp/common/grammar-parser.cpp +0 -536
- package/src/llama.cpp/common/grammar-parser.h +0 -29
- package/src/llama.cpp/common/train.cpp +0 -1513
- package/src/llama.cpp/common/train.h +0 -233
- package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1640
- package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
- package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
- package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +0 -1027
- package/src/llama.cpp/tests/test-grad0.cpp +0 -1566
- /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
- /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
#include "ggml-rpc.h"
|
|
2
|
-
#include "ggml.h"
|
|
2
|
+
#include "ggml-impl.h"
|
|
3
3
|
#include "ggml-backend-impl.h"
|
|
4
4
|
|
|
5
5
|
#include <cinttypes>
|
|
@@ -25,7 +25,7 @@
|
|
|
25
25
|
# include <netdb.h>
|
|
26
26
|
# include <unistd.h>
|
|
27
27
|
#endif
|
|
28
|
-
#include <
|
|
28
|
+
#include <cstring>
|
|
29
29
|
|
|
30
30
|
#define UNUSED GGML_UNUSED
|
|
31
31
|
|
|
@@ -57,8 +57,9 @@ struct socket_t {
|
|
|
57
57
|
}
|
|
58
58
|
};
|
|
59
59
|
|
|
60
|
-
//
|
|
60
|
+
// all RPC structures must be packed
|
|
61
61
|
#pragma pack(push, 1)
|
|
62
|
+
// ggml_tensor is serialized into rpc_tensor
|
|
62
63
|
struct rpc_tensor {
|
|
63
64
|
uint64_t id;
|
|
64
65
|
uint32_t type;
|
|
@@ -76,25 +77,84 @@ struct rpc_tensor {
|
|
|
76
77
|
|
|
77
78
|
char padding[4];
|
|
78
79
|
};
|
|
79
|
-
#pragma pack(pop)
|
|
80
80
|
|
|
81
81
|
static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of 8");
|
|
82
82
|
|
|
83
83
|
// RPC commands
|
|
84
84
|
enum rpc_cmd {
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
85
|
+
RPC_CMD_ALLOC_BUFFER = 0,
|
|
86
|
+
RPC_CMD_GET_ALIGNMENT,
|
|
87
|
+
RPC_CMD_GET_MAX_SIZE,
|
|
88
|
+
RPC_CMD_BUFFER_GET_BASE,
|
|
89
|
+
RPC_CMD_FREE_BUFFER,
|
|
90
|
+
RPC_CMD_BUFFER_CLEAR,
|
|
91
|
+
RPC_CMD_SET_TENSOR,
|
|
92
|
+
RPC_CMD_GET_TENSOR,
|
|
93
|
+
RPC_CMD_COPY_TENSOR,
|
|
94
|
+
RPC_CMD_GRAPH_COMPUTE,
|
|
95
|
+
RPC_CMD_GET_DEVICE_MEMORY,
|
|
96
|
+
RPC_CMD_COUNT,
|
|
97
|
+
};
|
|
98
|
+
|
|
99
|
+
struct rpc_msg_alloc_buffer_req {
|
|
100
|
+
uint64_t size;
|
|
101
|
+
};
|
|
102
|
+
|
|
103
|
+
struct rpc_msg_alloc_buffer_rsp {
|
|
104
|
+
uint64_t remote_ptr;
|
|
105
|
+
uint64_t remote_size;
|
|
106
|
+
};
|
|
107
|
+
|
|
108
|
+
struct rpc_msg_get_alignment_rsp {
|
|
109
|
+
uint64_t alignment;
|
|
110
|
+
};
|
|
111
|
+
|
|
112
|
+
struct rpc_msg_get_max_size_rsp {
|
|
113
|
+
uint64_t max_size;
|
|
96
114
|
};
|
|
97
115
|
|
|
116
|
+
struct rpc_msg_buffer_get_base_req {
|
|
117
|
+
uint64_t remote_ptr;
|
|
118
|
+
};
|
|
119
|
+
|
|
120
|
+
struct rpc_msg_buffer_get_base_rsp {
|
|
121
|
+
uint64_t base_ptr;
|
|
122
|
+
};
|
|
123
|
+
|
|
124
|
+
struct rpc_msg_free_buffer_req {
|
|
125
|
+
uint64_t remote_ptr;
|
|
126
|
+
};
|
|
127
|
+
|
|
128
|
+
struct rpc_msg_buffer_clear_req {
|
|
129
|
+
uint64_t remote_ptr;
|
|
130
|
+
uint8_t value;
|
|
131
|
+
};
|
|
132
|
+
|
|
133
|
+
struct rpc_msg_get_tensor_req {
|
|
134
|
+
rpc_tensor tensor;
|
|
135
|
+
uint64_t offset;
|
|
136
|
+
uint64_t size;
|
|
137
|
+
};
|
|
138
|
+
|
|
139
|
+
struct rpc_msg_copy_tensor_req {
|
|
140
|
+
rpc_tensor src;
|
|
141
|
+
rpc_tensor dst;
|
|
142
|
+
};
|
|
143
|
+
|
|
144
|
+
struct rpc_msg_copy_tensor_rsp {
|
|
145
|
+
uint8_t result;
|
|
146
|
+
};
|
|
147
|
+
|
|
148
|
+
struct rpc_msg_graph_compute_rsp {
|
|
149
|
+
uint8_t result;
|
|
150
|
+
};
|
|
151
|
+
|
|
152
|
+
struct rpc_msg_get_device_memory_rsp {
|
|
153
|
+
uint64_t free_mem;
|
|
154
|
+
uint64_t total_mem;
|
|
155
|
+
};
|
|
156
|
+
#pragma pack(pop)
|
|
157
|
+
|
|
98
158
|
// RPC data structures
|
|
99
159
|
|
|
100
160
|
static ggml_guid_t ggml_backend_rpc_guid() {
|
|
@@ -118,7 +178,6 @@ struct ggml_backend_rpc_buffer_context {
|
|
|
118
178
|
std::shared_ptr<socket_t> sock;
|
|
119
179
|
std::unordered_map<ggml_backend_buffer_t, void *> base_cache;
|
|
120
180
|
uint64_t remote_ptr;
|
|
121
|
-
std::string name;
|
|
122
181
|
};
|
|
123
182
|
|
|
124
183
|
// RPC helper functions
|
|
@@ -197,6 +256,10 @@ static std::shared_ptr<socket_t> create_server_socket(const char * host, int por
|
|
|
197
256
|
fprintf(stderr, "Failed to set SO_REUSEADDR\n");
|
|
198
257
|
return nullptr;
|
|
199
258
|
}
|
|
259
|
+
if (inet_addr(host) == INADDR_NONE) {
|
|
260
|
+
fprintf(stderr, "Invalid host address: %s\n", host);
|
|
261
|
+
return nullptr;
|
|
262
|
+
}
|
|
200
263
|
struct sockaddr_in serv_addr;
|
|
201
264
|
serv_addr.sin_family = AF_INET;
|
|
202
265
|
serv_addr.sin_addr.s_addr = inet_addr(host);
|
|
@@ -235,6 +298,38 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
|
|
|
235
298
|
return true;
|
|
236
299
|
}
|
|
237
300
|
|
|
301
|
+
static bool send_msg(sockfd_t sockfd, const void * msg, size_t msg_size) {
|
|
302
|
+
if (!send_data(sockfd, &msg_size, sizeof(msg_size))) {
|
|
303
|
+
return false;
|
|
304
|
+
}
|
|
305
|
+
return send_data(sockfd, msg, msg_size);
|
|
306
|
+
}
|
|
307
|
+
|
|
308
|
+
static bool recv_msg(sockfd_t sockfd, void * msg, size_t msg_size) {
|
|
309
|
+
uint64_t size;
|
|
310
|
+
if (!recv_data(sockfd, &size, sizeof(size))) {
|
|
311
|
+
return false;
|
|
312
|
+
}
|
|
313
|
+
if (size != msg_size) {
|
|
314
|
+
return false;
|
|
315
|
+
}
|
|
316
|
+
return recv_data(sockfd, msg, msg_size);
|
|
317
|
+
}
|
|
318
|
+
|
|
319
|
+
static bool recv_msg(sockfd_t sockfd, std::vector<uint8_t> & input) {
|
|
320
|
+
uint64_t size;
|
|
321
|
+
if (!recv_data(sockfd, &size, sizeof(size))) {
|
|
322
|
+
return false;
|
|
323
|
+
}
|
|
324
|
+
try {
|
|
325
|
+
input.resize(size);
|
|
326
|
+
} catch (const std::bad_alloc & e) {
|
|
327
|
+
fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", size);
|
|
328
|
+
return false;
|
|
329
|
+
}
|
|
330
|
+
return recv_data(sockfd, input.data(), size);
|
|
331
|
+
}
|
|
332
|
+
|
|
238
333
|
static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
|
|
239
334
|
size_t pos = endpoint.find(':');
|
|
240
335
|
if (pos == std::string::npos) {
|
|
@@ -247,28 +342,27 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int
|
|
|
247
342
|
|
|
248
343
|
// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
|
|
249
344
|
// RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
|
|
250
|
-
static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const
|
|
345
|
+
static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) {
|
|
251
346
|
uint8_t cmd_byte = cmd;
|
|
252
347
|
if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
|
|
253
348
|
return false;
|
|
254
349
|
}
|
|
255
|
-
uint64_t input_size = input.size();
|
|
256
350
|
if (!send_data(sock->fd, &input_size, sizeof(input_size))) {
|
|
257
351
|
return false;
|
|
258
352
|
}
|
|
259
|
-
if (!send_data(sock->fd, input
|
|
353
|
+
if (!send_data(sock->fd, input, input_size)) {
|
|
260
354
|
return false;
|
|
261
355
|
}
|
|
262
|
-
|
|
263
|
-
if
|
|
356
|
+
// TODO: currently the output_size is always known, do we need support for commands with variable output size?
|
|
357
|
+
// even if we do, we can skip sending output_size from the server for commands with known output size
|
|
358
|
+
uint64_t out_size;
|
|
359
|
+
if (!recv_data(sock->fd, &out_size, sizeof(out_size))) {
|
|
264
360
|
return false;
|
|
265
361
|
}
|
|
266
|
-
if (
|
|
267
|
-
|
|
268
|
-
return true;
|
|
362
|
+
if (out_size != output_size) {
|
|
363
|
+
return false;
|
|
269
364
|
}
|
|
270
|
-
output
|
|
271
|
-
if (!recv_data(sock->fd, output.data(), output_size)) {
|
|
365
|
+
if (!recv_data(sock->fd, output, output_size)) {
|
|
272
366
|
return false;
|
|
273
367
|
}
|
|
274
368
|
return true;
|
|
@@ -314,43 +408,26 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
|
|
|
314
408
|
return sock;
|
|
315
409
|
}
|
|
316
410
|
|
|
317
|
-
|
|
411
|
+
static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
|
318
412
|
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
GGML_CALL static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
|
323
|
-
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
|
324
|
-
// input serialization format: | remote_ptr (8 bytes) |
|
|
325
|
-
std::vector<uint8_t> input(sizeof(uint64_t), 0);
|
|
326
|
-
uint64_t remote_ptr = ctx->remote_ptr;
|
|
327
|
-
memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
|
|
328
|
-
std::vector<uint8_t> output;
|
|
329
|
-
bool status = send_rpc_cmd(ctx->sock, FREE_BUFFER, input, output);
|
|
413
|
+
rpc_msg_free_buffer_req request = {ctx->remote_ptr};
|
|
414
|
+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0);
|
|
330
415
|
GGML_ASSERT(status);
|
|
331
|
-
GGML_ASSERT(output.empty());
|
|
332
416
|
delete ctx;
|
|
333
417
|
}
|
|
334
418
|
|
|
335
|
-
|
|
419
|
+
static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
|
|
336
420
|
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
|
337
421
|
if (ctx->base_cache.find(buffer) != ctx->base_cache.end()) {
|
|
338
422
|
return ctx->base_cache[buffer];
|
|
339
423
|
}
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
|
|
344
|
-
std::vector<uint8_t> output;
|
|
345
|
-
bool status = send_rpc_cmd(ctx->sock, BUFFER_GET_BASE, input, output);
|
|
424
|
+
rpc_msg_buffer_get_base_req request = {ctx->remote_ptr};
|
|
425
|
+
rpc_msg_buffer_get_base_rsp response;
|
|
426
|
+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response));
|
|
346
427
|
GGML_ASSERT(status);
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
memcpy(&base_ptr, output.data(), sizeof(base_ptr));
|
|
351
|
-
void * base = reinterpret_cast<void *>(base_ptr);
|
|
352
|
-
ctx->base_cache[buffer] = base;
|
|
353
|
-
return base;
|
|
428
|
+
void * base_ptr = reinterpret_cast<void *>(response.base_ptr);
|
|
429
|
+
ctx->base_cache[buffer] = base_ptr;
|
|
430
|
+
return base_ptr;
|
|
354
431
|
}
|
|
355
432
|
|
|
356
433
|
static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
|
|
@@ -383,7 +460,7 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
|
|
|
383
460
|
return result;
|
|
384
461
|
}
|
|
385
462
|
|
|
386
|
-
|
|
463
|
+
static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
|
|
387
464
|
UNUSED(buffer);
|
|
388
465
|
if (ggml_is_quantized(tensor->type)) {
|
|
389
466
|
// TODO: this check is due to MATRIX_ROW_PADDING in CUDA and should be generalized
|
|
@@ -391,7 +468,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t
|
|
|
391
468
|
}
|
|
392
469
|
}
|
|
393
470
|
|
|
394
|
-
|
|
471
|
+
static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
|
395
472
|
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
|
396
473
|
// input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
|
|
397
474
|
size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size;
|
|
@@ -400,29 +477,21 @@ GGML_CALL static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t b
|
|
|
400
477
|
memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
|
|
401
478
|
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
|
|
402
479
|
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
|
|
403
|
-
|
|
404
|
-
bool status = send_rpc_cmd(ctx->sock, SET_TENSOR, input, output);
|
|
480
|
+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size(), nullptr, 0);
|
|
405
481
|
GGML_ASSERT(status);
|
|
406
482
|
}
|
|
407
483
|
|
|
408
|
-
|
|
484
|
+
static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
|
409
485
|
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
|
|
416
|
-
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &size, sizeof(size));
|
|
417
|
-
std::vector<uint8_t> output;
|
|
418
|
-
bool status = send_rpc_cmd(ctx->sock, GET_TENSOR, input, output);
|
|
486
|
+
rpc_msg_get_tensor_req request;
|
|
487
|
+
request.tensor = serialize_tensor(tensor);
|
|
488
|
+
request.offset = offset;
|
|
489
|
+
request.size = size;
|
|
490
|
+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, &request, sizeof(request), data, size);
|
|
419
491
|
GGML_ASSERT(status);
|
|
420
|
-
GGML_ASSERT(output.size() == size);
|
|
421
|
-
// output serialization format: | data (size bytes) |
|
|
422
|
-
memcpy(data, output.data(), size);
|
|
423
492
|
}
|
|
424
493
|
|
|
425
|
-
|
|
494
|
+
static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
|
|
426
495
|
// check if src and dst are on the same server
|
|
427
496
|
ggml_backend_buffer_t src_buffer = src->buffer;
|
|
428
497
|
ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
|
|
@@ -432,38 +501,27 @@ GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t b
|
|
|
432
501
|
return false;
|
|
433
502
|
}
|
|
434
503
|
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
memcpy(input.data(), &rpc_src, sizeof(rpc_src));
|
|
441
|
-
memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst));
|
|
442
|
-
std::vector<uint8_t> output;
|
|
443
|
-
bool status = send_rpc_cmd(ctx->sock, COPY_TENSOR, input, output);
|
|
504
|
+
rpc_msg_copy_tensor_req request;
|
|
505
|
+
request.src = serialize_tensor(src);
|
|
506
|
+
request.dst = serialize_tensor(dst);
|
|
507
|
+
rpc_msg_copy_tensor_rsp response;
|
|
508
|
+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
|
|
444
509
|
GGML_ASSERT(status);
|
|
445
|
-
|
|
446
|
-
GGML_ASSERT(output.size() == 1);
|
|
447
|
-
return output[0];
|
|
510
|
+
return response.result;
|
|
448
511
|
}
|
|
449
512
|
|
|
450
|
-
|
|
513
|
+
static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
|
451
514
|
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
std::vector<uint8_t> input(input_size, 0);
|
|
455
|
-
memcpy(input.data(), &ctx->remote_ptr, sizeof(ctx->remote_ptr));
|
|
456
|
-
memcpy(input.data() + sizeof(ctx->remote_ptr), &value, sizeof(value));
|
|
457
|
-
std::vector<uint8_t> output;
|
|
458
|
-
bool status = send_rpc_cmd(ctx->sock, BUFFER_CLEAR, input, output);
|
|
515
|
+
rpc_msg_buffer_clear_req request = {ctx->remote_ptr, value};
|
|
516
|
+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, &request, sizeof(request), nullptr, 0);
|
|
459
517
|
GGML_ASSERT(status);
|
|
460
518
|
}
|
|
461
519
|
|
|
462
520
|
static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
|
|
463
|
-
/* .get_name = */ ggml_backend_rpc_buffer_get_name,
|
|
464
521
|
/* .free_buffer = */ ggml_backend_rpc_buffer_free_buffer,
|
|
465
522
|
/* .get_base = */ ggml_backend_rpc_buffer_get_base,
|
|
466
523
|
/* .init_tensor = */ ggml_backend_rpc_buffer_init_tensor,
|
|
524
|
+
/* .memset_tensor = */ NULL,
|
|
467
525
|
/* .set_tensor = */ ggml_backend_rpc_buffer_set_tensor,
|
|
468
526
|
/* .get_tensor = */ ggml_backend_rpc_buffer_get_tensor,
|
|
469
527
|
/* .cpy_tensor = */ ggml_backend_rpc_buffer_cpy_tensor,
|
|
@@ -471,32 +529,23 @@ static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
|
|
|
471
529
|
/* .reset = */ NULL,
|
|
472
530
|
};
|
|
473
531
|
|
|
474
|
-
|
|
532
|
+
static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t buft) {
|
|
475
533
|
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
|
|
476
534
|
return buft_ctx->name.c_str();
|
|
477
535
|
}
|
|
478
536
|
|
|
479
|
-
|
|
537
|
+
static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
|
480
538
|
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
std::vector<uint8_t> input(input_size, 0);
|
|
484
|
-
memcpy(input.data(), &size, sizeof(size));
|
|
485
|
-
std::vector<uint8_t> output;
|
|
539
|
+
rpc_msg_alloc_buffer_req request = {size};
|
|
540
|
+
rpc_msg_alloc_buffer_rsp response;
|
|
486
541
|
auto sock = get_socket(buft_ctx->endpoint);
|
|
487
|
-
bool status = send_rpc_cmd(sock,
|
|
542
|
+
bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
|
|
488
543
|
GGML_ASSERT(status);
|
|
489
|
-
|
|
490
|
-
// output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
|
|
491
|
-
uint64_t remote_ptr;
|
|
492
|
-
memcpy(&remote_ptr, output.data(), sizeof(remote_ptr));
|
|
493
|
-
size_t remote_size;
|
|
494
|
-
memcpy(&remote_size, output.data() + sizeof(uint64_t), sizeof(remote_size));
|
|
495
|
-
if (remote_ptr != 0) {
|
|
544
|
+
if (response.remote_ptr != 0) {
|
|
496
545
|
ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
|
|
497
546
|
ggml_backend_rpc_buffer_interface,
|
|
498
|
-
new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr
|
|
499
|
-
remote_size);
|
|
547
|
+
new ggml_backend_rpc_buffer_context{sock, {}, response.remote_ptr},
|
|
548
|
+
response.remote_size);
|
|
500
549
|
return buffer;
|
|
501
550
|
} else {
|
|
502
551
|
return nullptr;
|
|
@@ -504,42 +553,30 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
|
|
|
504
553
|
}
|
|
505
554
|
|
|
506
555
|
static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
std::vector<uint8_t> output;
|
|
510
|
-
bool status = send_rpc_cmd(sock, GET_ALIGNMENT, input, output);
|
|
556
|
+
rpc_msg_get_alignment_rsp response;
|
|
557
|
+
bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response));
|
|
511
558
|
GGML_ASSERT(status);
|
|
512
|
-
|
|
513
|
-
// output serialization format: | alignment (8 bytes) |
|
|
514
|
-
uint64_t alignment;
|
|
515
|
-
memcpy(&alignment, output.data(), sizeof(alignment));
|
|
516
|
-
return alignment;
|
|
559
|
+
return response.alignment;
|
|
517
560
|
}
|
|
518
561
|
|
|
519
|
-
|
|
562
|
+
static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
|
|
520
563
|
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
|
|
521
564
|
return buft_ctx->alignment;
|
|
522
565
|
}
|
|
523
566
|
|
|
524
567
|
static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
std::vector<uint8_t> output;
|
|
528
|
-
bool status = send_rpc_cmd(sock, GET_MAX_SIZE, input, output);
|
|
568
|
+
rpc_msg_get_max_size_rsp response;
|
|
569
|
+
bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, nullptr, 0, &response, sizeof(response));
|
|
529
570
|
GGML_ASSERT(status);
|
|
530
|
-
|
|
531
|
-
// output serialization format: | max_size (8 bytes) |
|
|
532
|
-
uint64_t max_size;
|
|
533
|
-
memcpy(&max_size, output.data(), sizeof(max_size));
|
|
534
|
-
return max_size;
|
|
571
|
+
return response.max_size;
|
|
535
572
|
}
|
|
536
573
|
|
|
537
|
-
|
|
574
|
+
static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
|
|
538
575
|
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
|
|
539
576
|
return buft_ctx->max_size;
|
|
540
577
|
}
|
|
541
578
|
|
|
542
|
-
|
|
579
|
+
static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
|
|
543
580
|
UNUSED(buft);
|
|
544
581
|
return ggml_nbytes(tensor);
|
|
545
582
|
}
|
|
@@ -553,24 +590,19 @@ static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
|
|
|
553
590
|
/* .is_host = */ NULL,
|
|
554
591
|
};
|
|
555
592
|
|
|
556
|
-
|
|
593
|
+
static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
|
|
557
594
|
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
|
|
558
595
|
|
|
559
596
|
return rpc_ctx->name.c_str();
|
|
560
597
|
}
|
|
561
598
|
|
|
562
|
-
|
|
599
|
+
static void ggml_backend_rpc_free(ggml_backend_t backend) {
|
|
563
600
|
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
|
|
564
601
|
delete rpc_ctx;
|
|
565
602
|
delete backend;
|
|
566
603
|
}
|
|
567
604
|
|
|
568
|
-
|
|
569
|
-
ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
|
|
570
|
-
return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
|
|
571
|
-
}
|
|
572
|
-
|
|
573
|
-
GGML_CALL static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
|
|
605
|
+
static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
|
|
574
606
|
UNUSED(backend);
|
|
575
607
|
// this is no-op because we don't have any async operations
|
|
576
608
|
}
|
|
@@ -612,38 +644,20 @@ static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & o
|
|
|
612
644
|
memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor));
|
|
613
645
|
}
|
|
614
646
|
|
|
615
|
-
|
|
647
|
+
static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
|
616
648
|
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
|
|
617
649
|
std::vector<uint8_t> input;
|
|
618
650
|
serialize_graph(cgraph, input);
|
|
619
|
-
|
|
651
|
+
rpc_msg_graph_compute_rsp response;
|
|
620
652
|
auto sock = get_socket(rpc_ctx->endpoint);
|
|
621
|
-
bool status = send_rpc_cmd(sock,
|
|
653
|
+
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
|
|
622
654
|
GGML_ASSERT(status);
|
|
623
|
-
|
|
624
|
-
return (enum ggml_status)output[0];
|
|
625
|
-
}
|
|
626
|
-
|
|
627
|
-
GGML_CALL static bool ggml_backend_rpc_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
|
|
628
|
-
UNUSED(backend);
|
|
629
|
-
UNUSED(op);
|
|
630
|
-
//TODO: call the remote backend and cache the results
|
|
631
|
-
return true;
|
|
632
|
-
}
|
|
633
|
-
|
|
634
|
-
GGML_CALL static bool ggml_backend_rpc_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
|
|
635
|
-
if (buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
|
|
636
|
-
return false;
|
|
637
|
-
}
|
|
638
|
-
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
|
|
639
|
-
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
|
|
640
|
-
return buft_ctx->endpoint == rpc_ctx->endpoint;
|
|
655
|
+
return (enum ggml_status)response.result;
|
|
641
656
|
}
|
|
642
657
|
|
|
643
658
|
static ggml_backend_i ggml_backend_rpc_interface = {
|
|
644
659
|
/* .get_name = */ ggml_backend_rpc_name,
|
|
645
660
|
/* .free = */ ggml_backend_rpc_free,
|
|
646
|
-
/* .get_default_buffer_type = */ ggml_backend_rpc_get_default_buffer_type,
|
|
647
661
|
/* .set_tensor_async = */ NULL,
|
|
648
662
|
/* .get_tensor_async = */ NULL,
|
|
649
663
|
/* .cpy_tensor_async = */ NULL,
|
|
@@ -653,17 +667,11 @@ static ggml_backend_i ggml_backend_rpc_interface = {
|
|
|
653
667
|
/* .graph_plan_update = */ NULL,
|
|
654
668
|
/* .graph_plan_compute = */ NULL,
|
|
655
669
|
/* .graph_compute = */ ggml_backend_rpc_graph_compute,
|
|
656
|
-
/* .supports_op = */ ggml_backend_rpc_supports_op,
|
|
657
|
-
/* .supports_buft = */ ggml_backend_rpc_supports_buft,
|
|
658
|
-
/* .offload_op = */ NULL,
|
|
659
|
-
/* .event_new = */ NULL,
|
|
660
|
-
/* .event_free = */ NULL,
|
|
661
670
|
/* .event_record = */ NULL,
|
|
662
671
|
/* .event_wait = */ NULL,
|
|
663
|
-
/* .event_synchronize = */ NULL,
|
|
664
672
|
};
|
|
665
673
|
|
|
666
|
-
|
|
674
|
+
ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
|
|
667
675
|
static std::mutex mutex;
|
|
668
676
|
std::lock_guard<std::mutex> lock(mutex);
|
|
669
677
|
// NOTE: buffer types are allocated and never freed; this is by design
|
|
@@ -674,6 +682,7 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const
|
|
|
674
682
|
}
|
|
675
683
|
auto sock = get_socket(endpoint);
|
|
676
684
|
if (sock == nullptr) {
|
|
685
|
+
fprintf(stderr, "Failed to connect to %s\n", endpoint);
|
|
677
686
|
return nullptr;
|
|
678
687
|
}
|
|
679
688
|
size_t alignment = get_alignment(sock);
|
|
@@ -687,13 +696,14 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const
|
|
|
687
696
|
|
|
688
697
|
ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
|
|
689
698
|
/* .iface = */ ggml_backend_rpc_buffer_type_interface,
|
|
699
|
+
/* .device = */ ggml_backend_rpc_add_device(endpoint),
|
|
690
700
|
/* .context = */ buft_ctx
|
|
691
701
|
};
|
|
692
702
|
buft_map[endpoint] = buft;
|
|
693
703
|
return buft;
|
|
694
704
|
}
|
|
695
705
|
|
|
696
|
-
|
|
706
|
+
ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
|
|
697
707
|
ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
|
|
698
708
|
/* .endpoint = */ endpoint,
|
|
699
709
|
/* .name = */ "RPC[" + std::string(endpoint) + "]",
|
|
@@ -702,32 +712,25 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
|
|
|
702
712
|
ggml_backend_t backend = new ggml_backend {
|
|
703
713
|
/* .guid = */ ggml_backend_rpc_guid(),
|
|
704
714
|
/* .interface = */ ggml_backend_rpc_interface,
|
|
715
|
+
/* .device = */ ggml_backend_rpc_add_device(endpoint),
|
|
705
716
|
/* .context = */ ctx
|
|
706
717
|
};
|
|
707
718
|
return backend;
|
|
708
719
|
}
|
|
709
720
|
|
|
710
|
-
|
|
721
|
+
bool ggml_backend_is_rpc(ggml_backend_t backend) {
|
|
711
722
|
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
|
|
712
723
|
}
|
|
713
724
|
|
|
714
725
|
static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
std::vector<uint8_t> output;
|
|
718
|
-
bool status = send_rpc_cmd(sock, GET_DEVICE_MEMORY, input, output);
|
|
726
|
+
rpc_msg_get_device_memory_rsp response;
|
|
727
|
+
bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response));
|
|
719
728
|
GGML_ASSERT(status);
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
uint64_t free_mem;
|
|
723
|
-
memcpy(&free_mem, output.data(), sizeof(free_mem));
|
|
724
|
-
uint64_t total_mem;
|
|
725
|
-
memcpy(&total_mem, output.data() + sizeof(uint64_t), sizeof(total_mem));
|
|
726
|
-
*free = free_mem;
|
|
727
|
-
*total = total_mem;
|
|
729
|
+
*free = response.free_mem;
|
|
730
|
+
*total = response.total_mem;
|
|
728
731
|
}
|
|
729
732
|
|
|
730
|
-
|
|
733
|
+
void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
|
|
731
734
|
auto sock = get_socket(endpoint);
|
|
732
735
|
if (sock == nullptr) {
|
|
733
736
|
*free = 0;
|
|
@@ -744,16 +747,16 @@ public:
|
|
|
744
747
|
rpc_server(ggml_backend_t backend) : backend(backend) {}
|
|
745
748
|
~rpc_server();
|
|
746
749
|
|
|
747
|
-
|
|
748
|
-
void get_alignment(
|
|
749
|
-
void get_max_size(
|
|
750
|
-
bool buffer_get_base(const
|
|
751
|
-
bool free_buffer(const
|
|
752
|
-
bool buffer_clear(const
|
|
750
|
+
void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
|
|
751
|
+
void get_alignment(rpc_msg_get_alignment_rsp & response);
|
|
752
|
+
void get_max_size(rpc_msg_get_max_size_rsp & response);
|
|
753
|
+
bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response);
|
|
754
|
+
bool free_buffer(const rpc_msg_free_buffer_req & request);
|
|
755
|
+
bool buffer_clear(const rpc_msg_buffer_clear_req & request);
|
|
753
756
|
bool set_tensor(const std::vector<uint8_t> & input);
|
|
754
|
-
bool get_tensor(const
|
|
755
|
-
bool copy_tensor(const
|
|
756
|
-
bool graph_compute(const std::vector<uint8_t> & input,
|
|
757
|
+
bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
|
|
758
|
+
bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
|
|
759
|
+
bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
|
|
757
760
|
|
|
758
761
|
private:
|
|
759
762
|
ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
|
|
@@ -767,80 +770,50 @@ private:
|
|
|
767
770
|
std::unordered_set<ggml_backend_buffer_t> buffers;
|
|
768
771
|
};
|
|
769
772
|
|
|
770
|
-
|
|
771
|
-
// input serialization format: | size (8 bytes) |
|
|
772
|
-
if (input.size() != sizeof(uint64_t)) {
|
|
773
|
-
return false;
|
|
774
|
-
}
|
|
775
|
-
uint64_t size;
|
|
776
|
-
memcpy(&size, input.data(), sizeof(size));
|
|
773
|
+
void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
|
|
777
774
|
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
|
|
778
|
-
ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, size);
|
|
779
|
-
|
|
780
|
-
|
|
775
|
+
ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
|
|
776
|
+
response.remote_ptr = 0;
|
|
777
|
+
response.remote_size = 0;
|
|
781
778
|
if (buffer != nullptr) {
|
|
782
|
-
remote_ptr = reinterpret_cast<uint64_t>(buffer);
|
|
783
|
-
remote_size = buffer->size;
|
|
784
|
-
GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, size, remote_ptr, remote_size);
|
|
779
|
+
response.remote_ptr = reinterpret_cast<uint64_t>(buffer);
|
|
780
|
+
response.remote_size = buffer->size;
|
|
781
|
+
GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size);
|
|
785
782
|
buffers.insert(buffer);
|
|
786
783
|
} else {
|
|
787
|
-
GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> failed\n", __func__, size);
|
|
784
|
+
GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size);
|
|
788
785
|
}
|
|
789
|
-
// output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
|
|
790
|
-
output.resize(2*sizeof(uint64_t), 0);
|
|
791
|
-
memcpy(output.data(), &remote_ptr, sizeof(remote_ptr));
|
|
792
|
-
memcpy(output.data() + sizeof(uint64_t), &remote_size, sizeof(remote_size));
|
|
793
|
-
return true;
|
|
794
786
|
}
|
|
795
787
|
|
|
796
|
-
void rpc_server::get_alignment(
|
|
788
|
+
void rpc_server::get_alignment(rpc_msg_get_alignment_rsp & response) {
|
|
797
789
|
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
|
|
798
790
|
size_t alignment = ggml_backend_buft_get_alignment(buft);
|
|
799
791
|
GGML_PRINT_DEBUG("[%s] alignment: %lu\n", __func__, alignment);
|
|
800
|
-
|
|
801
|
-
output.resize(sizeof(uint64_t), 0);
|
|
802
|
-
memcpy(output.data(), &alignment, sizeof(alignment));
|
|
792
|
+
response.alignment = alignment;
|
|
803
793
|
}
|
|
804
794
|
|
|
805
|
-
void rpc_server::get_max_size(
|
|
795
|
+
void rpc_server::get_max_size(rpc_msg_get_max_size_rsp & response) {
|
|
806
796
|
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
|
|
807
797
|
size_t max_size = ggml_backend_buft_get_max_size(buft);
|
|
808
798
|
GGML_PRINT_DEBUG("[%s] max_size: %lu\n", __func__, max_size);
|
|
809
|
-
|
|
810
|
-
output.resize(sizeof(uint64_t), 0);
|
|
811
|
-
memcpy(output.data(), &max_size, sizeof(max_size));
|
|
799
|
+
response.max_size = max_size;
|
|
812
800
|
}
|
|
813
801
|
|
|
814
|
-
bool rpc_server::buffer_get_base(const
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
return false;
|
|
818
|
-
}
|
|
819
|
-
uint64_t remote_ptr;
|
|
820
|
-
memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
|
|
821
|
-
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
|
|
822
|
-
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
|
|
802
|
+
bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {
|
|
803
|
+
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
|
|
804
|
+
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
|
|
823
805
|
if (buffers.find(buffer) == buffers.end()) {
|
|
824
806
|
GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
|
|
825
807
|
return false;
|
|
826
808
|
}
|
|
827
809
|
void * base = ggml_backend_buffer_get_base(buffer);
|
|
828
|
-
|
|
829
|
-
uint64_t base_ptr = reinterpret_cast<uint64_t>(base);
|
|
830
|
-
output.resize(sizeof(uint64_t), 0);
|
|
831
|
-
memcpy(output.data(), &base_ptr, sizeof(base_ptr));
|
|
810
|
+
response.base_ptr = reinterpret_cast<uint64_t>(base);
|
|
832
811
|
return true;
|
|
833
812
|
}
|
|
834
813
|
|
|
835
|
-
bool rpc_server::free_buffer(const
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
return false;
|
|
839
|
-
}
|
|
840
|
-
uint64_t remote_ptr;
|
|
841
|
-
memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
|
|
842
|
-
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
|
|
843
|
-
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
|
|
814
|
+
bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) {
|
|
815
|
+
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
|
|
816
|
+
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
|
|
844
817
|
if (buffers.find(buffer) == buffers.end()) {
|
|
845
818
|
GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
|
|
846
819
|
return false;
|
|
@@ -850,22 +823,14 @@ bool rpc_server::free_buffer(const std::vector<uint8_t> & input) {
|
|
|
850
823
|
return true;
|
|
851
824
|
}
|
|
852
825
|
|
|
853
|
-
bool rpc_server::buffer_clear(const
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
return false;
|
|
857
|
-
}
|
|
858
|
-
uint64_t remote_ptr;
|
|
859
|
-
memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
|
|
860
|
-
uint8_t value;
|
|
861
|
-
memcpy(&value, input.data() + sizeof(uint64_t), sizeof(value));
|
|
862
|
-
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, remote_ptr, value);
|
|
863
|
-
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
|
|
826
|
+
bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {
|
|
827
|
+
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, request.remote_ptr, request.value);
|
|
828
|
+
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
|
|
864
829
|
if (buffers.find(buffer) == buffers.end()) {
|
|
865
830
|
GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
|
|
866
831
|
return false;
|
|
867
832
|
}
|
|
868
|
-
ggml_backend_buffer_clear(buffer, value);
|
|
833
|
+
ggml_backend_buffer_clear(buffer, request.value);
|
|
869
834
|
return true;
|
|
870
835
|
}
|
|
871
836
|
|
|
@@ -877,8 +842,18 @@ ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rp
|
|
|
877
842
|
}
|
|
878
843
|
result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);
|
|
879
844
|
if (result->buffer && buffers.find(result->buffer) == buffers.end()) {
|
|
880
|
-
|
|
845
|
+
result->buffer = nullptr;
|
|
846
|
+
}
|
|
847
|
+
|
|
848
|
+
if (result->buffer) {
|
|
849
|
+
// require that the tensor data does not go beyond the buffer end
|
|
850
|
+
uint64_t tensor_size = (uint64_t) ggml_nbytes(result);
|
|
851
|
+
uint64_t buffer_start = (uint64_t) ggml_backend_buffer_get_base(result->buffer);
|
|
852
|
+
uint64_t buffer_size = (uint64_t) ggml_backend_buffer_get_size(result->buffer);
|
|
853
|
+
GGML_ASSERT(tensor->data + tensor_size >= tensor->data); // check for overflow
|
|
854
|
+
GGML_ASSERT(tensor->data >= buffer_start && tensor->data + tensor_size <= buffer_start + buffer_size);
|
|
881
855
|
}
|
|
856
|
+
|
|
882
857
|
result->op = (ggml_op) tensor->op;
|
|
883
858
|
for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
|
|
884
859
|
result->op_params[i] = tensor->op_params[i];
|
|
@@ -898,7 +873,7 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
|
|
|
898
873
|
const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
|
|
899
874
|
uint64_t offset;
|
|
900
875
|
memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
|
|
901
|
-
size_t size = input.size() - sizeof(rpc_tensor) - sizeof(offset);
|
|
876
|
+
const size_t size = input.size() - sizeof(rpc_tensor) - sizeof(offset);
|
|
902
877
|
|
|
903
878
|
struct ggml_init_params params {
|
|
904
879
|
/*.mem_size =*/ ggml_tensor_overhead(),
|
|
@@ -913,69 +888,72 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
|
|
|
913
888
|
return false;
|
|
914
889
|
}
|
|
915
890
|
GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
|
|
891
|
+
|
|
892
|
+
// sanitize tensor->data
|
|
893
|
+
{
|
|
894
|
+
const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
|
|
895
|
+
const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
|
|
896
|
+
|
|
897
|
+
if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
|
|
898
|
+
GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
|
|
899
|
+
}
|
|
900
|
+
}
|
|
901
|
+
|
|
916
902
|
const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset);
|
|
917
903
|
ggml_backend_tensor_set(tensor, data, offset, size);
|
|
918
904
|
ggml_free(ctx);
|
|
919
905
|
return true;
|
|
920
906
|
}
|
|
921
907
|
|
|
922
|
-
bool rpc_server::get_tensor(const
|
|
923
|
-
// serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) |
|
|
924
|
-
if (input.size() != sizeof(rpc_tensor) + 2*sizeof(uint64_t)) {
|
|
925
|
-
return false;
|
|
926
|
-
}
|
|
927
|
-
const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
|
|
928
|
-
uint64_t offset;
|
|
929
|
-
memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
|
|
930
|
-
uint64_t size;
|
|
931
|
-
memcpy(&size, input.data() + sizeof(rpc_tensor) + sizeof(offset), sizeof(size));
|
|
932
|
-
|
|
908
|
+
bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response) {
|
|
933
909
|
struct ggml_init_params params {
|
|
934
910
|
/*.mem_size =*/ ggml_tensor_overhead(),
|
|
935
911
|
/*.mem_buffer =*/ NULL,
|
|
936
912
|
/*.no_alloc =*/ true,
|
|
937
913
|
};
|
|
938
914
|
struct ggml_context * ctx = ggml_init(params);
|
|
939
|
-
ggml_tensor * tensor = deserialize_tensor(ctx,
|
|
915
|
+
ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
|
|
940
916
|
if (tensor == nullptr) {
|
|
941
917
|
GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__);
|
|
942
918
|
ggml_free(ctx);
|
|
943
919
|
return false;
|
|
944
920
|
}
|
|
945
|
-
GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
|
|
946
|
-
|
|
947
|
-
|
|
948
|
-
|
|
921
|
+
GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size);
|
|
922
|
+
|
|
923
|
+
// sanitize tensor->data
|
|
924
|
+
{
|
|
925
|
+
const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
|
|
926
|
+
const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
|
|
927
|
+
|
|
928
|
+
if (request.tensor.data + request.offset < p0 ||
|
|
929
|
+
request.tensor.data + request.offset >= p1 ||
|
|
930
|
+
request.size > (p1 - request.tensor.data - request.offset)) {
|
|
931
|
+
GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
|
|
932
|
+
}
|
|
933
|
+
}
|
|
934
|
+
|
|
935
|
+
response.resize(request.size, 0);
|
|
936
|
+
ggml_backend_tensor_get(tensor, response.data(), request.offset, request.size);
|
|
949
937
|
ggml_free(ctx);
|
|
950
938
|
return true;
|
|
951
939
|
}
|
|
952
940
|
|
|
953
|
-
bool rpc_server::copy_tensor(const
|
|
954
|
-
// serialization format: | rpc_tensor src | rpc_tensor dst |
|
|
955
|
-
if (input.size() != 2*sizeof(rpc_tensor)) {
|
|
956
|
-
return false;
|
|
957
|
-
}
|
|
958
|
-
const rpc_tensor * rpc_src = (const rpc_tensor *)input.data();
|
|
959
|
-
const rpc_tensor * rpc_dst = (const rpc_tensor *)(input.data() + sizeof(rpc_src));
|
|
960
|
-
|
|
941
|
+
bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response) {
|
|
961
942
|
struct ggml_init_params params {
|
|
962
943
|
/*.mem_size =*/ 2*ggml_tensor_overhead(),
|
|
963
944
|
/*.mem_buffer =*/ NULL,
|
|
964
945
|
/*.no_alloc =*/ true,
|
|
965
946
|
};
|
|
966
947
|
struct ggml_context * ctx = ggml_init(params);
|
|
967
|
-
ggml_tensor * src = deserialize_tensor(ctx,
|
|
968
|
-
ggml_tensor * dst = deserialize_tensor(ctx,
|
|
948
|
+
ggml_tensor * src = deserialize_tensor(ctx, &request.src);
|
|
949
|
+
ggml_tensor * dst = deserialize_tensor(ctx, &request.dst);
|
|
969
950
|
if (src == nullptr || dst == nullptr) {
|
|
970
951
|
GGML_PRINT_DEBUG("[%s] error deserializing tensors\n", __func__);
|
|
971
952
|
ggml_free(ctx);
|
|
972
953
|
return false;
|
|
973
954
|
}
|
|
974
955
|
GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n", __func__, (void*)src->buffer, (void*)dst->buffer);
|
|
975
|
-
|
|
976
|
-
// output serialization format: | result (1 byte) |
|
|
977
|
-
output.resize(1, 0);
|
|
978
|
-
output[0] = result;
|
|
956
|
+
response.result = ggml_backend_buffer_copy_tensor(src, dst);
|
|
979
957
|
ggml_free(ctx);
|
|
980
958
|
return true;
|
|
981
959
|
}
|
|
@@ -1004,7 +982,7 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
|
|
|
1004
982
|
return result;
|
|
1005
983
|
}
|
|
1006
984
|
|
|
1007
|
-
bool rpc_server::graph_compute(const std::vector<uint8_t> & input,
|
|
985
|
+
bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
|
|
1008
986
|
// serialization format:
|
|
1009
987
|
// | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
|
|
1010
988
|
if (input.size() < sizeof(uint32_t)) {
|
|
@@ -1024,7 +1002,7 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, std::vector<u
|
|
|
1024
1002
|
const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t) + sizeof(n_tensors));
|
|
1025
1003
|
GGML_PRINT_DEBUG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);
|
|
1026
1004
|
|
|
1027
|
-
|
|
1005
|
+
size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
|
|
1028
1006
|
struct ggml_init_params params = {
|
|
1029
1007
|
/*.mem_size =*/ buf_size,
|
|
1030
1008
|
/*.mem_buffer =*/ NULL,
|
|
@@ -1044,9 +1022,7 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, std::vector<u
|
|
|
1044
1022
|
graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map);
|
|
1045
1023
|
}
|
|
1046
1024
|
ggml_status status = ggml_backend_graph_compute(backend, graph);
|
|
1047
|
-
|
|
1048
|
-
output.resize(1, 0);
|
|
1049
|
-
output[0] = status;
|
|
1025
|
+
response.result = status;
|
|
1050
1026
|
ggml_free(ctx);
|
|
1051
1027
|
return true;
|
|
1052
1028
|
}
|
|
@@ -1064,84 +1040,162 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
|
|
|
1064
1040
|
if (!recv_data(sockfd, &cmd, 1)) {
|
|
1065
1041
|
break;
|
|
1066
1042
|
}
|
|
1067
|
-
|
|
1068
|
-
|
|
1069
|
-
|
|
1070
|
-
if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
|
|
1043
|
+
if (cmd >= RPC_CMD_COUNT) {
|
|
1044
|
+
// fail fast if the command is invalid
|
|
1045
|
+
fprintf(stderr, "Unknown command: %d\n", cmd);
|
|
1071
1046
|
break;
|
|
1072
1047
|
}
|
|
1073
|
-
input.resize(input_size);
|
|
1074
|
-
if (!recv_data(sockfd, input.data(), input_size)) {
|
|
1075
|
-
break;
|
|
1076
|
-
}
|
|
1077
|
-
bool ok = true;
|
|
1078
1048
|
switch (cmd) {
|
|
1079
|
-
case
|
|
1080
|
-
|
|
1049
|
+
case RPC_CMD_ALLOC_BUFFER: {
|
|
1050
|
+
rpc_msg_alloc_buffer_req request;
|
|
1051
|
+
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
|
1052
|
+
return;
|
|
1053
|
+
}
|
|
1054
|
+
rpc_msg_alloc_buffer_rsp response;
|
|
1055
|
+
server.alloc_buffer(request, response);
|
|
1056
|
+
if (!send_msg(sockfd, &response, sizeof(response))) {
|
|
1057
|
+
return;
|
|
1058
|
+
}
|
|
1081
1059
|
break;
|
|
1082
1060
|
}
|
|
1083
|
-
case
|
|
1084
|
-
|
|
1061
|
+
case RPC_CMD_GET_ALIGNMENT: {
|
|
1062
|
+
if (!recv_msg(sockfd, nullptr, 0)) {
|
|
1063
|
+
return;
|
|
1064
|
+
}
|
|
1065
|
+
rpc_msg_get_alignment_rsp response;
|
|
1066
|
+
server.get_alignment(response);
|
|
1067
|
+
if (!send_msg(sockfd, &response, sizeof(response))) {
|
|
1068
|
+
return;
|
|
1069
|
+
}
|
|
1085
1070
|
break;
|
|
1086
1071
|
}
|
|
1087
|
-
case
|
|
1088
|
-
|
|
1072
|
+
case RPC_CMD_GET_MAX_SIZE: {
|
|
1073
|
+
if (!recv_msg(sockfd, nullptr, 0)) {
|
|
1074
|
+
return;
|
|
1075
|
+
}
|
|
1076
|
+
rpc_msg_get_max_size_rsp response;
|
|
1077
|
+
server.get_max_size(response);
|
|
1078
|
+
if (!send_msg(sockfd, &response, sizeof(response))) {
|
|
1079
|
+
return;
|
|
1080
|
+
}
|
|
1089
1081
|
break;
|
|
1090
1082
|
}
|
|
1091
|
-
case
|
|
1092
|
-
|
|
1083
|
+
case RPC_CMD_BUFFER_GET_BASE: {
|
|
1084
|
+
rpc_msg_buffer_get_base_req request;
|
|
1085
|
+
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
|
1086
|
+
return;
|
|
1087
|
+
}
|
|
1088
|
+
rpc_msg_buffer_get_base_rsp response;
|
|
1089
|
+
if (!server.buffer_get_base(request, response)) {
|
|
1090
|
+
return;
|
|
1091
|
+
}
|
|
1092
|
+
if (!send_msg(sockfd, &response, sizeof(response))) {
|
|
1093
|
+
return;
|
|
1094
|
+
}
|
|
1093
1095
|
break;
|
|
1094
1096
|
}
|
|
1095
|
-
case
|
|
1096
|
-
|
|
1097
|
+
case RPC_CMD_FREE_BUFFER: {
|
|
1098
|
+
rpc_msg_free_buffer_req request;
|
|
1099
|
+
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
|
1100
|
+
return;
|
|
1101
|
+
}
|
|
1102
|
+
if (!server.free_buffer(request)) {
|
|
1103
|
+
return;
|
|
1104
|
+
}
|
|
1105
|
+
if (!send_msg(sockfd, nullptr, 0)) {
|
|
1106
|
+
return;
|
|
1107
|
+
}
|
|
1097
1108
|
break;
|
|
1098
1109
|
}
|
|
1099
|
-
case
|
|
1100
|
-
|
|
1110
|
+
case RPC_CMD_BUFFER_CLEAR: {
|
|
1111
|
+
rpc_msg_buffer_clear_req request;
|
|
1112
|
+
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
|
1113
|
+
return;
|
|
1114
|
+
}
|
|
1115
|
+
if (!server.buffer_clear(request)) {
|
|
1116
|
+
return;
|
|
1117
|
+
}
|
|
1118
|
+
if (!send_msg(sockfd, nullptr, 0)) {
|
|
1119
|
+
return;
|
|
1120
|
+
}
|
|
1101
1121
|
break;
|
|
1102
1122
|
}
|
|
1103
|
-
case
|
|
1104
|
-
|
|
1123
|
+
case RPC_CMD_SET_TENSOR: {
|
|
1124
|
+
std::vector<uint8_t> input;
|
|
1125
|
+
if (!recv_msg(sockfd, input)) {
|
|
1126
|
+
return;
|
|
1127
|
+
}
|
|
1128
|
+
if (!server.set_tensor(input)) {
|
|
1129
|
+
return;
|
|
1130
|
+
}
|
|
1131
|
+
if (!send_msg(sockfd, nullptr, 0)) {
|
|
1132
|
+
return;
|
|
1133
|
+
}
|
|
1105
1134
|
break;
|
|
1106
1135
|
}
|
|
1107
|
-
case
|
|
1108
|
-
|
|
1136
|
+
case RPC_CMD_GET_TENSOR: {
|
|
1137
|
+
rpc_msg_get_tensor_req request;
|
|
1138
|
+
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
|
1139
|
+
return;
|
|
1140
|
+
}
|
|
1141
|
+
std::vector<uint8_t> response;
|
|
1142
|
+
if (!server.get_tensor(request, response)) {
|
|
1143
|
+
return;
|
|
1144
|
+
}
|
|
1145
|
+
if (!send_msg(sockfd, response.data(), response.size())) {
|
|
1146
|
+
return;
|
|
1147
|
+
}
|
|
1109
1148
|
break;
|
|
1110
1149
|
}
|
|
1111
|
-
case
|
|
1112
|
-
|
|
1150
|
+
case RPC_CMD_COPY_TENSOR: {
|
|
1151
|
+
rpc_msg_copy_tensor_req request;
|
|
1152
|
+
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
|
1153
|
+
return;
|
|
1154
|
+
}
|
|
1155
|
+
rpc_msg_copy_tensor_rsp response;
|
|
1156
|
+
if (!server.copy_tensor(request, response)) {
|
|
1157
|
+
return;
|
|
1158
|
+
}
|
|
1159
|
+
if (!send_msg(sockfd, &response, sizeof(response))) {
|
|
1160
|
+
return;
|
|
1161
|
+
}
|
|
1113
1162
|
break;
|
|
1114
1163
|
}
|
|
1115
|
-
case
|
|
1116
|
-
|
|
1164
|
+
case RPC_CMD_GRAPH_COMPUTE: {
|
|
1165
|
+
std::vector<uint8_t> input;
|
|
1166
|
+
if (!recv_msg(sockfd, input)) {
|
|
1167
|
+
return;
|
|
1168
|
+
}
|
|
1169
|
+
rpc_msg_graph_compute_rsp response;
|
|
1170
|
+
if (!server.graph_compute(input, response)) {
|
|
1171
|
+
return;
|
|
1172
|
+
}
|
|
1173
|
+
if (!send_msg(sockfd, &response, sizeof(response))) {
|
|
1174
|
+
return;
|
|
1175
|
+
}
|
|
1117
1176
|
break;
|
|
1118
1177
|
}
|
|
1119
|
-
case
|
|
1120
|
-
|
|
1121
|
-
|
|
1122
|
-
|
|
1123
|
-
|
|
1178
|
+
case RPC_CMD_GET_DEVICE_MEMORY: {
|
|
1179
|
+
if (!recv_msg(sockfd, nullptr, 0)) {
|
|
1180
|
+
return;
|
|
1181
|
+
}
|
|
1182
|
+
rpc_msg_get_device_memory_rsp response;
|
|
1183
|
+
response.free_mem = free_mem;
|
|
1184
|
+
response.total_mem = total_mem;
|
|
1185
|
+
if (!send_msg(sockfd, &response, sizeof(response))) {
|
|
1186
|
+
return;
|
|
1187
|
+
}
|
|
1124
1188
|
break;
|
|
1125
1189
|
}
|
|
1126
1190
|
default: {
|
|
1127
1191
|
fprintf(stderr, "Unknown command: %d\n", cmd);
|
|
1128
|
-
|
|
1192
|
+
return;
|
|
1129
1193
|
}
|
|
1130
1194
|
}
|
|
1131
|
-
if (!ok) {
|
|
1132
|
-
break;
|
|
1133
|
-
}
|
|
1134
|
-
uint64_t output_size = output.size();
|
|
1135
|
-
if (!send_data(sockfd, &output_size, sizeof(output_size))) {
|
|
1136
|
-
break;
|
|
1137
|
-
}
|
|
1138
|
-
if (!send_data(sockfd, output.data(), output_size)) {
|
|
1139
|
-
break;
|
|
1140
|
-
}
|
|
1141
1195
|
}
|
|
1142
1196
|
}
|
|
1143
1197
|
|
|
1144
|
-
void
|
|
1198
|
+
void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
|
|
1145
1199
|
std::string host;
|
|
1146
1200
|
int port;
|
|
1147
1201
|
if (!parse_endpoint(endpoint, host, port)) {
|
|
@@ -1169,10 +1223,181 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free
|
|
|
1169
1223
|
return;
|
|
1170
1224
|
}
|
|
1171
1225
|
printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
|
|
1226
|
+
fflush(stdout);
|
|
1172
1227
|
rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
|
|
1173
1228
|
printf("Client connection closed\n");
|
|
1229
|
+
fflush(stdout);
|
|
1174
1230
|
}
|
|
1175
1231
|
#ifdef _WIN32
|
|
1176
1232
|
WSACleanup();
|
|
1177
1233
|
#endif
|
|
1178
1234
|
}
|
|
1235
|
+
|
|
1236
|
+
// device interface
|
|
1237
|
+
|
|
1238
|
+
// Per-device state stored in ggml_backend_device::context for RPC devices.
// ggml_backend_rpc_add_device creates one instance per distinct endpoint.
// NOTE: member order is load-bearing - ggml_backend_rpc_add_device builds this
// struct with positional aggregate initialization as { endpoint, name }.
struct ggml_backend_rpc_device_context {
    std::string endpoint; // address of the remote RPC server, as passed to ggml_backend_rpc_add_device
    std::string name;     // human-readable label of the form "RPC[<endpoint>]"
};
|
|
1242
|
+
|
|
1243
|
+
// Device-name callback: hands out the label kept in the device context.
// The returned pointer is owned by the context's std::string.
static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
    const auto * dev_ctx = static_cast<const ggml_backend_rpc_device_context *>(dev->context);
    return dev_ctx->name.c_str();
}
|
|
1248
|
+
|
|
1249
|
+
// Device-description callback. There is no separate description for RPC
// devices, so the device's name string is returned here as well.
static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) {
    const auto * dev_ctx = static_cast<const ggml_backend_rpc_device_context *>(dev->context);
    const std::string & description = dev_ctx->name;
    return description.c_str();
}
|
|
1254
|
+
|
|
1255
|
+
// Device-memory callback: delegates to ggml_backend_rpc_get_device_memory for
// this device's endpoint, which writes the free/total byte counts into the
// caller-provided pointers.
//
// dev   - an RPC device whose context is a ggml_backend_rpc_device_context
// free  - out: free memory reported for the endpoint
// total - out: total memory reported for the endpoint
static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;

    ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total);
}
|
|
1262
|
+
|
|
1263
|
+
// Device-type callback. Every RPC device is currently reported as a GPU,
// regardless of what backend actually runs on the remote server.
// TODO: obtain value from the server
static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
    UNUSED(dev);

    const enum ggml_backend_dev_type reported = GGML_BACKEND_DEVICE_TYPE_GPU;
    return reported;
}
|
|
1269
|
+
|
|
1270
|
+
// Device-properties callback: fills *props by calling the individual device
// callbacks in turn, then reports every optional capability as unsupported.
static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
    struct ggml_backend_dev_props & out = *props;

    out.name        = ggml_backend_rpc_device_get_name(dev);
    out.description = ggml_backend_rpc_device_get_description(dev);
    out.type        = ggml_backend_rpc_device_get_type(dev);
    ggml_backend_rpc_device_get_memory(dev, &out.memory_free, &out.memory_total);

    // positional: async, host_buffer, buffer_from_host_ptr, events
    out.caps = { false, false, false, false };
}
|
|
1282
|
+
|
|
1283
|
+
// Backend-init callback: creates an RPC backend for this device's endpoint via
// ggml_backend_rpc_init. The params string is ignored.
static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) {
    UNUSED(params);

    const auto * dev_ctx = static_cast<const ggml_backend_rpc_device_context *>(dev->context);
    const char * server = dev_ctx->endpoint.c_str();
    return ggml_backend_rpc_init(server);
}
|
|
1290
|
+
|
|
1291
|
+
// Buffer-type callback: returns the RPC buffer type bound to this device's
// endpoint, as produced by ggml_backend_rpc_buffer_type.
//
// dev - an RPC device whose context is a ggml_backend_rpc_device_context
static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;

    return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
}
|
|
1298
|
+
|
|
1299
|
+
// Op-support callback. Always answers "supported" without consulting the
// remote backend, so unsupported ops are only discovered on the server side.
// TODO: call the remote backend and cache the results
static bool ggml_backend_rpc_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
    (void) dev;
    (void) op;

    const bool optimistic = true;
    return optimistic;
}
|
|
1305
|
+
|
|
1306
|
+
// Buffer-type compatibility callback. A buffer type is accepted only when it
// is an RPC buffer type (recognised by its get_name callback) and it targets
// the same endpoint as this device.
static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
    const bool is_rpc_buft = buft != nullptr && buft->iface.get_name == ggml_backend_rpc_buffer_type_name;
    if (!is_rpc_buft) {
        return false;
    }

    const auto * buft_ctx = static_cast<const ggml_backend_rpc_buffer_type_context *>(buft->context);
    const auto * dev_ctx  = static_cast<const ggml_backend_rpc_device_context *>(dev->context);
    return dev_ctx->endpoint == buft_ctx->endpoint;
}
|
|
1314
|
+
|
|
1315
|
+
// Device vtable shared by every RPC device. Entries are positional and must
// follow the member order of ggml_backend_device_i; a null entry means the
// RPC backend does not implement that hook (no host buffers, no buffers from
// host pointers, no op offloading, no events).
static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
    /* .get_name             = */ ggml_backend_rpc_device_get_name,
    /* .get_description      = */ ggml_backend_rpc_device_get_description,
    /* .get_memory           = */ ggml_backend_rpc_device_get_memory,
    /* .get_type             = */ ggml_backend_rpc_device_get_type,
    /* .get_props            = */ ggml_backend_rpc_device_get_props,
    /* .init_backend         = */ ggml_backend_rpc_device_init,
    /* .get_buffer_type      = */ ggml_backend_rpc_device_get_buffer_type,
    /* .get_host_buffer_type = */ nullptr,
    /* .buffer_from_host_ptr = */ nullptr,
    /* .supports_op          = */ ggml_backend_rpc_device_supports_op,
    /* .supports_buft        = */ ggml_backend_rpc_device_supports_buft,
    /* .offload_op           = */ nullptr,
    /* .event_new            = */ nullptr,
    /* .event_free           = */ nullptr,
    /* .event_synchronize    = */ nullptr,
};
|
|
1332
|
+
|
|
1333
|
+
// backend reg interface
|
|
1334
|
+
|
|
1335
|
+
// Registry-name callback: the RPC registry is always called "RPC".
static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
    UNUSED(reg);

    static const char * const reg_name = "RPC";
    return reg_name;
}
|
|
1340
|
+
|
|
1341
|
+
// Device-count callback. The RPC registry enumerates no devices; they are
// created on demand through ggml_backend_rpc_add_device instead.
static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
    UNUSED(reg);

    const size_t enumerated = 0;
    return enumerated;
}
|
|
1346
|
+
|
|
1347
|
+
// Device-by-index callback. The RPC registry has no enumerated devices
// (get_device_count reports 0), so any call here is a usage error and aborts.
// The message points callers at the real entry point in this file,
// ggml_backend_rpc_add_device; the previous text named a function,
// ggml_backend_add_device, that this backend does not provide.
static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
    UNUSED(reg);
    UNUSED(index);

    GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_rpc_add_device instead");
}
|
|
1353
|
+
|
|
1354
|
+
// Proc-address callback: resolves RPC-specific entry points by name.
// Only "ggml_backend_rpc_add_device" is exported; any other name yields null.
static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {
    UNUSED(reg);

    const bool wants_add_device = std::strcmp(name, "ggml_backend_rpc_add_device") == 0;
    return wants_add_device ? (void *)ggml_backend_rpc_add_device : nullptr;
}
|
|
1362
|
+
|
|
1363
|
+
// Registry vtable for the RPC backend. Entries are positional and must follow
// the member order of ggml_backend_reg_i.
static const struct ggml_backend_reg_i ggml_backend_rpc_reg_i = {
    /* .get_name         = */ ggml_backend_rpc_reg_get_name,
    /* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,
    /* .get_device       = */ ggml_backend_rpc_reg_get_device,
    /* .get_proc_address = */ ggml_backend_rpc_get_proc_address,
};
|
|
1369
|
+
|
|
1370
|
+
// Returns the RPC backend registry. It is a function-local static, so every
// call returns the same object; the registry carries no context.
ggml_backend_reg_t ggml_backend_rpc_reg(void) {
    static struct ggml_backend_reg rpc_reg = {
        /* .iface   = */ ggml_backend_rpc_reg_i,
        /* .context = */ nullptr,
    };
    return &rpc_reg;
}
|
|
1378
|
+
|
|
1379
|
+
// Returns the RPC device for the server at `endpoint`, creating it on first
// use. Devices are cached per endpoint string, so repeated calls with the same
// endpoint return the same device. The cache is guarded by a mutex.
//
// The device and its context are allocated with `new` and are not freed
// anywhere in this function; the cache keeps them for the rest of the process.
//
// endpoint - address of the remote RPC server; must not be null
// returns  - the cached or newly created device, or nullptr if endpoint is null
ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) {
    static std::unordered_map<std::string, ggml_backend_dev_t> dev_map;
    static std::mutex mutex;

    if (endpoint == nullptr) {
        // std::string(nullptr) is undefined behavior; refuse instead.
        return nullptr;
    }

    std::lock_guard<std::mutex> lock(mutex);

    // Build the key once and reuse it for lookup, context fields and insertion.
    const std::string key(endpoint);

    auto it = dev_map.find(key);
    if (it != dev_map.end()) {
        return it->second;
    }

    ggml_backend_rpc_device_context * ctx = new ggml_backend_rpc_device_context {
        /* .endpoint = */ key,
        /* .name     = */ "RPC[" + key + "]",
    };

    ggml_backend_dev_t dev = new ggml_backend_device {
        /* .iface   = */ ggml_backend_rpc_device_i,
        /* .reg     = */ ggml_backend_rpc_reg(),
        /* .context = */ ctx,
    };

    dev_map.emplace(key, dev);

    return dev;
}
|