llama_cpp 0.15.3 → 0.15.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3813,7 +3813,44 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         return;
     }
 #endif
-#if defined(__ARM_NEON)
+#if defined(__ARM_FEATURE_SVE)
+    const svbool_t ptrueh = svptrue_pat_b8(SV_VL16);
+    const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh);
+
+    svfloat32_t sumv0 = svdup_n_f32(0.0f);
+    svfloat32_t sumv1 = svdup_n_f32(0.0f);
+
+    assert(nb % 2 == 0); // TODO: handle odd nb
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q4_0 * restrict x0 = &x[i + 0];
+        const block_q4_0 * restrict x1 = &x[i + 1];
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
+
+        // load x
+        const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
+        const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
+
+        // 4-bit -> 8-bit
+        const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04));
+        const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04));
+
+        // sub 8
+        const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
+        const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
+
+        // load y
+        const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
+        const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+
+        // dot product
+        sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+    }
+
+    *s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+#elif defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
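For context, and not part of the published diff: this new __ARM_FEATURE_SVE branch vectorizes the same per-block reduction that ggml's scalar fallback computes. A minimal C++ sketch of that reduction, assuming the standard q4_0 layout (32 quants per block, low nibbles holding elements 0..15, high nibbles elements 16..31, each nibble offset by 8, one fp16 scale d per block); the helper name is hypothetical:

    // Hedged scalar sketch of one q4_0 x q8_0 block dot product.
    static float block_dot_q4_0_q8_0(const block_q4_0 * x, const block_q8_0 * y) {
        int sumi = 0;
        for (int j = 0; j < 16; ++j) {
            const int v0 = (x->qs[j] & 0x0F) - 8; // low nibble  -> element j
            const int v1 = (x->qs[j] >> 4)   - 8; // high nibble -> element j + 16
            sumi += v0 * y->qs[j] + v1 * y->qs[j + 16];
        }
        return sumi * GGML_FP16_TO_FP32(x->d) * GGML_FP16_TO_FP32(y->d);
    }

The SVE code computes the same thing two blocks at a time, using the ptrueh/ptruel predicates to unpack low and high nibbles in place of the scalar mask-and-shift.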
@@ -5384,7 +5421,32 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         return;
     }
 #endif
-#if defined(__ARM_NEON)
+#if defined(__ARM_FEATURE_SVE)
+    svfloat32_t sumv0 = svdup_n_f32(0.0f);
+    svfloat32_t sumv1 = svdup_n_f32(0.0f);
+
+    assert(nb % 2 == 0); // TODO: handle odd nb
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q8_0 * restrict x0 = &x[i + 0];
+        const block_q8_0 * restrict x1 = &x[i + 1];
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
+
+        // load x
+        const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
+        const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
+
+        // load y
+        const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
+        const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+
+        sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+    }
+
+    *s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+#elif defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
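The heavy lifting in both SVE branches is svdot_s32, which for each 32-bit lane of the accumulator sums four adjacent int8 products from the two input vectors. A scalar model of a single lane (illustrative only, not part of the diff):

    // What svdot_s32 contributes to one 32-bit accumulator lane.
    static int32_t sdot_lane(int32_t acc, const int8_t a[4], const int8_t b[4]) {
        for (int k = 0; k < 4; ++k) {
            acc += (int32_t) a[k] * (int32_t) b[k]; // widen before accumulating
        }
        return acc;
    }

svcvt_f32_s32_x and svmla_n_f32_x then fold each block's d*d scale into the running float sums, and svaddv_f32 reduces across lanes at the end.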
@@ -6026,6 +6088,7 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
         const uint8_t * restrict q2 = x[i].qs;
         const int8_t * restrict q8 = y[i].qs;
+
         const __m128i mins_and_scales = __lsx_vld((const __m128i*)x[i].scales, 0);
         const __m128i scales8 = __lsx_vand_v(mins_and_scales, m4);
         const __m128i mins8 = __lsx_vand_v(__lsx_vsrli_h(mins_and_scales, 4), m4);
@@ -6745,6 +6808,8 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     for (int i = 0; i < nb; ++i) {
 
         const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const uint8_t * restrict q3 = x[i].qs;
+        const int8_t * restrict q8 = y[i].qs;
         // Set up scales
         memcpy(aux, x[i].scales, 12);
         __m128i scales128 = lsx_set_w(
@@ -6766,29 +6831,32 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     int bit = 0;
     int is = 0;
+    __m256i xvbit;
 
-    const uint8_t * restrict q3 = x[i].qs;
-    const int8_t * restrict q8 = y[i].qs;
 
     for (int j = 0; j < QK_K/128; ++j) {
         // load low 2 bits
         const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32;
 
+        xvbit = __lasx_xvreplgr2vr_h(bit);
         // prepare low and high bits
         const __m256i q3l_0 = __lasx_xvand_v(q3bits, m3);
-        const __m256i q3h_0 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvandn_v(hbits, __lasx_xvslli_h(mone, bit)), bit), 2);
+        const __m256i q3h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
         ++bit;
 
+        xvbit = __lasx_xvreplgr2vr_h(bit);
         const __m256i q3l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 2), m3);
-        const __m256i q3h_1 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvandn_v(hbits, __lasx_xvslli_h(mone, bit)), bit), 2);
+        const __m256i q3h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
         ++bit;
 
+        xvbit = __lasx_xvreplgr2vr_h(bit);
         const __m256i q3l_2 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 4), m3);
-        const __m256i q3h_2 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvandn_v(hbits, __lasx_xvslli_h(mone, bit)), bit), 2);
+        const __m256i q3h_2 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
         ++bit;
 
+        xvbit = __lasx_xvreplgr2vr_h(bit);
         const __m256i q3l_3 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 6), m3);
-        const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvandn_v(hbits, __lasx_xvslli_h(mone, bit)), bit), 2);
+        const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
         ++bit;
 
         // load Q8 quants
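This hunk is a correctness fix, not a cleanup: the _i-suffixed LASX shift intrinsics (__lasx_xvslli_h, __lasx_xvsrli_h) encode their shift count as an instruction immediate, so the count must be a compile-time constant, while bit here varies per iteration. The fix broadcasts the count into a vector register and uses the register-shift forms instead. A minimal sketch of the pattern, using only intrinsics already visible in the hunk (helper name hypothetical):

    // Shift every 16-bit lane of v right by a runtime-variable count.
    static __m256i xvsrl_h_by(__m256i v, int bit) {
        const __m256i xvbit = __lasx_xvreplgr2vr_h(bit); // broadcast count to all lanes
        return __lasx_xvsrl_h(v, xvbit);                 // register (non-immediate) shift
    }

The same fix is applied to the q5_K path in a later hunk.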
@@ -7337,6 +7405,9 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     *s = vec_extract(vsumf0, 0);
 
 #elif defined __loongarch_asx
+    GGML_UNUSED(kmask1);
+    GGML_UNUSED(kmask2);
+    GGML_UNUSED(kmask3);
 
     const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
 
@@ -7349,6 +7420,11 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
 
         memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
 
         const uint8_t * restrict q4 = x[i].qs;
         const int8_t * restrict q8 = y[i].qs;
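This hunk restores the 6-bit scale/min unpacking that ggml's other SIMD paths already perform on the 12-byte x[i].scales field; the LASX path was previously reading the packed bytes raw. Assuming the kmask constants defined earlier in the file (in upstream ggml they are kmask1 = 0x3f3f3f3f, kmask2 = 0x0f0f0f0f, kmask3 = 0x03030303), the rearrangement leaves utmp[0..1] holding the eight 6-bit sub-block scales and utmp[2..3] the eight 6-bit mins. An equivalent per-entry view, mirroring the get_scale_min_k4 helper used by ggml's scalar path (a hedged sketch, not diff content):

    // Decode scale d and min m for sub-block j (0..7) from the 12 packed bytes q.
    static void get_scale_min(int j, const uint8_t * q, uint8_t * d, uint8_t * m) {
        if (j < 4) {            // first four sub-blocks: plain 6-bit fields
            *d = q[j] & 63;
            *m = q[j + 4] & 63;
        } else {                // last four: low 4 bits plus 2 spilled high bits
            *d = (q[j + 4] & 0x0F) | ((q[j - 4] >> 6) << 4);
            *m = (q[j + 4] >>  4) | ((q[j    ] >> 6) << 4);
        }
    }

The same unpacking is added to the q5_K LASX path below, which is also why both paths now mark the kmask constants as used.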
@@ -7388,16 +7464,17 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
         __m256 vd = __lasx_xvreplfr2vr_s(d);
         acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc);
+
     }
 
     acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vpermi_w((__m128i)acc_m, (__m128i)acc_m, 0xee));
     __m128i tmp1 = __lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w((__m128i)acc_m, 1), 0);
     acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1);
 
+
     ft_union fi;
     fi.i = __lsx_vpickve2gr_w(acc_m, 0);
     *s = hsum_float_8(acc) + fi.f ;
-
 #else
 
     const uint8_t * scales = (const uint8_t*)&utmp[0];
@@ -7935,6 +8012,9 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     *s = vec_extract(vsumf0, 0);
 
 #elif defined __loongarch_asx
+    GGML_UNUSED(kmask1);
+    GGML_UNUSED(kmask2);
+    GGML_UNUSED(kmask3);
 
     const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
     const __m128i mzero = __lsx_vldi(0);
@@ -7953,6 +8033,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
 
         memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
 
         const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]));
 
@@ -7971,6 +8056,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         __m256i sumi = __lasx_xvldi(0);
 
         int bit = 0;
+        __m256i xvbit;
 
         for (int j = 0; j < QK_K/64; ++j) {
 
@@ -7979,13 +8065,15 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
             const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32;
 
+            xvbit = __lasx_xvreplgr2vr_h(bit++);
             const __m256i q5l_0 = __lasx_xvand_v(q5bits, m4);
-            const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvand_v(hbits, hmask), bit++), 4);
+            const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
             const __m256i q5_0 = __lasx_xvadd_b(q5l_0, q5h_0);
             hmask = __lasx_xvslli_h(hmask, 1);
 
+            xvbit = __lasx_xvreplgr2vr_h(bit++);
             const __m256i q5l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q5bits, 4), m4);
-            const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvand_v(hbits, hmask), bit++), 4);
+            const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
             const __m256i q5_1 = __lasx_xvadd_b(q5l_1, q5h_1);
             hmask = __lasx_xvslli_h(hmask, 1);
 
@@ -7999,10 +8087,12 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
             p16_1 = lasx_madd_h(scale_1, p16_1);
 
             sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));
+
         }
 
         __m256 vd = __lasx_xvreplfr2vr_s(d);
         acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc);
+
     }
 
     *s = hsum_float_8(acc) + summs;
@@ -6,6 +6,7 @@
 #include <string>
 #include <vector>
 #include <memory>
+#include <mutex>
 #include <unordered_map>
 #include <unordered_set>
 #ifdef _WIN32
@@ -47,6 +48,7 @@ struct socket_t {
     sockfd_t fd;
     socket_t(sockfd_t fd) : fd(fd) {}
     ~socket_t() {
+        GGML_PRINT_DEBUG("[%s] closing socket %d\n", __func__, this->fd);
 #ifdef _WIN32
         closesocket(this->fd);
 #else
@@ -97,7 +99,7 @@ static ggml_guid_t ggml_backend_rpc_guid() {
 }
 
 struct ggml_backend_rpc_buffer_type_context {
-    std::shared_ptr<socket_t> sock;
+    std::string endpoint;
     std::string name;
     size_t alignment;
     size_t max_size;
@@ -106,8 +108,6 @@ struct ggml_backend_rpc_buffer_type_context {
 struct ggml_backend_rpc_context {
     std::string endpoint;
     std::string name;
-    std::shared_ptr<socket_t> sock;
-    ggml_backend_buffer_type_t buft;
 };
 
 struct ggml_backend_rpc_buffer_context {
@@ -231,14 +231,13 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
     return true;
 }
 
-static bool parse_endpoint(const char * endpoint, std::string & host, int & port) {
-    std::string str(endpoint);
-    size_t pos = str.find(':');
+static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
+    size_t pos = endpoint.find(':');
     if (pos == std::string::npos) {
         return false;
     }
-    host = str.substr(0, pos);
-    port = std::stoi(str.substr(pos + 1));
+    host = endpoint.substr(0, pos);
+    port = std::stoi(endpoint.substr(pos + 1));
     return true;
 }
 
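The signature change above just lets parse_endpoint work on std::string directly; its contract is unchanged. For illustration (endpoint value hypothetical):

    std::string host;
    int port;
    bool ok = parse_endpoint("localhost:50052", host, port);
    // ok == true, host == "localhost", port == 50052;
    // an endpoint without ':' returns false instead.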
@@ -273,6 +272,44 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
 
 // RPC client-side implementation
 
+static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
+    static std::mutex mutex;
+    std::lock_guard<std::mutex> lock(mutex);
+    static std::unordered_map<std::string, std::weak_ptr<socket_t>> sockets;
+    static bool initialized = false;
+
+    auto it = sockets.find(endpoint);
+    if (it != sockets.end()) {
+        if (auto sock = it->second.lock()) {
+            return sock;
+        }
+    }
+    std::string host;
+    int port;
+    if (!parse_endpoint(endpoint, host, port)) {
+        return nullptr;
+    }
+#ifdef _WIN32
+    if (!initialized) {
+        WSADATA wsaData;
+        int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
+        if (res != 0) {
+            return nullptr;
+        }
+        initialized = true;
+    }
+#else
+    UNUSED(initialized);
+#endif
+    auto sock = socket_connect(host.c_str(), port);
+    if (sock == nullptr) {
+        return nullptr;
+    }
+    GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
+    sockets[endpoint] = sock;
+    return sock;
+}
+
 GGML_CALL static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) {
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
     return ctx->name.c_str();
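Commentary, not diff content: the new get_socket gives one shared connection per endpoint, and because the cache stores weak_ptr entries it never keeps a connection alive by itself. A usage sketch of the resulting behavior (endpoint value hypothetical):

    auto a = get_socket("localhost:50052"); // first call dials the server
    auto b = get_socket("localhost:50052"); // cache hit: a.get() == b.get()
    a.reset();
    b.reset(); // last owner gone -> ~socket_t closes the fd (see debug print above)
    auto c = get_socket("localhost:50052"); // weak_ptr expired -> reconnects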
@@ -442,7 +479,8 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
     std::vector<uint8_t> input(input_size, 0);
     memcpy(input.data(), &size, sizeof(size));
     std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(buft_ctx->sock, ALLOC_BUFFER, input, output);
+    auto sock = get_socket(buft_ctx->endpoint);
+    bool status = send_rpc_cmd(sock, ALLOC_BUFFER, input, output);
     GGML_ASSERT(status);
     GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
     // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
@@ -453,7 +491,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
     if (remote_ptr != 0) {
         ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
             ggml_backend_rpc_buffer_interface,
-            new ggml_backend_rpc_buffer_context{buft_ctx->sock, {}, remote_ptr, "RPC"},
+            new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC"},
             remote_size);
         return buffer;
     } else {
@@ -508,7 +546,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_type_supports_backend(ggml_backend
     }
     ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
-    return buft_ctx->sock == rpc_ctx->sock;
+    return buft_ctx->endpoint == rpc_ctx->endpoint;
 }
 
 static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
@@ -521,7 +559,6 @@ static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
     /* .is_host = */ NULL,
 };
 
-
 GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
 
@@ -530,16 +567,13 @@ GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
 
 GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) {
     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
-    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)rpc_ctx->buft->context;
-    delete buft_ctx;
-    delete rpc_ctx->buft;
     delete rpc_ctx;
     delete backend;
 }
 
 GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) {
     ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
-    return ctx->buft;
+    return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
 }
 
 GGML_CALL static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
@@ -590,7 +624,8 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t
     std::vector<uint8_t> input;
     serialize_graph(cgraph, input);
     std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(rpc_ctx->sock, GRAPH_COMPUTE, input, output);
+    auto sock = get_socket(rpc_ctx->endpoint);
+    bool status = send_rpc_cmd(sock, GRAPH_COMPUTE, input, output);
     GGML_ASSERT(status);
     GGML_ASSERT(output.size() == 1);
     return (enum ggml_status)output[0];
@@ -624,65 +659,48 @@ static ggml_backend_i ggml_backend_rpc_interface = {
     /* .event_synchronize = */ NULL,
 };
 
-static std::unordered_map<std::string, ggml_backend_t> instances;
-
 GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
-    ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
-    return backend != nullptr ? ggml_backend_rpc_get_default_buffer_type(backend) : nullptr;
-}
-
-GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
-    std::string endpoint_str(endpoint);
-    if (instances.find(endpoint_str) != instances.end()) {
-        return instances[endpoint_str];
-    }
-#ifdef _WIN32
-    {
-        WSADATA wsaData;
-        int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
-        if (res != 0) {
-            return nullptr;
-        }
-    }
-#endif
-    fprintf(stderr, "Connecting to %s\n", endpoint);
-    std::string host;
-    int port;
-    if (!parse_endpoint(endpoint, host, port)) {
-        return nullptr;
-    }
-    auto sock = socket_connect(host.c_str(), port);
+    static std::mutex mutex;
+    std::lock_guard<std::mutex> lock(mutex);
+    // NOTE: buffer types are allocated and never freed; this is by design
+    static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;
+    auto it = buft_map.find(endpoint);
+    if (it != buft_map.end()) {
+        return it->second;
+    }
+    auto sock = get_socket(endpoint);
     if (sock == nullptr) {
         return nullptr;
     }
     size_t alignment = get_alignment(sock);
     size_t max_size = get_max_size(sock);
     ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
-        /* .sock = */ sock,
-        /* .name = */ "RPC" + std::to_string(sock->fd),
+        /* .endpoint = */ endpoint,
+        /* .name = */ "RPC[" + std::string(endpoint) + "]",
         /* .alignment = */ alignment,
-        /* .max_size = */ max_size
+        /* .max_size = */ max_size
     };
 
     ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
         /* .iface = */ ggml_backend_rpc_buffer_type_interface,
         /* .context = */ buft_ctx
     };
+    buft_map[endpoint] = buft;
+    return buft;
+}
 
+GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
     ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
-        /* .endpoint = */ endpoint,
-        /* .name = */ "RPC" + std::to_string(sock->fd),
-        /* .sock = */ sock,
-        /* .buft = */ buft
+        /* .endpoint = */ endpoint,
+        /* .name = */ "RPC",
     };
 
-    instances[endpoint] = new ggml_backend {
+    ggml_backend_t backend = new ggml_backend {
         /* .guid = */ ggml_backend_rpc_guid(),
         /* .interface = */ ggml_backend_rpc_interface,
         /* .context = */ ctx
     };
-
-    return instances[endpoint];
+    return backend;
 }
 
 GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {
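Also worth noting (commentary, not diff content): with the global instances cache gone, ggml_backend_rpc_init returns a fresh backend on every call, owned by the caller, while buffer types stay cached per endpoint and are deliberately never freed. A sketch of the resulting contract (endpoint value hypothetical):

    ggml_backend_t b1 = ggml_backend_rpc_init("localhost:50052");
    ggml_backend_t b2 = ggml_backend_rpc_init("localhost:50052");
    // b1 != b2 now, but both resolve to the same cached buffer type:
    // ggml_backend_get_default_buffer_type(b1)
    //     == ggml_backend_rpc_buffer_type("localhost:50052")
    ggml_backend_free(b1); // frees only this backend's context
    ggml_backend_free(b2); // the pooled socket closes once nothing references it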
@@ -706,14 +724,13 @@ static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * f
 }
 
 GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
-    ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
-    if (backend == nullptr) {
+    auto sock = get_socket(endpoint);
+    if (sock == nullptr) {
         *free = 0;
         *total = 0;
         return;
     }
-    ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
-    get_device_memory(ctx->sock, free, total);
+    get_device_memory(sock, free, total);
 }
 
 // RPC server-side implementation