llama_cpp 0.6.0 → 0.7.0
- checksums.yaml +4 -4
- data/CHANGELOG.md +8 -0
- data/ext/llama_cpp/llama_cpp.cpp +49 -3
- data/ext/llama_cpp/src/ggml-cuda.cu +122 -72
- data/ext/llama_cpp/src/ggml-metal.m +4 -5
- data/ext/llama_cpp/src/ggml-metal.metal +9 -2
- data/ext/llama_cpp/src/ggml-opencl.cpp +119 -53
- data/ext/llama_cpp/src/ggml.c +755 -320
- data/ext/llama_cpp/src/ggml.h +13 -0
- data/ext/llama_cpp/src/k_quants.c +744 -2
- data/ext/llama_cpp/src/llama.cpp +779 -113
- data/ext/llama_cpp/src/llama.h +22 -6
- data/ext/llama_cpp/src/unicode.h +462 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +1 -1
- data/sig/llama_cpp.rbs +5 -0
- metadata +3 -2
data/ext/llama_cpp/src/ggml-cuda.cu:

@@ -80,9 +80,9 @@
 #include "ggml.h"

 #define MIN_CC_DP4A   610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
-#define CC_TURING     700
+#define CC_VOLTA      700
 #define CC_OFFSET_AMD 1000000
-#define CC_RDNA2      CC_OFFSET_AMD + 1030
+#define CC_RDNA2      (CC_OFFSET_AMD + 1030)

 #if defined(GGML_USE_HIPBLAS)
 #define __CUDA_ARCH__ 1300
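Note: two things happen in this hunk. The 700 constant is renamed from CC_TURING (reconstructed name) to CC_VOLTA, presumably because compute capability 7.0 is Volta while Turing is 7.5, and CC_RDNA2 gains parentheses. A minimal sketch, with illustrative macros rather than the real ggml-cuda.cu code, of why the parentheses matter:

// Illustrative only: an unparenthesized macro body is an operator-precedence trap.
#define CC_OFFSET_AMD 1000000
#define CC_RDNA2_OLD  CC_OFFSET_AMD + 1030     // before this change
#define CC_RDNA2_NEW  (CC_OFFSET_AMD + 1030)   // after this change

// A bare comparison expands the same either way:
//   cc >= CC_RDNA2_OLD   ->   cc >= 1000000 + 1030
// but any surrounding operator exposes the bug:
//   cc - CC_RDNA2_OLD    ->   cc - 1000000 + 1030   (adds 1030 instead of subtracting)
//   cc - CC_RDNA2_NEW    ->   cc - (1000000 + 1030) (intended)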
@@ -715,7 +715,8 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in

 //================================== k-quants

-static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float * __restrict__ yy) {
+template<typename dst_t>
+static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {

     const int i = blockIdx.x;
     const block_q2_K * x = (const block_q2_K *) vx;
@@ -727,7 +728,7 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
     const int is = 8*n + l/16;

     const uint8_t q = x[i].qs[32*n + l];
-    float * y = yy + i*QK_K + 128*n;
+    dst_t * y = yy + i*QK_K + 128*n;

     float dall = __low2half(x[i].dm);
     float dmin = __high2half(x[i].dm);
@@ -739,7 +740,7 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
     const int is = tid/16;  // 0 or 1
     const int il = tid%16;  // 0...15
     const uint8_t q = x[i].qs[il] >> (2*is);
-    float * y = yy + i*QK_K + 16*is + il;
+    dst_t * y = yy + i*QK_K + 16*is + il;
     float dall = __low2half(x[i].dm);
     float dmin = __high2half(x[i].dm);
     y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
@@ -748,7 +749,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float

 }

-static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float * __restrict__ yy) {
+template<typename dst_t>
+static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {

     const int i = blockIdx.x;
     const block_q3_K * x = (const block_q3_K *) vx;
@@ -772,7 +774,7 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float
     float d_all = x[i].d;
     float dl = d_all * (us - 32);

-    float * y = yy + i*QK_K + 128*n + 32*j;
+    dst_t * y = yy + i*QK_K + 128*n + 32*j;
     const uint8_t * q = x[i].qs + 32*n;
     const uint8_t * hm = x[i].hmask;

@@ -784,7 +786,7 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float
     const int im = il/8;  // 0...1
     const int in = il%8;  // 0...7

-    float * y = yy + i*QK_K + 16*is + il;
+    dst_t * y = yy + i*QK_K + 16*is + il;

     const uint8_t q = x[i].qs[il] >> (2*is);
     const uint8_t h = x[i].hmask[in] >> (2*is + im);
@@ -812,7 +814,8 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t
 }
 #endif

-static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float * __restrict__ yy) {
+template<typename dst_t>
+static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
     const block_q4_K * x = (const block_q4_K *) vx;

     const int i = blockIdx.x;
@@ -825,7 +828,7 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float
     const int is = 2*il;
     const int n  = 4;

-    float * y = yy + i*QK_K + 64*il + n*ir;
+    dst_t * y = yy + i*QK_K + 64*il + n*ir;

     const float dall = __low2half(x[i].dm);
     const float dmin = __high2half(x[i].dm);
@@ -844,7 +847,7 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float
 #else
     const int tid = threadIdx.x;
     const uint8_t * q = x[i].qs;
-    float * y = yy + i*QK_K;
+    dst_t * y = yy + i*QK_K;
     const float d = (float)x[i].dm[0];
     const float m = (float)x[i].dm[1];
     y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
@@ -852,7 +855,8 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float
 #endif
 }

-static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float * __restrict__ yy) {
+template<typename dst_t>
+static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
     const block_q5_K * x = (const block_q5_K *) vx;

     const int i = blockIdx.x;
@@ -864,7 +868,7 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float
     const int ir = tid%16;   // ir is in 0...15
     const int is = 2*il;     // is is in 0...6

-    float * y = yy + i*QK_K + 64*il + 2*ir;
+    dst_t * y = yy + i*QK_K + 64*il + 2*ir;

     const float dall = __low2half(x[i].dm);
     const float dmin = __high2half(x[i].dm);
@@ -892,13 +896,14 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float
     const int is = tid/16; // 0 or 1
     const uint8_t h = x[i].qh[in] >> im;
     const float d = x[i].d;
-    float * y = yy + i*QK_K + tid;
+    dst_t * y = yy + i*QK_K + tid;
     y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
     y[32] = d * x[i].scales[is+2] * ((q >>  4) - ((h >> 4) & 1 ? 0 : 16));
 #endif
 }

-static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float * __restrict__ yy) {
+template<typename dst_t>
+static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
     const block_q6_K * x = (const block_q6_K *) vx;

     const int i = blockIdx.x;
@@ -910,7 +915,7 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float
     const int il = tid - 32*ip; // 0...32
     const int is = 8*ip + il/16;

-    float * y = yy + i*QK_K + 128*ip + il;
+    dst_t * y = yy + i*QK_K + 128*ip + il;

     const float d = x[i].d;

@@ -929,7 +934,7 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float
     const int ip = tid/16;       // 0 or 1
     const int il = tid - 16*ip;  // 0...15

-    float * y = yy + i*QK_K + 16*ip + il;
+    dst_t * y = yy + i*QK_K + 16*ip + il;

     const float d = x[i].d;

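Note: the hunks above give every dequantize_block_* kernel a dst_t template parameter, so one kernel body can write either fp32 or fp16 output; which one is chosen at compile time per instantiation. A toy sketch of the pattern (kernel name and shapes are illustrative, not from ggml):

#include <cuda_fp16.h>

template<typename dst_t>
__global__ void dequantize_sketch(const float * __restrict__ src, dst_t * __restrict__ dst, int n) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) {
        dst[i] = (dst_t) src[i];  // float -> float, or float -> half, picked at compile time
    }
}

// Host side: the destination buffer type selects the instantiation.
// dequantize_sketch<float><<<(n + 255)/256, 256, 0, stream>>>(src, dst_f32, n);
// dequantize_sketch<half> <<<(n + 255)/256, 256, 0, stream>>>(src, dst_f16, n);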
@@ -3548,7 +3553,7 @@ template <bool need_check> static __global__ void
     load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
     (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);

-#elif __CUDA_ARCH__ >= CC_TURING
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q4_0_AMPERE;
     const int mmq_y  =  MMQ_Y_Q4_0_AMPERE;
     const int nwarps = NWARPS_Q4_0_AMPERE;
@@ -3568,7 +3573,7 @@ template <bool need_check> static __global__ void
 #else
     (void) vec_dot_q4_0_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }

 #define MMQ_X_Q4_1_RDNA2 64
@@ -3589,9 +3594,9 @@ template <bool need_check> static __global__ void
 #if defined(RDNA3) || defined(RDNA2)
     __launch_bounds__(WARP_SIZE*NWARPS_Q4_1_RDNA2, 2)
 #endif // defined(RDNA3) || defined(RDNA2)
-#elif __CUDA_ARCH__ < CC_TURING
+#elif __CUDA_ARCH__ < CC_VOLTA
     __launch_bounds__(WARP_SIZE*NWARPS_Q4_1_PASCAL, 2)
-#endif // __CUDA_ARCH__ < CC_TURING
+#endif // __CUDA_ARCH__ < CC_VOLTA
 mul_mat_q4_1(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
@@ -3611,7 +3616,7 @@ template <bool need_check> static __global__ void
     load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
     (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);

-#elif __CUDA_ARCH__ >= CC_TURING
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q4_1_AMPERE;
     const int mmq_y  =  MMQ_Y_Q4_1_AMPERE;
     const int nwarps = NWARPS_Q4_1_AMPERE;
@@ -3631,7 +3636,7 @@ template <bool need_check> static __global__ void
 #else
     (void) vec_dot_q4_1_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }

 #define MMQ_X_Q5_0_RDNA2 64
@@ -3672,7 +3677,7 @@ template <bool need_check> static __global__ void
     load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
     (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);

-#elif __CUDA_ARCH__ >= CC_TURING
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q5_0_AMPERE;
     const int mmq_y  =  MMQ_Y_Q5_0_AMPERE;
     const int nwarps = NWARPS_Q5_0_AMPERE;
@@ -3692,7 +3697,7 @@ template <bool need_check> static __global__ void
 #else
     (void) vec_dot_q5_0_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }

 #define MMQ_X_Q5_1_RDNA2 64
@@ -3733,7 +3738,7 @@ mul_mat_q5_1(
     load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
     (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);

-#elif __CUDA_ARCH__ >= CC_TURING
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q5_1_AMPERE;
     const int mmq_y  =  MMQ_Y_Q5_1_AMPERE;
     const int nwarps = NWARPS_Q5_1_AMPERE;
@@ -3753,7 +3758,7 @@ mul_mat_q5_1(
 #else
     (void) vec_dot_q5_1_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }

 #define MMQ_X_Q8_0_RDNA2 64
@@ -3794,7 +3799,7 @@ template <bool need_check> static __global__ void
     load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
     (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);

-#elif __CUDA_ARCH__ >= CC_TURING
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q8_0_AMPERE;
     const int mmq_y  =  MMQ_Y_Q8_0_AMPERE;
     const int nwarps = NWARPS_Q8_0_AMPERE;
@@ -3814,7 +3819,7 @@ template <bool need_check> static __global__ void
 #else
     (void) vec_dot_q8_0_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }

 #define MMQ_X_Q2_K_RDNA2 64
@@ -3855,7 +3860,7 @@ mul_mat_q2_K(
     load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
     (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);

-#elif __CUDA_ARCH__ >= CC_TURING
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q2_K_AMPERE;
     const int mmq_y  =  MMQ_Y_Q2_K_AMPERE;
     const int nwarps = NWARPS_Q2_K_AMPERE;
@@ -3875,7 +3880,7 @@ mul_mat_q2_K(
 #else
     (void) vec_dot_q2_K_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }

 #define MMQ_X_Q3_K_RDNA2 128
@@ -3896,9 +3901,9 @@ template <bool need_check> static __global__ void
 #if defined(RDNA3) || defined(RDNA2)
     __launch_bounds__(WARP_SIZE*NWARPS_Q3_K_RDNA2, 2)
 #endif // defined(RDNA3) || defined(RDNA2)
-#elif __CUDA_ARCH__ < CC_TURING
+#elif __CUDA_ARCH__ < CC_VOLTA
     __launch_bounds__(WARP_SIZE*NWARPS_Q3_K_PASCAL, 2)
-#endif // __CUDA_ARCH__ < CC_TURING
+#endif // __CUDA_ARCH__ < CC_VOLTA
 mul_mat_q3_K(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
@@ -3918,7 +3923,7 @@ template <bool need_check> static __global__ void
     load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
     (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);

-#elif __CUDA_ARCH__ >= CC_TURING
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q3_K_AMPERE;
     const int mmq_y  =  MMQ_Y_Q3_K_AMPERE;
     const int nwarps = NWARPS_Q3_K_AMPERE;
@@ -3938,7 +3943,7 @@ template <bool need_check> static __global__ void
 #else
     (void) vec_dot_q3_K_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }

 #define MMQ_X_Q4_K_RDNA2 64
@@ -3959,9 +3964,9 @@ template <bool need_check> static __global__ void
 #if defined(RDNA3) || defined(RDNA2)
     __launch_bounds__(WARP_SIZE*NWARPS_Q4_K_RDNA2, 2)
 #endif // defined(RDNA3) || defined(RDNA2)
-#elif __CUDA_ARCH__ < CC_TURING
+#elif __CUDA_ARCH__ < CC_VOLTA
     __launch_bounds__(WARP_SIZE*NWARPS_Q4_K_PASCAL, 2)
-#endif // __CUDA_ARCH__ < CC_TURING
+#endif // __CUDA_ARCH__ < CC_VOLTA
 mul_mat_q4_K(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
@@ -3981,7 +3986,7 @@ template <bool need_check> static __global__ void
     load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
     (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);

-#elif __CUDA_ARCH__ >= CC_TURING
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q4_K_AMPERE;
     const int mmq_y  =  MMQ_Y_Q4_K_AMPERE;
     const int nwarps = NWARPS_Q4_K_AMPERE;
@@ -4001,7 +4006,7 @@ template <bool need_check> static __global__ void
 #else
     (void) vec_dot_q4_K_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }

 #define MMQ_X_Q5_K_RDNA2 64
@@ -4042,7 +4047,7 @@ mul_mat_q5_K(
     load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
     (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);

-#elif __CUDA_ARCH__ >= CC_TURING
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q5_K_AMPERE;
     const int mmq_y  =  MMQ_Y_Q5_K_AMPERE;
     const int nwarps = NWARPS_Q5_K_AMPERE;
@@ -4062,7 +4067,7 @@ mul_mat_q5_K(
 #else
     (void) vec_dot_q5_K_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }

 #define MMQ_X_Q6_K_RDNA2 64
@@ -4083,9 +4088,9 @@ template <bool need_check> static __global__ void
 #if defined(RDNA3) || defined(RDNA2)
     __launch_bounds__(WARP_SIZE*NWARPS_Q6_K_RDNA2, 2)
 #endif // defined(RDNA3) || defined(RDNA2)
-#elif __CUDA_ARCH__ < CC_TURING
+#elif __CUDA_ARCH__ < CC_VOLTA
     __launch_bounds__(WARP_SIZE*NWARPS_Q6_K_PASCAL, 2)
-#endif // __CUDA_ARCH__ < CC_TURING
+#endif // __CUDA_ARCH__ < CC_VOLTA
 mul_mat_q6_K(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
@@ -4105,7 +4110,7 @@ template <bool need_check> static __global__ void
     load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
     (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);

-#elif __CUDA_ARCH__ >= CC_TURING
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q6_K_AMPERE;
     const int mmq_y  =  MMQ_Y_Q6_K_AMPERE;
     const int nwarps = NWARPS_Q6_K_AMPERE;
@@ -4125,7 +4130,7 @@ template <bool need_check> static __global__ void
 #else
     (void) vec_dot_q6_K_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }

 template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
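Note: all of the mul_mat_q* hunks above are the same mechanical rename of the compute-capability threshold. The pattern they implement is a compile-time architecture gate inside the kernel, mirrored later by a run-time check on the host. An illustrative sketch (tile values made up, not the real tunings):

#define CC_VOLTA 700

__global__ void mmq_sketch(float * dst) {
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
    dst[threadIdx.x] = 1.0f;  // RDNA tunings would go here
#elif __CUDA_ARCH__ >= CC_VOLTA
    dst[threadIdx.x] = 2.0f;  // larger "AMPERE" tiles, tensor-core friendly
#else
    dst[threadIdx.x] = 3.0f;  // smaller "PASCAL" fallback tiles
#endif
}

Because the same CC_VOLTA constant is also tested on the host (see the ggml_mul_mat_*_q8_1_cuda hunks below), the device and host sides must agree on the threshold, which is why the rename touches both.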
@@ -4604,32 +4609,38 @@ static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, con
     quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
 }

-static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }

-static void dequantize_row_q4_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     dequantize_block<QK4_1, QR4_1, dequantize_q4_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }

-static void dequantize_row_q5_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q5_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     dequantize_block<QK5_0, QR5_0, dequantize_q5_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }

-static void dequantize_row_q5_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q5_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     dequantize_block<QK5_1, QR5_1, dequantize_q5_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }

-static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q8_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }

-static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -4638,7 +4649,8 @@ static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cu
 #endif
 }

-static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -4647,12 +4659,14 @@ static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cu
 #endif
 }

-static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
     dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
 }

-static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -4661,7 +4675,8 @@ static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cu
 #endif
 }

-static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -4868,6 +4883,26 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa

 static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
     switch (type) {
+        case GGML_TYPE_Q4_0:
+            return dequantize_row_q4_0_cuda;
+        case GGML_TYPE_Q4_1:
+            return dequantize_row_q4_1_cuda;
+        case GGML_TYPE_Q5_0:
+            return dequantize_row_q5_0_cuda;
+        case GGML_TYPE_Q5_1:
+            return dequantize_row_q5_1_cuda;
+        case GGML_TYPE_Q8_0:
+            return dequantize_row_q8_0_cuda;
+        case GGML_TYPE_Q2_K:
+            return dequantize_row_q2_K_cuda;
+        case GGML_TYPE_Q3_K:
+            return dequantize_row_q3_K_cuda;
+        case GGML_TYPE_Q4_K:
+            return dequantize_row_q4_K_cuda;
+        case GGML_TYPE_Q5_K:
+            return dequantize_row_q5_K_cuda;
+        case GGML_TYPE_Q6_K:
+            return dequantize_row_q6_K_cuda;
         case GGML_TYPE_F32:
             return convert_fp32_to_fp16_cuda;
         default:
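Note: with the new cases, ggml_get_to_fp16_cuda becomes a complete dispatch table: any supported quantized type can now be dequantized straight to fp16 on the GPU, which the cuBLAS path further down relies on. Roughly, usage looks like this (the to_fp16_cuda_t signature is paraphrased from the surrounding diff, not quoted from the header):

#include <cuda_fp16.h>
#include <cuda_runtime.h>

// Function-pointer type of the dispatch table entries (paraphrased).
typedef void (*to_fp16_cuda_t)(const void * x, half * y, int k, cudaStream_t stream);

// Caller pattern, mirroring ggml_cuda_op_mul_mat_cublas below:
//     const to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(src0->type);
//     GGML_ASSERT(to_fp16 != nullptr);
//     to_fp16(src0_data, src0_f16_buffer, n_elements, stream);

Returning a templated function name such as dequantize_row_q4_0_cuda from a function with this return type selects the half instantiation automatically.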
@@ -4921,7 +4956,7 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
         mmq_x  =  MMQ_X_Q4_0_RDNA1;
         mmq_y  =  MMQ_Y_Q4_0_RDNA1;
         nwarps = NWARPS_Q4_0_RDNA1;
-    } else if (compute_capability >= CC_TURING) {
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q4_0_AMPERE;
         mmq_y  =  MMQ_Y_Q4_0_AMPERE;
         nwarps = NWARPS_Q4_0_AMPERE;
@@ -4966,7 +5001,7 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
         mmq_x  =  MMQ_X_Q4_1_RDNA1;
         mmq_y  =  MMQ_Y_Q4_1_RDNA1;
         nwarps = NWARPS_Q4_1_RDNA1;
-    } else if (compute_capability >= CC_TURING) {
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q4_1_AMPERE;
         mmq_y  =  MMQ_Y_Q4_1_AMPERE;
         nwarps = NWARPS_Q4_1_AMPERE;
@@ -5011,7 +5046,7 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
         mmq_x  =  MMQ_X_Q5_0_RDNA1;
         mmq_y  =  MMQ_Y_Q5_0_RDNA1;
         nwarps = NWARPS_Q5_0_RDNA1;
-    } else if (compute_capability >= CC_TURING) {
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q5_0_AMPERE;
         mmq_y  =  MMQ_Y_Q5_0_AMPERE;
         nwarps = NWARPS_Q5_0_AMPERE;
@@ -5056,7 +5091,7 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
         mmq_x  =  MMQ_X_Q5_1_RDNA1;
         mmq_y  =  MMQ_Y_Q5_1_RDNA1;
         nwarps = NWARPS_Q5_1_RDNA1;
-    } else if (compute_capability >= CC_TURING) {
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q5_1_AMPERE;
         mmq_y  =  MMQ_Y_Q5_1_AMPERE;
         nwarps = NWARPS_Q5_1_AMPERE;
@@ -5101,7 +5136,7 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
         mmq_x  =  MMQ_X_Q8_0_RDNA1;
         mmq_y  =  MMQ_Y_Q8_0_RDNA1;
         nwarps = NWARPS_Q8_0_RDNA1;
-    } else if (compute_capability >= CC_TURING) {
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q8_0_AMPERE;
         mmq_y  =  MMQ_Y_Q8_0_AMPERE;
         nwarps = NWARPS_Q8_0_AMPERE;
@@ -5146,7 +5181,7 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
         mmq_x  =  MMQ_X_Q2_K_RDNA1;
         mmq_y  =  MMQ_Y_Q2_K_RDNA1;
         nwarps = NWARPS_Q2_K_RDNA1;
-    } else if (compute_capability >= CC_TURING) {
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q2_K_AMPERE;
         mmq_y  =  MMQ_Y_Q2_K_AMPERE;
         nwarps = NWARPS_Q2_K_AMPERE;
@@ -5193,7 +5228,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
         mmq_x  =  MMQ_X_Q3_K_RDNA1;
         mmq_y  =  MMQ_Y_Q3_K_RDNA1;
         nwarps = NWARPS_Q3_K_RDNA1;
-    } else if (compute_capability >= CC_TURING) {
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q3_K_AMPERE;
         mmq_y  =  MMQ_Y_Q3_K_AMPERE;
         nwarps = NWARPS_Q3_K_AMPERE;
@@ -5239,7 +5274,7 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(
         mmq_x  =  MMQ_X_Q4_K_RDNA1;
         mmq_y  =  MMQ_Y_Q4_K_RDNA1;
         nwarps = NWARPS_Q4_K_RDNA1;
-    } else if (compute_capability >= CC_TURING) {
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q4_K_AMPERE;
         mmq_y  =  MMQ_Y_Q4_K_AMPERE;
         nwarps = NWARPS_Q4_K_AMPERE;
@@ -5284,7 +5319,7 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
         mmq_x  =  MMQ_X_Q5_K_RDNA1;
         mmq_y  =  MMQ_Y_Q5_K_RDNA1;
         nwarps = NWARPS_Q5_K_RDNA1;
-    } else if (compute_capability >= CC_TURING) {
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q5_K_AMPERE;
         mmq_y  =  MMQ_Y_Q5_K_AMPERE;
         nwarps = NWARPS_Q5_K_AMPERE;
@@ -5329,7 +5364,7 @@ static void ggml_mul_mat_q6_K_q8_1_cuda(
         mmq_x  =  MMQ_X_Q6_K_RDNA1;
         mmq_y  =  MMQ_Y_Q6_K_RDNA1;
         nwarps = NWARPS_Q6_K_RDNA1;
-    } else if (compute_capability >= CC_TURING) {
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q6_K_AMPERE;
         mmq_y  =  MMQ_Y_Q6_K_AMPERE;
         nwarps = NWARPS_Q6_K_AMPERE;
@@ -5907,7 +5942,7 @@ static int64_t get_row_rounding(ggml_type type) {
     switch(type) {
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
-            return max_compute_capability >= CC_TURING ? 128 : 64;
+            return max_compute_capability >= CC_VOLTA ? 128 : 64;
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
@@ -5918,7 +5953,7 @@ static int64_t get_row_rounding(ggml_type type) {
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
-            return max_compute_capability >= CC_TURING ? 128 : 64;
+            return max_compute_capability >= CC_VOLTA ? 128 : 64;
         case GGML_TYPE_Q6_K:
             return 64;
         default:
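Note: get_row_rounding picks the granularity used when a tensor's rows are split across multiple GPUs, and on Volta or newer the mmq tiles are larger, hence 128 rather than 64. A small sketch of how such a rounding would be applied (the helper name is ours, not ggml's):

// Hypothetical helper: round a per-device row count up to the value returned
// by get_row_rounding so the tile kernels never see a partial tile.
static int64_t round_up_rows(int64_t rows, int64_t rounding) {
    return ((rows + rounding - 1) / rounding) * rounding;
}
// round_up_rows(900,  64) ==  960
// round_up_rows(900, 128) == 1024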
@@ -6083,8 +6118,19 @@ inline void ggml_cuda_op_mul_mat_cublas(

     const int compute_capability = g_compute_capabilities[id];

-    if (compute_capability >= …
-        // convert src1 to fp16, multiply as fp16, convert dst to fp32
+    if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
+        // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
+        half * src0_as_f16 = nullptr;
+        size_t src0_as = 0;
+        if (src0->type != GGML_TYPE_F16) {
+            const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
+            GGML_ASSERT(to_fp16_cuda != nullptr);
+            size_t ne = row_diff*ne00;
+            src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as);
+            to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream);
+        }
+        const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16;
+
         half * src1_as_f16 = nullptr;
         size_t src1_as = 0;
         if (src1->type != GGML_TYPE_F16) {
@@ -6106,9 +6152,9 @@ inline void ggml_cuda_op_mul_mat_cublas(
         CUBLAS_CHECK(
             cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                 row_diff, src1_ncols, ne10,
-                &alpha_f16, …
-                            src1_ptr, …
-                &beta_f16,  dst_f16, …
+                &alpha_f16, src0_ptr, CUDA_R_16F, ne00,
+                            src1_ptr, CUDA_R_16F, ne10,
+                &beta_f16,  dst_f16,  CUDA_R_16F, ldc,
                 CUBLAS_COMPUTE_16F,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));

@@ -6117,6 +6163,10 @@ inline void ggml_cuda_op_mul_mat_cublas(

     ggml_cuda_pool_free(dst_f16, dst_as);

+    if (src0_as != 0) {
+        ggml_cuda_pool_free(src0_as_f16, src0_as);
+    }
+
     if (src1_as != 0) {
         ggml_cuda_pool_free(src1_as_f16, src1_as);
     }
data/ext/llama_cpp/src/ggml-metal.m:

@@ -1213,12 +1213,9 @@ void ggml_metal_graph_compute(
     float max_bias;
     memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));

-    if (__builtin_popcount(n_head) != 1) {
-        GGML_ASSERT(false && "only power-of-two n_head implemented");
-    }
-
     const int   n_heads_log2_floor = 1 << (int) floor(log2(n_head));
     const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);

     [encoder setComputePipelineState:ctx->pipeline_alibi_f32];
     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1239,7 +1236,9 @@ void ggml_metal_graph_compute(
     [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
     [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
     [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
-    [encoder setBytes:&m0 length:sizeof(float) atIndex:18];
+    [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
+    [encoder setBytes:&m1 length:sizeof( float) atIndex:19];
+    [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];

     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
 } break;
data/ext/llama_cpp/src/ggml-metal.metal:

@@ -830,7 +830,9 @@ kernel void kernel_alibi_f32(
         constant uint64_t & nb1,
         constant uint64_t & nb2,
         constant uint64_t & nb3,
-        constant    float & m0,
+        constant    float & m0,
+        constant    float & m1,
+        constant      int & n_heads_log2_floor,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3 ntg[[threads_per_threadgroup]]) {
@@ -846,7 +848,12 @@ kernel void kernel_alibi_f32(
     const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);

     device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-    float m_k = pow(m0, i2 + 1);
+    float m_k;
+    if (i2 < n_heads_log2_floor) {
+        m_k = pow(m0, i2 + 1);
+    } else {
+        m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1);
+    }
     for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
         device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
         dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);