llama_cpp 0.15.1 → 0.15.2 (the diff below adds one new file: ggml-rpc.cpp, the ggml RPC backend)

@@ -0,0 +1,1032 @@
+ #include "ggml-rpc.h"
+ #include "ggml.h"
+ #include "ggml-backend-impl.h"
+
+ #include <cinttypes>
+ #include <string>
+ #include <vector>
+ #include <memory>
+ #include <unordered_map>
+ #include <unordered_set>
+ #ifdef _WIN32
+ #  define WIN32_LEAN_AND_MEAN
+ #  ifndef NOMINMAX
+ #     define NOMINMAX
+ #  endif
+ #  include <windows.h>
+ #  include <winsock2.h>
+ #else
+ #  include <arpa/inet.h>
+ #  include <sys/socket.h>
+ #  include <sys/types.h>
+ #  include <netinet/in.h>
+ #  include <netinet/tcp.h>
+ #  include <netdb.h>
+ #  include <unistd.h>
+ #endif
+ #include <string.h>
+
+ #define UNUSED GGML_UNUSED
+
+ #define GGML_DEBUG 0
+ #if (GGML_DEBUG >= 1)
+ #define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
+ #else
+ #define GGML_PRINT_DEBUG(...)
+ #endif
+
+ #ifdef _WIN32
+ typedef SOCKET sockfd_t;
+ using ssize_t = __int64;
+ #else
+ typedef int sockfd_t;
+ #endif
+
+ // cross-platform socket
+ struct socket_t {
+     sockfd_t fd;
+     socket_t(sockfd_t fd) : fd(fd) {}
+     ~socket_t() {
+ #ifdef _WIN32
+         closesocket(this->fd);
+ #else
+         close(this->fd);
+ #endif
+     }
+ };
+
+ // ggml_tensor is serialized into rpc_tensor
+ struct rpc_tensor {
+     uint64_t id;
+     uint32_t type;
+     uint64_t buffer;
+     uint32_t ne[GGML_MAX_DIMS];
+     uint32_t nb[GGML_MAX_DIMS];
+     uint32_t op;
+     int32_t  op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
+     int32_t  flags;
+     uint64_t src[GGML_MAX_SRC];
+     uint64_t view_src;
+     uint64_t view_offs;
+     uint64_t data;
+     char     name[GGML_MAX_NAME];
+ };
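
Since rpc_tensor is copied to and from the wire byte-for-byte, client and server must agree on its exact layout (field sizes, padding, byte order). A minimal compile-time guard one could add here, offered as an illustration and not part of the file above:

    // with GGML_MAX_DIMS = 4, GGML_MAX_SRC = 10, GGML_MAX_OP_PARAMS = 64 and
    // GGML_MAX_NAME = 64, sizeof(rpc_tensor) works out to 296 bytes; this is
    // a cheap sanity guard against unexpected compiler padding
    static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be a multiple of 8");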
+
+ // RPC commands
+ enum rpc_cmd {
+     ALLOC_BUFFER = 0,
+     GET_ALIGNMENT,
+     GET_MAX_SIZE,
+     BUFFER_GET_BASE,
+     FREE_BUFFER,
+     BUFFER_CLEAR,
+     SET_TENSOR,
+     GET_TENSOR,
+     COPY_TENSOR,
+     GRAPH_COMPUTE,
+     GET_DEVICE_MEMORY,
+ };
+
+ // RPC data structures
+
+ static ggml_guid_t ggml_backend_rpc_guid() {
+     static ggml_guid guid = {0x99, 0x68, 0x5b, 0x6c, 0xd2, 0x83, 0x3d, 0x24, 0x25, 0x36, 0x72, 0xe1, 0x5b, 0x0e, 0x14, 0x03};
+     return &guid;
+ }
+
+ struct ggml_backend_rpc_buffer_type_context {
+     std::shared_ptr<socket_t> sock;
+     std::string name;
+     size_t alignment;
+     size_t max_size;
+ };
+
+ struct ggml_backend_rpc_context {
+     std::string endpoint;
+     std::string name;
+     std::shared_ptr<socket_t> sock;
+     ggml_backend_buffer_type_t buft;
+ };
+
+ struct ggml_backend_rpc_buffer_context {
+     std::shared_ptr<socket_t> sock;
+     std::unordered_map<ggml_backend_buffer_t, void *> base_cache;
+     uint64_t remote_ptr;
+     std::string name;
+ };
+
+ // RPC helper functions
+
+ static std::shared_ptr<socket_t> make_socket(sockfd_t fd) {
+ #ifdef _WIN32
+     if (fd == INVALID_SOCKET) {
+         return nullptr;
+     }
+ #else
+     if (fd < 0) {
+         return nullptr;
+     }
+ #endif
+     return std::make_shared<socket_t>(fd);
+ }
+
+ static bool set_no_delay(sockfd_t sockfd) {
+     int flag = 1;
+     // set TCP_NODELAY to disable Nagle's algorithm
+     int ret = setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int));
+     return ret == 0;
+ }
+
+ static bool set_reuse_addr(sockfd_t sockfd) {
+     int flag = 1;
+     int ret = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, (char *)&flag, sizeof(int));
+     return ret == 0;
+ }
+
+ static std::shared_ptr<socket_t> socket_connect(const char * host, int port) {
+     struct sockaddr_in addr;
+     auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
+     auto sock_ptr = make_socket(sockfd);
+     if (sock_ptr == nullptr) {
+         return nullptr;
+     }
+     if (!set_no_delay(sockfd)) {
+         fprintf(stderr, "Failed to set TCP_NODELAY\n");
+         return nullptr;
+     }
+     addr.sin_family = AF_INET;
+     addr.sin_port = htons(port);
+     struct hostent * server = gethostbyname(host);
+     if (server == NULL) {
+         fprintf(stderr, "Cannot resolve host '%s'\n", host);
+         return nullptr;
+     }
+     memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length);
+     if (connect(sock_ptr->fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
+         return nullptr;
+     }
+     return sock_ptr;
+ }
+
+ static std::shared_ptr<socket_t> socket_accept(sockfd_t srv_sockfd) {
+     auto client_socket_fd = accept(srv_sockfd, NULL, NULL);
+     auto client_socket = make_socket(client_socket_fd);
+     if (client_socket == nullptr) {
+         return nullptr;
+     }
+     if (!set_no_delay(client_socket_fd)) {
+         fprintf(stderr, "Failed to set TCP_NODELAY\n");
+         return nullptr;
+     }
+     return client_socket;
+ }
+
+ static std::shared_ptr<socket_t> create_server_socket(const char * host, int port) {
+     auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
+     auto sock = make_socket(sockfd);
+     if (sock == nullptr) {
+         return nullptr;
+     }
+     if (!set_reuse_addr(sockfd)) {
+         fprintf(stderr, "Failed to set SO_REUSEADDR\n");
+         return nullptr;
+     }
+     struct sockaddr_in serv_addr;
+     serv_addr.sin_family = AF_INET;
+     serv_addr.sin_addr.s_addr = inet_addr(host);
+     serv_addr.sin_port = htons(port);
+
+     if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) {
+         return nullptr;
+     }
+     if (listen(sockfd, 1) < 0) {
+         return nullptr;
+     }
+     return sock;
+ }
+
+ static bool send_data(sockfd_t sockfd, const void * data, size_t size) {
+     size_t bytes_sent = 0;
+     while (bytes_sent < size) {
+         ssize_t n = send(sockfd, (const char *)data + bytes_sent, size - bytes_sent, 0);
+         if (n < 0) {
+             return false;
+         }
+         bytes_sent += n;
+     }
+     return true;
+ }
+
+ static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
+     size_t bytes_recv = 0;
+     while (bytes_recv < size) {
+         ssize_t n = recv(sockfd, (char *)data + bytes_recv, size - bytes_recv, 0);
+         if (n <= 0) {
+             return false;
+         }
+         bytes_recv += n;
+     }
+     return true;
+ }
+
+ static bool parse_endpoint(const char * endpoint, std::string & host, int & port) {
+     std::string str(endpoint);
+     size_t pos = str.find(':');
+     if (pos == std::string::npos) {
+         return false;
+     }
+     host = str.substr(0, pos);
+     port = std::stoi(str.substr(pos + 1));
+     return true;
+ }
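
For illustration, parse_endpoint splits "host:port" on the first colon; note that std::stoi throws on a malformed port rather than returning false. A hypothetical call (the address is an example value, not taken from the file):

    std::string host;
    int port = 0;
    if (parse_endpoint("127.0.0.1:50052", host, port)) {
        // host == "127.0.0.1", port == 50052
    }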
+
+ // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
+ // RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
+ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
+     uint8_t cmd_byte = cmd;
+     if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
+         return false;
+     }
+     uint64_t input_size = input.size();
+     if (!send_data(sock->fd, &input_size, sizeof(input_size))) {
+         return false;
+     }
+     if (!send_data(sock->fd, input.data(), input.size())) {
+         return false;
+     }
+     uint64_t output_size;
+     if (!recv_data(sock->fd, &output_size, sizeof(output_size))) {
+         return false;
+     }
+     if (output_size == 0) {
+         output.clear();
+         return true;
+     }
+     output.resize(output_size);
+     if (!recv_data(sock->fd, output.data(), output_size)) {
+         return false;
+     }
+     return true;
+ }
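
The size fields are sent in host byte order, so client and server are implicitly assumed to share endianness. As a worked example of the framing above: a GET_ALIGNMENT call carries no payload, so the full exchange between little-endian peers is 9 bytes out and 16 bytes back:

    // request : | 01 | 00 00 00 00 00 00 00 00 |                   (1 + 8 bytes)
    //             cmd   request_size = 0
    // response: | 08 00 00 00 00 00 00 00 | <alignment, 8 bytes> | (8 + 8 bytes)
    //             response_size = 8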
+
+ // RPC client-side implementation
+
+ GGML_CALL static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) {
+     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
+     return ctx->name.c_str();
+ }
+
+ GGML_CALL static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
+     // input serialization format: | remote_ptr (8 bytes) |
+     std::vector<uint8_t> input(sizeof(uint64_t), 0);
+     uint64_t remote_ptr = ctx->remote_ptr;
+     memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
+     std::vector<uint8_t> output;
+     bool status = send_rpc_cmd(ctx->sock, FREE_BUFFER, input, output);
+     GGML_ASSERT(status);
+     GGML_ASSERT(output.empty());
+     delete ctx;
+ }
+
+ GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
+     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
+     if (ctx->base_cache.find(buffer) != ctx->base_cache.end()) {
+         return ctx->base_cache[buffer];
+     }
+     // input serialization format: | remote_ptr (8 bytes) |
+     std::vector<uint8_t> input(sizeof(uint64_t), 0);
+     uint64_t remote_ptr = ctx->remote_ptr;
+     memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
+     std::vector<uint8_t> output;
+     bool status = send_rpc_cmd(ctx->sock, BUFFER_GET_BASE, input, output);
+     GGML_ASSERT(status);
+     GGML_ASSERT(output.size() == sizeof(uint64_t));
+     // output serialization format: | base_ptr (8 bytes) |
+     uint64_t base_ptr;
+     memcpy(&base_ptr, output.data(), sizeof(base_ptr));
+     void * base = reinterpret_cast<void *>(base_ptr);
+     ctx->base_cache[buffer] = base;
+     return base;
+ }
+
+ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
+     rpc_tensor result;
+     result.id = reinterpret_cast<uint64_t>(tensor);
+     result.type = tensor->type;
+     if (tensor->buffer) {
+         ggml_backend_buffer_t buffer = tensor->buffer;
+         ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
+         result.buffer = ctx->remote_ptr;
+     } else {
+         result.buffer = 0;
+     }
+     for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
+         result.ne[i] = tensor->ne[i];
+         result.nb[i] = tensor->nb[i];
+     }
+     result.op = tensor->op;
+     for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
+         result.op_params[i] = tensor->op_params[i];
+     }
+     result.flags = tensor->flags;
+     for (uint32_t i = 0; i < GGML_MAX_SRC; i++) {
+         result.src[i] = reinterpret_cast<uint64_t>(tensor->src[i]);
+     }
+     result.view_src = reinterpret_cast<uint64_t>(tensor->view_src);
+     result.view_offs = tensor->view_offs;
+     result.data = reinterpret_cast<uint64_t>(tensor->data);
+     snprintf(result.name, GGML_MAX_NAME, "%s", tensor->name);
+     return result;
+ }
+
+ static ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) {
+     ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type,
+         tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
+     for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
+         result->nb[i] = tensor->nb[i];
+     }
+     result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);
+     result->op = (ggml_op) tensor->op;
+     for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
+         result->op_params[i] = tensor->op_params[i];
+     }
+     result->flags = tensor->flags;
+     result->data = reinterpret_cast<void *>(tensor->data);
+     ggml_set_name(result, tensor->name);
+     return result;
+ }
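
serialize_tensor and deserialize_tensor are inverses for shape, strides, op metadata and name; the pointer-valued fields (data, buffer, src, view_src) cross the wire as opaque 64-bit ids that are only meaningful on the side that produced them. A small round-trip sketch, assuming a ggml context created with no_alloc = true:

    static void roundtrip_example(struct ggml_context * ctx, const ggml_tensor * t) {
        rpc_tensor wire = serialize_tensor(t);            // plain struct, safe to memcpy
        ggml_tensor * copy = deserialize_tensor(ctx, &wire);
        GGML_ASSERT(strcmp(copy->name, t->name) == 0);    // metadata survives the trip
        // copy->data and copy->buffer hold raw remote ids, not dereferenceable here
    }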
+
+ GGML_CALL static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
+     UNUSED(buffer);
+     if (ggml_is_quantized(tensor->type)) {
+         // TODO: this check is due to MATRIX_ROW_PADDING in CUDA and should be generalized
+         GGML_ASSERT(tensor->ne[0] % 512 == 0 && "unsupported quantized tensor");
+     }
+ }
+
+ GGML_CALL static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
+     // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
+     size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size;
+     std::vector<uint8_t> input(input_size, 0);
+     rpc_tensor rpc_tensor = serialize_tensor(tensor);
+     memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
+     memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
+     memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
+     std::vector<uint8_t> output;
+     bool status = send_rpc_cmd(ctx->sock, SET_TENSOR, input, output);
+     GGML_ASSERT(status);
+ }
+
+ GGML_CALL static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
+     // input serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) |
+     int input_size = sizeof(rpc_tensor) + 2*sizeof(uint64_t);
+     std::vector<uint8_t> input(input_size, 0);
+     rpc_tensor rpc_tensor = serialize_tensor(tensor);
+     memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
+     memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
+     memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &size, sizeof(size));
+     std::vector<uint8_t> output;
+     bool status = send_rpc_cmd(ctx->sock, GET_TENSOR, input, output);
+     GGML_ASSERT(status);
+     GGML_ASSERT(output.size() == size);
+     // output serialization format: | data (size bytes) |
+     memcpy(data, output.data(), size);
+ }
+
+ GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
+     // check if src and dst are on the same server
+     ggml_backend_buffer_t src_buffer = src->buffer;
+     ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
+     ggml_backend_buffer_t dst_buffer = dst->buffer;
+     ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
+     if (src_ctx->sock != dst_ctx->sock) {
+         return false;
+     }
+     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
+     // input serialization format: | rpc_tensor src | rpc_tensor dst |
+     int input_size = 2*sizeof(rpc_tensor);
+     std::vector<uint8_t> input(input_size, 0);
+     rpc_tensor rpc_src = serialize_tensor(src);
+     rpc_tensor rpc_dst = serialize_tensor(dst);
+     memcpy(input.data(), &rpc_src, sizeof(rpc_src));
+     memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst));
+     std::vector<uint8_t> output;
+     bool status = send_rpc_cmd(ctx->sock, COPY_TENSOR, input, output);
+     GGML_ASSERT(status);
+     // output serialization format: | result (1 byte) |
+     GGML_ASSERT(output.size() == 1);
+     return output[0];
+ }
+
+ GGML_CALL static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
+     // serialization format: | bufptr (8 bytes) | value (1 byte) |
+     int input_size = sizeof(uint64_t) + sizeof(uint8_t);
+     std::vector<uint8_t> input(input_size, 0);
+     memcpy(input.data(), &ctx->remote_ptr, sizeof(ctx->remote_ptr));
+     memcpy(input.data() + sizeof(ctx->remote_ptr), &value, sizeof(value));
+     std::vector<uint8_t> output;
+     bool status = send_rpc_cmd(ctx->sock, BUFFER_CLEAR, input, output);
+     GGML_ASSERT(status);
+ }
+
+ static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
+     /* .get_name    = */ ggml_backend_rpc_buffer_get_name,
+     /* .free_buffer = */ ggml_backend_rpc_buffer_free_buffer,
+     /* .get_base    = */ ggml_backend_rpc_buffer_get_base,
+     /* .init_tensor = */ ggml_backend_rpc_buffer_init_tensor,
+     /* .set_tensor  = */ ggml_backend_rpc_buffer_set_tensor,
+     /* .get_tensor  = */ ggml_backend_rpc_buffer_get_tensor,
+     /* .cpy_tensor  = */ ggml_backend_rpc_buffer_cpy_tensor,
+     /* .clear       = */ ggml_backend_rpc_buffer_clear,
+     /* .reset       = */ NULL,
+ };
+
+ GGML_CALL static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t buft) {
+     ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
+     return buft_ctx->name.c_str();
+ }
+
+ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+     ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
+     // input serialization format: | size (8 bytes) |
+     int input_size = sizeof(uint64_t);
+     std::vector<uint8_t> input(input_size, 0);
+     memcpy(input.data(), &size, sizeof(size));
+     std::vector<uint8_t> output;
+     bool status = send_rpc_cmd(buft_ctx->sock, ALLOC_BUFFER, input, output);
+     GGML_ASSERT(status);
+     GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
+     // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
+     uint64_t remote_ptr;
+     memcpy(&remote_ptr, output.data(), sizeof(remote_ptr));
+     size_t remote_size;
+     memcpy(&remote_size, output.data() + sizeof(uint64_t), sizeof(remote_size));
+
+     ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
+         ggml_backend_rpc_buffer_interface,
+         new ggml_backend_rpc_buffer_context{buft_ctx->sock, {}, remote_ptr, "RPC"},
+         remote_size);
+
+     return buffer;
+ }
+
+ static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
+     // input serialization format: | 0 bytes |
+     std::vector<uint8_t> input;
+     std::vector<uint8_t> output;
+     bool status = send_rpc_cmd(sock, GET_ALIGNMENT, input, output);
+     GGML_ASSERT(status);
+     GGML_ASSERT(output.size() == sizeof(uint64_t));
+     // output serialization format: | alignment (8 bytes) |
+     uint64_t alignment;
+     memcpy(&alignment, output.data(), sizeof(alignment));
+     return alignment;
+ }
+
+ GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+     ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
+     return buft_ctx->alignment;
+ }
+
+ static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
+     // input serialization format: | 0 bytes |
+     std::vector<uint8_t> input;
+     std::vector<uint8_t> output;
+     bool status = send_rpc_cmd(sock, GET_MAX_SIZE, input, output);
+     GGML_ASSERT(status);
+     GGML_ASSERT(output.size() == sizeof(uint64_t));
+     // output serialization format: | max_size (8 bytes) |
+     uint64_t max_size;
+     memcpy(&max_size, output.data(), sizeof(max_size));
+     return max_size;
+ }
+
+ GGML_CALL static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
+     ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
+     return buft_ctx->max_size;
+ }
+
+ GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
+     UNUSED(buft);
+     return ggml_nbytes(tensor);
+ }
+
+ GGML_CALL static bool ggml_backend_rpc_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
+     if (!ggml_backend_is_rpc(backend)) {
+         return false;
+     }
+     ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
+     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
+     return buft_ctx->sock == rpc_ctx->sock;
+ }
+
+ static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
+     /* .get_name         = */ ggml_backend_rpc_buffer_type_name,
+     /* .alloc_buffer     = */ ggml_backend_rpc_buffer_type_alloc_buffer,
+     /* .get_alignment    = */ ggml_backend_rpc_buffer_type_get_alignment,
+     /* .get_max_size     = */ ggml_backend_rpc_get_max_size,
+     /* .get_alloc_size   = */ ggml_backend_rpc_buffer_type_get_alloc_size,
+     /* .supports_backend = */ ggml_backend_rpc_buffer_type_supports_backend,
+     /* .is_host          = */ NULL,
+ };
+
+
+ GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
+     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
+
+     return rpc_ctx->name.c_str();
+ }
+
+ GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) {
+     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
+     ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)rpc_ctx->buft->context;
+     delete buft_ctx;
+     delete rpc_ctx->buft;
+     delete rpc_ctx;
+     delete backend;
+ }
+
+ GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) {
+     ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
+     return ctx->buft;
+ }
+
+ GGML_CALL static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
+     UNUSED(backend);
+     // this is a no-op because we don't have any async operations
+ }
+
+ static void add_tensor(ggml_tensor * tensor, std::vector<rpc_tensor> & tensors, std::unordered_set<ggml_tensor*> & visited) {
+     if (tensor == nullptr) {
+         return;
+     }
+     if (visited.find(tensor) != visited.end()) {
+         return;
+     }
+     visited.insert(tensor);
+     for (int i = 0; i < GGML_MAX_SRC; i++) {
+         add_tensor(tensor->src[i], tensors, visited);
+     }
+     add_tensor(tensor->view_src, tensors, visited);
+     tensors.push_back(serialize_tensor(tensor));
+ }
+
+ static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
+     uint32_t n_nodes = cgraph->n_nodes;
+     std::vector<rpc_tensor> tensors;
+     std::unordered_set<ggml_tensor*> visited;
+     for (uint32_t i = 0; i < n_nodes; i++) {
+         add_tensor(cgraph->nodes[i], tensors, visited);
+     }
+     // serialization format:
+     // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t)) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
+     uint32_t n_tensors = tensors.size();
+     int output_size = sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
+     output.resize(output_size, 0);
+     memcpy(output.data(), &n_nodes, sizeof(n_nodes));
+     uint64_t * out_nodes = (uint64_t *)(output.data() + sizeof(n_nodes));
+     for (uint32_t i = 0; i < n_nodes; i++) {
+         out_nodes[i] = reinterpret_cast<uint64_t>(cgraph->nodes[i]);
+     }
+     uint32_t * out_ntensors = (uint32_t *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t));
+     *out_ntensors = n_tensors;
+     rpc_tensor * out_tensors = (rpc_tensor *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t));
+     memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor));
+ }
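
add_tensor visits each node's sources and view parent before the node itself, so the tensors array lists every reachable tensor exactly once, leaves included. A worked size example (assuming a graph holding a single mul_mat node with two leaf sources):

    // n_nodes   = 1  (the mul_mat node)
    // n_tensors = 3  (node + 2 sources, deduplicated via `visited`)
    // payload   = 4 + 1*8 + 4 + 3*sizeof(rpc_tensor) bytes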
+
+ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
+     std::vector<uint8_t> input;
+     serialize_graph(cgraph, input);
+     std::vector<uint8_t> output;
+     bool status = send_rpc_cmd(rpc_ctx->sock, GRAPH_COMPUTE, input, output);
+     GGML_ASSERT(status);
+     GGML_ASSERT(output.size() == 1);
+     return (enum ggml_status)output[0];
+ }
+
+ GGML_CALL static bool ggml_backend_rpc_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
+     UNUSED(backend);
+     UNUSED(op);
+     GGML_ASSERT(false && "not implemented");
+     return false;
+ }
+
+ static ggml_backend_i ggml_backend_rpc_interface = {
+     /* .get_name                = */ ggml_backend_rpc_name,
+     /* .free                    = */ ggml_backend_rpc_free,
+     /* .get_default_buffer_type = */ ggml_backend_rpc_get_default_buffer_type,
+     /* .set_tensor_async        = */ NULL,
+     /* .get_tensor_async        = */ NULL,
+     /* .cpy_tensor_async        = */ NULL,
+     /* .synchronize             = */ ggml_backend_rpc_synchronize,
+     /* .graph_plan_create       = */ NULL,
+     /* .graph_plan_free         = */ NULL,
+     /* .graph_plan_compute      = */ NULL,
+     /* .graph_compute           = */ ggml_backend_rpc_graph_compute,
+     /* .supports_op             = */ ggml_backend_rpc_supports_op,
+     /* .offload_op              = */ NULL,
+     /* .event_new               = */ NULL,
+     /* .event_free              = */ NULL,
+     /* .event_record            = */ NULL,
+     /* .event_wait              = */ NULL,
+     /* .event_synchronize       = */ NULL,
+ };
+
+ static std::unordered_map<std::string, ggml_backend_t> instances;
+
+ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
+     ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
+     return backend != nullptr ? ggml_backend_rpc_get_default_buffer_type(backend) : nullptr;
+ }
+
+ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
+     std::string endpoint_str(endpoint);
+     if (instances.find(endpoint_str) != instances.end()) {
+         return instances[endpoint_str];
+     }
+ #ifdef _WIN32
+     {
+         WSADATA wsaData;
+         int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
+         if (res != 0) {
+             return nullptr;
+         }
+     }
+ #endif
+     GGML_PRINT_DEBUG("Connecting to %s\n", endpoint);
+     std::string host;
+     int port;
+     if (!parse_endpoint(endpoint, host, port)) {
+         return nullptr;
+     }
+     auto sock = socket_connect(host.c_str(), port);
+     if (sock == nullptr) {
+         return nullptr;
+     }
+     size_t alignment = get_alignment(sock);
+     size_t max_size = get_max_size(sock);
+     ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
+         /* .sock      = */ sock,
+         /* .name      = */ "RPC" + std::to_string(sock->fd),
+         /* .alignment = */ alignment,
+         /* .max_size  = */ max_size
+     };
+
+     ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
+         /* .iface   = */ ggml_backend_rpc_buffer_type_interface,
+         /* .context = */ buft_ctx
+     };
+
+     ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
+         /* .endpoint = */ endpoint,
+         /* .name     = */ "RPC" + std::to_string(sock->fd),
+         /* .sock     = */ sock,
+         /* .buft     = */ buft
+     };
+
+     instances[endpoint] = new ggml_backend {
+         /* .guid      = */ ggml_backend_rpc_guid(),
+         /* .interface = */ ggml_backend_rpc_interface,
+         /* .context   = */ ctx
+     };
+
+     return instances[endpoint];
+ }
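
Note that instances caches one backend per endpoint and ggml_backend_rpc_free does not erase its entry, so a backend is effectively expected to live for the rest of the process. A minimal client sketch (the endpoint address is a made-up example):

    ggml_backend_t backend = ggml_backend_rpc_init("192.168.0.42:50052");
    if (backend != nullptr) {
        GGML_ASSERT(ggml_backend_is_rpc(backend));
        ggml_backend_buffer_type_t buft = ggml_backend_rpc_get_default_buffer_type(backend);
        // allocate remote buffers via buft, then run graphs with
        // ggml_backend_graph_compute(backend, cgraph)
    }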
+
+ GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {
+     return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
+ }
+
+ static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
+     // input serialization format: | 0 bytes |
+     std::vector<uint8_t> input;
+     std::vector<uint8_t> output;
+     bool status = send_rpc_cmd(sock, GET_DEVICE_MEMORY, input, output);
+     GGML_ASSERT(status);
+     GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
+     // output serialization format: | free (8 bytes) | total (8 bytes) |
+     uint64_t free_mem;
+     memcpy(&free_mem, output.data(), sizeof(free_mem));
+     uint64_t total_mem;
+     memcpy(&total_mem, output.data() + sizeof(uint64_t), sizeof(total_mem));
+     *free = free_mem;
+     *total = total_mem;
+ }
+
+ GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
+     ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
+     if (backend == nullptr) {
+         *free = 0;
+         *total = 0;
+         return;
+     }
+     ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
+     get_device_memory(ctx->sock, free, total);
+ }
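
This is the query a scheduler can use to size offloading decisions before allocating anything remote. For illustration, with the same made-up endpoint as above:

    size_t free_mem = 0, total_mem = 0;
    ggml_backend_rpc_get_device_memory("192.168.0.42:50052", &free_mem, &total_mem);
    // the values are whatever the server was started with, not live measurements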
+
+ // RPC server-side implementation
+
+ static void rpc_alloc_buffer(ggml_backend_t backend, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
+     // input serialization format: | size (8 bytes) |
+     uint64_t size;
+     memcpy(&size, input.data(), sizeof(size));
+     ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
+     ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, size);
+     // guard against allocation failure: report a null remote buffer instead of
+     // dereferencing a null pointer
+     uint64_t remote_ptr = 0;
+     uint64_t remote_size = 0;
+     if (buffer != nullptr) {
+         remote_ptr = reinterpret_cast<uint64_t>(buffer);
+         remote_size = buffer->size;
+     }
+     GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, size, remote_ptr, remote_size);
+     // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
+     output.resize(2*sizeof(uint64_t), 0);
+     memcpy(output.data(), &remote_ptr, sizeof(remote_ptr));
+     memcpy(output.data() + sizeof(uint64_t), &remote_size, sizeof(remote_size));
+ }
+
+ static void rpc_get_alignment(ggml_backend_t backend, std::vector<uint8_t> & output) {
+     ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
+     size_t alignment = ggml_backend_buft_get_alignment(buft);
+     GGML_PRINT_DEBUG("[%s] alignment: %zu\n", __func__, alignment);
+     // output serialization format: | alignment (8 bytes) |
+     output.resize(sizeof(uint64_t), 0);
+     memcpy(output.data(), &alignment, sizeof(alignment));
+ }
+
+ static void rpc_get_max_size(ggml_backend_t backend, std::vector<uint8_t> & output) {
+     ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
+     size_t max_size = ggml_backend_buft_get_max_size(buft);
+     GGML_PRINT_DEBUG("[%s] max_size: %zu\n", __func__, max_size);
+     // output serialization format: | max_size (8 bytes) |
+     output.resize(sizeof(uint64_t), 0);
+     memcpy(output.data(), &max_size, sizeof(max_size));
+ }
+
+ static void rpc_buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
+     // input serialization format: | remote_ptr (8 bytes) |
+     uint64_t remote_ptr;
+     memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
+     GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
+     ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
+     void * base = ggml_backend_buffer_get_base(buffer);
+     // output serialization format: | base_ptr (8 bytes) |
+     uint64_t base_ptr = reinterpret_cast<uint64_t>(base);
+     output.resize(sizeof(uint64_t), 0);
+     memcpy(output.data(), &base_ptr, sizeof(base_ptr));
+ }
+
+ static void rpc_free_buffer(const std::vector<uint8_t> & input) {
+     // input serialization format: | remote_ptr (8 bytes) |
+     uint64_t remote_ptr;
+     memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
+     GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
+     ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
+     ggml_backend_buffer_free(buffer);
+ }
+
+ static void rpc_buffer_clear(const std::vector<uint8_t> & input) {
+     // input serialization format: | remote_ptr (8 bytes) | value (1 byte) |
+     uint64_t remote_ptr;
+     memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
+     uint8_t value;
+     memcpy(&value, input.data() + sizeof(uint64_t), sizeof(value));
+     GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, remote_ptr, value);
+     ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
+     ggml_backend_buffer_clear(buffer, value);
+ }
+
+ static void rpc_set_tensor(const std::vector<uint8_t> & input) {
+     // serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
+     const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
+     uint64_t offset;
+     memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
+     size_t size = input.size() - sizeof(rpc_tensor) - sizeof(offset);
+
+     struct ggml_init_params params {
+         /*.mem_size   =*/ ggml_tensor_overhead(),
+         /*.mem_buffer =*/ NULL,
+         /*.no_alloc   =*/ true,
+     };
+     struct ggml_context * ctx = ggml_init(params);
+     ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
+     GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
+     const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset);
+     ggml_backend_tensor_set(tensor, data, offset, size);
+     ggml_free(ctx);
+ }
+
+ static void rpc_get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
+     // serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) |
+     const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
+     uint64_t offset;
+     memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
+     uint64_t size;
+     memcpy(&size, input.data() + sizeof(rpc_tensor) + sizeof(offset), sizeof(size));
+
+     struct ggml_init_params params {
+         /*.mem_size   =*/ ggml_tensor_overhead(),
+         /*.mem_buffer =*/ NULL,
+         /*.no_alloc   =*/ true,
+     };
+     struct ggml_context * ctx = ggml_init(params);
+     ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
+     GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
+     // output serialization format: | data (size bytes) |
+     output.resize(size, 0);
+     ggml_backend_tensor_get(tensor, output.data(), offset, size);
+     ggml_free(ctx);
+ }
+
+ static void rpc_copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
+     // serialization format: | rpc_tensor src | rpc_tensor dst |
+     const rpc_tensor * rpc_src = (const rpc_tensor *)input.data();
+     // note: dst follows the full struct, so the offset is sizeof(rpc_tensor);
+     // sizeof(rpc_src) would only skip the size of a pointer (8 bytes)
+     const rpc_tensor * rpc_dst = (const rpc_tensor *)(input.data() + sizeof(rpc_tensor));
+
+     struct ggml_init_params params {
+         /*.mem_size   =*/ 2*ggml_tensor_overhead(),
+         /*.mem_buffer =*/ NULL,
+         /*.no_alloc   =*/ true,
+     };
+     struct ggml_context * ctx = ggml_init(params);
+     ggml_tensor * src = deserialize_tensor(ctx, rpc_src);
+     ggml_tensor * dst = deserialize_tensor(ctx, rpc_dst);
+     GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n", __func__, (void*)src->buffer, (void*)dst->buffer);
+     bool result = ggml_backend_buffer_copy_tensor(src, dst);
+     // output serialization format: | result (1 byte) |
+     output.resize(1, 0);
+     output[0] = result;
+     ggml_free(ctx);
+ }
+
+ static struct ggml_tensor * create_node(uint64_t id,
+                                         struct ggml_context * ctx,
+                                         const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
+                                         std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map) {
+     if (id == 0) {
+         return nullptr;
+     }
+     if (tensor_map.find(id) != tensor_map.end()) {
+         return tensor_map[id];
+     }
+     const rpc_tensor * tensor = tensor_ptrs.at(id);
+     struct ggml_tensor * result = deserialize_tensor(ctx, tensor);
+     tensor_map[id] = result;
+     for (int i = 0; i < GGML_MAX_SRC; i++) {
+         result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
+     }
+     result->view_src = create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map);
+     result->view_offs = tensor->view_offs;
+     return result;
+ }
+
+ static void rpc_graph_compute(ggml_backend_t backend, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
+     // serialization format:
+     // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t)) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
+     uint32_t n_nodes;
+     memcpy(&n_nodes, input.data(), sizeof(n_nodes));
+     const uint64_t * nodes = (const uint64_t *)(input.data() + sizeof(n_nodes));
+     uint32_t n_tensors;
+     memcpy(&n_tensors, input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t), sizeof(n_tensors));
+     const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t) + sizeof(n_tensors));
+     GGML_PRINT_DEBUG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);
+
+     // note: buf_size must be recomputed on every call; making it `static` would
+     // freeze the size computed for the first graph and under-allocate later ones
+     size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
+     struct ggml_init_params params = {
+         /*.mem_size   =*/ buf_size,
+         /*.mem_buffer =*/ NULL,
+         /*.no_alloc   =*/ true,
+     };
+     struct ggml_context * ctx = ggml_init(params);
+     struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false);
+     graph->n_nodes = n_nodes;
+     std::unordered_map<uint64_t, const rpc_tensor*> tensor_ptrs;
+     for (uint32_t i = 0; i < n_tensors; i++) {
+         tensor_ptrs[tensors[i].id] = &tensors[i];
+     }
+     std::unordered_map<uint64_t, ggml_tensor*> tensor_map;
+     for (uint32_t i = 0; i < n_nodes; i++) {
+         graph->nodes[i] = create_node(nodes[i], ctx, tensor_ptrs, tensor_map);
+     }
+     ggml_status status = ggml_backend_graph_compute(backend, graph);
+     // output serialization format: | status (1 byte) |
+     output.resize(1, 0);
+     output[0] = status;
+     ggml_free(ctx);
+ }
+
+ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
+     while (true) {
+         uint8_t cmd;
+         if (!recv_data(sockfd, &cmd, 1)) {
+             break;
+         }
+         std::vector<uint8_t> input;
+         std::vector<uint8_t> output;
+         uint64_t input_size;
+         if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
+             break;
+         }
+         input.resize(input_size);
+         if (!recv_data(sockfd, input.data(), input_size)) {
+             break;
+         }
+         switch (cmd) {
+             case ALLOC_BUFFER: {
+                 rpc_alloc_buffer(backend, input, output);
+                 break;
+             }
+             case GET_ALIGNMENT: {
+                 rpc_get_alignment(backend, output);
+                 break;
+             }
+             case GET_MAX_SIZE: {
+                 rpc_get_max_size(backend, output);
+                 break;
+             }
+             case BUFFER_GET_BASE: {
+                 rpc_buffer_get_base(input, output);
+                 break;
+             }
+             case FREE_BUFFER: {
+                 rpc_free_buffer(input);
+                 break;
+             }
+             case BUFFER_CLEAR: {
+                 rpc_buffer_clear(input);
+                 break;
+             }
+             case SET_TENSOR: {
+                 rpc_set_tensor(input);
+                 break;
+             }
+             case GET_TENSOR: {
+                 rpc_get_tensor(input, output);
+                 break;
+             }
+             case COPY_TENSOR: {
+                 rpc_copy_tensor(input, output);
+                 break;
+             }
+             case GRAPH_COMPUTE: {
+                 rpc_graph_compute(backend, input, output);
+                 break;
+             }
+             case GET_DEVICE_MEMORY: {
+                 // output serialization format: | free (8 bytes) | total (8 bytes) |
+                 output.resize(2*sizeof(uint64_t), 0);
+                 memcpy(output.data(), &free_mem, sizeof(free_mem));
+                 memcpy(output.data() + sizeof(uint64_t), &total_mem, sizeof(total_mem));
+                 break;
+             }
+             default: {
+                 fprintf(stderr, "Unknown command: %d\n", cmd);
+                 return;
+             }
+         }
+         uint64_t output_size = output.size();
+         if (!send_data(sockfd, &output_size, sizeof(output_size))) {
+             break;
+         }
+         if (!send_data(sockfd, output.data(), output_size)) {
+             break;
+         }
+     }
+ }
+
+ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
+     std::string host;
+     int port;
+     if (!parse_endpoint(endpoint, host, port)) {
+         return;
+     }
+ #ifdef _WIN32
+     {
+         WSADATA wsaData;
+         int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
+         if (res != 0) {
+             fprintf(stderr, "WSAStartup failed: %d\n", res);
+             return;
+         }
+     }
+ #endif
+     auto server_socket = create_server_socket(host.c_str(), port);
+     if (server_socket == nullptr) {
+         fprintf(stderr, "Failed to create server socket\n");
+         return;
+     }
+     while (true) {
+         auto client_socket = socket_accept(server_socket->fd);
+         if (client_socket == nullptr) {
+             fprintf(stderr, "Failed to accept client connection\n");
+             break; // fall through to WSACleanup on Windows instead of returning
+         }
+         printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
+         rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
+         printf("Client connection closed\n");
+     }
+ #ifdef _WIN32
+     WSACleanup();
+ #endif
+ }
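
A minimal standalone server built on this entry point could look as follows. This is a sketch assuming ggml's CPU backend (ggml_backend_cpu_init from ggml-backend.h), an example port, and placeholder memory figures; an actual server would report real device memory:

    #include "ggml-backend.h"
    #include "ggml-rpc.h"

    int main() {
        ggml_backend_t backend = ggml_backend_cpu_init();
        size_t free_mem  = 8ULL  * 1024 * 1024 * 1024;  // placeholder: 8 GiB
        size_t total_mem = 16ULL * 1024 * 1024 * 1024;  // placeholder: 16 GiB
        // blocks, serving one client connection at a time
        start_rpc_server(backend, "0.0.0.0:50052", free_mem, total_mem);
        ggml_backend_free(backend);
        return 0;
    }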