whispercpp 1.3.0 → 1.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +5 -0
- data/LICENSE +1 -1
- data/README.md +165 -434
- data/Rakefile +60 -11
- data/ext/.gitignore +13 -0
- data/ext/cpu.mk +9 -0
- data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
- data/ext/extconf.rb +185 -16
- data/ext/ggml/include/ggml-alloc.h +76 -0
- data/ext/ggml/include/ggml-backend.h +352 -0
- data/ext/ggml/include/ggml-blas.h +25 -0
- data/ext/ggml/include/ggml-cann.h +123 -0
- data/ext/ggml/include/ggml-cpp.h +38 -0
- data/ext/ggml/include/ggml-cpu.h +135 -0
- data/ext/ggml/include/ggml-cuda.h +47 -0
- data/ext/ggml/include/ggml-kompute.h +50 -0
- data/ext/ggml/include/ggml-metal.h +66 -0
- data/ext/ggml/include/ggml-opencl.h +26 -0
- data/ext/ggml/include/ggml-opt.h +216 -0
- data/ext/ggml/include/ggml-rpc.h +28 -0
- data/ext/ggml/include/ggml-sycl.h +49 -0
- data/ext/ggml/include/ggml-vulkan.h +31 -0
- data/ext/{ggml.h → ggml/include/ggml.h} +479 -596
- data/ext/ggml/src/ggml-alloc.c +1037 -0
- data/ext/ggml/src/ggml-amx/common.h +94 -0
- data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
- data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
- data/ext/ggml/src/ggml-amx/mmq.h +17 -0
- data/ext/ggml/src/ggml-backend-impl.h +256 -0
- data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
- data/ext/ggml/src/ggml-backend.cpp +1999 -0
- data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
- data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
- data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
- data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
- data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
- data/ext/ggml/src/ggml-cann/common.h +286 -0
- data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
- data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
- data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
- data/ext/ggml/src/ggml-common.h +1853 -0
- data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
- data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
- data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
- data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
- data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
- data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
- data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
- data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
- data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
- data/ext/ggml/src/ggml-impl.h +556 -0
- data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
- data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
- data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
- data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
- data/ext/ggml/src/ggml-opt.cpp +854 -0
- data/ext/ggml/src/ggml-quants.c +5238 -0
- data/ext/ggml/src/ggml-quants.h +100 -0
- data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
- data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
- data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
- data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
- data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
- data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
- data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
- data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
- data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
- data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
- data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
- data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
- data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
- data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
- data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
- data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
- data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
- data/ext/ggml/src/ggml-threading.cpp +12 -0
- data/ext/ggml/src/ggml-threading.h +14 -0
- data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
- data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
- data/ext/ggml/src/ggml.c +7694 -0
- data/ext/{whisper.h → include/whisper.h} +23 -22
- data/ext/metal-embed.mk +17 -0
- data/ext/metal.mk +6 -0
- data/ext/ruby_whisper.cpp +1492 -9
- data/ext/ruby_whisper.h +10 -0
- data/ext/scripts/get-flags.mk +38 -0
- data/ext/src/coreml/whisper-decoder-impl.h +146 -0
- data/ext/src/coreml/whisper-decoder-impl.m +201 -0
- data/ext/src/coreml/whisper-encoder-impl.h +142 -0
- data/ext/src/coreml/whisper-encoder-impl.m +197 -0
- data/ext/src/coreml/whisper-encoder.h +26 -0
- data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
- data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
- data/ext/{whisper.cpp → src/whisper.cpp} +661 -492
- data/extsources.rb +6 -0
- data/lib/whisper/model/uri.rb +157 -0
- data/lib/whisper.rb +2 -0
- data/tests/helper.rb +7 -0
- data/tests/jfk_reader/.gitignore +5 -0
- data/tests/jfk_reader/extconf.rb +3 -0
- data/tests/jfk_reader/jfk_reader.c +68 -0
- data/tests/test_callback.rb +160 -0
- data/tests/test_error.rb +20 -0
- data/tests/test_model.rb +71 -0
- data/tests/test_package.rb +31 -0
- data/tests/test_params.rb +160 -0
- data/tests/test_segment.rb +83 -0
- data/tests/test_whisper.rb +211 -123
- data/whispercpp.gemspec +36 -0
- metadata +137 -11
- data/ext/ggml.c +0 -21755
@@ -0,0 +1,1406 @@
|
|
1
|
+
#include "ggml-rpc.h"
|
2
|
+
#include "ggml-impl.h"
|
3
|
+
#include "ggml-backend-impl.h"
|
4
|
+
|
5
|
+
#include <cinttypes>
|
6
|
+
#include <string>
|
7
|
+
#include <vector>
|
8
|
+
#include <memory>
|
9
|
+
#include <mutex>
|
10
|
+
#include <unordered_map>
|
11
|
+
#include <unordered_set>
|
12
|
+
#ifdef _WIN32
|
13
|
+
# define WIN32_LEAN_AND_MEAN
|
14
|
+
# ifndef NOMINMAX
|
15
|
+
# define NOMINMAX
|
16
|
+
# endif
|
17
|
+
# include <windows.h>
|
18
|
+
# include <winsock2.h>
|
19
|
+
#else
|
20
|
+
# include <arpa/inet.h>
|
21
|
+
# include <sys/socket.h>
|
22
|
+
# include <sys/types.h>
|
23
|
+
# include <netinet/in.h>
|
24
|
+
# include <netinet/tcp.h>
|
25
|
+
# include <netdb.h>
|
26
|
+
# include <unistd.h>
|
27
|
+
#endif
|
28
|
+
#include <cstring>
|
29
|
+
|
30
|
+
#define UNUSED GGML_UNUSED

// Debug tracing for this file: raise GGML_DEBUG to 1 or more to route
// GGML_PRINT_DEBUG to printf; at 0 the macro expands to nothing.
#define GGML_DEBUG 0
#if GGML_DEBUG >= 1
#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
#else
#define GGML_PRINT_DEBUG(...)
#endif

// Native socket handle type (and ssize_t, which MSVC does not provide).
#ifdef _WIN32
using sockfd_t = SOCKET;
using ssize_t  = __int64;
#else
using sockfd_t = int;
#endif
|
45
|
+
|
46
|
+
// cross-platform socket
|
47
|
+
struct socket_t {
|
48
|
+
sockfd_t fd;
|
49
|
+
socket_t(sockfd_t fd) : fd(fd) {}
|
50
|
+
~socket_t() {
|
51
|
+
GGML_PRINT_DEBUG("[%s] closing socket %d\n", __func__, this->fd);
|
52
|
+
#ifdef _WIN32
|
53
|
+
closesocket(this->fd);
|
54
|
+
#else
|
55
|
+
close(this->fd);
|
56
|
+
#endif
|
57
|
+
}
|
58
|
+
};
|
59
|
+
|
60
|
+
// all RPC structures must be packed
|
61
|
+
#pragma pack(push, 1)
|
62
|
+
// ggml_tensor is serialized into rpc_tensor
|
63
|
+
struct rpc_tensor {
|
64
|
+
uint64_t id;
|
65
|
+
uint32_t type;
|
66
|
+
uint64_t buffer;
|
67
|
+
uint32_t ne[GGML_MAX_DIMS];
|
68
|
+
uint32_t nb[GGML_MAX_DIMS];
|
69
|
+
uint32_t op;
|
70
|
+
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
|
71
|
+
int32_t flags;
|
72
|
+
uint64_t src[GGML_MAX_SRC];
|
73
|
+
uint64_t view_src;
|
74
|
+
uint64_t view_offs;
|
75
|
+
uint64_t data;
|
76
|
+
char name[GGML_MAX_NAME];
|
77
|
+
|
78
|
+
char padding[4];
|
79
|
+
};
|
80
|
+
|
81
|
+
static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of 8");
|
82
|
+
|
83
|
+
// RPC commands
|
84
|
+
enum rpc_cmd {
|
85
|
+
RPC_CMD_ALLOC_BUFFER = 0,
|
86
|
+
RPC_CMD_GET_ALIGNMENT,
|
87
|
+
RPC_CMD_GET_MAX_SIZE,
|
88
|
+
RPC_CMD_BUFFER_GET_BASE,
|
89
|
+
RPC_CMD_FREE_BUFFER,
|
90
|
+
RPC_CMD_BUFFER_CLEAR,
|
91
|
+
RPC_CMD_SET_TENSOR,
|
92
|
+
RPC_CMD_GET_TENSOR,
|
93
|
+
RPC_CMD_COPY_TENSOR,
|
94
|
+
RPC_CMD_GRAPH_COMPUTE,
|
95
|
+
RPC_CMD_GET_DEVICE_MEMORY,
|
96
|
+
RPC_CMD_COUNT,
|
97
|
+
};
|
98
|
+
|
99
|
+
struct rpc_msg_alloc_buffer_req {
|
100
|
+
uint64_t size;
|
101
|
+
};
|
102
|
+
|
103
|
+
struct rpc_msg_alloc_buffer_rsp {
|
104
|
+
uint64_t remote_ptr;
|
105
|
+
uint64_t remote_size;
|
106
|
+
};
|
107
|
+
|
108
|
+
struct rpc_msg_get_alignment_rsp {
|
109
|
+
uint64_t alignment;
|
110
|
+
};
|
111
|
+
|
112
|
+
struct rpc_msg_get_max_size_rsp {
|
113
|
+
uint64_t max_size;
|
114
|
+
};
|
115
|
+
|
116
|
+
struct rpc_msg_buffer_get_base_req {
|
117
|
+
uint64_t remote_ptr;
|
118
|
+
};
|
119
|
+
|
120
|
+
struct rpc_msg_buffer_get_base_rsp {
|
121
|
+
uint64_t base_ptr;
|
122
|
+
};
|
123
|
+
|
124
|
+
struct rpc_msg_free_buffer_req {
|
125
|
+
uint64_t remote_ptr;
|
126
|
+
};
|
127
|
+
|
128
|
+
struct rpc_msg_buffer_clear_req {
|
129
|
+
uint64_t remote_ptr;
|
130
|
+
uint8_t value;
|
131
|
+
};
|
132
|
+
|
133
|
+
struct rpc_msg_get_tensor_req {
|
134
|
+
rpc_tensor tensor;
|
135
|
+
uint64_t offset;
|
136
|
+
uint64_t size;
|
137
|
+
};
|
138
|
+
|
139
|
+
struct rpc_msg_copy_tensor_req {
|
140
|
+
rpc_tensor src;
|
141
|
+
rpc_tensor dst;
|
142
|
+
};
|
143
|
+
|
144
|
+
struct rpc_msg_copy_tensor_rsp {
|
145
|
+
uint8_t result;
|
146
|
+
};
|
147
|
+
|
148
|
+
struct rpc_msg_graph_compute_rsp {
|
149
|
+
uint8_t result;
|
150
|
+
};
|
151
|
+
|
152
|
+
struct rpc_msg_get_device_memory_rsp {
|
153
|
+
uint64_t free_mem;
|
154
|
+
uint64_t total_mem;
|
155
|
+
};
|
156
|
+
#pragma pack(pop)
|
157
|
+
|
158
|
+
// RPC data structures
|
159
|
+
|
160
|
+
// Returns a pointer to the fixed 16-byte GUID used to tag RPC backend objects.
static ggml_guid_t ggml_backend_rpc_guid() {
    static ggml_guid rpc_guid = {
        0x99, 0x68, 0x5b, 0x6c, 0xd2, 0x83, 0x3d, 0x24,
        0x25, 0x36, 0x72, 0xe1, 0x5b, 0x0e, 0x14, 0x03,
    };
    return &rpc_guid;
}
|
164
|
+
|
165
|
+
struct ggml_backend_rpc_buffer_type_context {
|
166
|
+
std::string endpoint;
|
167
|
+
std::string name;
|
168
|
+
size_t alignment;
|
169
|
+
size_t max_size;
|
170
|
+
};
|
171
|
+
|
172
|
+
struct ggml_backend_rpc_context {
|
173
|
+
std::string endpoint;
|
174
|
+
std::string name;
|
175
|
+
};
|
176
|
+
|
177
|
+
struct ggml_backend_rpc_buffer_context {
|
178
|
+
std::shared_ptr<socket_t> sock;
|
179
|
+
std::unordered_map<ggml_backend_buffer_t, void *> base_cache;
|
180
|
+
uint64_t remote_ptr;
|
181
|
+
};
|
182
|
+
|
183
|
+
// RPC helper functions
|
184
|
+
|
185
|
+
// Wraps a raw descriptor in a shared, self-closing socket_t. Returns nullptr
// when the descriptor signals failure (INVALID_SOCKET on Windows, negative elsewhere).
static std::shared_ptr<socket_t> make_socket(sockfd_t fd) {
#ifdef _WIN32
    const bool invalid = (fd == INVALID_SOCKET);
#else
    const bool invalid = (fd < 0);
#endif
    if (invalid) {
        return nullptr;
    }
    return std::make_shared<socket_t>(fd);
}
|
197
|
+
|
198
|
+
// Turns on TCP_NODELAY (disabling Nagle's algorithm) so small RPC messages are
// sent immediately. Returns true when setsockopt succeeds.
static bool set_no_delay(sockfd_t sockfd) {
    const int enable = 1;
    return setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY,
                      reinterpret_cast<const char *>(&enable), sizeof(enable)) == 0;
}
|
204
|
+
|
205
|
+
// Enables SO_REUSEADDR so a restarted server can rebind its port right away.
// Returns true when setsockopt succeeds.
static bool set_reuse_addr(sockfd_t sockfd) {
    const int enable = 1;
    return setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR,
                      reinterpret_cast<const char *>(&enable), sizeof(enable)) == 0;
}
|
210
|
+
|
211
|
+
static std::shared_ptr<socket_t> socket_connect(const char * host, int port) {
|
212
|
+
struct sockaddr_in addr;
|
213
|
+
auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
|
214
|
+
auto sock_ptr = make_socket(sockfd);
|
215
|
+
if (sock_ptr == nullptr) {
|
216
|
+
return nullptr;
|
217
|
+
}
|
218
|
+
if (!set_no_delay(sockfd)) {
|
219
|
+
fprintf(stderr, "Failed to set TCP_NODELAY\n");
|
220
|
+
return nullptr;
|
221
|
+
}
|
222
|
+
addr.sin_family = AF_INET;
|
223
|
+
addr.sin_port = htons(port);
|
224
|
+
struct hostent * server = gethostbyname(host);
|
225
|
+
if (server == NULL) {
|
226
|
+
fprintf(stderr, "Cannot resolve host '%s'\n", host);
|
227
|
+
return nullptr;
|
228
|
+
}
|
229
|
+
memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length);
|
230
|
+
if (connect(sock_ptr->fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
|
231
|
+
return nullptr;
|
232
|
+
}
|
233
|
+
return sock_ptr;
|
234
|
+
}
|
235
|
+
|
236
|
+
// Accepts one pending connection on srv_sockfd and enables TCP_NODELAY on it.
// Returns nullptr if accept() fails or the option cannot be set.
static std::shared_ptr<socket_t> socket_accept(sockfd_t srv_sockfd) {
    auto client_fd = accept(srv_sockfd, nullptr, nullptr);
    std::shared_ptr<socket_t> client = make_socket(client_fd);
    if (!client) {
        return nullptr;
    }
    if (set_no_delay(client_fd)) {
        return client;
    }
    fprintf(stderr, "Failed to set TCP_NODELAY\n");
    return nullptr;
}
|
248
|
+
|
249
|
+
static std::shared_ptr<socket_t> create_server_socket(const char * host, int port) {
|
250
|
+
auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
|
251
|
+
auto sock = make_socket(sockfd);
|
252
|
+
if (sock == nullptr) {
|
253
|
+
return nullptr;
|
254
|
+
}
|
255
|
+
if (!set_reuse_addr(sockfd)) {
|
256
|
+
fprintf(stderr, "Failed to set SO_REUSEADDR\n");
|
257
|
+
return nullptr;
|
258
|
+
}
|
259
|
+
if (inet_addr(host) == INADDR_NONE) {
|
260
|
+
fprintf(stderr, "Invalid host address: %s\n", host);
|
261
|
+
return nullptr;
|
262
|
+
}
|
263
|
+
struct sockaddr_in serv_addr;
|
264
|
+
serv_addr.sin_family = AF_INET;
|
265
|
+
serv_addr.sin_addr.s_addr = inet_addr(host);
|
266
|
+
serv_addr.sin_port = htons(port);
|
267
|
+
|
268
|
+
if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) {
|
269
|
+
return nullptr;
|
270
|
+
}
|
271
|
+
if (listen(sockfd, 1) < 0) {
|
272
|
+
return nullptr;
|
273
|
+
}
|
274
|
+
return sock;
|
275
|
+
}
|
276
|
+
|
277
|
+
// Sends exactly `size` bytes from `data`, looping over partial writes.
// Returns false as soon as send() reports an error.
static bool send_data(sockfd_t sockfd, const void * data, size_t size) {
    // Cap each send() call at 1 GiB. Winsock's send() takes an int length, so
    // passing a larger size_t would be truncated or become negative.
    const size_t max_chunk = size_t(1) << 30;
    const char * bytes = (const char *)data;
    size_t bytes_sent = 0;
    while (bytes_sent < size) {
        const size_t remaining = size - bytes_sent;
        const size_t chunk = remaining < max_chunk ? remaining : max_chunk;
        ssize_t n = send(sockfd, bytes + bytes_sent, chunk, 0);
        if (n < 0) {
            return false;
        }
        bytes_sent += (size_t)n;
    }
    return true;
}
|
288
|
+
|
289
|
+
// Receives exactly `size` bytes into `data`, looping over partial reads.
// Returns false on error or when the peer closes the connection early.
static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
    // Cap each recv() call at 1 GiB. Winsock's recv() takes an int length, so
    // passing a larger size_t would be truncated or become negative.
    const size_t max_chunk = size_t(1) << 30;
    char * bytes = (char *)data;
    size_t bytes_recv = 0;
    while (bytes_recv < size) {
        const size_t remaining = size - bytes_recv;
        const size_t chunk = remaining < max_chunk ? remaining : max_chunk;
        ssize_t n = recv(sockfd, bytes + bytes_recv, chunk, 0);
        if (n <= 0) {
            return false;
        }
        bytes_recv += (size_t)n;
    }
    return true;
}
|
300
|
+
|
301
|
+
// Sends a length-prefixed message: | size (8 bytes) | payload (size bytes) |.
static bool send_msg(sockfd_t sockfd, const void * msg, size_t msg_size) {
    // The receiving side (recv_msg) reads a uint64_t prefix. Widen explicitly so
    // 32-bit hosts do not send a 4-byte size_t.
    const uint64_t size = msg_size;
    if (!send_data(sockfd, &size, sizeof(size))) {
        return false;
    }
    return send_data(sockfd, msg, msg_size);
}
|
307
|
+
|
308
|
+
// Receives one length-prefixed message whose payload must be exactly msg_size
// bytes. If the announced length differs, it fails without reading the payload.
static bool recv_msg(sockfd_t sockfd, void * msg, size_t msg_size) {
    uint64_t announced = 0;
    const bool header_ok = recv_data(sockfd, &announced, sizeof(announced)) && announced == msg_size;
    return header_ok && recv_data(sockfd, msg, msg_size);
}
|
318
|
+
|
319
|
+
// Receives one length-prefixed message of arbitrary size into `input`.
// The length comes from the peer, so it is validated before allocating.
static bool recv_msg(sockfd_t sockfd, std::vector<uint8_t> & input) {
    uint64_t size;
    if (!recv_data(sockfd, &size, sizeof(size))) {
        return false;
    }
    // Reject sizes the vector cannot hold. This also prevents a silent
    // truncation to a 32-bit size_t, and resize() would otherwise throw
    // std::length_error, which is not caught below.
    if (size > input.max_size()) {
        fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", size);
        return false;
    }
    try {
        input.resize(static_cast<size_t>(size));
    } catch (const std::bad_alloc &) {
        fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", size);
        return false;
    }
    return recv_data(sockfd, input.data(), input.size());
}
|
332
|
+
|
333
|
+
// Splits "host:port" at the first ':' into `host` and `port`.
// Returns false, and leaves the outputs untouched, when the separator is
// missing or the port is not a plain decimal number in [0, 65535]. The
// original std::stoi version could throw or accept trailing garbage.
static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
    size_t pos = endpoint.find(':');
    if (pos == std::string::npos) {
        return false;
    }
    const std::string port_str = endpoint.substr(pos + 1);
    // at most 5 digits keeps the accumulation below far from int overflow
    if (port_str.empty() || port_str.size() > 5) {
        return false;
    }
    int value = 0;
    for (char c : port_str) {
        if (c < '0' || c > '9') {
            return false;
        }
        value = value * 10 + (c - '0');
    }
    if (value > 65535) {
        return false;
    }
    host = endpoint.substr(0, pos);
    port = value;
    return true;
}
|
342
|
+
|
343
|
+
// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
|
344
|
+
// RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
|
345
|
+
static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) {
|
346
|
+
uint8_t cmd_byte = cmd;
|
347
|
+
if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
|
348
|
+
return false;
|
349
|
+
}
|
350
|
+
if (!send_data(sock->fd, &input_size, sizeof(input_size))) {
|
351
|
+
return false;
|
352
|
+
}
|
353
|
+
if (!send_data(sock->fd, input, input_size)) {
|
354
|
+
return false;
|
355
|
+
}
|
356
|
+
// TODO: currently the output_size is always known, do we need support for commands with variable output size?
|
357
|
+
// even if we do, we can skip sending output_size from the server for commands with known output size
|
358
|
+
uint64_t out_size;
|
359
|
+
if (!recv_data(sock->fd, &out_size, sizeof(out_size))) {
|
360
|
+
return false;
|
361
|
+
}
|
362
|
+
if (out_size != output_size) {
|
363
|
+
return false;
|
364
|
+
}
|
365
|
+
if (!recv_data(sock->fd, output, output_size)) {
|
366
|
+
return false;
|
367
|
+
}
|
368
|
+
return true;
|
369
|
+
}
|
370
|
+
|
371
|
+
// RPC client-side implementation
|
372
|
+
|
373
|
+
static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
|
374
|
+
static std::mutex mutex;
|
375
|
+
std::lock_guard<std::mutex> lock(mutex);
|
376
|
+
static std::unordered_map<std::string, std::weak_ptr<socket_t>> sockets;
|
377
|
+
static bool initialized = false;
|
378
|
+
|
379
|
+
auto it = sockets.find(endpoint);
|
380
|
+
if (it != sockets.end()) {
|
381
|
+
if (auto sock = it->second.lock()) {
|
382
|
+
return sock;
|
383
|
+
}
|
384
|
+
}
|
385
|
+
std::string host;
|
386
|
+
int port;
|
387
|
+
if (!parse_endpoint(endpoint, host, port)) {
|
388
|
+
return nullptr;
|
389
|
+
}
|
390
|
+
#ifdef _WIN32
|
391
|
+
if (!initialized) {
|
392
|
+
WSADATA wsaData;
|
393
|
+
int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
|
394
|
+
if (res != 0) {
|
395
|
+
return nullptr;
|
396
|
+
}
|
397
|
+
initialized = true;
|
398
|
+
}
|
399
|
+
#else
|
400
|
+
UNUSED(initialized);
|
401
|
+
#endif
|
402
|
+
auto sock = socket_connect(host.c_str(), port);
|
403
|
+
if (sock == nullptr) {
|
404
|
+
return nullptr;
|
405
|
+
}
|
406
|
+
GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
|
407
|
+
sockets[endpoint] = sock;
|
408
|
+
return sock;
|
409
|
+
}
|
410
|
+
|
411
|
+
// Releases the server-side buffer, then destroys the local context.
// Aborts if the RPC fails.
static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
    auto * ctx = static_cast<ggml_backend_rpc_buffer_context *>(buffer->context);
    const rpc_msg_free_buffer_req req = {ctx->remote_ptr};
    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &req, sizeof(req), nullptr, 0);
    GGML_ASSERT(status);
    delete ctx;
}
|
418
|
+
|
419
|
+
// Returns the server-side base address of `buffer`. The address is fetched
// with RPC_CMD_BUFFER_GET_BASE on first use and cached in the buffer context.
// This assumes the remote base stays stable for the buffer's lifetime.
static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
    // single lookup instead of find() followed by operator[]
    auto it = ctx->base_cache.find(buffer);
    if (it != ctx->base_cache.end()) {
        return it->second;
    }
    rpc_msg_buffer_get_base_req request = {ctx->remote_ptr};
    rpc_msg_buffer_get_base_rsp response;
    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response));
    GGML_ASSERT(status);
    void * base_ptr = reinterpret_cast<void *>(response.base_ptr);
    ctx->base_cache.emplace(buffer, base_ptr);
    return base_ptr;
}
|
432
|
+
|
433
|
+
// Converts a client-side ggml_tensor into its wire form. Pointer fields (the
// tensor itself, its sources, view source and data) are sent as raw 64-bit
// values. The buffer is sent as the owning RPC buffer's remote_ptr, or 0.
static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
    rpc_tensor result;
    // Zero everything first. Otherwise the padding bytes and the tail of `name`
    // after the terminator are sent with indeterminate stack contents.
    memset(&result, 0, sizeof(result));
    result.id = reinterpret_cast<uint64_t>(tensor);
    result.type = tensor->type;
    if (tensor->buffer) {
        // NOTE(review): assumes tensor->buffer is an RPC buffer; its context is
        // reinterpreted without a type check.
        ggml_backend_buffer_t buffer = tensor->buffer;
        ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
        result.buffer = ctx->remote_ptr;
    } else {
        result.buffer = 0;
    }
    for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
        result.ne[i] = tensor->ne[i];
        result.nb[i] = tensor->nb[i];
    }
    result.op = tensor->op;
    for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
        result.op_params[i] = tensor->op_params[i];
    }
    result.flags = tensor->flags;
    for (uint32_t i = 0; i < GGML_MAX_SRC; i++) {
        result.src[i] = reinterpret_cast<uint64_t>(tensor->src[i]);
    }
    result.view_src = reinterpret_cast<uint64_t>(tensor->view_src);
    result.view_offs = tensor->view_offs;
    result.data = reinterpret_cast<uint64_t>(tensor->data);
    snprintf(result.name, GGML_MAX_NAME, "%s", tensor->name);
    return result;
}
|
462
|
+
|
463
|
+
// Validates a tensor placed in an RPC buffer. Quantized tensors must have
// ne[0] divisible by 512; non-quantized tensors are accepted as-is.
static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
    UNUSED(buffer);
    if (!ggml_is_quantized(tensor->type)) {
        return;
    }
    // TODO: this check is due to MATRIX_ROW_PADDING in CUDA and should be generalized
    GGML_ASSERT(tensor->ne[0] % 512 == 0 && "unsupported quantized tensor");
}
|
470
|
+
|
471
|
+
// Uploads `size` bytes of `data` into `tensor` at byte `offset` on the server.
// Aborts if the RPC fails.
static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
    // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
    // The offset is widened to uint64_t so the layout also holds on 32-bit hosts,
    // where size_t would occupy only 4 of the 8 reserved bytes.
    const uint64_t offset64 = offset;
    const rpc_tensor serialized = serialize_tensor(tensor);
    size_t input_size = sizeof(serialized) + sizeof(offset64) + size;
    std::vector<uint8_t> input(input_size, 0);
    uint8_t * dst = input.data();
    memcpy(dst, &serialized, sizeof(serialized));
    memcpy(dst + sizeof(serialized), &offset64, sizeof(offset64));
    memcpy(dst + sizeof(serialized) + sizeof(offset64), data, size);
    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size(), nullptr, 0);
    GGML_ASSERT(status);
}
|
483
|
+
|
484
|
+
// Downloads `size` bytes of `tensor` starting at byte `offset` into `data`.
// Aborts if the RPC fails.
static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    auto * ctx = static_cast<ggml_backend_rpc_buffer_context *>(buffer->context);
    const rpc_msg_get_tensor_req request = { serialize_tensor(tensor), offset, size };
    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, &request, sizeof(request), data, size);
    GGML_ASSERT(status);
}
|
493
|
+
|
494
|
+
// Asks the server to copy src into dst. Returns false immediately when the two
// tensors live on different connections; otherwise returns the server's result.
static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
    auto * src_ctx = static_cast<ggml_backend_rpc_buffer_context *>(src->buffer->context);
    auto * dst_ctx = static_cast<ggml_backend_rpc_buffer_context *>(dst->buffer->context);
    // a server-side copy is only possible when both tensors share one server
    if (src_ctx->sock != dst_ctx->sock) {
        return false;
    }
    auto * ctx = static_cast<ggml_backend_rpc_buffer_context *>(buffer->context);
    const rpc_msg_copy_tensor_req request = { serialize_tensor(src), serialize_tensor(dst) };
    rpc_msg_copy_tensor_rsp response;
    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
    GGML_ASSERT(status);
    return response.result != 0;
}
|
512
|
+
|
513
|
+
// Fills the whole remote buffer with `value`. Aborts if the RPC fails.
static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
    auto * ctx = static_cast<ggml_backend_rpc_buffer_context *>(buffer->context);
    rpc_msg_buffer_clear_req request;
    request.remote_ptr = ctx->remote_ptr;
    request.value      = value;
    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, &request, sizeof(request), nullptr, 0);
    GGML_ASSERT(status);
}
|
519
|
+
|
520
|
+
// Buffer callbacks for RPC buffers, in ggml_backend_buffer_i field order;
// null entries are not provided by this backend.
static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
    ggml_backend_rpc_buffer_free_buffer, // free_buffer
    ggml_backend_rpc_buffer_get_base,    // get_base
    ggml_backend_rpc_buffer_init_tensor, // init_tensor
    nullptr,                             // memset_tensor
    ggml_backend_rpc_buffer_set_tensor,  // set_tensor
    ggml_backend_rpc_buffer_get_tensor,  // get_tensor
    ggml_backend_rpc_buffer_cpy_tensor,  // cpy_tensor
    ggml_backend_rpc_buffer_clear,       // clear
    nullptr,                             // reset
};
|
531
|
+
|
532
|
+
// Returns the display name stored in the buffer type's context.
static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t buft) {
    const auto * buft_ctx = static_cast<const ggml_backend_rpc_buffer_type_context *>(buft->context);
    return buft_ctx->name.c_str();
}
|
536
|
+
|
537
|
+
// Allocates `size` bytes on the server behind this buffer type and wraps the
// returned handle in a ggml buffer. Returns nullptr when the server cannot be
// reached or reports a failed allocation (remote_ptr == 0).
static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
    auto sock = get_socket(buft_ctx->endpoint);
    if (sock == nullptr) {
        // get_socket fails when the endpoint cannot be parsed or connected to;
        // report it instead of dereferencing a null socket in send_rpc_cmd
        fprintf(stderr, "Failed to connect to %s\n", buft_ctx->endpoint.c_str());
        return nullptr;
    }
    rpc_msg_alloc_buffer_req request = {size};
    rpc_msg_alloc_buffer_rsp response;
    bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
    GGML_ASSERT(status);
    if (response.remote_ptr == 0) {
        return nullptr;
    }
    return ggml_backend_buffer_init(buft,
                                    ggml_backend_rpc_buffer_interface,
                                    new ggml_backend_rpc_buffer_context{sock, {}, response.remote_ptr},
                                    response.remote_size);
}
|
554
|
+
|
555
|
+
static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
|
556
|
+
rpc_msg_get_alignment_rsp response;
|
557
|
+
bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response));
|
558
|
+
GGML_ASSERT(status);
|
559
|
+
return response.alignment;
|
560
|
+
}
|
561
|
+
|
562
|
+
// Returns the alignment cached in the buffer type context (no network traffic).
static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    const auto * buft_ctx = static_cast<const ggml_backend_rpc_buffer_type_context *>(buft->context);
    return buft_ctx->alignment;
}
|
566
|
+
|
567
|
+
static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
|
568
|
+
rpc_msg_get_max_size_rsp response;
|
569
|
+
bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, nullptr, 0, &response, sizeof(response));
|
570
|
+
GGML_ASSERT(status);
|
571
|
+
return response.max_size;
|
572
|
+
}
|
573
|
+
|
574
|
+
// Returns the max allocation size cached in the buffer type context.
static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
    const auto * buft_ctx = static_cast<const ggml_backend_rpc_buffer_type_context *>(buft->context);
    return buft_ctx->max_size;
}
|
578
|
+
|
579
|
+
// Reports exactly ggml_nbytes(tensor); no backend-specific padding is added here.
static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
    UNUSED(buft);
    const size_t nbytes = ggml_nbytes(tensor);
    return nbytes;
}
|
583
|
+
|
584
|
+
// Buffer-type callbacks, in ggml_backend_buffer_type_i field order.
static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
    ggml_backend_rpc_buffer_type_name,           // get_name
    ggml_backend_rpc_buffer_type_alloc_buffer,   // alloc_buffer
    ggml_backend_rpc_buffer_type_get_alignment,  // get_alignment
    ggml_backend_rpc_get_max_size,               // get_max_size
    ggml_backend_rpc_buffer_type_get_alloc_size, // get_alloc_size
    nullptr,                                     // is_host
};
|
592
|
+
|
593
|
+
// Returns the backend's display name from its RPC context.
static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
    const auto * rpc_ctx = static_cast<const ggml_backend_rpc_context *>(backend->context);
    return rpc_ctx->name.c_str();
}
|
598
|
+
|
599
|
+
// Destroys the RPC context first, then the backend object itself.
static void ggml_backend_rpc_free(ggml_backend_t backend) {
    delete static_cast<ggml_backend_rpc_context *>(backend->context);
    delete backend;
}
|
604
|
+
|
605
|
+
// Nothing to wait for: this backend issues no asynchronous operations (the
// *_async interface entries are NULL), so every call has completed on return.
static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
    UNUSED(backend);
}
|
609
|
+
|
610
|
+
// Depth-first post-order walk: every source and the view source of `tensor`
// are appended before the tensor itself, and each tensor is emitted at most
// once (tracked through `visited`). Null tensors are ignored.
static void add_tensor(ggml_tensor * tensor, std::vector<rpc_tensor> & tensors, std::unordered_set<ggml_tensor*> & visited) {
    if (tensor == nullptr || !visited.insert(tensor).second) {
        return;
    }
    for (ggml_tensor * source : tensor->src) {
        add_tensor(source, tensors, visited);
    }
    add_tensor(tensor->view_src, tensors, visited);
    tensors.push_back(serialize_tensor(tensor));
}
|
624
|
+
|
625
|
+
static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
|
626
|
+
uint32_t n_nodes = cgraph->n_nodes;
|
627
|
+
std::vector<rpc_tensor> tensors;
|
628
|
+
std::unordered_set<ggml_tensor*> visited;
|
629
|
+
for (uint32_t i = 0; i < n_nodes; i++) {
|
630
|
+
add_tensor(cgraph->nodes[i], tensors, visited);
|
631
|
+
}
|
632
|
+
// serialization format:
|
633
|
+
// | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
|
634
|
+
uint32_t n_tensors = tensors.size();
|
635
|
+
int output_size = sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
|
636
|
+
output.resize(output_size, 0);
|
637
|
+
memcpy(output.data(), &n_nodes, sizeof(n_nodes));
|
638
|
+
for (uint32_t i = 0; i < n_nodes; i++) {
|
639
|
+
memcpy(output.data() + sizeof(n_nodes) + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));
|
640
|
+
}
|
641
|
+
uint32_t * out_ntensors = (uint32_t *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t));
|
642
|
+
*out_ntensors = n_tensors;
|
643
|
+
rpc_tensor * out_tensors = (rpc_tensor *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t));
|
644
|
+
memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor));
|
645
|
+
}
|
646
|
+
|
647
|
+
// graph_compute entry of the RPC backend vtable: serializes `cgraph`, sends it
// to the server at this backend's endpoint and returns the status the server
// reports. Aborts (GGML_ASSERT) if the RPC round trip fails.
static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
    auto * rpc_ctx = static_cast<ggml_backend_rpc_context *>(backend->context);

    std::vector<uint8_t> payload;
    serialize_graph(cgraph, payload);

    auto sock = get_socket(rpc_ctx->endpoint);
    rpc_msg_graph_compute_rsp reply;
    bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, payload.data(), payload.size(), &reply, sizeof(reply));
    GGML_ASSERT(status);

    return static_cast<enum ggml_status>(reply.result);
}
|
657
|
+
|
658
|
+
// Backend vtable shared by all RPC backends. Only synchronous operations are
// implemented: async tensor transfers, graph plans and events are NULL.
// Declared const, like the device and registry vtables of this backend; it is
// only read (copied by value into each new ggml_backend).
static const ggml_backend_i ggml_backend_rpc_interface = {
    /* .get_name           = */ ggml_backend_rpc_name,
    /* .free               = */ ggml_backend_rpc_free,
    /* .set_tensor_async   = */ NULL,
    /* .get_tensor_async   = */ NULL,
    /* .cpy_tensor_async   = */ NULL,
    /* .synchronize        = */ ggml_backend_rpc_synchronize,
    /* .graph_plan_create  = */ NULL,
    /* .graph_plan_free    = */ NULL,
    /* .graph_plan_update  = */ NULL,
    /* .graph_plan_compute = */ NULL,
    /* .graph_compute      = */ ggml_backend_rpc_graph_compute,
    /* .event_record       = */ NULL,
    /* .event_wait         = */ NULL,
};
|
673
|
+
|
674
|
+
// Returns the buffer type for the RPC server at `endpoint`, creating and caching
// it on first use. Creation contacts the server to query its alignment and
// maximum buffer size. Returns nullptr if no connection can be established.
// Lookups and creation are serialized by a function-local mutex.
ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
    static std::mutex mutex;
    std::lock_guard<std::mutex> lock(mutex);
    // NOTE: buffer types are allocated and never freed; this is by design
    static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;

    const auto cached = buft_map.find(endpoint);
    if (cached != buft_map.end()) {
        return cached->second;
    }

    auto sock = get_socket(endpoint);
    if (!sock) {
        fprintf(stderr, "Failed to connect to %s\n", endpoint);
        return nullptr;
    }

    const size_t alignment = get_alignment(sock);
    const size_t max_size  = get_max_size(sock);

    auto * buft_ctx = new ggml_backend_rpc_buffer_type_context {
        /* .endpoint  = */ endpoint,
        /* .name      = */ "RPC[" + std::string(endpoint) + "]",
        /* .alignment = */ alignment,
        /* .max_size  = */ max_size
    };

    auto * buft = new ggml_backend_buffer_type {
        /* .iface   = */ ggml_backend_rpc_buffer_type_interface,
        /* .device  = */ ggml_backend_rpc_add_device(endpoint),
        /* .context = */ buft_ctx
    };

    buft_map[endpoint] = buft;
    return buft;
}
|
705
|
+
|
706
|
+
// Creates a new backend instance for the RPC server at `endpoint`.
// This function does not itself contact the server; the returned backend
// is bound to the (cached) RPC device for the same endpoint.
ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
    auto * ctx = new ggml_backend_rpc_context {
        /* .endpoint = */ endpoint,
        /* .name     = */ "RPC[" + std::string(endpoint) + "]",
    };

    return new ggml_backend {
        /* .guid      = */ ggml_backend_rpc_guid(),
        /* .interface = */ ggml_backend_rpc_interface,
        /* .device    = */ ggml_backend_rpc_add_device(endpoint),
        /* .context   = */ ctx
    };
}
|
720
|
+
|
721
|
+
// True iff `backend` is non-null and carries the RPC backend GUID, i.e. it was
// created by ggml_backend_rpc_init.
bool ggml_backend_is_rpc(ggml_backend_t backend) {
    if (backend == NULL) {
        return false;
    }
    return ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
}
|
724
|
+
|
725
|
+
static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
|
726
|
+
rpc_msg_get_device_memory_rsp response;
|
727
|
+
bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response));
|
728
|
+
GGML_ASSERT(status);
|
729
|
+
*free = response.free_mem;
|
730
|
+
*total = response.total_mem;
|
731
|
+
}
|
732
|
+
|
733
|
+
void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
|
734
|
+
auto sock = get_socket(endpoint);
|
735
|
+
if (sock == nullptr) {
|
736
|
+
*free = 0;
|
737
|
+
*total = 0;
|
738
|
+
return;
|
739
|
+
}
|
740
|
+
get_device_memory(sock, free, total);
|
741
|
+
}
|
742
|
+
|
743
|
+
// RPC server-side implementation
|
744
|
+
|
745
|
+
class rpc_server {
|
746
|
+
public:
|
747
|
+
rpc_server(ggml_backend_t backend) : backend(backend) {}
|
748
|
+
~rpc_server();
|
749
|
+
|
750
|
+
void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
|
751
|
+
void get_alignment(rpc_msg_get_alignment_rsp & response);
|
752
|
+
void get_max_size(rpc_msg_get_max_size_rsp & response);
|
753
|
+
bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response);
|
754
|
+
bool free_buffer(const rpc_msg_free_buffer_req & request);
|
755
|
+
bool buffer_clear(const rpc_msg_buffer_clear_req & request);
|
756
|
+
bool set_tensor(const std::vector<uint8_t> & input);
|
757
|
+
bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
|
758
|
+
bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
|
759
|
+
bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
|
760
|
+
|
761
|
+
private:
|
762
|
+
ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
|
763
|
+
ggml_tensor * create_node(uint64_t id,
|
764
|
+
struct ggml_context * ctx,
|
765
|
+
const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
|
766
|
+
std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map);
|
767
|
+
|
768
|
+
|
769
|
+
ggml_backend_t backend;
|
770
|
+
std::unordered_set<ggml_backend_buffer_t> buffers;
|
771
|
+
};
|
772
|
+
|
773
|
+
void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
|
774
|
+
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
|
775
|
+
ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
|
776
|
+
response.remote_ptr = 0;
|
777
|
+
response.remote_size = 0;
|
778
|
+
if (buffer != nullptr) {
|
779
|
+
response.remote_ptr = reinterpret_cast<uint64_t>(buffer);
|
780
|
+
response.remote_size = buffer->size;
|
781
|
+
GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size);
|
782
|
+
buffers.insert(buffer);
|
783
|
+
} else {
|
784
|
+
GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size);
|
785
|
+
}
|
786
|
+
}
|
787
|
+
|
788
|
+
// Reports the alignment required by the backend's default buffer type.
void rpc_server::get_alignment(rpc_msg_get_alignment_rsp & response) {
    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
    size_t alignment = ggml_backend_buft_get_alignment(buft);
    // %zu is the portable size_t specifier; %lu is wrong where size_t is not
    // unsigned long (e.g. 64-bit Windows)
    GGML_PRINT_DEBUG("[%s] alignment: %zu\n", __func__, alignment);
    response.alignment = alignment;
}
|
794
|
+
|
795
|
+
// Reports the maximum buffer size of the backend's default buffer type.
void rpc_server::get_max_size(rpc_msg_get_max_size_rsp & response) {
    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
    size_t max_size = ggml_backend_buft_get_max_size(buft);
    // %zu is the portable size_t specifier; %lu is wrong where size_t is not
    // unsigned long (e.g. 64-bit Windows)
    GGML_PRINT_DEBUG("[%s] max_size: %zu\n", __func__, max_size);
    response.max_size = max_size;
}
|
801
|
+
|
802
|
+
bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {
|
803
|
+
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
|
804
|
+
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
|
805
|
+
if (buffers.find(buffer) == buffers.end()) {
|
806
|
+
GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
|
807
|
+
return false;
|
808
|
+
}
|
809
|
+
void * base = ggml_backend_buffer_get_base(buffer);
|
810
|
+
response.base_ptr = reinterpret_cast<uint64_t>(base);
|
811
|
+
return true;
|
812
|
+
}
|
813
|
+
|
814
|
+
bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) {
|
815
|
+
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
|
816
|
+
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
|
817
|
+
if (buffers.find(buffer) == buffers.end()) {
|
818
|
+
GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
|
819
|
+
return false;
|
820
|
+
}
|
821
|
+
ggml_backend_buffer_free(buffer);
|
822
|
+
buffers.erase(buffer);
|
823
|
+
return true;
|
824
|
+
}
|
825
|
+
|
826
|
+
bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {
|
827
|
+
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, request.remote_ptr, request.value);
|
828
|
+
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
|
829
|
+
if (buffers.find(buffer) == buffers.end()) {
|
830
|
+
GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
|
831
|
+
return false;
|
832
|
+
}
|
833
|
+
ggml_backend_buffer_clear(buffer, request.value);
|
834
|
+
return true;
|
835
|
+
}
|
836
|
+
|
837
|
+
// Rebuilds the tensor described by the (untrusted) wire record `tensor` inside
// `ctx`, as metadata only: shape, strides, op, op params, flags, name, buffer
// and data pointer are copied from the record.
// A buffer handle that is not one of this server's buffers is replaced with
// nullptr. Returns nullptr if the record is invalid, i.e. if:
//   - the type is out of range, or
//   - the buffer is known but the tensor's data range does not lie inside it.
// Callers treat nullptr as a failed request; they do not abort the server.
ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) {
    // Validate the type before ggml uses it to index its per-type tables.
    if ((uint32_t) tensor->type >= (uint32_t) GGML_TYPE_COUNT) {
        fprintf(stderr, "[%s] invalid tensor type: %u\n", __func__, (unsigned) tensor->type);
        return nullptr;
    }

    ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type,
        tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
    if (result == nullptr) {
        return nullptr;
    }
    for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
        result->nb[i] = tensor->nb[i];
    }

    // only handles of buffers allocated by this server are trusted
    result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);
    if (result->buffer && buffers.find(result->buffer) == buffers.end()) {
        result->buffer = nullptr;
    }
    // NOTE(review): when the buffer is unknown, `data` below is still taken
    // verbatim from the client; confirm no caller dereferences it in that case.

    if (result->buffer) {
        // require that the tensor data does not go beyond the buffer end
        const uint64_t tensor_size  = (uint64_t) ggml_nbytes(result);
        const uint64_t buffer_start = (uint64_t) ggml_backend_buffer_get_base(result->buffer);
        const uint64_t buffer_size  = (uint64_t) ggml_backend_buffer_get_size(result->buffer);
        const bool overflow = tensor->data + tensor_size < tensor->data;
        if (overflow || tensor->data < buffer_start || tensor->data + tensor_size > buffer_start + buffer_size) {
            fprintf(stderr, "[%s] tensor data out of buffer bounds\n", __func__);
            return nullptr;
        }
    }

    result->op = (ggml_op) tensor->op;
    for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
        result->op_params[i] = tensor->op_params[i];
    }
    result->flags = tensor->flags;
    result->data = reinterpret_cast<void *>(tensor->data);
    ggml_set_name(result, tensor->name);
    return result;
}
|
866
|
+
|
867
|
+
|
868
|
+
bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
|
869
|
+
// serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
|
870
|
+
if (input.size() < sizeof(rpc_tensor) + sizeof(uint64_t)) {
|
871
|
+
return false;
|
872
|
+
}
|
873
|
+
const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
|
874
|
+
uint64_t offset;
|
875
|
+
memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
|
876
|
+
const size_t size = input.size() - sizeof(rpc_tensor) - sizeof(offset);
|
877
|
+
|
878
|
+
struct ggml_init_params params {
|
879
|
+
/*.mem_size =*/ ggml_tensor_overhead(),
|
880
|
+
/*.mem_buffer =*/ NULL,
|
881
|
+
/*.no_alloc =*/ true,
|
882
|
+
};
|
883
|
+
struct ggml_context * ctx = ggml_init(params);
|
884
|
+
ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
|
885
|
+
if (tensor == nullptr) {
|
886
|
+
GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__);
|
887
|
+
ggml_free(ctx);
|
888
|
+
return false;
|
889
|
+
}
|
890
|
+
GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
|
891
|
+
|
892
|
+
// sanitize tensor->data
|
893
|
+
{
|
894
|
+
const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
|
895
|
+
const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
|
896
|
+
|
897
|
+
if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
|
898
|
+
GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
|
899
|
+
}
|
900
|
+
}
|
901
|
+
|
902
|
+
const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset);
|
903
|
+
ggml_backend_tensor_set(tensor, data, offset, size);
|
904
|
+
ggml_free(ctx);
|
905
|
+
return true;
|
906
|
+
}
|
907
|
+
|
908
|
+
bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response) {
|
909
|
+
struct ggml_init_params params {
|
910
|
+
/*.mem_size =*/ ggml_tensor_overhead(),
|
911
|
+
/*.mem_buffer =*/ NULL,
|
912
|
+
/*.no_alloc =*/ true,
|
913
|
+
};
|
914
|
+
struct ggml_context * ctx = ggml_init(params);
|
915
|
+
ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
|
916
|
+
if (tensor == nullptr) {
|
917
|
+
GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__);
|
918
|
+
ggml_free(ctx);
|
919
|
+
return false;
|
920
|
+
}
|
921
|
+
GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size);
|
922
|
+
|
923
|
+
// sanitize tensor->data
|
924
|
+
{
|
925
|
+
const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
|
926
|
+
const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
|
927
|
+
|
928
|
+
if (request.tensor.data + request.offset < p0 ||
|
929
|
+
request.tensor.data + request.offset >= p1 ||
|
930
|
+
request.size > (p1 - request.tensor.data - request.offset)) {
|
931
|
+
GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
|
932
|
+
}
|
933
|
+
}
|
934
|
+
|
935
|
+
response.resize(request.size, 0);
|
936
|
+
ggml_backend_tensor_get(tensor, response.data(), request.offset, request.size);
|
937
|
+
ggml_free(ctx);
|
938
|
+
return true;
|
939
|
+
}
|
940
|
+
|
941
|
+
bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response) {
|
942
|
+
struct ggml_init_params params {
|
943
|
+
/*.mem_size =*/ 2*ggml_tensor_overhead(),
|
944
|
+
/*.mem_buffer =*/ NULL,
|
945
|
+
/*.no_alloc =*/ true,
|
946
|
+
};
|
947
|
+
struct ggml_context * ctx = ggml_init(params);
|
948
|
+
ggml_tensor * src = deserialize_tensor(ctx, &request.src);
|
949
|
+
ggml_tensor * dst = deserialize_tensor(ctx, &request.dst);
|
950
|
+
if (src == nullptr || dst == nullptr) {
|
951
|
+
GGML_PRINT_DEBUG("[%s] error deserializing tensors\n", __func__);
|
952
|
+
ggml_free(ctx);
|
953
|
+
return false;
|
954
|
+
}
|
955
|
+
GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n", __func__, (void*)src->buffer, (void*)dst->buffer);
|
956
|
+
response.result = ggml_backend_buffer_copy_tensor(src, dst);
|
957
|
+
ggml_free(ctx);
|
958
|
+
return true;
|
959
|
+
}
|
960
|
+
|
961
|
+
// Materializes the tensor with client id `id` (and, recursively, its sources
// and view source) in `ctx`. `tensor_ptrs` maps ids to the serialized records;
// `tensor_map` memoizes tensors already created so shared inputs are built once.
// Returns nullptr for id 0 (no tensor) and on failure, i.e. when:
//   - the id has no serialized record,
//   - the record cannot be deserialized, or
//   - a non-zero dependency cannot be created.
ggml_tensor * rpc_server::create_node(uint64_t id,
                                      struct ggml_context * ctx,
                                      const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
                                      std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map) {
    if (id == 0) {
        return nullptr;
    }
    auto cached = tensor_map.find(id);
    if (cached != tensor_map.end()) {
        return cached->second;
    }
    // find() rather than at(): a dangling id from the client must not throw
    // (an uncaught std::out_of_range would terminate the server)
    auto record = tensor_ptrs.find(id);
    if (record == tensor_ptrs.end()) {
        return nullptr;
    }
    const rpc_tensor * tensor = record->second;
    struct ggml_tensor * result = deserialize_tensor(ctx, tensor);
    if (result == nullptr) {
        return nullptr;
    }
    // registered before recursing so that repeated references resolve to this tensor
    tensor_map[id] = result;
    for (int i = 0; i < GGML_MAX_SRC; i++) {
        result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
        // a non-zero id that could not be created is a deserialization error
        if (result->src[i] == nullptr && tensor->src[i] != 0) {
            return nullptr;
        }
    }
    result->view_src = create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map);
    if (result->view_src == nullptr && tensor->view_src != 0) {
        return nullptr;
    }
    result->view_offs = tensor->view_offs;
    return result;
}
|
984
|
+
|
985
|
+
bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
|
986
|
+
// serialization format:
|
987
|
+
// | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
|
988
|
+
if (input.size() < sizeof(uint32_t)) {
|
989
|
+
return false;
|
990
|
+
}
|
991
|
+
uint32_t n_nodes;
|
992
|
+
memcpy(&n_nodes, input.data(), sizeof(n_nodes));
|
993
|
+
if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
|
994
|
+
return false;
|
995
|
+
}
|
996
|
+
const uint64_t * nodes = (const uint64_t *)(input.data() + sizeof(n_nodes));
|
997
|
+
uint32_t n_tensors;
|
998
|
+
memcpy(&n_tensors, input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t), sizeof(n_tensors));
|
999
|
+
if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
|
1000
|
+
return false;
|
1001
|
+
}
|
1002
|
+
const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t) + sizeof(n_tensors));
|
1003
|
+
GGML_PRINT_DEBUG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);
|
1004
|
+
|
1005
|
+
size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
|
1006
|
+
struct ggml_init_params params = {
|
1007
|
+
/*.mem_size =*/ buf_size,
|
1008
|
+
/*.mem_buffer =*/ NULL,
|
1009
|
+
/*.no_alloc =*/ true,
|
1010
|
+
};
|
1011
|
+
struct ggml_context * ctx = ggml_init(params);
|
1012
|
+
struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false);
|
1013
|
+
graph->n_nodes = n_nodes;
|
1014
|
+
std::unordered_map<uint64_t, const rpc_tensor*> tensor_ptrs;
|
1015
|
+
for (uint32_t i = 0; i < n_tensors; i++) {
|
1016
|
+
tensor_ptrs[tensors[i].id] = &tensors[i];
|
1017
|
+
}
|
1018
|
+
std::unordered_map<uint64_t, ggml_tensor*> tensor_map;
|
1019
|
+
for (uint32_t i = 0; i < n_nodes; i++) {
|
1020
|
+
int64_t id;
|
1021
|
+
memcpy(&id, &nodes[i], sizeof(id));
|
1022
|
+
graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map);
|
1023
|
+
}
|
1024
|
+
ggml_status status = ggml_backend_graph_compute(backend, graph);
|
1025
|
+
response.result = status;
|
1026
|
+
ggml_free(ctx);
|
1027
|
+
return true;
|
1028
|
+
}
|
1029
|
+
|
1030
|
+
// Releases every buffer that is still allocated on behalf of this client.
rpc_server::~rpc_server() {
    for (ggml_backend_buffer_t buf : buffers) {
        ggml_backend_buffer_free(buf);
    }
}
|
1035
|
+
|
1036
|
+
// Serves one client connection on `sockfd` until the client disconnects or
// misbehaves. Each request is a 1-byte command id followed by a
// command-specific message.
// The session ends on any of:
//   - a receive or send failure,
//   - a handler returning false, or
//   - an unknown command.
// `free_mem`/`total_mem` are reported verbatim for RPC_CMD_GET_DEVICE_MEMORY.
// Buffers the client allocated are released when `server` goes out of scope.
static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
    rpc_server server(backend);
    for (;;) {
        uint8_t command;
        if (!recv_data(sockfd, &command, 1)) {
            return;
        }
        if (command >= RPC_CMD_COUNT) {
            // fail fast if the command is invalid
            fprintf(stderr, "Unknown command: %d\n", command);
            return;
        }
        switch (command) {
            case RPC_CMD_ALLOC_BUFFER: {
                rpc_msg_alloc_buffer_req req;
                if (!recv_msg(sockfd, &req, sizeof(req))) { return; }
                rpc_msg_alloc_buffer_rsp rsp;
                server.alloc_buffer(req, rsp);
                if (!send_msg(sockfd, &rsp, sizeof(rsp))) { return; }
                continue;
            }
            case RPC_CMD_GET_ALIGNMENT: {
                if (!recv_msg(sockfd, nullptr, 0)) { return; }
                rpc_msg_get_alignment_rsp rsp;
                server.get_alignment(rsp);
                if (!send_msg(sockfd, &rsp, sizeof(rsp))) { return; }
                continue;
            }
            case RPC_CMD_GET_MAX_SIZE: {
                if (!recv_msg(sockfd, nullptr, 0)) { return; }
                rpc_msg_get_max_size_rsp rsp;
                server.get_max_size(rsp);
                if (!send_msg(sockfd, &rsp, sizeof(rsp))) { return; }
                continue;
            }
            case RPC_CMD_BUFFER_GET_BASE: {
                rpc_msg_buffer_get_base_req req;
                if (!recv_msg(sockfd, &req, sizeof(req))) { return; }
                rpc_msg_buffer_get_base_rsp rsp;
                if (!server.buffer_get_base(req, rsp)) { return; }
                if (!send_msg(sockfd, &rsp, sizeof(rsp))) { return; }
                continue;
            }
            case RPC_CMD_FREE_BUFFER: {
                rpc_msg_free_buffer_req req;
                if (!recv_msg(sockfd, &req, sizeof(req))) { return; }
                if (!server.free_buffer(req)) { return; }
                if (!send_msg(sockfd, nullptr, 0)) { return; }
                continue;
            }
            case RPC_CMD_BUFFER_CLEAR: {
                rpc_msg_buffer_clear_req req;
                if (!recv_msg(sockfd, &req, sizeof(req))) { return; }
                if (!server.buffer_clear(req)) { return; }
                if (!send_msg(sockfd, nullptr, 0)) { return; }
                continue;
            }
            case RPC_CMD_SET_TENSOR: {
                std::vector<uint8_t> payload;
                if (!recv_msg(sockfd, payload)) { return; }
                if (!server.set_tensor(payload)) { return; }
                if (!send_msg(sockfd, nullptr, 0)) { return; }
                continue;
            }
            case RPC_CMD_GET_TENSOR: {
                rpc_msg_get_tensor_req req;
                if (!recv_msg(sockfd, &req, sizeof(req))) { return; }
                std::vector<uint8_t> bytes;
                if (!server.get_tensor(req, bytes)) { return; }
                if (!send_msg(sockfd, bytes.data(), bytes.size())) { return; }
                continue;
            }
            case RPC_CMD_COPY_TENSOR: {
                rpc_msg_copy_tensor_req req;
                if (!recv_msg(sockfd, &req, sizeof(req))) { return; }
                rpc_msg_copy_tensor_rsp rsp;
                if (!server.copy_tensor(req, rsp)) { return; }
                if (!send_msg(sockfd, &rsp, sizeof(rsp))) { return; }
                continue;
            }
            case RPC_CMD_GRAPH_COMPUTE: {
                std::vector<uint8_t> payload;
                if (!recv_msg(sockfd, payload)) { return; }
                rpc_msg_graph_compute_rsp rsp;
                if (!server.graph_compute(payload, rsp)) { return; }
                if (!send_msg(sockfd, &rsp, sizeof(rsp))) { return; }
                continue;
            }
            case RPC_CMD_GET_DEVICE_MEMORY: {
                if (!recv_msg(sockfd, nullptr, 0)) { return; }
                rpc_msg_get_device_memory_rsp rsp;
                rsp.free_mem  = free_mem;
                rsp.total_mem = total_mem;
                if (!send_msg(sockfd, &rsp, sizeof(rsp))) { return; }
                continue;
            }
            default: {
                fprintf(stderr, "Unknown command: %d\n", command);
                return;
            }
        }
    }
}
|
1197
|
+
|
1198
|
+
void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
|
1199
|
+
std::string host;
|
1200
|
+
int port;
|
1201
|
+
if (!parse_endpoint(endpoint, host, port)) {
|
1202
|
+
return;
|
1203
|
+
}
|
1204
|
+
#ifdef _WIN32
|
1205
|
+
{
|
1206
|
+
WSADATA wsaData;
|
1207
|
+
int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
|
1208
|
+
if (res != 0) {
|
1209
|
+
fprintf(stderr, "WSAStartup failed: %d\n", res);
|
1210
|
+
return;
|
1211
|
+
}
|
1212
|
+
}
|
1213
|
+
#endif
|
1214
|
+
auto server_socket = create_server_socket(host.c_str(), port);
|
1215
|
+
if (server_socket == nullptr) {
|
1216
|
+
fprintf(stderr, "Failed to create server socket\n");
|
1217
|
+
return;
|
1218
|
+
}
|
1219
|
+
while (true) {
|
1220
|
+
auto client_socket = socket_accept(server_socket->fd);
|
1221
|
+
if (client_socket == nullptr) {
|
1222
|
+
fprintf(stderr, "Failed to accept client connection\n");
|
1223
|
+
return;
|
1224
|
+
}
|
1225
|
+
printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
|
1226
|
+
fflush(stdout);
|
1227
|
+
rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
|
1228
|
+
printf("Client connection closed\n");
|
1229
|
+
fflush(stdout);
|
1230
|
+
}
|
1231
|
+
#ifdef _WIN32
|
1232
|
+
WSACleanup();
|
1233
|
+
#endif
|
1234
|
+
}
|
1235
|
+
|
1236
|
+
// device interface
|
1237
|
+
|
1238
|
+
// Per-endpoint state of an RPC device (created by ggml_backend_rpc_add_device).
struct ggml_backend_rpc_device_context {
    std::string endpoint;  // address of the RPC server backing this device
    std::string name;      // display name, "RPC[<endpoint>]"
};
|
1242
|
+
|
1243
|
+
// Device name: "RPC[<endpoint>]".
static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
    const auto * dev_ctx = static_cast<const ggml_backend_rpc_device_context *>(dev->context);
    return dev_ctx->name.c_str();
}
|
1248
|
+
|
1249
|
+
// Device description; RPC devices reuse their name.
static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) {
    const auto * dev_ctx = static_cast<const ggml_backend_rpc_device_context *>(dev->context);
    return dev_ctx->name.c_str();
}
|
1254
|
+
|
1255
|
+
// Free/total memory as reported by the device's RPC server (0/0 if unreachable).
static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
    const auto * dev_ctx = static_cast<const ggml_backend_rpc_device_context *>(dev->context);
    ggml_backend_rpc_get_device_memory(dev_ctx->endpoint.c_str(), free, total);
}
|
1262
|
+
|
1263
|
+
// RPC devices always report themselves as GPUs for now.
static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
    UNUSED(dev);
    // TODO: obtain value from the server
    return GGML_BACKEND_DEVICE_TYPE_GPU;
}
|
1269
|
+
|
1270
|
+
// Fills `props` from the other device getters. Memory is queried from the
// server. All capabilities (async, host buffers, buffers from host pointers,
// events) are reported as unsupported.
static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
    props->name        = ggml_backend_rpc_device_get_name(dev);
    props->description = ggml_backend_rpc_device_get_description(dev);
    props->type        = ggml_backend_rpc_device_get_type(dev);
    ggml_backend_rpc_device_get_memory(dev, &props->memory_free, &props->memory_total);

    props->caps = {
        /* .async                 = */ false,
        /* .host_buffer           = */ false,
        /* .buffer_from_host_ptr  = */ false,
        /* .events                = */ false,
    };
}
|
1282
|
+
|
1283
|
+
// Creates a backend for this device's endpoint; `params` is ignored.
static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) {
    UNUSED(params);
    const auto * dev_ctx = static_cast<const ggml_backend_rpc_device_context *>(dev->context);
    return ggml_backend_rpc_init(dev_ctx->endpoint.c_str());
}
|
1290
|
+
|
1291
|
+
// Buffer type of this device's endpoint (nullptr if the server is unreachable).
static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
    const auto * dev_ctx = static_cast<const ggml_backend_rpc_device_context *>(dev->context);
    return ggml_backend_rpc_buffer_type(dev_ctx->endpoint.c_str());
}
|
1298
|
+
|
1299
|
+
// Optimistically claims support for every op; the server is not consulted.
static bool ggml_backend_rpc_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
    //TODO: call the remote backend and cache the results
    UNUSED(op);
    UNUSED(dev);
    return true;
}
|
1305
|
+
|
1306
|
+
// A buffer type is usable by this device only if it is an RPC buffer type
// (recognized by its get_name callback) for the same endpoint.
static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
    const bool is_rpc_buft = buft != nullptr && buft->iface.get_name == ggml_backend_rpc_buffer_type_name;
    if (!is_rpc_buft) {
        return false;
    }
    const auto * buft_ctx = static_cast<const ggml_backend_rpc_buffer_type_context *>(buft->context);
    const auto * dev_ctx  = static_cast<const ggml_backend_rpc_device_context *>(dev->context);
    return buft_ctx->endpoint == dev_ctx->endpoint;
}
|
1314
|
+
|
1315
|
+
// Device vtable for RPC devices. Host buffers, buffers from host pointers,
// op offloading and events are not provided.
static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
    /* .get_name             = */ ggml_backend_rpc_device_get_name,
    /* .get_description      = */ ggml_backend_rpc_device_get_description,
    /* .get_memory           = */ ggml_backend_rpc_device_get_memory,
    /* .get_type             = */ ggml_backend_rpc_device_get_type,
    /* .get_props            = */ ggml_backend_rpc_device_get_props,
    /* .init_backend         = */ ggml_backend_rpc_device_init,
    /* .get_buffer_type      = */ ggml_backend_rpc_device_get_buffer_type,
    /* .get_host_buffer_type = */ nullptr,
    /* .buffer_from_host_ptr = */ nullptr,
    /* .supports_op          = */ ggml_backend_rpc_device_supports_op,
    /* .supports_buft        = */ ggml_backend_rpc_device_supports_buft,
    /* .offload_op           = */ nullptr,
    /* .event_new            = */ nullptr,
    /* .event_free           = */ nullptr,
    /* .event_synchronize    = */ nullptr,
};
|
1332
|
+
|
1333
|
+
// backend reg interface
|
1334
|
+
|
1335
|
+
// Registry name of the RPC backend.
static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
    UNUSED(reg);
    return "RPC";
}
|
1340
|
+
|
1341
|
+
// The RPC registry enumerates no devices; they are added per endpoint instead.
static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
    UNUSED(reg);
    return 0;
}
|
1346
|
+
|
1347
|
+
// Never valid: the registry reports zero devices. Aborts with a hint pointing to
// the entry point this backend actually exports (see get_proc_address) for
// creating RPC devices.
static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
    UNUSED(reg);
    UNUSED(index);
    GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_rpc_add_device instead");
}
|
1353
|
+
|
1354
|
+
// Resolves backend-specific entry points by name. Only
// "ggml_backend_rpc_add_device" is exported; any other name yields NULL.
static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {
    UNUSED(reg);
    const bool wants_add_device = std::strcmp(name, "ggml_backend_rpc_add_device") == 0;
    return wants_add_device ? (void *)ggml_backend_rpc_add_device : NULL;
}
|
1362
|
+
|
1363
|
+
// Registry vtable of the RPC backend.
static const struct ggml_backend_reg_i ggml_backend_rpc_reg_i = {
    /* .get_name         = */ ggml_backend_rpc_reg_get_name,
    /* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,
    /* .get_device       = */ ggml_backend_rpc_reg_get_device,
    /* .get_proc_address = */ ggml_backend_rpc_get_proc_address,
};
|
1369
|
+
|
1370
|
+
// Returns the process-wide registry object of the RPC backend (no context).
ggml_backend_reg_t ggml_backend_rpc_reg(void) {
    static struct ggml_backend_reg rpc_registry = {
        /* .api_version = */ GGML_BACKEND_API_VERSION,
        /* .iface       = */ ggml_backend_rpc_reg_i,
        /* .context     = */ NULL,
    };
    return &rpc_registry;
}
|
1379
|
+
|
1380
|
+
// Returns the RPC device for `endpoint`, creating it on first request.
// Devices are cached per endpoint and this function never frees them; lookup
// and creation are serialized by a function-local mutex.
ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) {
    static std::unordered_map<std::string, ggml_backend_dev_t> dev_map;

    static std::mutex mutex;
    std::lock_guard<std::mutex> lock(mutex);

    const std::string key(endpoint);
    const auto existing = dev_map.find(key);
    if (existing != dev_map.end()) {
        return existing->second;
    }

    auto * ctx = new ggml_backend_rpc_device_context {
        /* .endpoint = */ key,
        /* .name     = */ "RPC[" + key + "]",
    };

    ggml_backend_dev_t dev = new ggml_backend_device {
        /* .iface   = */ ggml_backend_rpc_device_i,
        /* .reg     = */ ggml_backend_rpc_reg(),
        /* .context = */ ctx,
    };

    dev_map.emplace(key, dev);
    return dev;
}
|
1405
|
+
|
1406
|
+
// Expands the ggml macro for dynamically loadable backends with
// ggml_backend_rpc_reg as the registry function (macro defined elsewhere;
// presumably emits the module's exported entry points — confirm in ggml-backend-impl.h).
GGML_BACKEND_DL_IMPL(ggml_backend_rpc_reg)
|