llama_cpp 0.15.3 → 0.15.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/ext/llama_cpp/llama_cpp.cpp +12 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +2 -0
- data/vendor/tmp/llama.cpp/Makefile +4 -1
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +27 -10
- data/vendor/tmp/llama.cpp/ggml-impl.h +4 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +0 -1
- data/vendor/tmp/llama.cpp/ggml-metal.m +65 -11
- data/vendor/tmp/llama.cpp/ggml-metal.metal +69 -27
- data/vendor/tmp/llama.cpp/ggml-quants.c +101 -11
- data/vendor/tmp/llama.cpp/ggml-rpc.cpp +75 -58
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +338 -160
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +2 -0
- data/vendor/tmp/llama.cpp/ggml.c +145 -101
- data/vendor/tmp/llama.cpp/ggml.h +18 -3
- data/vendor/tmp/llama.cpp/llama.cpp +637 -249
- data/vendor/tmp/llama.cpp/llama.h +11 -5
- metadata +2 -2
data/vendor/tmp/llama.cpp/ggml-quants.c

@@ -3813,7 +3813,44 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         return;
     }
 #endif
-#if defined(__ARM_NEON)
+#if defined(__ARM_FEATURE_SVE)
+    const svbool_t ptrueh = svptrue_pat_b8(SV_VL16);
+    const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh);
+
+    svfloat32_t sumv0 = svdup_n_f32(0.0f);
+    svfloat32_t sumv1 = svdup_n_f32(0.0f);
+
+    assert(nb % 2 == 0); // TODO: handle odd nb
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q4_0 * restrict x0 = &x[i + 0];
+        const block_q4_0 * restrict x1 = &x[i + 1];
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
+
+        // load x
+        const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
+        const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
+
+        // 4-bit -> 8-bit
+        const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04));
+        const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04));
+
+        // sub 8
+        const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
+        const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
+
+        // load y
+        const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
+        const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+
+        // dot product
+        sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+    }
+
+    *s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+#elif defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
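The new SVE path loads each 16-byte q4_0 quant array with `svld1rq_u8`, which replicates it across the vector, and then unpacks both nibble halves in place: `ptrueh` activates the first 16 byte lanes (masked with `0x0F` to keep the low nibbles) and `ptruel` the remaining lanes (logically shifted right by 4 for the high nibbles). A scalar sketch of the effect, assuming a 256-bit SVE vector (which this predicate layout targets) and ggml's q4_0 element order; the helper name is hypothetical, not part of the diff:

```cpp
#include <stdint.h>

// Scalar picture of svand_n_u8_m(ptrueh, ...) followed by
// svlsr_n_u8_m(ptruel, ..., 4) on the replicated 16-byte block,
// plus the svsub_n_s8_x recentering from [0,15] to [-8,7].
static void unpack_q4_0_block(const uint8_t qs[16], int8_t out[32]) {
    for (int j = 0; j < 16; ++j) {
        out[j]      = (int8_t)(qs[j] & 0x0F) - 8; // ptrueh lanes: low nibble
        out[j + 16] = (int8_t)(qs[j] >> 4)   - 8; // ptruel lanes: high nibble
    }
}
```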
@@ -5384,7 +5421,32 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         return;
     }
 #endif
-#if defined(__ARM_NEON)
+#if defined(__ARM_FEATURE_SVE)
+    svfloat32_t sumv0 = svdup_n_f32(0.0f);
+    svfloat32_t sumv1 = svdup_n_f32(0.0f);
+
+    assert(nb % 2 == 0); // TODO: handle odd nb
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q8_0 * restrict x0 = &x[i + 0];
+        const block_q8_0 * restrict x1 = &x[i + 1];
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
+
+        // load x
+        const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
+        const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
+
+        // load y
+        const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
+        const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+
+        sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+    }
+
+    *s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+#elif defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
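The q8_0 variant needs no unpacking, so each pair of blocks reduces to `svdot_s32`, the SVE `SDOT` instruction: every 32-bit accumulator lane receives the dot product of the four corresponding int8 lanes, and one `svmla_n_f32_x` per block applies the combined fp16 scales. A scalar picture of the `SDOT` step (illustrative helper, not from the diff):

```cpp
#include <stdint.h>

// 32-bit lane i of the accumulator gains the dot product of int8
// lanes 4*i .. 4*i+3 of a and b -- the SDOT semantics behind svdot_s32.
static void sdot_ref(int32_t * acc, const int8_t * a, const int8_t * b, int lanes32) {
    for (int i = 0; i < lanes32; ++i) {
        for (int k = 0; k < 4; ++k) {
            acc[i] += (int32_t)a[4*i + k] * (int32_t)b[4*i + k];
        }
    }
}
```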
@@ -6026,6 +6088,7 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
         const uint8_t * restrict q2 = x[i].qs;
         const int8_t * restrict q8 = y[i].qs;
+
         const __m128i mins_and_scales = __lsx_vld((const __m128i*)x[i].scales, 0);
         const __m128i scales8 = __lsx_vand_v(mins_and_scales, m4);
         const __m128i mins8 = __lsx_vand_v(__lsx_vsrli_h(mins_and_scales, 4), m4);
@@ -6745,6 +6808,8 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     for (int i = 0; i < nb; ++i) {
 
         const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const uint8_t * restrict q3 = x[i].qs;
+        const int8_t * restrict q8 = y[i].qs;
         // Set up scales
         memcpy(aux, x[i].scales, 12);
         __m128i scales128 = lsx_set_w(
@@ -6766,29 +6831,32 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
         int bit = 0;
         int is = 0;
+        __m256i xvbit;
 
-        const uint8_t * restrict q3 = x[i].qs;
-        const int8_t * restrict q8 = y[i].qs;
 
         for (int j = 0; j < QK_K/128; ++j) {
             // load low 2 bits
             const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32;
 
+            xvbit = __lasx_xvreplgr2vr_h(bit);
             // prepare low and high bits
             const __m256i q3l_0 = __lasx_xvand_v(q3bits, m3);
-            const __m256i q3h_0 = __lasx_xvslli_h(
+            const __m256i q3h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
             ++bit;
 
+            xvbit = __lasx_xvreplgr2vr_h(bit);
             const __m256i q3l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 2), m3);
-            const __m256i q3h_1 = __lasx_xvslli_h(
+            const __m256i q3h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
             ++bit;
 
+            xvbit = __lasx_xvreplgr2vr_h(bit);
             const __m256i q3l_2 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 4), m3);
-            const __m256i q3h_2 = __lasx_xvslli_h(
+            const __m256i q3h_2 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
             ++bit;
 
+            xvbit = __lasx_xvreplgr2vr_h(bit);
             const __m256i q3l_3 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 6), m3);
-            const __m256i q3h_3 = __lasx_xvslli_h(
+            const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
             ++bit;
 
             // load Q8 quants
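This LASX change hoists the shift count into a vector register: `bit` is broadcast with `__lasx_xvreplgr2vr_h` before each step, and the completed `q3h_*` expressions extract the inverted high-bit plane with a variable per-lane shift. Reading one 16-bit lane scalar-wise, and assuming `__lasx_xvandn_v(a, b)` computes `~a & b` (a sketch of the expression shape with a hypothetical helper name, not code from the diff):

```cpp
#include <stdint.h>

// One lane of: __lasx_xvslli_h(__lasx_xvsrl_h(
//     __lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2)
// i.e. test bit `bit` of the inverted high-bit plane and park the result
// at bit 2, where it tops up the 2-bit quant extracted from q3bits.
static uint16_t q3_high_contrib(uint16_t hbits_lane, int bit) {
    const uint16_t mask = (uint16_t)(1u << bit);          // mone << bit
    const uint16_t sel  = (uint16_t)(~hbits_lane & mask); // andn(hbits, mask)
    return (uint16_t)((sel >> bit) << 2);                 // (>> bit) << 2
}
```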
@@ -7337,6 +7405,9 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     *s = vec_extract(vsumf0, 0);
 
 #elif defined __loongarch_asx
+    GGML_UNUSED(kmask1);
+    GGML_UNUSED(kmask2);
+    GGML_UNUSED(kmask3);
 
     const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
 
@@ -7349,6 +7420,11 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
 
         memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
 
         const uint8_t * restrict q4 = x[i].qs;
         const int8_t * restrict q8 = y[i].qs;
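The q4_K and q5_K LoongArch paths gain the same 6-bit scale/min unpacking the other SIMD paths already perform: the twelve packed scale bytes are shuffled through `utmp` so that `utmp[0..1]` hold the eight sub-block scales and `utmp[2..3]` the eight sub-block minimums, one byte each. A self-contained sketch of that shuffle, with the `kmask` constants as they are defined in ggml-quants.c (the helper name is ours):

```cpp
#include <stdint.h>
#include <string.h>

static const uint32_t kmask1 = 0x3f3f3f3f;
static const uint32_t kmask2 = 0x0f0f0f0f;
static const uint32_t kmask3 = 0x03030303;

// Repack the 12 scale/min bytes of a K-quant superblock: afterwards the
// bytes of utmp[0..1] are the eight 6-bit scales and the bytes of
// utmp[2..3] the eight 6-bit minimums.
static void unpack_k_scales(const uint8_t scales[12], uint32_t utmp[4]) {
    memcpy(utmp, scales, 12);
    utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
    const uint32_t uaux = utmp[1] & kmask1;
    utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
    utmp[2] = uaux;
    utmp[0] &= kmask1;
}
```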
@@ -7388,16 +7464,17 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
         __m256 vd = __lasx_xvreplfr2vr_s(d);
         acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc);
+
     }
 
     acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vpermi_w((__m128i)acc_m, (__m128i)acc_m, 0xee));
     __m128i tmp1 = __lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w((__m128i)acc_m, 1), 0);
     acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1);
 
+
     ft_union fi;
     fi.i = __lsx_vpickve2gr_w(acc_m, 0);
     *s = hsum_float_8(acc) + fi.f ;
-
 #else
 
     const uint8_t * scales = (const uint8_t*)&utmp[0];
@@ -7935,6 +8012,9 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     *s = vec_extract(vsumf0, 0);
 
 #elif defined __loongarch_asx
+    GGML_UNUSED(kmask1);
+    GGML_UNUSED(kmask2);
+    GGML_UNUSED(kmask3);
 
     const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
     const __m128i mzero = __lsx_vldi(0);
@@ -7953,6 +8033,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
 
         memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
 
         const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]));
 
@@ -7971,6 +8056,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         __m256i sumi = __lasx_xvldi(0);
 
         int bit = 0;
+        __m256i xvbit;
 
         for (int j = 0; j < QK_K/64; ++j) {
 
@@ -7979,13 +8065,15 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
             const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32;
 
+            xvbit = __lasx_xvreplgr2vr_h(bit++);
             const __m256i q5l_0 = __lasx_xvand_v(q5bits, m4);
-            const __m256i q5h_0 = __lasx_xvslli_h(
+            const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
             const __m256i q5_0 = __lasx_xvadd_b(q5l_0, q5h_0);
             hmask = __lasx_xvslli_h(hmask, 1);
 
+            xvbit = __lasx_xvreplgr2vr_h(bit++);
             const __m256i q5l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q5bits, 4), m4);
-            const __m256i q5h_1 = __lasx_xvslli_h(
+            const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
             const __m256i q5_1 = __lasx_xvadd_b(q5l_1, q5h_1);
             hmask = __lasx_xvslli_h(hmask, 1);
 
@@ -7999,10 +8087,12 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
             p16_1 = lasx_madd_h(scale_1, p16_1);
 
             sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));
+
         }
 
         __m256 vd = __lasx_xvreplfr2vr_s(d);
         acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc);
+
     }
 
     *s = hsum_float_8(acc) + summs;
data/vendor/tmp/llama.cpp/ggml-rpc.cpp

@@ -6,6 +6,7 @@
 #include <string>
 #include <vector>
 #include <memory>
+#include <mutex>
 #include <unordered_map>
 #include <unordered_set>
 #ifdef _WIN32
@@ -47,6 +48,7 @@ struct socket_t {
     sockfd_t fd;
     socket_t(sockfd_t fd) : fd(fd) {}
     ~socket_t() {
+        GGML_PRINT_DEBUG("[%s] closing socket %d\n", __func__, this->fd);
 #ifdef _WIN32
         closesocket(this->fd);
 #else
|
|
97
99
|
}
|
98
100
|
|
99
101
|
struct ggml_backend_rpc_buffer_type_context {
|
100
|
-
std::
|
102
|
+
std::string endpoint;
|
101
103
|
std::string name;
|
102
104
|
size_t alignment;
|
103
105
|
size_t max_size;
|
@@ -106,8 +108,6 @@ struct ggml_backend_rpc_buffer_type_context {
|
|
106
108
|
struct ggml_backend_rpc_context {
|
107
109
|
std::string endpoint;
|
108
110
|
std::string name;
|
109
|
-
std::shared_ptr<socket_t> sock;
|
110
|
-
ggml_backend_buffer_type_t buft;
|
111
111
|
};
|
112
112
|
|
113
113
|
struct ggml_backend_rpc_buffer_context {
|
@@ -231,14 +231,13 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
     return true;
 }
 
-static bool parse_endpoint(const
-
-    size_t pos = str.find(':');
+static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
+    size_t pos = endpoint.find(':');
     if (pos == std::string::npos) {
         return false;
     }
-    host =
-    port = std::stoi(
+    host = endpoint.substr(0, pos);
+    port = std::stoi(endpoint.substr(pos + 1));
     return true;
 }
 
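`parse_endpoint` now takes the endpoint as a `std::string` and fills in host and port directly from it. A small usage sketch (the endpoint value is hypothetical); note that `std::stoi` throws if the text after the colon is not numeric, so callers are expected to pass well-formed endpoints:

```cpp
std::string host;
int port = 0;
if (parse_endpoint("192.168.0.42:50052", host, port)) {
    // host == "192.168.0.42", port == 50052
}
```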
@@ -273,6 +272,44 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
 
 // RPC client-side implementation
 
+static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
+    static std::mutex mutex;
+    std::lock_guard<std::mutex> lock(mutex);
+    static std::unordered_map<std::string, std::weak_ptr<socket_t>> sockets;
+    static bool initialized = false;
+
+    auto it = sockets.find(endpoint);
+    if (it != sockets.end()) {
+        if (auto sock = it->second.lock()) {
+            return sock;
+        }
+    }
+    std::string host;
+    int port;
+    if (!parse_endpoint(endpoint, host, port)) {
+        return nullptr;
+    }
+#ifdef _WIN32
+    if (!initialized) {
+        WSADATA wsaData;
+        int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
+        if (res != 0) {
+            return nullptr;
+        }
+        initialized = true;
+    }
+#else
+    UNUSED(initialized);
+#endif
+    auto sock = socket_connect(host.c_str(), port);
+    if (sock == nullptr) {
+        return nullptr;
+    }
+    GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
+    sockets[endpoint] = sock;
+    return sock;
+}
+
 GGML_CALL static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) {
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
     return ctx->name.c_str();
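`get_socket` replaces the per-object `sock` members with a process-wide cache keyed by endpoint. Because the cache stores `std::weak_ptr`, it can hand back a still-live connection but never keeps one open on its own: when the last buffer or backend drops its `shared_ptr`, `~socket_t` runs and closes the descriptor (now with a debug trace). A minimal sketch of the pattern, with `Resource` standing in for `socket_t`:

```cpp
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

struct Resource { /* stands in for socket_t */ };

std::shared_ptr<Resource> acquire(const std::string & key) {
    static std::mutex mutex;
    std::lock_guard<std::mutex> lock(mutex);
    static std::unordered_map<std::string, std::weak_ptr<Resource>> cache;
    auto it = cache.find(key);
    if (it != cache.end()) {
        if (auto res = it->second.lock()) {
            return res; // still alive somewhere: reuse it
        }
    }
    auto res = std::make_shared<Resource>();
    cache[key] = res;   // non-owning entry; expires with its last user
    return res;
}
```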
@@ -442,7 +479,8 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
     std::vector<uint8_t> input(input_size, 0);
     memcpy(input.data(), &size, sizeof(size));
     std::vector<uint8_t> output;
-
+    auto sock = get_socket(buft_ctx->endpoint);
+    bool status = send_rpc_cmd(sock, ALLOC_BUFFER, input, output);
     GGML_ASSERT(status);
     GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
     // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
@@ -453,7 +491,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
     if (remote_ptr != 0) {
         ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
             ggml_backend_rpc_buffer_interface,
-            new ggml_backend_rpc_buffer_context{
+            new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC"},
             remote_size);
         return buffer;
     } else {
@@ -508,7 +546,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_type_supports_backend(ggml_backend
     }
     ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
-    return buft_ctx->
+    return buft_ctx->endpoint == rpc_ctx->endpoint;
 }
 
 static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
@@ -521,7 +559,6 @@ static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
     /* .is_host = */ NULL,
 };
 
-
 GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
 
@@ -530,16 +567,13 @@ GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
 
 GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) {
     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
-    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)rpc_ctx->buft->context;
-    delete buft_ctx;
-    delete rpc_ctx->buft;
     delete rpc_ctx;
     delete backend;
 }
 
 GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) {
     ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
-    return ctx->
+    return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
 }
 
 GGML_CALL static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
@@ -590,7 +624,8 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t
     std::vector<uint8_t> input;
     serialize_graph(cgraph, input);
     std::vector<uint8_t> output;
-
+    auto sock = get_socket(rpc_ctx->endpoint);
+    bool status = send_rpc_cmd(sock, GRAPH_COMPUTE, input, output);
     GGML_ASSERT(status);
     GGML_ASSERT(output.size() == 1);
     return (enum ggml_status)output[0];
@@ -624,65 +659,48 @@ static ggml_backend_i ggml_backend_rpc_interface = {
     /* .event_synchronize = */ NULL,
 };
 
-static std::unordered_map<std::string, ggml_backend_t> instances;
-
 GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
-
-
-
-
-
-
-
-
-
-#ifdef _WIN32
-    {
-        WSADATA wsaData;
-        int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
-        if (res != 0) {
-            return nullptr;
-        }
-    }
-#endif
-    fprintf(stderr, "Connecting to %s\n", endpoint);
-    std::string host;
-    int port;
-    if (!parse_endpoint(endpoint, host, port)) {
-        return nullptr;
-    }
-    auto sock = socket_connect(host.c_str(), port);
+    static std::mutex mutex;
+    std::lock_guard<std::mutex> lock(mutex);
+    // NOTE: buffer types are allocated and never freed; this is by design
+    static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;
+    auto it = buft_map.find(endpoint);
+    if (it != buft_map.end()) {
+        return it->second;
+    }
+    auto sock = get_socket(endpoint);
     if (sock == nullptr) {
         return nullptr;
     }
     size_t alignment = get_alignment(sock);
     size_t max_size = get_max_size(sock);
     ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
-        /* .
-        /* .name
+        /* .endpoint = */ endpoint,
+        /* .name = */ "RPC[" + std::string(endpoint) + "]",
         /* .alignment = */ alignment,
-        /* .max_size
+        /* .max_size = */ max_size
     };
 
     ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
         /* .iface = */ ggml_backend_rpc_buffer_type_interface,
         /* .context = */ buft_ctx
     };
+    buft_map[endpoint] = buft;
+    return buft;
+}
 
+GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
     ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
-        /* .endpoint
-        /* .name
-        /* .sock = */ sock,
-        /* .buft = */ buft
+        /* .endpoint = */ endpoint,
+        /* .name = */ "RPC",
     };
 
-
+    ggml_backend_t backend = new ggml_backend {
         /* .guid = */ ggml_backend_rpc_guid(),
         /* .interface = */ ggml_backend_rpc_interface,
         /* .context = */ ctx
     };
-
-    return instances[endpoint];
+    return backend;
 }
 
 GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {
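With the `instances` map gone, `ggml_backend_rpc_init` no longer returns a cached singleton: each call builds a fresh backend that records only the endpoint string, while the socket (via the weak_ptr cache) and the buffer type (cached and deliberately never freed) are shared per endpoint behind the scenes. A sketch of the resulting behavior, assuming an rpc-server listening on the default port 50052:

```cpp
ggml_backend_t b1 = ggml_backend_rpc_init("localhost:50052");
ggml_backend_t b2 = ggml_backend_rpc_init("localhost:50052");
// b1 and b2 are distinct objects that share one cached connection;
// ggml_backend_rpc_buffer_type("localhost:50052") yields the same
// cached buffer type for both.
ggml_backend_free(b1); // no longer tears down state b2 depends on
ggml_backend_free(b2);
```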
@@ -706,14 +724,13 @@ static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * f
 }
 
 GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
-
-    if (
+    auto sock = get_socket(endpoint);
+    if (sock == nullptr) {
         *free = 0;
         *total = 0;
         return;
     }
-
-    get_device_memory(ctx->sock, free, total);
+    get_device_memory(sock, free, total);
 }
 
 // RPC server-side implementation