llama_cpp 0.15.3 → 0.15.4
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/ext/llama_cpp/llama_cpp.cpp +12 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +2 -0
- data/vendor/tmp/llama.cpp/Makefile +4 -1
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +27 -10
- data/vendor/tmp/llama.cpp/ggml-impl.h +4 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +0 -1
- data/vendor/tmp/llama.cpp/ggml-metal.m +65 -11
- data/vendor/tmp/llama.cpp/ggml-metal.metal +69 -27
- data/vendor/tmp/llama.cpp/ggml-quants.c +101 -11
- data/vendor/tmp/llama.cpp/ggml-rpc.cpp +75 -58
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +338 -160
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +2 -0
- data/vendor/tmp/llama.cpp/ggml.c +145 -101
- data/vendor/tmp/llama.cpp/ggml.h +18 -3
- data/vendor/tmp/llama.cpp/llama.cpp +637 -249
- data/vendor/tmp/llama.cpp/llama.h +11 -5
- metadata +2 -2
@@ -3813,7 +3813,44 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
|
3813
3813
|
return;
|
3814
3814
|
}
|
3815
3815
|
#endif
|
3816
|
-
#if defined(
|
3816
|
+
#if defined(__ARM_FEATURE_SVE)
|
3817
|
+
const svbool_t ptrueh = svptrue_pat_b8(SV_VL16);
|
3818
|
+
const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh);
|
3819
|
+
|
3820
|
+
svfloat32_t sumv0 = svdup_n_f32(0.0f);
|
3821
|
+
svfloat32_t sumv1 = svdup_n_f32(0.0f);
|
3822
|
+
|
3823
|
+
assert(nb % 2 == 0); // TODO: handle odd nb
|
3824
|
+
|
3825
|
+
for (int i = 0; i < nb; i += 2) {
|
3826
|
+
const block_q4_0 * restrict x0 = &x[i + 0];
|
3827
|
+
const block_q4_0 * restrict x1 = &x[i + 1];
|
3828
|
+
const block_q8_0 * restrict y0 = &y[i + 0];
|
3829
|
+
const block_q8_0 * restrict y1 = &y[i + 1];
|
3830
|
+
|
3831
|
+
// load x
|
3832
|
+
const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
|
3833
|
+
const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
|
3834
|
+
|
3835
|
+
// 4-bit -> 8-bit
|
3836
|
+
const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04));
|
3837
|
+
const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04));
|
3838
|
+
|
3839
|
+
// sub 8
|
3840
|
+
const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
|
3841
|
+
const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
|
3842
|
+
|
3843
|
+
// load y
|
3844
|
+
const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
|
3845
|
+
const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
|
3846
|
+
|
3847
|
+
// dot product
|
3848
|
+
sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
|
3849
|
+
sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
|
3850
|
+
}
|
3851
|
+
|
3852
|
+
*s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
|
3853
|
+
#elif defined(__ARM_NEON)
|
3817
3854
|
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
3818
3855
|
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
3819
3856
|
|
@@ -5384,7 +5421,32 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
|
5384
5421
|
return;
|
5385
5422
|
}
|
5386
5423
|
#endif
|
5387
|
-
#if defined(
|
5424
|
+
#if defined(__ARM_FEATURE_SVE)
|
5425
|
+
svfloat32_t sumv0 = svdup_n_f32(0.0f);
|
5426
|
+
svfloat32_t sumv1 = svdup_n_f32(0.0f);
|
5427
|
+
|
5428
|
+
assert(nb % 2 == 0); // TODO: handle odd nb
|
5429
|
+
|
5430
|
+
for (int i = 0; i < nb; i += 2) {
|
5431
|
+
const block_q8_0 * restrict x0 = &x[i + 0];
|
5432
|
+
const block_q8_0 * restrict x1 = &x[i + 1];
|
5433
|
+
const block_q8_0 * restrict y0 = &y[i + 0];
|
5434
|
+
const block_q8_0 * restrict y1 = &y[i + 1];
|
5435
|
+
|
5436
|
+
// load x
|
5437
|
+
const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
|
5438
|
+
const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
|
5439
|
+
|
5440
|
+
// load y
|
5441
|
+
const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
|
5442
|
+
const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
|
5443
|
+
|
5444
|
+
sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
|
5445
|
+
sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
|
5446
|
+
}
|
5447
|
+
|
5448
|
+
*s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
|
5449
|
+
#elif defined(__ARM_NEON)
|
5388
5450
|
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
5389
5451
|
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
5390
5452
|
|
@@ -6026,6 +6088,7 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
|
6026
6088
|
|
6027
6089
|
const uint8_t * restrict q2 = x[i].qs;
|
6028
6090
|
const int8_t * restrict q8 = y[i].qs;
|
6091
|
+
|
6029
6092
|
const __m128i mins_and_scales = __lsx_vld((const __m128i*)x[i].scales, 0);
|
6030
6093
|
const __m128i scales8 = __lsx_vand_v(mins_and_scales, m4);
|
6031
6094
|
const __m128i mins8 = __lsx_vand_v(__lsx_vsrli_h(mins_and_scales, 4), m4);
|
@@ -6745,6 +6808,8 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
|
6745
6808
|
for (int i = 0; i < nb; ++i) {
|
6746
6809
|
|
6747
6810
|
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
|
6811
|
+
const uint8_t * restrict q3 = x[i].qs;
|
6812
|
+
const int8_t * restrict q8 = y[i].qs;
|
6748
6813
|
// Set up scales
|
6749
6814
|
memcpy(aux, x[i].scales, 12);
|
6750
6815
|
__m128i scales128 = lsx_set_w(
|
@@ -6766,29 +6831,32 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
|
6766
6831
|
|
6767
6832
|
int bit = 0;
|
6768
6833
|
int is = 0;
|
6834
|
+
__m256i xvbit;
|
6769
6835
|
|
6770
|
-
const uint8_t * restrict q3 = x[i].qs;
|
6771
|
-
const int8_t * restrict q8 = y[i].qs;
|
6772
6836
|
|
6773
6837
|
for (int j = 0; j < QK_K/128; ++j) {
|
6774
6838
|
// load low 2 bits
|
6775
6839
|
const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32;
|
6776
6840
|
|
6841
|
+
xvbit = __lasx_xvreplgr2vr_h(bit);
|
6777
6842
|
// prepare low and high bits
|
6778
6843
|
const __m256i q3l_0 = __lasx_xvand_v(q3bits, m3);
|
6779
|
-
const __m256i q3h_0 = __lasx_xvslli_h(
|
6844
|
+
const __m256i q3h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
|
6780
6845
|
++bit;
|
6781
6846
|
|
6847
|
+
xvbit = __lasx_xvreplgr2vr_h(bit);
|
6782
6848
|
const __m256i q3l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 2), m3);
|
6783
|
-
const __m256i q3h_1 = __lasx_xvslli_h(
|
6849
|
+
const __m256i q3h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
|
6784
6850
|
++bit;
|
6785
6851
|
|
6852
|
+
xvbit = __lasx_xvreplgr2vr_h(bit);
|
6786
6853
|
const __m256i q3l_2 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 4), m3);
|
6787
|
-
const __m256i q3h_2 = __lasx_xvslli_h(
|
6854
|
+
const __m256i q3h_2 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
|
6788
6855
|
++bit;
|
6789
6856
|
|
6857
|
+
xvbit = __lasx_xvreplgr2vr_h(bit);
|
6790
6858
|
const __m256i q3l_3 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 6), m3);
|
6791
|
-
const __m256i q3h_3 = __lasx_xvslli_h(
|
6859
|
+
const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
|
6792
6860
|
++bit;
|
6793
6861
|
|
6794
6862
|
// load Q8 quants
|
@@ -7337,6 +7405,9 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
|
7337
7405
|
*s = vec_extract(vsumf0, 0);
|
7338
7406
|
|
7339
7407
|
#elif defined __loongarch_asx
|
7408
|
+
GGML_UNUSED(kmask1);
|
7409
|
+
GGML_UNUSED(kmask2);
|
7410
|
+
GGML_UNUSED(kmask3);
|
7340
7411
|
|
7341
7412
|
const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
|
7342
7413
|
|
@@ -7349,6 +7420,11 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
|
7349
7420
|
const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
|
7350
7421
|
|
7351
7422
|
memcpy(utmp, x[i].scales, 12);
|
7423
|
+
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
|
7424
|
+
const uint32_t uaux = utmp[1] & kmask1;
|
7425
|
+
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
|
7426
|
+
utmp[2] = uaux;
|
7427
|
+
utmp[0] &= kmask1;
|
7352
7428
|
|
7353
7429
|
const uint8_t * restrict q4 = x[i].qs;
|
7354
7430
|
const int8_t * restrict q8 = y[i].qs;
|
@@ -7388,16 +7464,17 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
|
7388
7464
|
|
7389
7465
|
__m256 vd = __lasx_xvreplfr2vr_s(d);
|
7390
7466
|
acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc);
|
7467
|
+
|
7391
7468
|
}
|
7392
7469
|
|
7393
7470
|
acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vpermi_w((__m128i)acc_m, (__m128i)acc_m, 0xee));
|
7394
7471
|
__m128i tmp1 = __lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w((__m128i)acc_m, 1), 0);
|
7395
7472
|
acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1);
|
7396
7473
|
|
7474
|
+
|
7397
7475
|
ft_union fi;
|
7398
7476
|
fi.i = __lsx_vpickve2gr_w(acc_m, 0);
|
7399
7477
|
*s = hsum_float_8(acc) + fi.f ;
|
7400
|
-
|
7401
7478
|
#else
|
7402
7479
|
|
7403
7480
|
const uint8_t * scales = (const uint8_t*)&utmp[0];
|
@@ -7935,6 +8012,9 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
|
7935
8012
|
*s = vec_extract(vsumf0, 0);
|
7936
8013
|
|
7937
8014
|
#elif defined __loongarch_asx
|
8015
|
+
GGML_UNUSED(kmask1);
|
8016
|
+
GGML_UNUSED(kmask2);
|
8017
|
+
GGML_UNUSED(kmask3);
|
7938
8018
|
|
7939
8019
|
const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
|
7940
8020
|
const __m128i mzero = __lsx_vldi(0);
|
@@ -7953,6 +8033,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
|
7953
8033
|
const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
|
7954
8034
|
|
7955
8035
|
memcpy(utmp, x[i].scales, 12);
|
8036
|
+
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
|
8037
|
+
const uint32_t uaux = utmp[1] & kmask1;
|
8038
|
+
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
|
8039
|
+
utmp[2] = uaux;
|
8040
|
+
utmp[0] &= kmask1;
|
7956
8041
|
|
7957
8042
|
const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]));
|
7958
8043
|
|
@@ -7971,6 +8056,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
|
7971
8056
|
__m256i sumi = __lasx_xvldi(0);
|
7972
8057
|
|
7973
8058
|
int bit = 0;
|
8059
|
+
__m256i xvbit;
|
7974
8060
|
|
7975
8061
|
for (int j = 0; j < QK_K/64; ++j) {
|
7976
8062
|
|
@@ -7979,13 +8065,15 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
|
7979
8065
|
|
7980
8066
|
const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32;
|
7981
8067
|
|
8068
|
+
xvbit = __lasx_xvreplgr2vr_h(bit++);
|
7982
8069
|
const __m256i q5l_0 = __lasx_xvand_v(q5bits, m4);
|
7983
|
-
const __m256i q5h_0 = __lasx_xvslli_h(
|
8070
|
+
const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
|
7984
8071
|
const __m256i q5_0 = __lasx_xvadd_b(q5l_0, q5h_0);
|
7985
8072
|
hmask = __lasx_xvslli_h(hmask, 1);
|
7986
8073
|
|
8074
|
+
xvbit = __lasx_xvreplgr2vr_h(bit++);
|
7987
8075
|
const __m256i q5l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q5bits, 4), m4);
|
7988
|
-
const __m256i q5h_1 = __lasx_xvslli_h(
|
8076
|
+
const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
|
7989
8077
|
const __m256i q5_1 = __lasx_xvadd_b(q5l_1, q5h_1);
|
7990
8078
|
hmask = __lasx_xvslli_h(hmask, 1);
|
7991
8079
|
|
@@ -7999,10 +8087,12 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
|
7999
8087
|
p16_1 = lasx_madd_h(scale_1, p16_1);
|
8000
8088
|
|
8001
8089
|
sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));
|
8090
|
+
|
8002
8091
|
}
|
8003
8092
|
|
8004
8093
|
__m256 vd = __lasx_xvreplfr2vr_s(d);
|
8005
8094
|
acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc);
|
8095
|
+
|
8006
8096
|
}
|
8007
8097
|
|
8008
8098
|
*s = hsum_float_8(acc) + summs;
|
@@ -6,6 +6,7 @@
|
|
6
6
|
#include <string>
|
7
7
|
#include <vector>
|
8
8
|
#include <memory>
|
9
|
+
#include <mutex>
|
9
10
|
#include <unordered_map>
|
10
11
|
#include <unordered_set>
|
11
12
|
#ifdef _WIN32
|
@@ -47,6 +48,7 @@ struct socket_t {
|
|
47
48
|
sockfd_t fd;
|
48
49
|
socket_t(sockfd_t fd) : fd(fd) {}
|
49
50
|
~socket_t() {
|
51
|
+
GGML_PRINT_DEBUG("[%s] closing socket %d\n", __func__, this->fd);
|
50
52
|
#ifdef _WIN32
|
51
53
|
closesocket(this->fd);
|
52
54
|
#else
|
@@ -97,7 +99,7 @@ static ggml_guid_t ggml_backend_rpc_guid() {
|
|
97
99
|
}
|
98
100
|
|
99
101
|
struct ggml_backend_rpc_buffer_type_context {
|
100
|
-
std::
|
102
|
+
std::string endpoint;
|
101
103
|
std::string name;
|
102
104
|
size_t alignment;
|
103
105
|
size_t max_size;
|
@@ -106,8 +108,6 @@ struct ggml_backend_rpc_buffer_type_context {
|
|
106
108
|
struct ggml_backend_rpc_context {
|
107
109
|
std::string endpoint;
|
108
110
|
std::string name;
|
109
|
-
std::shared_ptr<socket_t> sock;
|
110
|
-
ggml_backend_buffer_type_t buft;
|
111
111
|
};
|
112
112
|
|
113
113
|
struct ggml_backend_rpc_buffer_context {
|
@@ -231,14 +231,13 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
|
|
231
231
|
return true;
|
232
232
|
}
|
233
233
|
|
234
|
-
static bool parse_endpoint(const
|
235
|
-
|
236
|
-
size_t pos = str.find(':');
|
234
|
+
static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
|
235
|
+
size_t pos = endpoint.find(':');
|
237
236
|
if (pos == std::string::npos) {
|
238
237
|
return false;
|
239
238
|
}
|
240
|
-
host =
|
241
|
-
port = std::stoi(
|
239
|
+
host = endpoint.substr(0, pos);
|
240
|
+
port = std::stoi(endpoint.substr(pos + 1));
|
242
241
|
return true;
|
243
242
|
}
|
244
243
|
|
@@ -273,6 +272,44 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
|
|
273
272
|
|
274
273
|
// RPC client-side implementation
|
275
274
|
|
275
|
+
static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
|
276
|
+
static std::mutex mutex;
|
277
|
+
std::lock_guard<std::mutex> lock(mutex);
|
278
|
+
static std::unordered_map<std::string, std::weak_ptr<socket_t>> sockets;
|
279
|
+
static bool initialized = false;
|
280
|
+
|
281
|
+
auto it = sockets.find(endpoint);
|
282
|
+
if (it != sockets.end()) {
|
283
|
+
if (auto sock = it->second.lock()) {
|
284
|
+
return sock;
|
285
|
+
}
|
286
|
+
}
|
287
|
+
std::string host;
|
288
|
+
int port;
|
289
|
+
if (!parse_endpoint(endpoint, host, port)) {
|
290
|
+
return nullptr;
|
291
|
+
}
|
292
|
+
#ifdef _WIN32
|
293
|
+
if (!initialized) {
|
294
|
+
WSADATA wsaData;
|
295
|
+
int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
|
296
|
+
if (res != 0) {
|
297
|
+
return nullptr;
|
298
|
+
}
|
299
|
+
initialized = true;
|
300
|
+
}
|
301
|
+
#else
|
302
|
+
UNUSED(initialized);
|
303
|
+
#endif
|
304
|
+
auto sock = socket_connect(host.c_str(), port);
|
305
|
+
if (sock == nullptr) {
|
306
|
+
return nullptr;
|
307
|
+
}
|
308
|
+
GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
|
309
|
+
sockets[endpoint] = sock;
|
310
|
+
return sock;
|
311
|
+
}
|
312
|
+
|
276
313
|
GGML_CALL static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) {
|
277
314
|
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
278
315
|
return ctx->name.c_str();
|
@@ -442,7 +479,8 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
|
|
442
479
|
std::vector<uint8_t> input(input_size, 0);
|
443
480
|
memcpy(input.data(), &size, sizeof(size));
|
444
481
|
std::vector<uint8_t> output;
|
445
|
-
|
482
|
+
auto sock = get_socket(buft_ctx->endpoint);
|
483
|
+
bool status = send_rpc_cmd(sock, ALLOC_BUFFER, input, output);
|
446
484
|
GGML_ASSERT(status);
|
447
485
|
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
|
448
486
|
// output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
|
@@ -453,7 +491,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
|
|
453
491
|
if (remote_ptr != 0) {
|
454
492
|
ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
|
455
493
|
ggml_backend_rpc_buffer_interface,
|
456
|
-
new ggml_backend_rpc_buffer_context{
|
494
|
+
new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC"},
|
457
495
|
remote_size);
|
458
496
|
return buffer;
|
459
497
|
} else {
|
@@ -508,7 +546,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_type_supports_backend(ggml_backend
|
|
508
546
|
}
|
509
547
|
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
|
510
548
|
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
|
511
|
-
return buft_ctx->
|
549
|
+
return buft_ctx->endpoint == rpc_ctx->endpoint;
|
512
550
|
}
|
513
551
|
|
514
552
|
static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
|
@@ -521,7 +559,6 @@ static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
|
|
521
559
|
/* .is_host = */ NULL,
|
522
560
|
};
|
523
561
|
|
524
|
-
|
525
562
|
GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
|
526
563
|
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
|
527
564
|
|
@@ -530,16 +567,13 @@ GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
|
|
530
567
|
|
531
568
|
GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) {
|
532
569
|
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
|
533
|
-
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)rpc_ctx->buft->context;
|
534
|
-
delete buft_ctx;
|
535
|
-
delete rpc_ctx->buft;
|
536
570
|
delete rpc_ctx;
|
537
571
|
delete backend;
|
538
572
|
}
|
539
573
|
|
540
574
|
GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) {
|
541
575
|
ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
|
542
|
-
return ctx->
|
576
|
+
return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
|
543
577
|
}
|
544
578
|
|
545
579
|
GGML_CALL static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
|
@@ -590,7 +624,8 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t
|
|
590
624
|
std::vector<uint8_t> input;
|
591
625
|
serialize_graph(cgraph, input);
|
592
626
|
std::vector<uint8_t> output;
|
593
|
-
|
627
|
+
auto sock = get_socket(rpc_ctx->endpoint);
|
628
|
+
bool status = send_rpc_cmd(sock, GRAPH_COMPUTE, input, output);
|
594
629
|
GGML_ASSERT(status);
|
595
630
|
GGML_ASSERT(output.size() == 1);
|
596
631
|
return (enum ggml_status)output[0];
|
@@ -624,65 +659,48 @@ static ggml_backend_i ggml_backend_rpc_interface = {
|
|
624
659
|
/* .event_synchronize = */ NULL,
|
625
660
|
};
|
626
661
|
|
627
|
-
static std::unordered_map<std::string, ggml_backend_t> instances;
|
628
|
-
|
629
662
|
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
|
630
|
-
|
631
|
-
|
632
|
-
|
633
|
-
|
634
|
-
|
635
|
-
|
636
|
-
|
637
|
-
|
638
|
-
|
639
|
-
#ifdef _WIN32
|
640
|
-
{
|
641
|
-
WSADATA wsaData;
|
642
|
-
int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
|
643
|
-
if (res != 0) {
|
644
|
-
return nullptr;
|
645
|
-
}
|
646
|
-
}
|
647
|
-
#endif
|
648
|
-
fprintf(stderr, "Connecting to %s\n", endpoint);
|
649
|
-
std::string host;
|
650
|
-
int port;
|
651
|
-
if (!parse_endpoint(endpoint, host, port)) {
|
652
|
-
return nullptr;
|
653
|
-
}
|
654
|
-
auto sock = socket_connect(host.c_str(), port);
|
663
|
+
static std::mutex mutex;
|
664
|
+
std::lock_guard<std::mutex> lock(mutex);
|
665
|
+
// NOTE: buffer types are allocated and never freed; this is by design
|
666
|
+
static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;
|
667
|
+
auto it = buft_map.find(endpoint);
|
668
|
+
if (it != buft_map.end()) {
|
669
|
+
return it->second;
|
670
|
+
}
|
671
|
+
auto sock = get_socket(endpoint);
|
655
672
|
if (sock == nullptr) {
|
656
673
|
return nullptr;
|
657
674
|
}
|
658
675
|
size_t alignment = get_alignment(sock);
|
659
676
|
size_t max_size = get_max_size(sock);
|
660
677
|
ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
|
661
|
-
/* .
|
662
|
-
/* .name
|
678
|
+
/* .endpoint = */ endpoint,
|
679
|
+
/* .name = */ "RPC[" + std::string(endpoint) + "]",
|
663
680
|
/* .alignment = */ alignment,
|
664
|
-
/* .max_size
|
681
|
+
/* .max_size = */ max_size
|
665
682
|
};
|
666
683
|
|
667
684
|
ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
|
668
685
|
/* .iface = */ ggml_backend_rpc_buffer_type_interface,
|
669
686
|
/* .context = */ buft_ctx
|
670
687
|
};
|
688
|
+
buft_map[endpoint] = buft;
|
689
|
+
return buft;
|
690
|
+
}
|
671
691
|
|
692
|
+
GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
|
672
693
|
ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
|
673
|
-
/* .endpoint
|
674
|
-
/* .name
|
675
|
-
/* .sock = */ sock,
|
676
|
-
/* .buft = */ buft
|
694
|
+
/* .endpoint = */ endpoint,
|
695
|
+
/* .name = */ "RPC",
|
677
696
|
};
|
678
697
|
|
679
|
-
|
698
|
+
ggml_backend_t backend = new ggml_backend {
|
680
699
|
/* .guid = */ ggml_backend_rpc_guid(),
|
681
700
|
/* .interface = */ ggml_backend_rpc_interface,
|
682
701
|
/* .context = */ ctx
|
683
702
|
};
|
684
|
-
|
685
|
-
return instances[endpoint];
|
703
|
+
return backend;
|
686
704
|
}
|
687
705
|
|
688
706
|
GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {
|
@@ -706,14 +724,13 @@ static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * f
|
|
706
724
|
}
|
707
725
|
|
708
726
|
GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
|
709
|
-
|
710
|
-
if (
|
727
|
+
auto sock = get_socket(endpoint);
|
728
|
+
if (sock == nullptr) {
|
711
729
|
*free = 0;
|
712
730
|
*total = 0;
|
713
731
|
return;
|
714
732
|
}
|
715
|
-
|
716
|
-
get_device_memory(ctx->sock, free, total);
|
733
|
+
get_device_memory(sock, free, total);
|
717
734
|
}
|
718
735
|
|
719
736
|
// RPC server-side implementation
|