llama_cpp 0.15.2 → 0.15.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -56,6 +56,7 @@ struct socket_t {
56
56
  };
57
57
 
58
58
  // ggml_tensor is serialized into rpc_tensor
59
+ #pragma pack(push, 1)
59
60
  struct rpc_tensor {
60
61
  uint64_t id;
61
62
  uint32_t type;
@@ -71,6 +72,7 @@ struct rpc_tensor {
71
72
  uint64_t data;
72
73
  char name[GGML_MAX_NAME];
73
74
  };
75
+ #pragma pack(pop)
74
76
 
75
77
  // RPC commands
76
78
  enum rpc_cmd {
@@ -340,23 +342,6 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
340
342
  return result;
341
343
  }
342
344
 
343
- static ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) {
344
- ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type,
345
- tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
346
- for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
347
- result->nb[i] = tensor->nb[i];
348
- }
349
- result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);
350
- result->op = (ggml_op) tensor->op;
351
- for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
352
- result->op_params[i] = tensor->op_params[i];
353
- }
354
- result->flags = tensor->flags;
355
- result->data = reinterpret_cast<void *>(tensor->data);
356
- ggml_set_name(result, tensor->name);
357
- return result;
358
- }
359
-
360
345
  GGML_CALL static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
361
346
  UNUSED(buffer);
362
347
  if (ggml_is_quantized(tensor->type)) {
@@ -465,13 +450,15 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
465
450
  memcpy(&remote_ptr, output.data(), sizeof(remote_ptr));
466
451
  size_t remote_size;
467
452
  memcpy(&remote_size, output.data() + sizeof(uint64_t), sizeof(remote_size));
468
-
469
- ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
470
- ggml_backend_rpc_buffer_interface,
471
- new ggml_backend_rpc_buffer_context{buft_ctx->sock, {}, remote_ptr, "RPC"},
472
- remote_size);
473
-
474
- return buffer;
453
+ if (remote_ptr != 0) {
454
+ ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
455
+ ggml_backend_rpc_buffer_interface,
456
+ new ggml_backend_rpc_buffer_context{buft_ctx->sock, {}, remote_ptr, "RPC"},
457
+ remote_size);
458
+ return buffer;
459
+ } else {
460
+ return nullptr;
461
+ }
475
462
  }
476
463
 
477
464
  static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
@@ -658,7 +645,7 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
658
645
  }
659
646
  }
660
647
  #endif
661
- GGML_PRINT_DEBUG("Connecting to %s\n", endpoint);
648
+ fprintf(stderr, "Connecting to %s\n", endpoint);
662
649
  std::string host;
663
650
  int port;
664
651
  if (!parse_endpoint(endpoint, host, port)) {
@@ -731,22 +718,61 @@ GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint
731
718
 
732
719
  // RPC server-side implementation
733
720
 
734
- static void rpc_alloc_buffer(ggml_backend_t backend, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
721
+ class rpc_server {
722
+ public:
723
+ rpc_server(ggml_backend_t backend) : backend(backend) {}
724
+ ~rpc_server();
725
+
726
+ bool alloc_buffer(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
727
+ void get_alignment(std::vector<uint8_t> & output);
728
+ void get_max_size(std::vector<uint8_t> & output);
729
+ bool buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
730
+ bool free_buffer(const std::vector<uint8_t> & input);
731
+ bool buffer_clear(const std::vector<uint8_t> & input);
732
+ bool set_tensor(const std::vector<uint8_t> & input);
733
+ bool get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
734
+ bool copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
735
+ bool graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
736
+
737
+ private:
738
+ ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
739
+ ggml_tensor * create_node(uint64_t id,
740
+ struct ggml_context * ctx,
741
+ const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
742
+ std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map);
743
+
744
+
745
+ ggml_backend_t backend;
746
+ std::unordered_set<ggml_backend_buffer_t> buffers;
747
+ };
748
+
749
+ bool rpc_server::alloc_buffer(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
735
750
  // input serialization format: | size (8 bytes) |
751
+ if (input.size() != sizeof(uint64_t)) {
752
+ return false;
753
+ }
736
754
  uint64_t size;
737
755
  memcpy(&size, input.data(), sizeof(size));
738
756
  ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
739
757
  ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, size);
740
- uint64_t remote_ptr = reinterpret_cast<uint64_t>(buffer);
741
- uint64_t remote_size = buffer->size;
742
- GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, size, remote_ptr, remote_size);
758
+ uint64_t remote_ptr = 0;
759
+ uint64_t remote_size = 0;
760
+ if (buffer != nullptr) {
761
+ remote_ptr = reinterpret_cast<uint64_t>(buffer);
762
+ remote_size = buffer->size;
763
+ GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, size, remote_ptr, remote_size);
764
+ buffers.insert(buffer);
765
+ } else {
766
+ GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> failed\n", __func__, size);
767
+ }
743
768
  // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
744
769
  output.resize(2*sizeof(uint64_t), 0);
745
770
  memcpy(output.data(), &remote_ptr, sizeof(remote_ptr));
746
771
  memcpy(output.data() + sizeof(uint64_t), &remote_size, sizeof(remote_size));
772
+ return true;
747
773
  }
748
774
 
749
- static void rpc_get_alignment(ggml_backend_t backend, std::vector<uint8_t> & output) {
775
+ void rpc_server::get_alignment(std::vector<uint8_t> & output) {
750
776
  ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
751
777
  size_t alignment = ggml_backend_buft_get_alignment(buft);
752
778
  GGML_PRINT_DEBUG("[%s] alignment: %lu\n", __func__, alignment);
@@ -755,7 +781,7 @@ static void rpc_get_alignment(ggml_backend_t backend, std::vector<uint8_t> & out
755
781
  memcpy(output.data(), &alignment, sizeof(alignment));
756
782
  }
757
783
 
758
- static void rpc_get_max_size(ggml_backend_t backend, std::vector<uint8_t> & output) {
784
+ void rpc_server::get_max_size(std::vector<uint8_t> & output) {
759
785
  ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
760
786
  size_t max_size = ggml_backend_buft_get_max_size(buft);
761
787
  GGML_PRINT_DEBUG("[%s] max_size: %lu\n", __func__, max_size);
@@ -764,41 +790,90 @@ static void rpc_get_max_size(ggml_backend_t backend, std::vector<uint8_t> & outp
764
790
  memcpy(output.data(), &max_size, sizeof(max_size));
765
791
  }
766
792
 
767
- static void rpc_buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
793
+ bool rpc_server::buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
768
794
  // input serialization format: | remote_ptr (8 bytes) |
795
+ if (input.size() != sizeof(uint64_t)) {
796
+ return false;
797
+ }
769
798
  uint64_t remote_ptr;
770
799
  memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
771
800
  GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
772
801
  ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
802
+ if (buffers.find(buffer) == buffers.end()) {
803
+ GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
804
+ return false;
805
+ }
773
806
  void * base = ggml_backend_buffer_get_base(buffer);
774
807
  // output serialization format: | base_ptr (8 bytes) |
775
808
  uint64_t base_ptr = reinterpret_cast<uint64_t>(base);
776
809
  output.resize(sizeof(uint64_t), 0);
777
810
  memcpy(output.data(), &base_ptr, sizeof(base_ptr));
811
+ return true;
778
812
  }
779
813
 
780
- static void rpc_free_buffer(const std::vector<uint8_t> & input) {
814
+ bool rpc_server::free_buffer(const std::vector<uint8_t> & input) {
781
815
  // input serialization format: | remote_ptr (8 bytes) |
816
+ if (input.size() != sizeof(uint64_t)) {
817
+ return false;
818
+ }
782
819
  uint64_t remote_ptr;
783
820
  memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
784
821
  GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
785
822
  ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
823
+ if (buffers.find(buffer) == buffers.end()) {
824
+ GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
825
+ return false;
826
+ }
786
827
  ggml_backend_buffer_free(buffer);
828
+ buffers.erase(buffer);
829
+ return true;
787
830
  }
788
831
 
789
- static void rpc_buffer_clear(const std::vector<uint8_t> & input) {
832
+ bool rpc_server::buffer_clear(const std::vector<uint8_t> & input) {
790
833
  // input serialization format: | remote_ptr (8 bytes) | value (1 byte) |
834
+ if (input.size() != sizeof(uint64_t) + sizeof(uint8_t)) {
835
+ return false;
836
+ }
791
837
  uint64_t remote_ptr;
792
838
  memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
793
839
  uint8_t value;
794
840
  memcpy(&value, input.data() + sizeof(uint64_t), sizeof(value));
795
841
  GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, remote_ptr, value);
796
842
  ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
843
+ if (buffers.find(buffer) == buffers.end()) {
844
+ GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
845
+ return false;
846
+ }
797
847
  ggml_backend_buffer_clear(buffer, value);
848
+ return true;
798
849
  }
799
850
 
800
- static void rpc_set_tensor(const std::vector<uint8_t> & input) {
851
+ ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) {
852
+ ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type,
853
+ tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
854
+ for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
855
+ result->nb[i] = tensor->nb[i];
856
+ }
857
+ result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);
858
+ if (result->buffer && buffers.find(result->buffer) == buffers.end()) {
859
+ return nullptr;
860
+ }
861
+ result->op = (ggml_op) tensor->op;
862
+ for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
863
+ result->op_params[i] = tensor->op_params[i];
864
+ }
865
+ result->flags = tensor->flags;
866
+ result->data = reinterpret_cast<void *>(tensor->data);
867
+ ggml_set_name(result, tensor->name);
868
+ return result;
869
+ }
870
+
871
+
872
+ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
801
873
  // serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
874
+ if (input.size() < sizeof(rpc_tensor) + sizeof(uint64_t)) {
875
+ return false;
876
+ }
802
877
  const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
803
878
  uint64_t offset;
804
879
  memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
@@ -811,14 +886,23 @@ static void rpc_set_tensor(const std::vector<uint8_t> & input) {
811
886
  };
812
887
  struct ggml_context * ctx = ggml_init(params);
813
888
  ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
889
+ if (tensor == nullptr) {
890
+ GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__);
891
+ ggml_free(ctx);
892
+ return false;
893
+ }
814
894
  GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
815
895
  const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset);
816
896
  ggml_backend_tensor_set(tensor, data, offset, size);
817
897
  ggml_free(ctx);
898
+ return true;
818
899
  }
819
900
 
820
- static void rpc_get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
901
+ bool rpc_server::get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
821
902
  // serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) |
903
+ if (input.size() != sizeof(rpc_tensor) + 2*sizeof(uint64_t)) {
904
+ return false;
905
+ }
822
906
  const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
823
907
  uint64_t offset;
824
908
  memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
@@ -832,15 +916,24 @@ static void rpc_get_tensor(const std::vector<uint8_t> & input, std::vector<uint8
832
916
  };
833
917
  struct ggml_context * ctx = ggml_init(params);
834
918
  ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
919
+ if (tensor == nullptr) {
920
+ GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__);
921
+ ggml_free(ctx);
922
+ return false;
923
+ }
835
924
  GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
836
925
  // output serialization format: | data (size bytes) |
837
926
  output.resize(size, 0);
838
927
  ggml_backend_tensor_get(tensor, output.data(), offset, size);
839
928
  ggml_free(ctx);
929
+ return true;
840
930
  }
841
931
 
842
- static void rpc_copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
932
+ bool rpc_server::copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
843
933
  // serialization format: | rpc_tensor src | rpc_tensor dst |
934
+ if (input.size() != 2*sizeof(rpc_tensor)) {
935
+ return false;
936
+ }
844
937
  const rpc_tensor * rpc_src = (const rpc_tensor *)input.data();
845
938
  const rpc_tensor * rpc_dst = (const rpc_tensor *)(input.data() + sizeof(rpc_src));
846
939
 
@@ -852,18 +945,24 @@ static void rpc_copy_tensor(const std::vector<uint8_t> & input, std::vector<uint
852
945
  struct ggml_context * ctx = ggml_init(params);
853
946
  ggml_tensor * src = deserialize_tensor(ctx, rpc_src);
854
947
  ggml_tensor * dst = deserialize_tensor(ctx, rpc_dst);
948
+ if (src == nullptr || dst == nullptr) {
949
+ GGML_PRINT_DEBUG("[%s] error deserializing tensors\n", __func__);
950
+ ggml_free(ctx);
951
+ return false;
952
+ }
855
953
  GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n", __func__, (void*)src->buffer, (void*)dst->buffer);
856
954
  bool result = ggml_backend_buffer_copy_tensor(src, dst);
857
955
  // output serialization format: | result (1 byte) |
858
956
  output.resize(1, 0);
859
957
  output[0] = result;
860
958
  ggml_free(ctx);
959
+ return true;
861
960
  }
862
961
 
863
- static struct ggml_tensor * create_node(uint64_t id,
864
- struct ggml_context * ctx,
865
- const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
866
- std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map) {
962
+ ggml_tensor * rpc_server::create_node(uint64_t id,
963
+ struct ggml_context * ctx,
964
+ const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
965
+ std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map) {
867
966
  if (id == 0) {
868
967
  return nullptr;
869
968
  }
@@ -872,6 +971,9 @@ static struct ggml_tensor * create_node(uint64_t id,
872
971
  }
873
972
  const rpc_tensor * tensor = tensor_ptrs.at(id);
874
973
  struct ggml_tensor * result = deserialize_tensor(ctx, tensor);
974
+ if (result == nullptr) {
975
+ return nullptr;
976
+ }
875
977
  tensor_map[id] = result;
876
978
  for (int i = 0; i < GGML_MAX_SRC; i++) {
877
979
  result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
@@ -881,14 +983,23 @@ static struct ggml_tensor * create_node(uint64_t id,
881
983
  return result;
882
984
  }
883
985
 
884
- static void rpc_graph_compute(ggml_backend_t backend, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
986
+ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
885
987
  // serialization format:
886
988
  // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
989
+ if (input.size() < sizeof(uint32_t)) {
990
+ return false;
991
+ }
887
992
  uint32_t n_nodes;
888
993
  memcpy(&n_nodes, input.data(), sizeof(n_nodes));
994
+ if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
995
+ return false;
996
+ }
889
997
  const uint64_t * nodes = (const uint64_t *)(input.data() + sizeof(n_nodes));
890
998
  uint32_t n_tensors;
891
999
  memcpy(&n_tensors, input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t), sizeof(n_tensors));
1000
+ if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
1001
+ return false;
1002
+ }
892
1003
  const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t) + sizeof(n_tensors));
893
1004
  GGML_PRINT_DEBUG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);
894
1005
 
@@ -914,9 +1025,17 @@ static void rpc_graph_compute(ggml_backend_t backend, const std::vector<uint8_t>
914
1025
  output.resize(1, 0);
915
1026
  output[0] = status;
916
1027
  ggml_free(ctx);
1028
+ return true;
1029
+ }
1030
+
1031
+ rpc_server::~rpc_server() {
1032
+ for (auto buffer : buffers) {
1033
+ ggml_backend_buffer_free(buffer);
1034
+ }
917
1035
  }
918
1036
 
919
1037
  static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
1038
+ rpc_server server(backend);
920
1039
  while (true) {
921
1040
  uint8_t cmd;
922
1041
  if (!recv_data(sockfd, &cmd, 1)) {
@@ -932,45 +1051,46 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
932
1051
  if (!recv_data(sockfd, input.data(), input_size)) {
933
1052
  break;
934
1053
  }
1054
+ bool ok = true;
935
1055
  switch (cmd) {
936
1056
  case ALLOC_BUFFER: {
937
- rpc_alloc_buffer(backend, input, output);
1057
+ ok = server.alloc_buffer(input, output);
938
1058
  break;
939
1059
  }
940
1060
  case GET_ALIGNMENT: {
941
- rpc_get_alignment(backend, output);
1061
+ server.get_alignment(output);
942
1062
  break;
943
1063
  }
944
1064
  case GET_MAX_SIZE: {
945
- rpc_get_max_size(backend, output);
1065
+ server.get_max_size(output);
946
1066
  break;
947
1067
  }
948
1068
  case BUFFER_GET_BASE: {
949
- rpc_buffer_get_base(input, output);
1069
+ ok = server.buffer_get_base(input, output);
950
1070
  break;
951
1071
  }
952
1072
  case FREE_BUFFER: {
953
- rpc_free_buffer(input);
1073
+ ok = server.free_buffer(input);
954
1074
  break;
955
1075
  }
956
1076
  case BUFFER_CLEAR: {
957
- rpc_buffer_clear(input);
1077
+ ok = server.buffer_clear(input);
958
1078
  break;
959
1079
  }
960
1080
  case SET_TENSOR: {
961
- rpc_set_tensor(input);
1081
+ ok = server.set_tensor(input);
962
1082
  break;
963
1083
  }
964
1084
  case GET_TENSOR: {
965
- rpc_get_tensor(input, output);
1085
+ ok = server.get_tensor(input, output);
966
1086
  break;
967
1087
  }
968
1088
  case COPY_TENSOR: {
969
- rpc_copy_tensor(input, output);
1089
+ ok = server.copy_tensor(input, output);
970
1090
  break;
971
1091
  }
972
1092
  case GRAPH_COMPUTE: {
973
- rpc_graph_compute(backend, input, output);
1093
+ ok = server.graph_compute(input, output);
974
1094
  break;
975
1095
  }
976
1096
  case GET_DEVICE_MEMORY: {
@@ -982,9 +1102,12 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
982
1102
  }
983
1103
  default: {
984
1104
  fprintf(stderr, "Unknown command: %d\n", cmd);
985
- return;
1105
+ ok = false;
986
1106
  }
987
1107
  }
1108
+ if (!ok) {
1109
+ break;
1110
+ }
988
1111
  uint64_t output_size = output.size();
989
1112
  if (!send_data(sockfd, &output_size, sizeof(output_size))) {
990
1113
  break;