llama_cpp 0.15.2 → 0.15.3

Sign up to get free protection for your applications and to get access to all the features.
@@ -56,6 +56,7 @@ struct socket_t {
56
56
  };
57
57
 
58
58
  // ggml_tensor is serialized into rpc_tensor
59
+ #pragma pack(push, 1)
59
60
  struct rpc_tensor {
60
61
  uint64_t id;
61
62
  uint32_t type;
@@ -71,6 +72,7 @@ struct rpc_tensor {
71
72
  uint64_t data;
72
73
  char name[GGML_MAX_NAME];
73
74
  };
75
+ #pragma pack(pop)
74
76
 
75
77
  // RPC commands
76
78
  enum rpc_cmd {
@@ -340,23 +342,6 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
340
342
  return result;
341
343
  }
342
344
 
343
- static ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) {
344
- ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type,
345
- tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
346
- for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
347
- result->nb[i] = tensor->nb[i];
348
- }
349
- result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);
350
- result->op = (ggml_op) tensor->op;
351
- for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
352
- result->op_params[i] = tensor->op_params[i];
353
- }
354
- result->flags = tensor->flags;
355
- result->data = reinterpret_cast<void *>(tensor->data);
356
- ggml_set_name(result, tensor->name);
357
- return result;
358
- }
359
-
360
345
  GGML_CALL static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
361
346
  UNUSED(buffer);
362
347
  if (ggml_is_quantized(tensor->type)) {
@@ -465,13 +450,15 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
465
450
  memcpy(&remote_ptr, output.data(), sizeof(remote_ptr));
466
451
  size_t remote_size;
467
452
  memcpy(&remote_size, output.data() + sizeof(uint64_t), sizeof(remote_size));
468
-
469
- ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
470
- ggml_backend_rpc_buffer_interface,
471
- new ggml_backend_rpc_buffer_context{buft_ctx->sock, {}, remote_ptr, "RPC"},
472
- remote_size);
473
-
474
- return buffer;
453
+ if (remote_ptr != 0) {
454
+ ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
455
+ ggml_backend_rpc_buffer_interface,
456
+ new ggml_backend_rpc_buffer_context{buft_ctx->sock, {}, remote_ptr, "RPC"},
457
+ remote_size);
458
+ return buffer;
459
+ } else {
460
+ return nullptr;
461
+ }
475
462
  }
476
463
 
477
464
  static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
@@ -658,7 +645,7 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
658
645
  }
659
646
  }
660
647
  #endif
661
- GGML_PRINT_DEBUG("Connecting to %s\n", endpoint);
648
+ fprintf(stderr, "Connecting to %s\n", endpoint);
662
649
  std::string host;
663
650
  int port;
664
651
  if (!parse_endpoint(endpoint, host, port)) {
@@ -731,22 +718,61 @@ GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint
731
718
 
732
719
  // RPC server-side implementation
733
720
 
734
- static void rpc_alloc_buffer(ggml_backend_t backend, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
721
+ class rpc_server {
722
+ public:
723
+ rpc_server(ggml_backend_t backend) : backend(backend) {}
724
+ ~rpc_server();
725
+
726
+ bool alloc_buffer(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
727
+ void get_alignment(std::vector<uint8_t> & output);
728
+ void get_max_size(std::vector<uint8_t> & output);
729
+ bool buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
730
+ bool free_buffer(const std::vector<uint8_t> & input);
731
+ bool buffer_clear(const std::vector<uint8_t> & input);
732
+ bool set_tensor(const std::vector<uint8_t> & input);
733
+ bool get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
734
+ bool copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
735
+ bool graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
736
+
737
+ private:
738
+ ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
739
+ ggml_tensor * create_node(uint64_t id,
740
+ struct ggml_context * ctx,
741
+ const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
742
+ std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map);
743
+
744
+
745
+ ggml_backend_t backend;
746
+ std::unordered_set<ggml_backend_buffer_t> buffers;
747
+ };
748
+
749
+ bool rpc_server::alloc_buffer(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
735
750
  // input serialization format: | size (8 bytes) |
751
+ if (input.size() != sizeof(uint64_t)) {
752
+ return false;
753
+ }
736
754
  uint64_t size;
737
755
  memcpy(&size, input.data(), sizeof(size));
738
756
  ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
739
757
  ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, size);
740
- uint64_t remote_ptr = reinterpret_cast<uint64_t>(buffer);
741
- uint64_t remote_size = buffer->size;
742
- GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, size, remote_ptr, remote_size);
758
+ uint64_t remote_ptr = 0;
759
+ uint64_t remote_size = 0;
760
+ if (buffer != nullptr) {
761
+ remote_ptr = reinterpret_cast<uint64_t>(buffer);
762
+ remote_size = buffer->size;
763
+ GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, size, remote_ptr, remote_size);
764
+ buffers.insert(buffer);
765
+ } else {
766
+ GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> failed\n", __func__, size);
767
+ }
743
768
  // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
744
769
  output.resize(2*sizeof(uint64_t), 0);
745
770
  memcpy(output.data(), &remote_ptr, sizeof(remote_ptr));
746
771
  memcpy(output.data() + sizeof(uint64_t), &remote_size, sizeof(remote_size));
772
+ return true;
747
773
  }
748
774
 
749
- static void rpc_get_alignment(ggml_backend_t backend, std::vector<uint8_t> & output) {
775
+ void rpc_server::get_alignment(std::vector<uint8_t> & output) {
750
776
  ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
751
777
  size_t alignment = ggml_backend_buft_get_alignment(buft);
752
778
  GGML_PRINT_DEBUG("[%s] alignment: %lu\n", __func__, alignment);
@@ -755,7 +781,7 @@ static void rpc_get_alignment(ggml_backend_t backend, std::vector<uint8_t> & out
755
781
  memcpy(output.data(), &alignment, sizeof(alignment));
756
782
  }
757
783
 
758
- static void rpc_get_max_size(ggml_backend_t backend, std::vector<uint8_t> & output) {
784
+ void rpc_server::get_max_size(std::vector<uint8_t> & output) {
759
785
  ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
760
786
  size_t max_size = ggml_backend_buft_get_max_size(buft);
761
787
  GGML_PRINT_DEBUG("[%s] max_size: %lu\n", __func__, max_size);
@@ -764,41 +790,90 @@ static void rpc_get_max_size(ggml_backend_t backend, std::vector<uint8_t> & outp
764
790
  memcpy(output.data(), &max_size, sizeof(max_size));
765
791
  }
766
792
 
767
- static void rpc_buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
793
+ bool rpc_server::buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
768
794
  // input serialization format: | remote_ptr (8 bytes) |
795
+ if (input.size() != sizeof(uint64_t)) {
796
+ return false;
797
+ }
769
798
  uint64_t remote_ptr;
770
799
  memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
771
800
  GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
772
801
  ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
802
+ if (buffers.find(buffer) == buffers.end()) {
803
+ GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
804
+ return false;
805
+ }
773
806
  void * base = ggml_backend_buffer_get_base(buffer);
774
807
  // output serialization format: | base_ptr (8 bytes) |
775
808
  uint64_t base_ptr = reinterpret_cast<uint64_t>(base);
776
809
  output.resize(sizeof(uint64_t), 0);
777
810
  memcpy(output.data(), &base_ptr, sizeof(base_ptr));
811
+ return true;
778
812
  }
779
813
 
780
- static void rpc_free_buffer(const std::vector<uint8_t> & input) {
814
+ bool rpc_server::free_buffer(const std::vector<uint8_t> & input) {
781
815
  // input serialization format: | remote_ptr (8 bytes) |
816
+ if (input.size() != sizeof(uint64_t)) {
817
+ return false;
818
+ }
782
819
  uint64_t remote_ptr;
783
820
  memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
784
821
  GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
785
822
  ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
823
+ if (buffers.find(buffer) == buffers.end()) {
824
+ GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
825
+ return false;
826
+ }
786
827
  ggml_backend_buffer_free(buffer);
828
+ buffers.erase(buffer);
829
+ return true;
787
830
  }
788
831
 
789
- static void rpc_buffer_clear(const std::vector<uint8_t> & input) {
832
+ bool rpc_server::buffer_clear(const std::vector<uint8_t> & input) {
790
833
  // input serialization format: | remote_ptr (8 bytes) | value (1 byte) |
834
+ if (input.size() != sizeof(uint64_t) + sizeof(uint8_t)) {
835
+ return false;
836
+ }
791
837
  uint64_t remote_ptr;
792
838
  memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
793
839
  uint8_t value;
794
840
  memcpy(&value, input.data() + sizeof(uint64_t), sizeof(value));
795
841
  GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, remote_ptr, value);
796
842
  ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
843
+ if (buffers.find(buffer) == buffers.end()) {
844
+ GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
845
+ return false;
846
+ }
797
847
  ggml_backend_buffer_clear(buffer, value);
848
+ return true;
798
849
  }
799
850
 
800
- static void rpc_set_tensor(const std::vector<uint8_t> & input) {
851
+ ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) {
852
+ ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type,
853
+ tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
854
+ for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
855
+ result->nb[i] = tensor->nb[i];
856
+ }
857
+ result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);
858
+ if (result->buffer && buffers.find(result->buffer) == buffers.end()) {
859
+ return nullptr;
860
+ }
861
+ result->op = (ggml_op) tensor->op;
862
+ for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
863
+ result->op_params[i] = tensor->op_params[i];
864
+ }
865
+ result->flags = tensor->flags;
866
+ result->data = reinterpret_cast<void *>(tensor->data);
867
+ ggml_set_name(result, tensor->name);
868
+ return result;
869
+ }
870
+
871
+
872
+ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
801
873
  // serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
874
+ if (input.size() < sizeof(rpc_tensor) + sizeof(uint64_t)) {
875
+ return false;
876
+ }
802
877
  const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
803
878
  uint64_t offset;
804
879
  memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
@@ -811,14 +886,23 @@ static void rpc_set_tensor(const std::vector<uint8_t> & input) {
811
886
  };
812
887
  struct ggml_context * ctx = ggml_init(params);
813
888
  ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
889
+ if (tensor == nullptr) {
890
+ GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__);
891
+ ggml_free(ctx);
892
+ return false;
893
+ }
814
894
  GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
815
895
  const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset);
816
896
  ggml_backend_tensor_set(tensor, data, offset, size);
817
897
  ggml_free(ctx);
898
+ return true;
818
899
  }
819
900
 
820
- static void rpc_get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
901
+ bool rpc_server::get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
821
902
  // serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) |
903
+ if (input.size() != sizeof(rpc_tensor) + 2*sizeof(uint64_t)) {
904
+ return false;
905
+ }
822
906
  const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
823
907
  uint64_t offset;
824
908
  memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
@@ -832,15 +916,24 @@ static void rpc_get_tensor(const std::vector<uint8_t> & input, std::vector<uint8
832
916
  };
833
917
  struct ggml_context * ctx = ggml_init(params);
834
918
  ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
919
+ if (tensor == nullptr) {
920
+ GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__);
921
+ ggml_free(ctx);
922
+ return false;
923
+ }
835
924
  GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
836
925
  // output serialization format: | data (size bytes) |
837
926
  output.resize(size, 0);
838
927
  ggml_backend_tensor_get(tensor, output.data(), offset, size);
839
928
  ggml_free(ctx);
929
+ return true;
840
930
  }
841
931
 
842
- static void rpc_copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
932
+ bool rpc_server::copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
843
933
  // serialization format: | rpc_tensor src | rpc_tensor dst |
934
+ if (input.size() != 2*sizeof(rpc_tensor)) {
935
+ return false;
936
+ }
844
937
  const rpc_tensor * rpc_src = (const rpc_tensor *)input.data();
845
938
  const rpc_tensor * rpc_dst = (const rpc_tensor *)(input.data() + sizeof(rpc_src));
846
939
 
@@ -852,18 +945,24 @@ static void rpc_copy_tensor(const std::vector<uint8_t> & input, std::vector<uint
852
945
  struct ggml_context * ctx = ggml_init(params);
853
946
  ggml_tensor * src = deserialize_tensor(ctx, rpc_src);
854
947
  ggml_tensor * dst = deserialize_tensor(ctx, rpc_dst);
948
+ if (src == nullptr || dst == nullptr) {
949
+ GGML_PRINT_DEBUG("[%s] error deserializing tensors\n", __func__);
950
+ ggml_free(ctx);
951
+ return false;
952
+ }
855
953
  GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n", __func__, (void*)src->buffer, (void*)dst->buffer);
856
954
  bool result = ggml_backend_buffer_copy_tensor(src, dst);
857
955
  // output serialization format: | result (1 byte) |
858
956
  output.resize(1, 0);
859
957
  output[0] = result;
860
958
  ggml_free(ctx);
959
+ return true;
861
960
  }
862
961
 
863
- static struct ggml_tensor * create_node(uint64_t id,
864
- struct ggml_context * ctx,
865
- const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
866
- std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map) {
962
+ ggml_tensor * rpc_server::create_node(uint64_t id,
963
+ struct ggml_context * ctx,
964
+ const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
965
+ std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map) {
867
966
  if (id == 0) {
868
967
  return nullptr;
869
968
  }
@@ -872,6 +971,9 @@ static struct ggml_tensor * create_node(uint64_t id,
872
971
  }
873
972
  const rpc_tensor * tensor = tensor_ptrs.at(id);
874
973
  struct ggml_tensor * result = deserialize_tensor(ctx, tensor);
974
+ if (result == nullptr) {
975
+ return nullptr;
976
+ }
875
977
  tensor_map[id] = result;
876
978
  for (int i = 0; i < GGML_MAX_SRC; i++) {
877
979
  result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
@@ -881,14 +983,23 @@ static struct ggml_tensor * create_node(uint64_t id,
881
983
  return result;
882
984
  }
883
985
 
884
- static void rpc_graph_compute(ggml_backend_t backend, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
986
+ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
885
987
  // serialization format:
886
988
  // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
989
+ if (input.size() < sizeof(uint32_t)) {
990
+ return false;
991
+ }
887
992
  uint32_t n_nodes;
888
993
  memcpy(&n_nodes, input.data(), sizeof(n_nodes));
994
+ if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
995
+ return false;
996
+ }
889
997
  const uint64_t * nodes = (const uint64_t *)(input.data() + sizeof(n_nodes));
890
998
  uint32_t n_tensors;
891
999
  memcpy(&n_tensors, input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t), sizeof(n_tensors));
1000
+ if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
1001
+ return false;
1002
+ }
892
1003
  const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t) + sizeof(n_tensors));
893
1004
  GGML_PRINT_DEBUG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);
894
1005
 
@@ -914,9 +1025,17 @@ static void rpc_graph_compute(ggml_backend_t backend, const std::vector<uint8_t>
914
1025
  output.resize(1, 0);
915
1026
  output[0] = status;
916
1027
  ggml_free(ctx);
1028
+ return true;
1029
+ }
1030
+
1031
+ rpc_server::~rpc_server() {
1032
+ for (auto buffer : buffers) {
1033
+ ggml_backend_buffer_free(buffer);
1034
+ }
917
1035
  }
918
1036
 
919
1037
  static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
1038
+ rpc_server server(backend);
920
1039
  while (true) {
921
1040
  uint8_t cmd;
922
1041
  if (!recv_data(sockfd, &cmd, 1)) {
@@ -932,45 +1051,46 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
932
1051
  if (!recv_data(sockfd, input.data(), input_size)) {
933
1052
  break;
934
1053
  }
1054
+ bool ok = true;
935
1055
  switch (cmd) {
936
1056
  case ALLOC_BUFFER: {
937
- rpc_alloc_buffer(backend, input, output);
1057
+ ok = server.alloc_buffer(input, output);
938
1058
  break;
939
1059
  }
940
1060
  case GET_ALIGNMENT: {
941
- rpc_get_alignment(backend, output);
1061
+ server.get_alignment(output);
942
1062
  break;
943
1063
  }
944
1064
  case GET_MAX_SIZE: {
945
- rpc_get_max_size(backend, output);
1065
+ server.get_max_size(output);
946
1066
  break;
947
1067
  }
948
1068
  case BUFFER_GET_BASE: {
949
- rpc_buffer_get_base(input, output);
1069
+ ok = server.buffer_get_base(input, output);
950
1070
  break;
951
1071
  }
952
1072
  case FREE_BUFFER: {
953
- rpc_free_buffer(input);
1073
+ ok = server.free_buffer(input);
954
1074
  break;
955
1075
  }
956
1076
  case BUFFER_CLEAR: {
957
- rpc_buffer_clear(input);
1077
+ ok = server.buffer_clear(input);
958
1078
  break;
959
1079
  }
960
1080
  case SET_TENSOR: {
961
- rpc_set_tensor(input);
1081
+ ok = server.set_tensor(input);
962
1082
  break;
963
1083
  }
964
1084
  case GET_TENSOR: {
965
- rpc_get_tensor(input, output);
1085
+ ok = server.get_tensor(input, output);
966
1086
  break;
967
1087
  }
968
1088
  case COPY_TENSOR: {
969
- rpc_copy_tensor(input, output);
1089
+ ok = server.copy_tensor(input, output);
970
1090
  break;
971
1091
  }
972
1092
  case GRAPH_COMPUTE: {
973
- rpc_graph_compute(backend, input, output);
1093
+ ok = server.graph_compute(input, output);
974
1094
  break;
975
1095
  }
976
1096
  case GET_DEVICE_MEMORY: {
@@ -982,9 +1102,12 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
982
1102
  }
983
1103
  default: {
984
1104
  fprintf(stderr, "Unknown command: %d\n", cmd);
985
- return;
1105
+ ok = false;
986
1106
  }
987
1107
  }
1108
+ if (!ok) {
1109
+ break;
1110
+ }
988
1111
  uint64_t output_size = output.size();
989
1112
  if (!send_data(sockfd, &output_size, sizeof(output_size))) {
990
1113
  break;