llama_cpp 0.15.2 → 0.15.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
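The diff below covers the ggml RPC backend (ggml-rpc.cpp). On the client side, sockets are no longer stored in the backend and buffer-type contexts; a get_socket() helper now caches one connection per endpoint, and buffer types are likewise cached per endpoint. rpc_tensor is packed with #pragma pack(1), and the server-side handlers move into an rpc_server class that tracks the buffers it allocated and validates the size of every incoming message. For orientation only, here is a minimal client-side sketch of the public entry points touched by this diff; the header names, endpoint value, and teardown call are assumptions, not part of the diff itself:

// Hypothetical usage sketch (not part of the diff): exercises the endpoint-keyed
// entry points changed in this release. An RPC server is assumed to be listening
// on the placeholder address below.
#include "ggml-backend.h"   // assumed include for ggml_backend_t / ggml_backend_free
#include "ggml-rpc.h"       // assumed include declaring the RPC backend API
#include <cstdio>

int main() {
    const char * endpoint = "127.0.0.1:50052";  // placeholder host:port

    // Query remote device memory; after this change it reuses a cached
    // per-endpoint socket instead of spinning up a full backend instance.
    size_t free_mem = 0, total_mem = 0;
    ggml_backend_rpc_get_device_memory(endpoint, &free_mem, &total_mem);
    std::printf("remote memory: %zu free / %zu total\n", free_mem, total_mem);

    // Buffer types are now cached per endpoint, so repeated calls return the
    // same object for the same address (and are never freed, per the NOTE in the diff).
    ggml_backend_buffer_type_t buft = ggml_backend_rpc_buffer_type(endpoint);
    (void) buft;

    // ggml_backend_rpc_init() no longer caches backend instances: each call
    // returns a fresh backend that only stores the endpoint string.
    ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
    if (backend == nullptr) {
        std::fprintf(stderr, "failed to connect to %s\n", endpoint);
        return 1;
    }
    ggml_backend_free(backend);  // assumed standard ggml backend teardown
    return 0;
}

Because backends are no longer shared through a global instances map, freeing one no longer tears down the endpoint's socket or buffer type, which stay cached for other users of the same endpoint.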
@@ -6,6 +6,7 @@
 #include <string>
 #include <vector>
 #include <memory>
+#include <mutex>
 #include <unordered_map>
 #include <unordered_set>
 #ifdef _WIN32
@@ -47,6 +48,7 @@ struct socket_t {
     sockfd_t fd;
     socket_t(sockfd_t fd) : fd(fd) {}
     ~socket_t() {
+        GGML_PRINT_DEBUG("[%s] closing socket %d\n", __func__, this->fd);
 #ifdef _WIN32
         closesocket(this->fd);
 #else
@@ -56,6 +58,7 @@ struct socket_t {
 };
 
 // ggml_tensor is serialized into rpc_tensor
+#pragma pack(push, 1)
 struct rpc_tensor {
     uint64_t id;
     uint32_t type;
@@ -71,6 +74,7 @@ struct rpc_tensor {
     uint64_t data;
     char name[GGML_MAX_NAME];
 };
+#pragma pack(pop)
 
 // RPC commands
 enum rpc_cmd {
@@ -95,7 +99,7 @@ static ggml_guid_t ggml_backend_rpc_guid() {
 }
 
 struct ggml_backend_rpc_buffer_type_context {
-    std::shared_ptr<socket_t> sock;
+    std::string endpoint;
     std::string name;
     size_t alignment;
     size_t max_size;
@@ -104,8 +108,6 @@ struct ggml_backend_rpc_buffer_type_context {
 struct ggml_backend_rpc_context {
     std::string endpoint;
     std::string name;
-    std::shared_ptr<socket_t> sock;
-    ggml_backend_buffer_type_t buft;
 };
 
 struct ggml_backend_rpc_buffer_context {
@@ -229,14 +231,13 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
     return true;
 }
 
-static bool parse_endpoint(const char * endpoint, std::string & host, int & port) {
-    std::string str(endpoint);
-    size_t pos = str.find(':');
+static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
+    size_t pos = endpoint.find(':');
     if (pos == std::string::npos) {
         return false;
     }
-    host = str.substr(0, pos);
-    port = std::stoi(str.substr(pos + 1));
+    host = endpoint.substr(0, pos);
+    port = std::stoi(endpoint.substr(pos + 1));
     return true;
 }
 
@@ -271,6 +272,44 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
 
 // RPC client-side implementation
 
+static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
+    static std::mutex mutex;
+    std::lock_guard<std::mutex> lock(mutex);
+    static std::unordered_map<std::string, std::weak_ptr<socket_t>> sockets;
+    static bool initialized = false;
+
+    auto it = sockets.find(endpoint);
+    if (it != sockets.end()) {
+        if (auto sock = it->second.lock()) {
+            return sock;
+        }
+    }
+    std::string host;
+    int port;
+    if (!parse_endpoint(endpoint, host, port)) {
+        return nullptr;
+    }
+#ifdef _WIN32
+    if (!initialized) {
+        WSADATA wsaData;
+        int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
+        if (res != 0) {
+            return nullptr;
+        }
+        initialized = true;
+    }
+#else
+    UNUSED(initialized);
+#endif
+    auto sock = socket_connect(host.c_str(), port);
+    if (sock == nullptr) {
+        return nullptr;
+    }
+    GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
+    sockets[endpoint] = sock;
+    return sock;
+}
+
 GGML_CALL static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) {
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
     return ctx->name.c_str();
@@ -340,23 +379,6 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
     return result;
 }
 
-static ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) {
-    ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type,
-        tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
-    for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
-        result->nb[i] = tensor->nb[i];
-    }
-    result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);
-    result->op = (ggml_op) tensor->op;
-    for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
-        result->op_params[i] = tensor->op_params[i];
-    }
-    result->flags = tensor->flags;
-    result->data = reinterpret_cast<void *>(tensor->data);
-    ggml_set_name(result, tensor->name);
-    return result;
-}
-
 GGML_CALL static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
     UNUSED(buffer);
     if (ggml_is_quantized(tensor->type)) {
@@ -457,7 +479,8 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
     std::vector<uint8_t> input(input_size, 0);
     memcpy(input.data(), &size, sizeof(size));
     std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(buft_ctx->sock, ALLOC_BUFFER, input, output);
+    auto sock = get_socket(buft_ctx->endpoint);
+    bool status = send_rpc_cmd(sock, ALLOC_BUFFER, input, output);
     GGML_ASSERT(status);
     GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
     // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
@@ -465,13 +488,15 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
     memcpy(&remote_ptr, output.data(), sizeof(remote_ptr));
     size_t remote_size;
     memcpy(&remote_size, output.data() + sizeof(uint64_t), sizeof(remote_size));
-
-    ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
-        ggml_backend_rpc_buffer_interface,
-        new ggml_backend_rpc_buffer_context{buft_ctx->sock, {}, remote_ptr, "RPC"},
-        remote_size);
-
-    return buffer;
+    if (remote_ptr != 0) {
+        ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
+            ggml_backend_rpc_buffer_interface,
+            new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC"},
+            remote_size);
+        return buffer;
+    } else {
+        return nullptr;
+    }
 }
 
 static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
@@ -521,7 +546,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_type_supports_backend(ggml_backend
     }
     ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
-    return buft_ctx->sock == rpc_ctx->sock;
+    return buft_ctx->endpoint == rpc_ctx->endpoint;
 }
 
 static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
@@ -534,7 +559,6 @@ static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
     /* .is_host = */ NULL,
 };
 
-
 GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
 
@@ -543,16 +567,13 @@ GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
 
 GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) {
     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
-    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)rpc_ctx->buft->context;
-    delete buft_ctx;
-    delete rpc_ctx->buft;
     delete rpc_ctx;
     delete backend;
 }
 
 GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) {
     ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
-    return ctx->buft;
+    return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
 }
 
 GGML_CALL static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
@@ -603,7 +624,8 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t
     std::vector<uint8_t> input;
     serialize_graph(cgraph, input);
     std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(rpc_ctx->sock, GRAPH_COMPUTE, input, output);
+    auto sock = get_socket(rpc_ctx->endpoint);
+    bool status = send_rpc_cmd(sock, GRAPH_COMPUTE, input, output);
     GGML_ASSERT(status);
     GGML_ASSERT(output.size() == 1);
     return (enum ggml_status)output[0];
@@ -637,65 +659,48 @@ static ggml_backend_i ggml_backend_rpc_interface = {
     /* .event_synchronize = */ NULL,
 };
 
-static std::unordered_map<std::string, ggml_backend_t> instances;
-
 GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
-    ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
-    return backend != nullptr ? ggml_backend_rpc_get_default_buffer_type(backend) : nullptr;
-}
-
-GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
-    std::string endpoint_str(endpoint);
-    if (instances.find(endpoint_str) != instances.end()) {
-        return instances[endpoint_str];
-    }
-#ifdef _WIN32
-    {
-        WSADATA wsaData;
-        int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
-        if (res != 0) {
-            return nullptr;
-        }
-    }
-#endif
-    GGML_PRINT_DEBUG("Connecting to %s\n", endpoint);
-    std::string host;
-    int port;
-    if (!parse_endpoint(endpoint, host, port)) {
-        return nullptr;
+    static std::mutex mutex;
+    std::lock_guard<std::mutex> lock(mutex);
+    // NOTE: buffer types are allocated and never freed; this is by design
+    static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;
+    auto it = buft_map.find(endpoint);
+    if (it != buft_map.end()) {
+        return it->second;
     }
-    auto sock = socket_connect(host.c_str(), port);
+    auto sock = get_socket(endpoint);
     if (sock == nullptr) {
         return nullptr;
     }
     size_t alignment = get_alignment(sock);
     size_t max_size = get_max_size(sock);
     ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
-        /* .sock = */ sock,
-        /* .name = */ "RPC" + std::to_string(sock->fd),
+        /* .endpoint = */ endpoint,
+        /* .name = */ "RPC[" + std::string(endpoint) + "]",
         /* .alignment = */ alignment,
-        /* .max_size = */ max_size
+        /* .max_size = */ max_size
     };
 
     ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
         /* .iface = */ ggml_backend_rpc_buffer_type_interface,
         /* .context = */ buft_ctx
     };
+    buft_map[endpoint] = buft;
+    return buft;
+}
 
+GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
     ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
-        /* .endpoint = */ endpoint,
-        /* .name = */ "RPC" + std::to_string(sock->fd),
-        /* .sock = */ sock,
-        /* .buft = */ buft
+        /* .endpoint = */ endpoint,
+        /* .name = */ "RPC",
    };
 
-    instances[endpoint] = new ggml_backend {
+    ggml_backend_t backend = new ggml_backend {
         /* .guid = */ ggml_backend_rpc_guid(),
         /* .interface = */ ggml_backend_rpc_interface,
         /* .context = */ ctx
     };
-
-    return instances[endpoint];
+    return backend;
 }
 
 GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {
@@ -719,34 +724,72 @@ static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * f
 }
 
 GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
-    ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
-    if (backend == nullptr) {
+    auto sock = get_socket(endpoint);
+    if (sock == nullptr) {
        *free = 0;
        *total = 0;
        return;
    }
-    ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
-    get_device_memory(ctx->sock, free, total);
+    get_device_memory(sock, free, total);
}
 
 // RPC server-side implementation
 
-static void rpc_alloc_buffer(ggml_backend_t backend, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
+class rpc_server {
+public:
+    rpc_server(ggml_backend_t backend) : backend(backend) {}
+    ~rpc_server();
+
+    bool alloc_buffer(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
+    void get_alignment(std::vector<uint8_t> & output);
+    void get_max_size(std::vector<uint8_t> & output);
+    bool buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
+    bool free_buffer(const std::vector<uint8_t> & input);
+    bool buffer_clear(const std::vector<uint8_t> & input);
+    bool set_tensor(const std::vector<uint8_t> & input);
+    bool get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
+    bool copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
+    bool graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
+
+private:
+    ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
+    ggml_tensor * create_node(uint64_t id,
+                              struct ggml_context * ctx,
+                              const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
+                              std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map);
+
+
+    ggml_backend_t backend;
+    std::unordered_set<ggml_backend_buffer_t> buffers;
+};
+
+bool rpc_server::alloc_buffer(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
     // input serialization format: | size (8 bytes) |
+    if (input.size() != sizeof(uint64_t)) {
+        return false;
+    }
     uint64_t size;
     memcpy(&size, input.data(), sizeof(size));
     ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
     ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, size);
-    uint64_t remote_ptr = reinterpret_cast<uint64_t>(buffer);
-    uint64_t remote_size = buffer->size;
-    GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, size, remote_ptr, remote_size);
+    uint64_t remote_ptr = 0;
+    uint64_t remote_size = 0;
+    if (buffer != nullptr) {
+        remote_ptr = reinterpret_cast<uint64_t>(buffer);
+        remote_size = buffer->size;
+        GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, size, remote_ptr, remote_size);
+        buffers.insert(buffer);
+    } else {
+        GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> failed\n", __func__, size);
+    }
     // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
     output.resize(2*sizeof(uint64_t), 0);
     memcpy(output.data(), &remote_ptr, sizeof(remote_ptr));
     memcpy(output.data() + sizeof(uint64_t), &remote_size, sizeof(remote_size));
+    return true;
 }
 
-static void rpc_get_alignment(ggml_backend_t backend, std::vector<uint8_t> & output) {
+void rpc_server::get_alignment(std::vector<uint8_t> & output) {
     ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
     size_t alignment = ggml_backend_buft_get_alignment(buft);
     GGML_PRINT_DEBUG("[%s] alignment: %lu\n", __func__, alignment);
@@ -755,7 +798,7 @@ static void rpc_get_alignment(ggml_backend_t backend, std::vector<uint8_t> & out
     memcpy(output.data(), &alignment, sizeof(alignment));
 }
 
-static void rpc_get_max_size(ggml_backend_t backend, std::vector<uint8_t> & output) {
+void rpc_server::get_max_size(std::vector<uint8_t> & output) {
     ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
     size_t max_size = ggml_backend_buft_get_max_size(buft);
     GGML_PRINT_DEBUG("[%s] max_size: %lu\n", __func__, max_size);
@@ -764,41 +807,90 @@ static void rpc_get_max_size(ggml_backend_t backend, std::vector<uint8_t> & outp
     memcpy(output.data(), &max_size, sizeof(max_size));
 }
 
-static void rpc_buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
+bool rpc_server::buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
     // input serialization format: | remote_ptr (8 bytes) |
+    if (input.size() != sizeof(uint64_t)) {
+        return false;
+    }
     uint64_t remote_ptr;
     memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
     GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
     ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
+    if (buffers.find(buffer) == buffers.end()) {
+        GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
+        return false;
+    }
     void * base = ggml_backend_buffer_get_base(buffer);
     // output serialization format: | base_ptr (8 bytes) |
     uint64_t base_ptr = reinterpret_cast<uint64_t>(base);
     output.resize(sizeof(uint64_t), 0);
     memcpy(output.data(), &base_ptr, sizeof(base_ptr));
+    return true;
 }
 
-static void rpc_free_buffer(const std::vector<uint8_t> & input) {
+bool rpc_server::free_buffer(const std::vector<uint8_t> & input) {
     // input serialization format: | remote_ptr (8 bytes) |
+    if (input.size() != sizeof(uint64_t)) {
+        return false;
+    }
     uint64_t remote_ptr;
     memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
     GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
     ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
+    if (buffers.find(buffer) == buffers.end()) {
+        GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
+        return false;
+    }
     ggml_backend_buffer_free(buffer);
+    buffers.erase(buffer);
+    return true;
 }
 
-static void rpc_buffer_clear(const std::vector<uint8_t> & input) {
+bool rpc_server::buffer_clear(const std::vector<uint8_t> & input) {
     // input serialization format: | remote_ptr (8 bytes) | value (1 byte) |
+    if (input.size() != sizeof(uint64_t) + sizeof(uint8_t)) {
+        return false;
+    }
     uint64_t remote_ptr;
     memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
     uint8_t value;
     memcpy(&value, input.data() + sizeof(uint64_t), sizeof(value));
     GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, remote_ptr, value);
     ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
+    if (buffers.find(buffer) == buffers.end()) {
+        GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
+        return false;
+    }
     ggml_backend_buffer_clear(buffer, value);
+    return true;
+}
+
+ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) {
+    ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type,
+        tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
+    for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
+        result->nb[i] = tensor->nb[i];
+    }
+    result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);
+    if (result->buffer && buffers.find(result->buffer) == buffers.end()) {
+        return nullptr;
+    }
+    result->op = (ggml_op) tensor->op;
+    for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
+        result->op_params[i] = tensor->op_params[i];
+    }
+    result->flags = tensor->flags;
+    result->data = reinterpret_cast<void *>(tensor->data);
+    ggml_set_name(result, tensor->name);
+    return result;
 }
 
-static void rpc_set_tensor(const std::vector<uint8_t> & input) {
+
+bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
     // serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
+    if (input.size() < sizeof(rpc_tensor) + sizeof(uint64_t)) {
+        return false;
+    }
     const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
     uint64_t offset;
     memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
@@ -811,14 +903,23 @@ static void rpc_set_tensor(const std::vector<uint8_t> & input) {
     };
     struct ggml_context * ctx = ggml_init(params);
     ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
+    if (tensor == nullptr) {
+        GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__);
+        ggml_free(ctx);
+        return false;
+    }
     GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
     const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset);
     ggml_backend_tensor_set(tensor, data, offset, size);
     ggml_free(ctx);
+    return true;
 }
 
-static void rpc_get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
+bool rpc_server::get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
     // serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) |
+    if (input.size() != sizeof(rpc_tensor) + 2*sizeof(uint64_t)) {
+        return false;
+    }
     const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
     uint64_t offset;
     memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
@@ -832,15 +933,24 @@ static void rpc_get_tensor(const std::vector<uint8_t> & input, std::vector<uint8
     };
     struct ggml_context * ctx = ggml_init(params);
     ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
+    if (tensor == nullptr) {
+        GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__);
+        ggml_free(ctx);
+        return false;
+    }
     GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
     // output serialization format: | data (size bytes) |
     output.resize(size, 0);
     ggml_backend_tensor_get(tensor, output.data(), offset, size);
     ggml_free(ctx);
+    return true;
 }
 
-static void rpc_copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
+bool rpc_server::copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
     // serialization format: | rpc_tensor src | rpc_tensor dst |
+    if (input.size() != 2*sizeof(rpc_tensor)) {
+        return false;
+    }
     const rpc_tensor * rpc_src = (const rpc_tensor *)input.data();
     const rpc_tensor * rpc_dst = (const rpc_tensor *)(input.data() + sizeof(rpc_src));
 
@@ -852,18 +962,24 @@ static void rpc_copy_tensor(const std::vector<uint8_t> & input, std::vector<uint
     struct ggml_context * ctx = ggml_init(params);
     ggml_tensor * src = deserialize_tensor(ctx, rpc_src);
     ggml_tensor * dst = deserialize_tensor(ctx, rpc_dst);
+    if (src == nullptr || dst == nullptr) {
+        GGML_PRINT_DEBUG("[%s] error deserializing tensors\n", __func__);
+        ggml_free(ctx);
+        return false;
+    }
     GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n", __func__, (void*)src->buffer, (void*)dst->buffer);
     bool result = ggml_backend_buffer_copy_tensor(src, dst);
     // output serialization format: | result (1 byte) |
     output.resize(1, 0);
     output[0] = result;
     ggml_free(ctx);
+    return true;
 }
 
-static struct ggml_tensor * create_node(uint64_t id,
-                                        struct ggml_context * ctx,
-                                        const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
-                                        std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map) {
+ggml_tensor * rpc_server::create_node(uint64_t id,
+                                      struct ggml_context * ctx,
+                                      const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
+                                      std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map) {
     if (id == 0) {
         return nullptr;
     }
@@ -872,6 +988,9 @@ static struct ggml_tensor * create_node(uint64_t id,
     }
     const rpc_tensor * tensor = tensor_ptrs.at(id);
     struct ggml_tensor * result = deserialize_tensor(ctx, tensor);
+    if (result == nullptr) {
+        return nullptr;
+    }
     tensor_map[id] = result;
     for (int i = 0; i < GGML_MAX_SRC; i++) {
         result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
@@ -881,14 +1000,23 @@ static struct ggml_tensor * create_node(uint64_t id,
     return result;
 }
 
-static void rpc_graph_compute(ggml_backend_t backend, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
+bool rpc_server::graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
     // serialization format:
     // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
+    if (input.size() < sizeof(uint32_t)) {
+        return false;
+    }
     uint32_t n_nodes;
     memcpy(&n_nodes, input.data(), sizeof(n_nodes));
+    if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
+        return false;
+    }
     const uint64_t * nodes = (const uint64_t *)(input.data() + sizeof(n_nodes));
     uint32_t n_tensors;
     memcpy(&n_tensors, input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t), sizeof(n_tensors));
+    if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
+        return false;
+    }
     const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t) + sizeof(n_tensors));
     GGML_PRINT_DEBUG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);
 
@@ -914,9 +1042,17 @@ static void rpc_graph_compute(ggml_backend_t backend, const std::vector<uint8_t>
     output.resize(1, 0);
     output[0] = status;
     ggml_free(ctx);
+    return true;
+}
+
+rpc_server::~rpc_server() {
+    for (auto buffer : buffers) {
+        ggml_backend_buffer_free(buffer);
+    }
 }
 
 static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
+    rpc_server server(backend);
     while (true) {
         uint8_t cmd;
         if (!recv_data(sockfd, &cmd, 1)) {
@@ -932,45 +1068,46 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
         if (!recv_data(sockfd, input.data(), input_size)) {
             break;
         }
+        bool ok = true;
        switch (cmd) {
            case ALLOC_BUFFER: {
-                rpc_alloc_buffer(backend, input, output);
+                ok = server.alloc_buffer(input, output);
                break;
            }
            case GET_ALIGNMENT: {
-                rpc_get_alignment(backend, output);
+                server.get_alignment(output);
                break;
            }
            case GET_MAX_SIZE: {
-                rpc_get_max_size(backend, output);
+                server.get_max_size(output);
                break;
            }
            case BUFFER_GET_BASE: {
-                rpc_buffer_get_base(input, output);
+                ok = server.buffer_get_base(input, output);
                break;
            }
            case FREE_BUFFER: {
-                rpc_free_buffer(input);
+                ok = server.free_buffer(input);
                break;
            }
            case BUFFER_CLEAR: {
-                rpc_buffer_clear(input);
+                ok = server.buffer_clear(input);
                break;
            }
            case SET_TENSOR: {
-                rpc_set_tensor(input);
+                ok = server.set_tensor(input);
                break;
            }
            case GET_TENSOR: {
-                rpc_get_tensor(input, output);
+                ok = server.get_tensor(input, output);
                break;
            }
            case COPY_TENSOR: {
-                rpc_copy_tensor(input, output);
+                ok = server.copy_tensor(input, output);
                break;
            }
            case GRAPH_COMPUTE: {
-                rpc_graph_compute(backend, input, output);
+                ok = server.graph_compute(input, output);
                break;
            }
            case GET_DEVICE_MEMORY: {
@@ -982,9 +1119,12 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
            }
            default: {
                fprintf(stderr, "Unknown command: %d\n", cmd);
-                return;
+                ok = false;
            }
        }
+        if (!ok) {
+            break;
+        }
        uint64_t output_size = output.size();
        if (!send_data(sockfd, &output_size, sizeof(output_size))) {
            break;