llama_cpp 0.15.3 → 0.15.4

@@ -3813,7 +3813,44 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
  return;
  }
  #endif
- #if defined(__ARM_NEON)
+ #if defined(__ARM_FEATURE_SVE)
+ const svbool_t ptrueh = svptrue_pat_b8(SV_VL16);
+ const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh);
+
+ svfloat32_t sumv0 = svdup_n_f32(0.0f);
+ svfloat32_t sumv1 = svdup_n_f32(0.0f);
+
+ assert(nb % 2 == 0); // TODO: handle odd nb
+
+ for (int i = 0; i < nb; i += 2) {
+ const block_q4_0 * restrict x0 = &x[i + 0];
+ const block_q4_0 * restrict x1 = &x[i + 1];
+ const block_q8_0 * restrict y0 = &y[i + 0];
+ const block_q8_0 * restrict y1 = &y[i + 1];
+
+ // load x
+ const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
+ const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
+
+ // 4-bit -> 8-bit
+ const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04));
+ const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04));
+
+ // sub 8
+ const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
+ const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
+
+ // load y
+ const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
+ const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+
+ // dot product
+ sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+ sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+ }
+
+ *s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+ #elif defined(__ARM_NEON)
  float32x4_t sumv0 = vdupq_n_f32(0.0f);
  float32x4_t sumv1 = vdupq_n_f32(0.0f);

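As written, this SVE path effectively assumes a 256-bit vector length: svld1rq_u8 replicates the 16 packed bytes of a q4_0 block across the register, ptrueh selects the first 16 byte lanes (masked with & 0x0F to keep the low nibbles) and ptruel the remaining lanes (shifted right by 4 for the high nibbles), and subtracting 8 removes the storage bias. A scalar model of that unpack, following ggml's block_q4_0 layout (names here are illustrative, not from the source):

    // Scalar sketch of the nibble unpack the SVE hunk vectorizes:
    // 16 packed bytes -> 32 signed weights, stored with a +8 bias.
    #include <cstdint>

    void unpack_q4_0(const uint8_t qs[16], int8_t w[32]) {
        for (int j = 0; j < 16; ++j) {
            w[j]      = (int8_t)(qs[j] & 0x0F) - 8; // lanes under ptrueh: AND
            w[j + 16] = (int8_t)(qs[j] >> 4)   - 8; // lanes under ptruel: LSR
        }
    }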
@@ -5384,7 +5421,32 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
  return;
  }
  #endif
- #if defined(__ARM_NEON)
+ #if defined(__ARM_FEATURE_SVE)
+ svfloat32_t sumv0 = svdup_n_f32(0.0f);
+ svfloat32_t sumv1 = svdup_n_f32(0.0f);
+
+ assert(nb % 2 == 0); // TODO: handle odd nb
+
+ for (int i = 0; i < nb; i += 2) {
+ const block_q8_0 * restrict x0 = &x[i + 0];
+ const block_q8_0 * restrict x1 = &x[i + 1];
+ const block_q8_0 * restrict y0 = &y[i + 0];
+ const block_q8_0 * restrict y1 = &y[i + 1];
+
+ // load x
+ const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
+ const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
+
+ // load y
+ const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
+ const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+
+ sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+ sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+ }
+
+ *s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+ #elif defined(__ARM_NEON)
  float32x4_t sumv0 = vdupq_n_f32(0.0f);
  float32x4_t sumv1 = vdupq_n_f32(0.0f);

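The q8_0 case needs no unpacking: both operands are already signed bytes, so the loop is just svdot_s32 (which sums groups of four int8 products into each int32 lane) followed by one fused multiply-add with the two decoded fp16 block scales. A scalar sketch of the per-block arithmetic (illustrative names):

    // Scalar equivalent of one q8_0 block pairing in the SVE loop above.
    #include <cstdint>

    float dot_q8_0_block(const int8_t xqs[32], const int8_t yqs[32],
                         float xd, float yd) { // decoded fp16 block scales
        int32_t sumi = 0;                      // the role of svdot_s32
        for (int j = 0; j < 32; ++j) {
            sumi += xqs[j] * yqs[j];
        }
        return sumi * xd * yd;                 // the role of svmla_n_f32_x
    }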
@@ -6026,6 +6088,7 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r

  const uint8_t * restrict q2 = x[i].qs;
  const int8_t * restrict q8 = y[i].qs;
+
  const __m128i mins_and_scales = __lsx_vld((const __m128i*)x[i].scales, 0);
  const __m128i scales8 = __lsx_vand_v(mins_and_scales, m4);
  const __m128i mins8 = __lsx_vand_v(__lsx_vsrli_h(mins_and_scales, 4), m4);
@@ -6745,6 +6808,8 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
  for (int i = 0; i < nb; ++i) {

  const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const uint8_t * restrict q3 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
  // Set up scales
  memcpy(aux, x[i].scales, 12);
  __m128i scales128 = lsx_set_w(
@@ -6766,29 +6831,32 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r

  int bit = 0;
  int is = 0;
+ __m256i xvbit;

- const uint8_t * restrict q3 = x[i].qs;
- const int8_t * restrict q8 = y[i].qs;

  for (int j = 0; j < QK_K/128; ++j) {
  // load low 2 bits
  const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32;

+ xvbit = __lasx_xvreplgr2vr_h(bit);
  // prepare low and high bits
  const __m256i q3l_0 = __lasx_xvand_v(q3bits, m3);
- const __m256i q3h_0 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvandn_v(hbits, __lasx_xvslli_h(mone, bit)), bit), 2);
+ const __m256i q3h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
  ++bit;

+ xvbit = __lasx_xvreplgr2vr_h(bit);
  const __m256i q3l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 2), m3);
- const __m256i q3h_1 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvandn_v(hbits, __lasx_xvslli_h(mone, bit)), bit), 2);
+ const __m256i q3h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
  ++bit;

+ xvbit = __lasx_xvreplgr2vr_h(bit);
  const __m256i q3l_2 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 4), m3);
- const __m256i q3h_2 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvandn_v(hbits, __lasx_xvslli_h(mone, bit)), bit), 2);
+ const __m256i q3h_2 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
  ++bit;

+ xvbit = __lasx_xvreplgr2vr_h(bit);
  const __m256i q3l_3 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 6), m3);
- const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvandn_v(hbits, __lasx_xvslli_h(mone, bit)), bit), 2);
+ const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
  ++bit;

  // load Q8 quants
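The bug this hunk fixes: __lasx_xvslli_h and __lasx_xvsrli_h take an immediate shift count, but bit changes every iteration, so the shift amount is now broadcast into a vector with __lasx_xvreplgr2vr_h and fed to the register-operand shifts __lasx_xvsll_h/__lasx_xvsrl_h. Per 16-bit lane the corrected q3h_* expression computes the following (a scalar sketch; note xvandn_v(a, b) is ~a & b):

    // Scalar model of one lane: select bit `bit` of ~hbits, move it down
    // to position 0, then up to position 2 (the 3rd bit of the quant).
    #include <cstdint>

    uint16_t q3_high_bit(uint16_t hbits, int bit) {
        uint16_t mask = (uint16_t)(1u << bit);     // xvsll_h(mone, xvbit)
        uint16_t sel  = (uint16_t)(~hbits & mask); // xvandn_v(hbits, mask)
        return (uint16_t)((sel >> bit) << 2);      // xvsrl_h, then xvslli_h(.., 2)
    }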
@@ -7337,6 +7405,9 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
  *s = vec_extract(vsumf0, 0);

  #elif defined __loongarch_asx
+ GGML_UNUSED(kmask1);
+ GGML_UNUSED(kmask2);
+ GGML_UNUSED(kmask3);

  const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);

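The GGML_UNUSED lines keep unused-variable warnings quiet for the kmask constants on this branch; in ggml the macro is a plain void cast along the lines of:

    #define GGML_UNUSED(x) (void)(x)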
@@ -7349,6 +7420,11 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
  const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);

  memcpy(utmp, x[i].scales, 12);
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+ const uint32_t uaux = utmp[1] & kmask1;
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+ utmp[2] = uaux;
+ utmp[0] &= kmask1;

  const uint8_t * restrict q4 = x[i].qs;
  const int8_t * restrict q8 = y[i].qs;
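These five utmp lines are the standard K-quant scale unpack the other SIMD paths in this file already perform: q4_K packs eight 6-bit scales and eight 6-bit mins into 12 bytes, and after the shuffle the bytes of utmp[0..1] hold scales[0..7] while the bytes of utmp[2..3] hold mins[0..7]. A self-contained sketch, assuming the kmask values match the constants defined earlier in ggml-quants.c:

    #include <cstdint>
    #include <cstring>

    // Assumed to match the constants defined near the top of ggml-quants.c.
    static const uint32_t kmask1 = 0x3f3f3f3f, kmask2 = 0x0f0f0f0f, kmask3 = 0x03030303;

    void unpack_k_scales(const uint8_t packed[12], uint8_t scales[8], uint8_t mins[8]) {
        uint32_t utmp[4];
        memcpy(utmp, packed, 12);
        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
        const uint32_t uaux = utmp[1] & kmask1;
        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
        utmp[2] = uaux;
        utmp[0] &= kmask1;
        memcpy(scales, utmp, 8);   // bytes of utmp[0..1]: 6-bit scales
        memcpy(mins, utmp + 2, 8); // bytes of utmp[2..3]: 6-bit mins
    }

The same unpack is added to the q5_K LASX path further down.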
@@ -7388,16 +7464,17 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r

  __m256 vd = __lasx_xvreplfr2vr_s(d);
  acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc);
+
  }

  acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vpermi_w((__m128i)acc_m, (__m128i)acc_m, 0xee));
  __m128i tmp1 = __lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w((__m128i)acc_m, 1), 0);
  acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1);

+
  ft_union fi;
  fi.i = __lsx_vpickve2gr_w(acc_m, 0);
  *s = hsum_float_8(acc) + fi.f ;
-
  #else

  const uint8_t * scales = (const uint8_t*)&utmp[0];
@@ -7935,6 +8012,9 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
  *s = vec_extract(vsumf0, 0);

  #elif defined __loongarch_asx
+ GGML_UNUSED(kmask1);
+ GGML_UNUSED(kmask2);
+ GGML_UNUSED(kmask3);

  const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
  const __m128i mzero = __lsx_vldi(0);
@@ -7953,6 +8033,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
  const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);

  memcpy(utmp, x[i].scales, 12);
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+ const uint32_t uaux = utmp[1] & kmask1;
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+ utmp[2] = uaux;
+ utmp[0] &= kmask1;

  const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]));

@@ -7971,6 +8056,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
  __m256i sumi = __lasx_xvldi(0);

  int bit = 0;
+ __m256i xvbit;

  for (int j = 0; j < QK_K/64; ++j) {

@@ -7979,13 +8065,15 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r

  const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32;

+ xvbit = __lasx_xvreplgr2vr_h(bit++);
  const __m256i q5l_0 = __lasx_xvand_v(q5bits, m4);
- const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvand_v(hbits, hmask), bit++), 4);
+ const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
  const __m256i q5_0 = __lasx_xvadd_b(q5l_0, q5h_0);
  hmask = __lasx_xvslli_h(hmask, 1);

+ xvbit = __lasx_xvreplgr2vr_h(bit++);
  const __m256i q5l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q5bits, 4), m4);
- const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvand_v(hbits, hmask), bit++), 4);
+ const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
  const __m256i q5_1 = __lasx_xvadd_b(q5l_1, q5h_1);
  hmask = __lasx_xvslli_h(hmask, 1);

@@ -7999,10 +8087,12 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
  p16_1 = lasx_madd_h(scale_1, p16_1);

  sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));
+
  }

  __m256 vd = __lasx_xvreplfr2vr_s(d);
  acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc);
+
  }

  *s = hsum_float_8(acc) + summs;
@@ -6,6 +6,7 @@
  #include <string>
  #include <vector>
  #include <memory>
+ #include <mutex>
  #include <unordered_map>
  #include <unordered_set>
  #ifdef _WIN32
@@ -47,6 +48,7 @@ struct socket_t {
  sockfd_t fd;
  socket_t(sockfd_t fd) : fd(fd) {}
  ~socket_t() {
+ GGML_PRINT_DEBUG("[%s] closing socket %d\n", __func__, this->fd);
  #ifdef _WIN32
  closesocket(this->fd);
  #else
@@ -97,7 +99,7 @@ static ggml_guid_t ggml_backend_rpc_guid() {
  }

  struct ggml_backend_rpc_buffer_type_context {
- std::shared_ptr<socket_t> sock;
+ std::string endpoint;
  std::string name;
  size_t alignment;
  size_t max_size;
@@ -106,8 +108,6 @@ struct ggml_backend_rpc_buffer_type_context {
  struct ggml_backend_rpc_context {
  std::string endpoint;
  std::string name;
- std::shared_ptr<socket_t> sock;
- ggml_backend_buffer_type_t buft;
  };

  struct ggml_backend_rpc_buffer_context {
@@ -231,14 +231,13 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
  return true;
  }

- static bool parse_endpoint(const char * endpoint, std::string & host, int & port) {
- std::string str(endpoint);
- size_t pos = str.find(':');
+ static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
+ size_t pos = endpoint.find(':');
  if (pos == std::string::npos) {
  return false;
  }
- host = str.substr(0, pos);
- port = std::stoi(str.substr(pos + 1));
+ host = endpoint.substr(0, pos);
+ port = std::stoi(endpoint.substr(pos + 1));
  return true;
  }

@@ -273,6 +272,44 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm

  // RPC client-side implementation

+ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
+ static std::mutex mutex;
+ std::lock_guard<std::mutex> lock(mutex);
+ static std::unordered_map<std::string, std::weak_ptr<socket_t>> sockets;
+ static bool initialized = false;
+
+ auto it = sockets.find(endpoint);
+ if (it != sockets.end()) {
+ if (auto sock = it->second.lock()) {
+ return sock;
+ }
+ }
+ std::string host;
+ int port;
+ if (!parse_endpoint(endpoint, host, port)) {
+ return nullptr;
+ }
+ #ifdef _WIN32
+ if (!initialized) {
+ WSADATA wsaData;
+ int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
+ if (res != 0) {
+ return nullptr;
+ }
+ initialized = true;
+ }
+ #else
+ UNUSED(initialized);
+ #endif
+ auto sock = socket_connect(host.c_str(), port);
+ if (sock == nullptr) {
+ return nullptr;
+ }
+ GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
+ sockets[endpoint] = sock;
+ return sock;
+ }
+
  GGML_CALL static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) {
  ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
  return ctx->name.c_str();
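This get_socket() helper is the heart of the refactor: client objects now store only the endpoint string and look up a shared connection on demand. Because the static map holds weak_ptr values, the cache never extends a socket's lifetime; once the last buffer or backend using an endpoint is freed, the socket_t destructor closes the fd, and the next call simply reconnects. A distilled sketch of the pattern with placeholder types (Connection and open_connection stand in for socket_t and socket_connect; they are not ggml-rpc symbols):

    #include <memory>
    #include <mutex>
    #include <string>
    #include <unordered_map>

    struct Connection { /* wraps a socket fd; closes it in its destructor */ };

    // Stand-in for socket_connect(); always "succeeds" in this sketch.
    static std::shared_ptr<Connection> open_connection(const std::string &) {
        return std::make_shared<Connection>();
    }

    std::shared_ptr<Connection> get_connection(const std::string & key) {
        static std::mutex mutex;
        std::lock_guard<std::mutex> lock(mutex);
        static std::unordered_map<std::string, std::weak_ptr<Connection>> cache;
        auto it = cache.find(key);
        if (it != cache.end()) {
            if (auto conn = it->second.lock()) {
                return conn;               // a live user keeps it open: reuse
            }
        }
        auto conn = open_connection(key);  // absent or expired: reconnect
        if (conn != nullptr) {
            cache[key] = conn;             // weak_ptr: the cache never owns it
        }
        return conn;
    }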
@@ -442,7 +479,8 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
  std::vector<uint8_t> input(input_size, 0);
  memcpy(input.data(), &size, sizeof(size));
  std::vector<uint8_t> output;
- bool status = send_rpc_cmd(buft_ctx->sock, ALLOC_BUFFER, input, output);
+ auto sock = get_socket(buft_ctx->endpoint);
+ bool status = send_rpc_cmd(sock, ALLOC_BUFFER, input, output);
  GGML_ASSERT(status);
  GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
  // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
@@ -453,7 +491,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
  if (remote_ptr != 0) {
  ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
  ggml_backend_rpc_buffer_interface,
- new ggml_backend_rpc_buffer_context{buft_ctx->sock, {}, remote_ptr, "RPC"},
+ new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC"},
  remote_size);
  return buffer;
  } else {
@@ -508,7 +546,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_type_supports_backend(ggml_backend
  }
  ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
  ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
- return buft_ctx->sock == rpc_ctx->sock;
+ return buft_ctx->endpoint == rpc_ctx->endpoint;
  }

  static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
@@ -521,7 +559,6 @@ static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
  /* .is_host = */ NULL,
  };

-
  GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
  ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;

@@ -530,16 +567,13 @@ GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {

  GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) {
  ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
- ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)rpc_ctx->buft->context;
- delete buft_ctx;
- delete rpc_ctx->buft;
  delete rpc_ctx;
  delete backend;
  }

  GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) {
  ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
- return ctx->buft;
+ return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
  }

  GGML_CALL static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
@@ -590,7 +624,8 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t
  std::vector<uint8_t> input;
  serialize_graph(cgraph, input);
  std::vector<uint8_t> output;
- bool status = send_rpc_cmd(rpc_ctx->sock, GRAPH_COMPUTE, input, output);
+ auto sock = get_socket(rpc_ctx->endpoint);
+ bool status = send_rpc_cmd(sock, GRAPH_COMPUTE, input, output);
  GGML_ASSERT(status);
  GGML_ASSERT(output.size() == 1);
  return (enum ggml_status)output[0];
@@ -624,65 +659,48 @@ static ggml_backend_i ggml_backend_rpc_interface = {
  /* .event_synchronize = */ NULL,
  };

- static std::unordered_map<std::string, ggml_backend_t> instances;
-
  GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
- ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
- return backend != nullptr ? ggml_backend_rpc_get_default_buffer_type(backend) : nullptr;
- }
-
- GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
- std::string endpoint_str(endpoint);
- if (instances.find(endpoint_str) != instances.end()) {
- return instances[endpoint_str];
- }
- #ifdef _WIN32
- {
- WSADATA wsaData;
- int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
- if (res != 0) {
- return nullptr;
- }
- }
- #endif
- fprintf(stderr, "Connecting to %s\n", endpoint);
- std::string host;
- int port;
- if (!parse_endpoint(endpoint, host, port)) {
- return nullptr;
- }
- auto sock = socket_connect(host.c_str(), port);
+ static std::mutex mutex;
+ std::lock_guard<std::mutex> lock(mutex);
+ // NOTE: buffer types are allocated and never freed; this is by design
+ static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;
+ auto it = buft_map.find(endpoint);
+ if (it != buft_map.end()) {
+ return it->second;
+ }
+ auto sock = get_socket(endpoint);
  if (sock == nullptr) {
  return nullptr;
  }
  size_t alignment = get_alignment(sock);
  size_t max_size = get_max_size(sock);
  ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
- /* .sock = */ sock,
- /* .name = */ "RPC" + std::to_string(sock->fd),
+ /* .endpoint = */ endpoint,
+ /* .name = */ "RPC[" + std::string(endpoint) + "]",
  /* .alignment = */ alignment,
- /* .max_size = */ max_size
+ /* .max_size = */ max_size
  };

  ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
  /* .iface = */ ggml_backend_rpc_buffer_type_interface,
  /* .context = */ buft_ctx
  };
+ buft_map[endpoint] = buft;
+ return buft;
+ }

+ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
  ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
- /* .endpoint = */ endpoint,
- /* .name = */ "RPC" + std::to_string(sock->fd),
- /* .sock = */ sock,
- /* .buft = */ buft
+ /* .endpoint = */ endpoint,
+ /* .name = */ "RPC",
  };

- instances[endpoint] = new ggml_backend {
+ ggml_backend_t backend = new ggml_backend {
  /* .guid = */ ggml_backend_rpc_guid(),
  /* .interface = */ ggml_backend_rpc_interface,
  /* .context = */ ctx
  };
-
- return instances[endpoint];
+ return backend;
  }

  GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {
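Downstream of get_socket(), ggml_backend_rpc_buffer_type() now memoizes one buffer type per endpoint (deliberately never freed, per the NOTE), while ggml_backend_rpc_init() no longer caches backends in the removed instances map, so every call returns an object the caller owns. A usage sketch against the public API (the endpoint value is an example; assumes the ggml-rpc.h header from the vendored llama.cpp):

    #include "ggml-rpc.h"

    int main() {
        size_t free_mem = 0, total_mem = 0;
        ggml_backend_rpc_get_device_memory("127.0.0.1:50052", &free_mem, &total_mem);

        ggml_backend_t backend = ggml_backend_rpc_init("127.0.0.1:50052");
        if (backend == nullptr) {
            return 1;
        }
        // ... allocate buffers / run graphs on the remote server ...
        ggml_backend_free(backend); // safe: backends are per-call objects now
        return 0;
    }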
@@ -706,14 +724,13 @@ static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * f
  }

  GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
- ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
- if (backend == nullptr) {
+ auto sock = get_socket(endpoint);
+ if (sock == nullptr) {
  *free = 0;
  *total = 0;
  return;
  }
- ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
- get_device_memory(ctx->sock, free, total);
+ get_device_memory(sock, free, total);
  }

  // RPC server-side implementation