llama_cpp 0.6.0 → 0.7.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +8 -0
- data/ext/llama_cpp/llama_cpp.cpp +49 -3
- data/ext/llama_cpp/src/ggml-cuda.cu +122 -72
- data/ext/llama_cpp/src/ggml-metal.m +4 -5
- data/ext/llama_cpp/src/ggml-metal.metal +9 -2
- data/ext/llama_cpp/src/ggml-opencl.cpp +119 -53
- data/ext/llama_cpp/src/ggml.c +755 -320
- data/ext/llama_cpp/src/ggml.h +13 -0
- data/ext/llama_cpp/src/k_quants.c +744 -2
- data/ext/llama_cpp/src/llama.cpp +779 -113
- data/ext/llama_cpp/src/llama.h +22 -6
- data/ext/llama_cpp/src/unicode.h +462 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +1 -1
- data/sig/llama_cpp.rbs +5 -0
- metadata +3 -2
data/ext/llama_cpp/src/ggml-cuda.cu

@@ -80,9 +80,9 @@
 #include "ggml.h"

 #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
-#define CC_TURING 700
+#define CC_VOLTA 700
 #define CC_OFFSET_AMD 1000000
-#define CC_RDNA2 CC_OFFSET_AMD + 1030
+#define CC_RDNA2 (CC_OFFSET_AMD + 1030)

 #if defined(GGML_USE_HIPBLAS)
 #define __CUDA_ARCH__ 1300
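Two changes land in this first hunk: the compute-capability threshold 700 is renamed (from CC_TURING in the bundled upstream) to CC_VOLTA, so the mmq kernels and the fp16 cuBLAS path further down are enabled from Volta onward rather than Turing, and CC_RDNA2 is wrapped in parentheses. The parentheses guard against operator-precedence surprises when the macro expands inside a larger expression; a standalone sketch of the pitfall (illustrative code, not from the diff):

    #include <stdio.h>

    #define CC_OFFSET_AMD 1000000
    #define CC_RDNA2_OLD  CC_OFFSET_AMD + 1030   // unparenthesized, as before
    #define CC_RDNA2_NEW  (CC_OFFSET_AMD + 1030) // parenthesized, as after

    int main(void) {
        // '*' binds tighter than '+', so the old macro expands to
        // 2*CC_OFFSET_AMD + 1030 instead of 2*(CC_OFFSET_AMD + 1030).
        printf("%d vs %d\n", 2 * CC_RDNA2_OLD, 2 * CC_RDNA2_NEW); // 2001030 vs 2002060
        return 0;
    }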
@@ -715,7 +715,8 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in

 //================================== k-quants

-static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float * __restrict__ yy) {
+template<typename dst_t>
+static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {

     const int i = blockIdx.x;
     const block_q2_K * x = (const block_q2_K *) vx;

@@ -727,7 +728,7 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
     const int is = 8*n + l/16;

     const uint8_t q = x[i].qs[32*n + l];
-    float * y = yy + i*QK_K + 128*n;
+    dst_t * y = yy + i*QK_K + 128*n;

     float dall = __low2half(x[i].dm);
     float dmin = __high2half(x[i].dm);

@@ -739,7 +740,7 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
     const int is = tid/16; // 0 or 1
     const int il = tid%16; // 0...15
     const uint8_t q = x[i].qs[il] >> (2*is);
-    float * y = yy + i*QK_K + 16*is + il;
+    dst_t * y = yy + i*QK_K + 16*is + il;
     float dall = __low2half(x[i].dm);
     float dmin = __high2half(x[i].dm);
     y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);

@@ -748,7 +749,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float

 }

-static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float * __restrict__ yy) {
+template<typename dst_t>
+static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {

     const int i = blockIdx.x;
     const block_q3_K * x = (const block_q3_K *) vx;

@@ -772,7 +774,7 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float
     float d_all = x[i].d;
     float dl = d_all * (us - 32);

-    float * y = yy + i*QK_K + 128*n + 32*j;
+    dst_t * y = yy + i*QK_K + 128*n + 32*j;
     const uint8_t * q = x[i].qs + 32*n;
     const uint8_t * hm = x[i].hmask;

@@ -784,7 +786,7 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float
     const int im = il/8;  // 0...1
     const int in = il%8;  // 0...7

-    float * y = yy + i*QK_K + 16*is + il;
+    dst_t * y = yy + i*QK_K + 16*is + il;

     const uint8_t q = x[i].qs[il] >> (2*is);
     const uint8_t h = x[i].hmask[in] >> (2*is + im);

@@ -812,7 +814,8 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t
 }
 #endif

-static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float * __restrict__ yy) {
+template<typename dst_t>
+static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
     const block_q4_K * x = (const block_q4_K *) vx;

     const int i = blockIdx.x;

@@ -825,7 +828,7 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float
     const int is = 2*il;
     const int n = 4;

-    float * y = yy + i*QK_K + 64*il + n*ir;
+    dst_t * y = yy + i*QK_K + 64*il + n*ir;

     const float dall = __low2half(x[i].dm);
     const float dmin = __high2half(x[i].dm);

@@ -844,7 +847,7 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float
 #else
     const int tid = threadIdx.x;
     const uint8_t * q = x[i].qs;
-    float * y = yy + i*QK_K;
+    dst_t * y = yy + i*QK_K;
     const float d = (float)x[i].dm[0];
     const float m = (float)x[i].dm[1];
     y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);

@@ -852,7 +855,8 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float
 #endif
 }

-static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float * __restrict__ yy) {
+template<typename dst_t>
+static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
     const block_q5_K * x = (const block_q5_K *) vx;

     const int i = blockIdx.x;

@@ -864,7 +868,7 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float
     const int ir = tid%16;   // ir is in 0...15
     const int is = 2*il;     // is is in 0...6

-    float * y = yy + i*QK_K + 64*il + 2*ir;
+    dst_t * y = yy + i*QK_K + 64*il + 2*ir;

     const float dall = __low2half(x[i].dm);
     const float dmin = __high2half(x[i].dm);

@@ -892,13 +896,14 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float
     const int is = tid/16; // 0 or 1
     const uint8_t h = x[i].qh[in] >> im;
     const float d = x[i].d;
-    float * y = yy + i*QK_K + tid;
+    dst_t * y = yy + i*QK_K + tid;
     y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
     y[32] = d * x[i].scales[is+2] * ((q >> 4) - ((h >> 4) & 1 ? 0 : 16));
 #endif
 }

-static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float * __restrict__ yy) {
+template<typename dst_t>
+static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
     const block_q6_K * x = (const block_q6_K *) vx;

     const int i = blockIdx.x;

@@ -910,7 +915,7 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float
     const int il  = tid - 32*ip; // 0...32
     const int is  = 8*ip + il/16;

-    float * y = yy + i*QK_K + 128*ip + il;
+    dst_t * y = yy + i*QK_K + 128*ip + il;

     const float d = x[i].d;

@@ -929,7 +934,7 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float
     const int ip  = tid/16;      // 0 or 1
     const int il  = tid - 16*ip; // 0...15

-    float * y = yy + i*QK_K + 16*ip + il;
+    dst_t * y = yy + i*QK_K + 16*ip + il;

     const float d = x[i].d;

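Every dequantize_block_qN_K hunk above is the same mechanical change: the fixed float * output pointer becomes a template parameter dst_t, so one kernel definition can be instantiated for fp32 or fp16 output. A minimal sketch of the pattern with a simplified block layout (block_sketch and its fields are illustrative, not the gem's types):

    #include <cuda_fp16.h>
    #include <stdint.h>

    // Toy quantized block: one fp32 scale for 32 packed 4-bit values.
    struct block_sketch {
        float   d;
        uint8_t qs[16];
    };

    template<typename dst_t>
    static __global__ void dequantize_block_sketch(const void * __restrict__ vx, dst_t * __restrict__ yy) {
        const block_sketch * x = (const block_sketch *) vx;
        const int i   = blockIdx.x;  // one quantized block per CUDA block
        const int tid = threadIdx.x; // 0...31
        // The arithmetic stays in float; only the store type follows dst_t.
        dst_t * y = yy + i*32;
        const uint8_t q = (x[i].qs[tid/2] >> (4*(tid & 1))) & 0xF;
        y[tid] = (dst_t)(x[i].d * (float) q);
    }

    // One definition now serves both output types:
    //   dequantize_block_sketch<float><<<nb, 32>>>(vx, y_f32); // existing fp32 path
    //   dequantize_block_sketch<half ><<<nb, 32>>>(vx, y_f16); // new fp16 path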
@@ -3548,7 +3553,7 @@ template <bool need_check> static __global__ void
         load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);

-#elif __CUDA_ARCH__ >= CC_TURING
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q4_0_AMPERE;
     const int mmq_y  =  MMQ_Y_Q4_0_AMPERE;
     const int nwarps = NWARPS_Q4_0_AMPERE;

@@ -3568,7 +3573,7 @@ template <bool need_check> static __global__ void
 #else
     (void) vec_dot_q4_0_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }

 #define MMQ_X_Q4_1_RDNA2 64

@@ -3589,9 +3594,9 @@ template <bool need_check> static __global__ void
 #if defined(RDNA3) || defined(RDNA2)
     __launch_bounds__(WARP_SIZE*NWARPS_Q4_1_RDNA2, 2)
 #endif // defined(RDNA3) || defined(RDNA2)
-#elif __CUDA_ARCH__ < CC_TURING
+#elif __CUDA_ARCH__ < CC_VOLTA
     __launch_bounds__(WARP_SIZE*NWARPS_Q4_1_PASCAL, 2)
-#endif // __CUDA_ARCH__ < CC_TURING
+#endif // __CUDA_ARCH__ < CC_VOLTA
 mul_mat_q4_1(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {

@@ -3611,7 +3616,7 @@ template <bool need_check> static __global__ void
         load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);

-#elif __CUDA_ARCH__ >= CC_TURING
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q4_1_AMPERE;
     const int mmq_y  =  MMQ_Y_Q4_1_AMPERE;
     const int nwarps = NWARPS_Q4_1_AMPERE;

@@ -3631,7 +3636,7 @@ template <bool need_check> static __global__ void
 #else
     (void) vec_dot_q4_1_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }

 #define MMQ_X_Q5_0_RDNA2 64

@@ -3672,7 +3677,7 @@ template <bool need_check> static __global__ void
         load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);

-#elif __CUDA_ARCH__ >= CC_TURING
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q5_0_AMPERE;
     const int mmq_y  =  MMQ_Y_Q5_0_AMPERE;
     const int nwarps = NWARPS_Q5_0_AMPERE;

@@ -3692,7 +3697,7 @@ template <bool need_check> static __global__ void
 #else
     (void) vec_dot_q5_0_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }

 #define MMQ_X_Q5_1_RDNA2 64

@@ -3733,7 +3738,7 @@ mul_mat_q5_1(
         load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);

-#elif __CUDA_ARCH__ >= CC_TURING
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q5_1_AMPERE;
     const int mmq_y  =  MMQ_Y_Q5_1_AMPERE;
     const int nwarps = NWARPS_Q5_1_AMPERE;

@@ -3753,7 +3758,7 @@ mul_mat_q5_1(
 #else
     (void) vec_dot_q5_1_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }

 #define MMQ_X_Q8_0_RDNA2 64

@@ -3794,7 +3799,7 @@ template <bool need_check> static __global__ void
         load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);

-#elif __CUDA_ARCH__ >= CC_TURING
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q8_0_AMPERE;
     const int mmq_y  =  MMQ_Y_Q8_0_AMPERE;
     const int nwarps = NWARPS_Q8_0_AMPERE;

@@ -3814,7 +3819,7 @@ template <bool need_check> static __global__ void
 #else
     (void) vec_dot_q8_0_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }

 #define MMQ_X_Q2_K_RDNA2 64

@@ -3855,7 +3860,7 @@ mul_mat_q2_K(
         load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);

-#elif __CUDA_ARCH__ >= CC_TURING
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q2_K_AMPERE;
     const int mmq_y  =  MMQ_Y_Q2_K_AMPERE;
     const int nwarps = NWARPS_Q2_K_AMPERE;

@@ -3875,7 +3880,7 @@ mul_mat_q2_K(
 #else
     (void) vec_dot_q2_K_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }

 #define MMQ_X_Q3_K_RDNA2 128

@@ -3896,9 +3901,9 @@ template <bool need_check> static __global__ void
 #if defined(RDNA3) || defined(RDNA2)
     __launch_bounds__(WARP_SIZE*NWARPS_Q3_K_RDNA2, 2)
 #endif // defined(RDNA3) || defined(RDNA2)
-#elif __CUDA_ARCH__ < CC_TURING
+#elif __CUDA_ARCH__ < CC_VOLTA
     __launch_bounds__(WARP_SIZE*NWARPS_Q3_K_PASCAL, 2)
-#endif // __CUDA_ARCH__ < CC_TURING
+#endif // __CUDA_ARCH__ < CC_VOLTA
 mul_mat_q3_K(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {

@@ -3918,7 +3923,7 @@ template <bool need_check> static __global__ void
         load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);

-#elif __CUDA_ARCH__ >= CC_TURING
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q3_K_AMPERE;
     const int mmq_y  =  MMQ_Y_Q3_K_AMPERE;
     const int nwarps = NWARPS_Q3_K_AMPERE;

@@ -3938,7 +3943,7 @@ template <bool need_check> static __global__ void
 #else
     (void) vec_dot_q3_K_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }

 #define MMQ_X_Q4_K_RDNA2 64

@@ -3959,9 +3964,9 @@ template <bool need_check> static __global__ void
 #if defined(RDNA3) || defined(RDNA2)
     __launch_bounds__(WARP_SIZE*NWARPS_Q4_K_RDNA2, 2)
 #endif // defined(RDNA3) || defined(RDNA2)
-#elif __CUDA_ARCH__ < CC_TURING
+#elif __CUDA_ARCH__ < CC_VOLTA
     __launch_bounds__(WARP_SIZE*NWARPS_Q4_K_PASCAL, 2)
-#endif // __CUDA_ARCH__ < CC_TURING
+#endif // __CUDA_ARCH__ < CC_VOLTA
 mul_mat_q4_K(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {

@@ -3981,7 +3986,7 @@ template <bool need_check> static __global__ void
         load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);

-#elif __CUDA_ARCH__ >= CC_TURING
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q4_K_AMPERE;
     const int mmq_y  =  MMQ_Y_Q4_K_AMPERE;
     const int nwarps = NWARPS_Q4_K_AMPERE;

@@ -4001,7 +4006,7 @@ template <bool need_check> static __global__ void
 #else
     (void) vec_dot_q4_K_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }

 #define MMQ_X_Q5_K_RDNA2 64

@@ -4042,7 +4047,7 @@ mul_mat_q5_K(
         load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);

-#elif __CUDA_ARCH__ >= CC_TURING
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q5_K_AMPERE;
     const int mmq_y  =  MMQ_Y_Q5_K_AMPERE;
     const int nwarps = NWARPS_Q5_K_AMPERE;

@@ -4062,7 +4067,7 @@ mul_mat_q5_K(
 #else
     (void) vec_dot_q5_K_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }

 #define MMQ_X_Q6_K_RDNA2 64

@@ -4083,9 +4088,9 @@ template <bool need_check> static __global__ void
 #if defined(RDNA3) || defined(RDNA2)
     __launch_bounds__(WARP_SIZE*NWARPS_Q6_K_RDNA2, 2)
 #endif // defined(RDNA3) || defined(RDNA2)
-#elif __CUDA_ARCH__ < CC_TURING
+#elif __CUDA_ARCH__ < CC_VOLTA
     __launch_bounds__(WARP_SIZE*NWARPS_Q6_K_PASCAL, 2)
-#endif // __CUDA_ARCH__ < CC_TURING
+#endif // __CUDA_ARCH__ < CC_VOLTA
 mul_mat_q6_K(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {

@@ -4105,7 +4110,7 @@ template <bool need_check> static __global__ void
         load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);

-#elif __CUDA_ARCH__ >= CC_TURING
+#elif __CUDA_ARCH__ >= CC_VOLTA
     const int mmq_x  =  MMQ_X_Q6_K_AMPERE;
     const int mmq_y  =  MMQ_Y_Q6_K_AMPERE;
     const int nwarps = NWARPS_Q6_K_AMPERE;

@@ -4125,7 +4130,7 @@ template <bool need_check> static __global__ void
 #else
     (void) vec_dot_q6_K_q8_1_mul_mat;
     assert(false);
-#endif // __CUDA_ARCH__ >= CC_TURING
+#endif // __CUDA_ARCH__ >= CC_VOLTA
 }

 template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
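The long run of mul_mat_q* hunks is one rename applied everywhere: each #elif __CUDA_ARCH__ guard and its matching #endif comment now compares against CC_VOLTA, so the tile configurations previously reserved for Turing and newer also cover Volta. __CUDA_ARCH__ is defined only while nvcc compiles device code for a specific architecture, so each branch is resolved at compile time, once per target. A reduced sketch of the gating structure (tile values and the kernel name are placeholders):

    #include <assert.h>

    #define MIN_CC_DP4A 610
    #define CC_VOLTA    700

    __global__ void mul_mat_gate_sketch(int * out) {
        // Exactly one branch survives in each per-architecture compilation pass.
    #if __CUDA_ARCH__ >= CC_VOLTA
        const int mmq_x = 64, mmq_y = 128, nwarps = 4; // tensor-core era tiles
    #elif __CUDA_ARCH__ >= MIN_CC_DP4A
        const int mmq_x = 64, mmq_y = 64,  nwarps = 8; // Pascal dp4a tiles
    #else
        const int mmq_x = 0,  mmq_y = 0,   nwarps = 0;
        assert(false); // no byte-wise dot product below CC 6.1
    #endif
        out[0] = mmq_x * mmq_y * nwarps;
    }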
@@ -4604,32 +4609,38 @@ static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, con
     quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
 }

-static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }

-static void dequantize_row_q4_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     dequantize_block<QK4_1, QR4_1, dequantize_q4_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }

-static void dequantize_row_q5_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q5_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     dequantize_block<QK5_0, QR5_0, dequantize_q5_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }

-static void dequantize_row_q5_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q5_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     dequantize_block<QK5_1, QR5_1, dequantize_q5_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }

-static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q8_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }

-static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
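The host-side launchers get the matching treatment: template<typename dst_t> replaces the float-only signature, and the parameter flows straight through to the kernel instantiation. A self-contained sketch of the wrapper shape (stub kernel and names are hypothetical):

    #include <cuda_fp16.h>

    #define DEQUANT_BLOCK_SIZE 256

    template<typename dst_t>
    static __global__ void dequantize_block_stub(const float * vx, dst_t * y, const int k) {
        const int i = blockIdx.x*blockDim.x + threadIdx.x;
        if (i < k) { y[i] = (dst_t) vx[i]; }
    }

    // Same shape as dequantize_row_*_cuda after this diff: the caller's dst_t
    // selects which kernel instantiation is launched.
    template<typename dst_t>
    static void dequantize_row_stub_cuda(const float * vx, dst_t * y, const int k, cudaStream_t stream) {
        const int num_blocks = (k + DEQUANT_BLOCK_SIZE - 1) / DEQUANT_BLOCK_SIZE;
        dequantize_block_stub<<<num_blocks, DEQUANT_BLOCK_SIZE, 0, stream>>>(vx, y, k);
    }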
@@ -4638,7 +4649,8 @@ static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cu
 #endif
 }

-static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);

@@ -4647,12 +4659,14 @@ static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cu
 #endif
 }

-static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
     dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
 }

-static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);

@@ -4661,7 +4675,8 @@ static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cu
 #endif
 }

-static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+template<typename dst_t>
+static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -4868,6 +4883,26 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa

 static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
     switch (type) {
+        case GGML_TYPE_Q4_0:
+            return dequantize_row_q4_0_cuda;
+        case GGML_TYPE_Q4_1:
+            return dequantize_row_q4_1_cuda;
+        case GGML_TYPE_Q5_0:
+            return dequantize_row_q5_0_cuda;
+        case GGML_TYPE_Q5_1:
+            return dequantize_row_q5_1_cuda;
+        case GGML_TYPE_Q8_0:
+            return dequantize_row_q8_0_cuda;
+        case GGML_TYPE_Q2_K:
+            return dequantize_row_q2_K_cuda;
+        case GGML_TYPE_Q3_K:
+            return dequantize_row_q3_K_cuda;
+        case GGML_TYPE_Q4_K:
+            return dequantize_row_q4_K_cuda;
+        case GGML_TYPE_Q5_K:
+            return dequantize_row_q5_K_cuda;
+        case GGML_TYPE_Q6_K:
+            return dequantize_row_q6_K_cuda;
         case GGML_TYPE_F32:
             return convert_fp32_to_fp16_cuda;
         default:
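This switch is the payoff of the templating: ggml_get_to_fp16_cuda can now hand back a half-output instantiation of the matching dequantize launcher for every quantized type, not only F32, which is what lets quantized weights feed the fp16 cuBLAS multiply below. Returning a templated launcher through the to_fp16_cuda_t function-pointer type is what selects the dst_t = half instantiation. A reduced sketch with two types instead of eleven (stub names are hypothetical):

    #include <cuda_fp16.h>

    typedef void (*to_fp16_cuda_t)(const void * x, half * y, int k, cudaStream_t stream);

    enum type_sketch { TYPE_Q4_0, TYPE_Q8_0, TYPE_OTHER };

    // Stand-ins for the templated launchers above, fixed to half output.
    static void dequantize_row_q4_0_stub(const void *, half *, int, cudaStream_t) {}
    static void dequantize_row_q8_0_stub(const void *, half *, int, cudaStream_t) {}

    static to_fp16_cuda_t get_to_fp16_sketch(type_sketch t) {
        switch (t) {
            case TYPE_Q4_0: return dequantize_row_q4_0_stub;
            case TYPE_Q8_0: return dequantize_row_q8_0_stub;
            default:        return nullptr; // caller must check, as GGML_ASSERT does below
        }
    }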
@@ -4921,7 +4956,7 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
         mmq_x  =  MMQ_X_Q4_0_RDNA1;
         mmq_y  =  MMQ_Y_Q4_0_RDNA1;
         nwarps = NWARPS_Q4_0_RDNA1;
-    } else if (compute_capability >= CC_TURING) {
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q4_0_AMPERE;
         mmq_y  =  MMQ_Y_Q4_0_AMPERE;
         nwarps = NWARPS_Q4_0_AMPERE;

@@ -4966,7 +5001,7 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
         mmq_x  =  MMQ_X_Q4_1_RDNA1;
         mmq_y  =  MMQ_Y_Q4_1_RDNA1;
         nwarps = NWARPS_Q4_1_RDNA1;
-    } else if (compute_capability >= CC_TURING) {
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q4_1_AMPERE;
         mmq_y  =  MMQ_Y_Q4_1_AMPERE;
         nwarps = NWARPS_Q4_1_AMPERE;

@@ -5011,7 +5046,7 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
         mmq_x  =  MMQ_X_Q5_0_RDNA1;
         mmq_y  =  MMQ_Y_Q5_0_RDNA1;
         nwarps = NWARPS_Q5_0_RDNA1;
-    } else if (compute_capability >= CC_TURING) {
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q5_0_AMPERE;
         mmq_y  =  MMQ_Y_Q5_0_AMPERE;
         nwarps = NWARPS_Q5_0_AMPERE;

@@ -5056,7 +5091,7 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
         mmq_x  =  MMQ_X_Q5_1_RDNA1;
         mmq_y  =  MMQ_Y_Q5_1_RDNA1;
         nwarps = NWARPS_Q5_1_RDNA1;
-    } else if (compute_capability >= CC_TURING) {
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q5_1_AMPERE;
         mmq_y  =  MMQ_Y_Q5_1_AMPERE;
         nwarps = NWARPS_Q5_1_AMPERE;

@@ -5101,7 +5136,7 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
         mmq_x  =  MMQ_X_Q8_0_RDNA1;
         mmq_y  =  MMQ_Y_Q8_0_RDNA1;
         nwarps = NWARPS_Q8_0_RDNA1;
-    } else if (compute_capability >= CC_TURING) {
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q8_0_AMPERE;
         mmq_y  =  MMQ_Y_Q8_0_AMPERE;
         nwarps = NWARPS_Q8_0_AMPERE;

@@ -5146,7 +5181,7 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
         mmq_x  =  MMQ_X_Q2_K_RDNA1;
         mmq_y  =  MMQ_Y_Q2_K_RDNA1;
         nwarps = NWARPS_Q2_K_RDNA1;
-    } else if (compute_capability >= CC_TURING) {
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q2_K_AMPERE;
         mmq_y  =  MMQ_Y_Q2_K_AMPERE;
         nwarps = NWARPS_Q2_K_AMPERE;

@@ -5193,7 +5228,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
         mmq_x  =  MMQ_X_Q3_K_RDNA1;
         mmq_y  =  MMQ_Y_Q3_K_RDNA1;
         nwarps = NWARPS_Q3_K_RDNA1;
-    } else if (compute_capability >= CC_TURING) {
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q3_K_AMPERE;
         mmq_y  =  MMQ_Y_Q3_K_AMPERE;
         nwarps = NWARPS_Q3_K_AMPERE;

@@ -5239,7 +5274,7 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(
         mmq_x  =  MMQ_X_Q4_K_RDNA1;
         mmq_y  =  MMQ_Y_Q4_K_RDNA1;
         nwarps = NWARPS_Q4_K_RDNA1;
-    } else if (compute_capability >= CC_TURING) {
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q4_K_AMPERE;
         mmq_y  =  MMQ_Y_Q4_K_AMPERE;
         nwarps = NWARPS_Q4_K_AMPERE;

@@ -5284,7 +5319,7 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
         mmq_x  =  MMQ_X_Q5_K_RDNA1;
         mmq_y  =  MMQ_Y_Q5_K_RDNA1;
         nwarps = NWARPS_Q5_K_RDNA1;
-    } else if (compute_capability >= CC_TURING) {
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q5_K_AMPERE;
         mmq_y  =  MMQ_Y_Q5_K_AMPERE;
         nwarps = NWARPS_Q5_K_AMPERE;

@@ -5329,7 +5364,7 @@ static void ggml_mul_mat_q6_K_q8_1_cuda(
         mmq_x  =  MMQ_X_Q6_K_RDNA1;
         mmq_y  =  MMQ_Y_Q6_K_RDNA1;
         nwarps = NWARPS_Q6_K_RDNA1;
-    } else if (compute_capability >= CC_TURING) {
+    } else if (compute_capability >= CC_VOLTA) {
         mmq_x  =  MMQ_X_Q6_K_AMPERE;
         mmq_y  =  MMQ_Y_Q6_K_AMPERE;
         nwarps = NWARPS_Q6_K_AMPERE;
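In contrast to the compile-time __CUDA_ARCH__ guards inside the kernels, these host functions branch on the compute capability recorded per device, so one binary picks RDNA, Volta-and-newer, or Pascal tile shapes at runtime (ggml-cuda stores 100*major + 10*minor per device in g_compute_capabilities; the sketch below recomputes that value directly):

    #include <cuda_runtime.h>
    #include <stdio.h>

    #define CC_VOLTA 700

    int main(void) {
        int device = 0, major = 0, minor = 0;
        cudaGetDevice(&device);
        cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
        cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device);
        const int compute_capability = 100*major + 10*minor; // e.g. 700 on a V100

        // Same branch shape as ggml_mul_mat_*_q8_1_cuda after this diff:
        if (compute_capability >= CC_VOLTA) {
            printf("using the AMPERE tile constants\n");
        } else {
            printf("using the PASCAL tile constants\n");
        }
        return 0;
    }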
@@ -5907,7 +5942,7 @@ static int64_t get_row_rounding(ggml_type type) {
     switch(type) {
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
-            return max_compute_capability >= CC_TURING ? 128 : 64;
+            return max_compute_capability >= CC_VOLTA ? 128 : 64;
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:

@@ -5918,7 +5953,7 @@ static int64_t get_row_rounding(ggml_type type) {
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
-            return max_compute_capability >= CC_TURING ? 128 : 64;
+            return max_compute_capability >= CC_VOLTA ? 128 : 64;
         case GGML_TYPE_Q6_K:
             return 64;
         default:
@@ -6083,8 +6118,19 @@ inline void ggml_cuda_op_mul_mat_cublas(

     const int compute_capability = g_compute_capabilities[id];

-    if (compute_capability >= CC_TURING && src0->type == GGML_TYPE_F16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
-        // convert src1 to fp16, multiply as fp16, convert dst to fp32
+    if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
+        // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
+        half * src0_as_f16 = nullptr;
+        size_t src0_as = 0;
+        if (src0->type != GGML_TYPE_F16) {
+            const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
+            GGML_ASSERT(to_fp16_cuda != nullptr);
+            size_t ne = row_diff*ne00;
+            src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as);
+            to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream);
+        }
+        const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16;

         half * src1_as_f16 = nullptr;
         size_t src1_as = 0;
         if (src1->type != GGML_TYPE_F16) {

@@ -6106,9 +6152,9 @@ inline void ggml_cuda_op_mul_mat_cublas(
         CUBLAS_CHECK(
             cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                 row_diff, src1_ncols, ne10,
-                &alpha_f16, src0_dd_i, CUDA_R_16F, ne00,
-                            src1_ptr,  CUDA_R_16F, ne10,
-                &beta_f16,  dst_f16,   CUDA_R_16F, ldc,
+                &alpha_f16, src0_ptr, CUDA_R_16F, ne00,
+                            src1_ptr, CUDA_R_16F, ne10,
+                &beta_f16,  dst_f16,  CUDA_R_16F, ldc,
                 CUBLAS_COMPUTE_16F,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));

@@ -6117,6 +6163,10 @@ inline void ggml_cuda_op_mul_mat_cublas(

         ggml_cuda_pool_free(dst_f16, dst_as);

+        if (src0_as != 0) {
+            ggml_cuda_pool_free(src0_as_f16, src0_as);
+        }
+
         if (src1_as != 0) {
             ggml_cuda_pool_free(src1_as_f16, src1_as);
         }
data/ext/llama_cpp/src/ggml-metal.m

@@ -1213,12 +1213,9 @@ void ggml_metal_graph_compute(
             float max_bias;
             memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));

-            if (__builtin_popcount(n_head) != 1) {
-                GGML_ASSERT(false && "only power-of-two n_head implemented");
-            }
-
             const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
             const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
+            const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);

             [encoder setComputePipelineState:ctx->pipeline_alibi_f32];
             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];

@@ -1239,7 +1236,9 @@
             [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
             [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
             [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
-            [encoder setBytes:&m0 length:sizeof(float) atIndex:18];
+            [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
+            [encoder setBytes:&m1 length:sizeof( float) atIndex:19];
+            [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];

             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
         } break;
data/ext/llama_cpp/src/ggml-metal.metal

@@ -830,7 +830,9 @@ kernel void kernel_alibi_f32(
         constant uint64_t & nb1,
         constant uint64_t & nb2,
         constant uint64_t & nb3,
-        constant    float & m0,
+        constant    float & m0,
+        constant    float & m1,
+        constant      int & n_heads_log2_floor,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3 ntg[[threads_per_threadgroup]]) {

@@ -846,7 +848,12 @@ kernel void kernel_alibi_f32(
     const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);

     device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-    float m_k = pow(m0, i2 + 1);
+    float m_k;
+    if (i2 < n_heads_log2_floor) {
+        m_k = pow(m0, i2 + 1);
+    } else {
+        m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1);
+    }
     for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
         device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
         dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);