llama_cpp 0.2.0 → 0.2.1

@@ -1,5 +1,6 @@
1
1
  #include <cstddef>
2
2
  #include <cstdint>
3
+ #include <limits>
3
4
  #include <stdint.h>
4
5
  #include <stdio.h>
5
6
  #include <atomic>
@@ -24,7 +25,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
24
25
  } \
25
26
  } while (0)
26
27
 
27
- #if CUDART_VERSION >= 12
28
+ #if CUDART_VERSION >= 12000
28
29
  #define CUBLAS_CHECK(err) \
29
30
  do { \
30
31
  cublasStatus_t err_ = (err); \
@@ -48,6 +49,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
48
49
  typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1);
49
50
  typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
50
51
  typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v);
52
+ typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
51
53
  typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
52
54
  typedef void (*ggml_cuda_op_t)(
53
55
  const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, float * src0_ddf_i,
@@ -151,7 +153,10 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
151
153
  #define CUDA_ADD_BLOCK_SIZE 256
152
154
  #define CUDA_MUL_BLOCK_SIZE 256
153
155
  #define CUDA_SILU_BLOCK_SIZE 256
156
+ #define CUDA_CPY_BLOCK_SIZE 32
157
+ #define CUDA_SCALE_BLOCK_SIZE 256
154
158
  #define CUDA_ROPE_BLOCK_SIZE 256
159
+ #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
155
160
  #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
156
161
 
157
162
  // dmmv = dequantize_mul_mat_vec
@@ -655,10 +660,15 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
655
660
  }
656
661
 
657
662
  template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
658
- static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
663
+ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols, const int nrows) {
659
664
  // qk = quantized weights per x block
660
665
  // qr = number of quantized weights per data value in x block
661
- const int row = blockIdx.x*blockDim.y + threadIdx.y;
666
+ const int row = blockIdx.y*blockDim.y + threadIdx.y;
667
+
668
+ if (row >= nrows) {
669
+ return;
670
+ }
671
+
662
672
  const int tid = threadIdx.x;
663
673
 
664
674
  const int iter_stride = 2*GGML_CUDA_DMMV_X;
@@ -703,8 +713,13 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
703
713
  }
704
714
 
705
715
  template <int n_thread, dot_kernel_k_t dot_kernel>
706
- static __global__ void dequantize_mul_mat_vec_k(const void * vx, const float * y, float * dst, const int ncols) {
707
- const int row = blockIdx.x*blockDim.y + threadIdx.y;
716
+ static __global__ void dequantize_mul_mat_vec_k(const void * vx, const float * y, float * dst, const int ncols, const int nrows) {
717
+ const int row = blockIdx.y*blockDim.y + threadIdx.y;
718
+
719
+ if (row >= nrows) {
720
+ return;
721
+ }
722
+
708
723
  const int tid = threadIdx.x;
709
724
 
710
725
  const int iter_stride = QK_K;
@@ -737,6 +752,139 @@ static __global__ void dequantize_mul_mat_vec_k(const void * vx, const float * y
737
752
  }
738
753
  }
739
754
 
755
+ static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
756
+ const half * x = (half *) vx;
757
+
758
+ const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
759
+ const int channel = blockDim.z*blockIdx.z + threadIdx.z;
760
+
761
+ const int nrows_y = ncols_x;
762
+ const int nrows_dst = nrows_x;
763
+ const int row_dst = row_x;
764
+
765
+ float tmp = 0.0f;
766
+
767
+ for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
768
+ const int col_x = col_x0 + threadIdx.x;
769
+
770
+ if (col_x >= ncols_x) {
771
+ break;
772
+ }
773
+
774
+ // x is transposed and permuted
775
+ const int ix = row_x*nchannels_x*ncols_x + channel*ncols_x + col_x;
776
+ const float xi = __half2float(x[ix]);
777
+
778
+ const int row_y = col_x;
779
+
780
+
781
+ // y is not transposed but permuted
782
+ const int iy = channel*nrows_y + row_y;
783
+
784
+ tmp += xi * y[iy];
785
+ }
786
+
787
+ // dst is not transposed and not permuted
788
+ const int idst = channel*nrows_dst + row_dst;
789
+
790
+ // sum up partial sums and write back result
791
+ __syncthreads();
792
+ #pragma unroll
793
+ for (int mask = 16; mask > 0; mask >>= 1) {
794
+ tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
795
+ }
796
+
797
+ if (threadIdx.x == 0) {
798
+ dst[idst] = tmp;
799
+ }
800
+ }
801
+
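Both of the new mul_mat kernels finish with a warp-level XOR butterfly reduction via __shfl_xor_sync, after which only threadIdx.x == 0 writes the result. A minimal host-side C++ sketch of the same reduction pattern over 32 simulated lanes (illustrative only, not part of the diff):

    #include <array>
    #include <cstdio>

    int main() {
        std::array<float, 32> lane;
        lane.fill(1.0f);                                  // pretend each lane holds a partial sum of 1.0
        for (int mask = 16; mask > 0; mask >>= 1) {       // same butterfly schedule as the kernels
            std::array<float, 32> next = lane;
            for (int i = 0; i < 32; ++i) {
                next[i] = lane[i] + lane[i ^ mask];       // lane i adds the value of lane (i XOR mask)
            }
            lane = next;
        }
        std::printf("lane 0 holds %.1f\n", lane[0]);      // every lane ends with the full sum, 32.0
        return 0;
    }

After the last step every lane holds the full sum, which is why a single write per warp is enough.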
802
+ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
803
+ const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
804
+ const int row_stride_x, const int nchannels_x, const int channel_stride_x) {
805
+
806
+ const half * x = (half *) vx;
807
+
808
+ const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
809
+ const int channel = blockDim.z*blockIdx.z + threadIdx.z;
810
+
811
+ const int nrows_y = ncols_x;
812
+ const int nrows_dst = nrows_x;
813
+ const int row_dst = row_x;
814
+
815
+ const int idst = channel*nrows_dst + row_dst;
816
+
817
+ float tmp = 0.0f;
818
+
819
+ for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
820
+ const int col_x = col_x0 + threadIdx.x;
821
+
822
+ if (col_x >= ncols_x) {
823
+ break;
824
+ }
825
+
826
+ const int ix = channel*channel_stride_x + row_x*row_stride_x + col_x;
827
+ const float xi = __half2float(x[ix]);
828
+
829
+ const int row_y = col_x;
830
+
831
+ const int iy = channel*nrows_y + row_y;
832
+
833
+ tmp += xi * y[iy];
834
+ }
835
+
836
+ // sum up partial sums and write back result
837
+ __syncthreads();
838
+ #pragma unroll
839
+ for (int mask = 16; mask > 0; mask >>= 1) {
840
+ tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
841
+ }
842
+
843
+ if (threadIdx.x == 0) {
844
+ dst[idst] = tmp;
845
+ }
846
+ }
847
+
848
+ static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
849
+ const float * xi = (float *) cxi;
850
+ float * dsti = (float *) cdsti;
851
+
852
+ *dsti = *xi;
853
+ }
854
+
855
+ static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
856
+ const float * xi = (float *) cxi;
857
+ half * dsti = (half *) cdsti;
858
+
859
+ *dsti = __float2half(*xi);
860
+ }
861
+
862
+ template <cpy_kernel_t cpy_1>
863
+ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
864
+ const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
865
+ const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) {
866
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
867
+
868
+ if (i >= ne) {
869
+ return;
870
+ }
871
+
872
+ // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
873
+ // then combine those indices with the corresponding byte offsets to get the total offsets
874
+ const int i02 = i / (ne00*ne01);
875
+ const int i01 = (i - i02*ne01*ne00) / ne00;
876
+ const int i00 = i - i02*ne01*ne00 - i01*ne00;
877
+ const int x_offset = i00*nb00 + i01*nb01 + i02*nb02;
878
+
879
+ const int i12 = i / (ne10*ne11);
880
+ const int i11 = (i - i12*ne10*ne11) / ne10;
881
+ const int i10 = i - i12*ne10*ne11 - i11*ne10;
882
+ const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12;
883
+
884
+ cpy_1(cx + x_offset, cdst + dst_offset);
885
+ }
886
+
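The comment in cpy_f32_f16 describes recovering (i02, i01, i00) from the flattened index i and combining them with byte strides into a total offset. A small host-side check of that arithmetic; the extents and strides below (ne00, ne01, nb00, nb01, nb02) are made up for illustration and not taken from the diff:

    #include <cstdio>

    int main() {
        const int ne00 = 4, ne01 = 3;                          // assumed element counts per dim
        const int nb00 = 4, nb01 = 16, nb02 = 48;              // assumed byte strides (contiguous f32)
        const int i   = 17;                                    // flattened element index
        const int i02 = i / (ne00*ne01);                       // 17 / 12      = 1
        const int i01 = (i - i02*ne01*ne00) / ne00;            // (17 - 12)/4  = 1
        const int i00 = i - i02*ne01*ne00 - i01*ne00;          // 17 - 12 - 4  = 1
        const int offset = i00*nb00 + i01*nb01 + i02*nb02;     // 4 + 16 + 48  = 68 bytes
        std::printf("i02=%d i01=%d i00=%d offset=%d\n", i02, i01, i00, offset);
        return 0;
    }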
887
+ // rope == RoPE == rotary positional embedding
740
888
  static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p, const float theta_scale) {
741
889
  const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x);
742
890
 
@@ -758,6 +906,72 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
758
906
  dst[i + 1] = x0*sin_theta + x1*cos_theta;
759
907
  }
760
908
 
909
+ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
910
+ const int col = blockDim.x*blockIdx.x + threadIdx.x;
911
+ const int row = blockDim.y*blockIdx.y + threadIdx.y;
912
+
913
+ if (col >= ncols) {
914
+ return;
915
+ }
916
+
917
+ const int i = row*ncols + col;
918
+ // dst[i] = col > n_past + row ? -INFINITY : x[i];
919
+ dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
920
+ }
921
+
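The branch-free mask in diag_mask_inf_f32 relies on the comparison (col > n_past + row % rows_per_channel) evaluating to 0 or 1, so masked positions become x[i] minus INT_MAX (effectively -inf once soft max is applied) while the rest pass through unchanged. A host-side C++ sketch with illustrative values (not from the diff):

    #include <climits>
    #include <cstdio>

    int main() {
        const int n_past = 2, rows_per_channel = 8;            // assumed values
        const float x = 0.75f;                                 // some activation
        for (int row = 0; row < 2; ++row) {
            for (int col = 0; col < 5; ++col) {
                const float v = x - (col > n_past + row % rows_per_channel) * (float) INT_MAX;
                std::printf("row=%d col=%d -> %s\n", row, col, v < -1e9f ? "masked" : "kept");
            }
        }
        return 0;
    }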
922
+ // the CUDA soft max implementation differs from the CPU implementation
923
+ // instead of doubles floats are used
924
+ // values are also not normalized to the maximum value by subtracting it in the exponential function
925
+ // theoretically these changes could cause problems with rounding error and arithmetic overflow but for LLaMa it seems to be fine
926
+ static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
927
+ const int row = blockDim.y*blockIdx.y + threadIdx.y;
928
+ const int block_size = blockDim.x;
929
+ const int tid = threadIdx.x;
930
+
931
+ float tmp = 0.0;
932
+
933
+ for (int block_start = 0; block_start < ncols; block_start += block_size) {
934
+ const int col = block_start + tid;
935
+
936
+ if (col >= ncols) {
937
+ break;
938
+ }
939
+
940
+ const int i = row*ncols + col;
941
+ const float val = expf(x[i]);
942
+ tmp += val;
943
+ dst[i] = val;
944
+ }
945
+
946
+ // sum up partial sums
947
+ __syncthreads();
948
+ #pragma unroll
949
+ for (int mask = 16; mask > 0; mask >>= 1) {
950
+ tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
951
+ }
952
+
953
+ for (int block_start = 0; block_start < ncols; block_start += block_size) {
954
+ const int col = block_start + tid;
955
+
956
+ if (col >= ncols) {
957
+ break;
958
+ }
959
+
960
+ const int i = row*ncols + col;
961
+ dst[i] /= tmp;
962
+ }
963
+ }
964
+
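As the comments above note, soft_max_f32 accumulates in single precision and skips the usual max subtraction, which is stated to be fine for LLaMA but could in principle overflow for very large logits. For comparison, a numerically stable host-side variant looks like this (a reference sketch, not the kernel's implementation):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<float> x = {1.0f, 2.0f, 3.0f};
        const float m = *std::max_element(x.begin(), x.end());
        float sum = 0.0f;
        for (float & v : x) { v = std::exp(v - m); sum += v; } // exp(v - max) keeps the exponent small
        for (float & v : x) { v /= sum; }
        std::printf("%.4f %.4f %.4f\n", x[0], x[1], x[2]);     // ~0.0900 0.2447 0.6652
        return 0;
    }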
965
+ static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) {
966
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
967
+
968
+ if (i >= k) {
969
+ return;
970
+ }
971
+
972
+ dst[i] = scale * x[i];
973
+ }
974
+
761
975
  static void add_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
762
976
  const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
763
977
  add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
@@ -831,73 +1045,92 @@ static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cu
831
1045
 
832
1046
  static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
833
1047
  GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
834
- GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
1048
+ const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
1049
+ const dim3 block_nums(1, block_num_y, 1);
835
1050
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
836
1051
  dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
837
- <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
1052
+ <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
838
1053
  }
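Instead of asserting nrows % GGML_CUDA_DMMV_Y == 0, the launchers now round the block count up and the kernels return early for out-of-range rows. A tiny host-side check of the rounding; the GGML_CUDA_DMMV_Y value of 2 is assumed here for illustration:

    #include <cstdio>

    int main() {
        const int dmmv_y = 2;                                      // assumed GGML_CUDA_DMMV_Y
        const int rows[] = {4, 5};
        for (int nrows : rows) {
            const int block_num_y = (nrows + dmmv_y - 1) / dmmv_y; // ceil(nrows / dmmv_y)
            std::printf("nrows=%d -> block_num_y=%d\n", nrows, block_num_y); // prints 2 and 3
        }
        return 0;
    }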
839
1054
 
840
1055
  static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
841
1056
  GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
842
- GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
1057
+ const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
1058
+ const dim3 block_nums(1, block_num_y, 1);
843
1059
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
844
1060
  dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
845
- <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
1061
+ <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
846
1062
  }
847
1063
 
848
1064
  static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
849
1065
  GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
850
- GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
1066
+ const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
1067
+ const dim3 block_nums(1, block_num_y, 1);
851
1068
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
852
1069
  dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
853
- <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
1070
+ <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
854
1071
  }
855
1072
 
856
1073
  static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
857
1074
  GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
858
- GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
1075
+ const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
1076
+ const dim3 block_nums(1, block_num_y, 1);
859
1077
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
860
1078
  dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
861
- <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
1079
+ <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
862
1080
  }
863
1081
 
864
1082
  static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
865
1083
  GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
866
- GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
1084
+ const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
1085
+ const dim3 block_nums(1, block_num_y, 1);
867
1086
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
868
1087
  dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
869
- <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
1088
+ <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
870
1089
  }
871
1090
 
872
1091
  static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
873
1092
  GGML_ASSERT(ncols % QK_K == 0);
874
1093
  const int ny = 2;
1094
+ const int block_num_y = (nrows + ny - 1) / ny;
1095
+ const dim3 block_nums(1, block_num_y, 1);
875
1096
  const dim3 block_dims(32, ny, 1);
876
- dequantize_mul_mat_vec_k<32, vec_dot_q2_K><<<(nrows + ny - 1)/ny, block_dims, 0, stream>>>(vx, y, dst, ncols);
1097
+ dequantize_mul_mat_vec_k<32, vec_dot_q2_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
877
1098
  }
878
1099
 
879
1100
  static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
880
1101
  GGML_ASSERT(ncols % QK_K == 0);
881
- const dim3 block_dims(32, 2, 1);
882
- dequantize_mul_mat_vec_k<32, vec_dot_q3_K><<<nrows/2, block_dims, 0, stream>>>(vx, y, dst, ncols);
1102
+ const int ny = 2;
1103
+ const int block_num_y = (nrows + ny - 1) / ny;
1104
+ const dim3 block_nums(1, block_num_y, 1);
1105
+ const dim3 block_dims(32, ny, 1);
1106
+ dequantize_mul_mat_vec_k<32, vec_dot_q3_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
883
1107
  }
884
1108
 
885
1109
  static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
886
1110
  GGML_ASSERT(ncols % QK_K == 0);
887
- const dim3 block_dims(32, 2, 1);
888
- dequantize_mul_mat_vec_k<32, vec_dot_q4_K><<<nrows/2, block_dims, 0, stream>>>(vx, y, dst, ncols);
1111
+ const int ny = 2;
1112
+ const int block_num_y = (nrows + ny - 1) / ny;
1113
+ const dim3 block_nums(1, block_num_y, 1);
1114
+ const dim3 block_dims(32, ny, 1);
1115
+ dequantize_mul_mat_vec_k<32, vec_dot_q4_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
889
1116
  }
890
1117
 
891
1118
  static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
892
1119
  GGML_ASSERT(ncols % QK_K == 0);
893
- const dim3 block_dims(32, 2, 1);
894
- dequantize_mul_mat_vec_k<32, vec_dot_q5_K><<<nrows/2, block_dims, 0, stream>>>(vx, y, dst, ncols);
1120
+ const int ny = 2;
1121
+ const int block_num_y = (nrows + ny - 1) / ny;
1122
+ const dim3 block_nums(1, block_num_y, 1);
1123
+ const dim3 block_dims(32, ny, 1);
1124
+ dequantize_mul_mat_vec_k<32, vec_dot_q5_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
895
1125
  }
896
1126
 
897
1127
  static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
898
1128
  GGML_ASSERT(ncols % QK_K == 0);
899
- const dim3 block_dims(32, 2, 1);
900
- dequantize_mul_mat_vec_k<32, vec_dot_q6_K><<<nrows/2, block_dims, 0, stream>>>(vx, y, dst, ncols);
1129
+ const int ny = 2;
1130
+ const int block_num_y = (nrows + ny - 1) / ny;
1131
+ const dim3 block_nums(1, block_num_y, 1);
1132
+ const dim3 block_dims(32, ny, 1);
1133
+ dequantize_mul_mat_vec_k<32, vec_dot_q6_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
901
1134
  }
902
1135
 
903
1136
  static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
@@ -907,10 +1140,11 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
907
1140
 
908
1141
  static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
909
1142
  GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
910
- GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
1143
+ const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
1144
+ const dim3 block_nums(1, block_num_y, 1);
911
1145
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
912
1146
  dequantize_mul_mat_vec<1, 1, convert_f16>
913
- <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
1147
+ <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
914
1148
  }
915
1149
 
916
1150
  static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
@@ -942,6 +1176,47 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
942
1176
  }
943
1177
  }
944
1178
 
1179
+ static void ggml_mul_mat_p021_f16_f32_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, cudaStream_t stream) {
1180
+ const dim3 block_nums(1, nrows_x, nchannels_x);
1181
+ const dim3 block_dims(WARP_SIZE, 1, 1);
1182
+ mul_mat_p021_f16_f32<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, nchannels_x);
1183
+ }
1184
+
1185
+ static void ggml_mul_mat_vec_nc_f16_f32_cuda(
1186
+ const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x,
1187
+ const int nchannels_x, const int channel_stride_x, cudaStream_t stream) {
1188
+
1189
+ const dim3 block_nums(1, nrows_x, nchannels_x);
1190
+ const dim3 block_dims(WARP_SIZE, 1, 1);
1191
+ mul_mat_vec_nc_f16_f32<<<block_nums, block_dims, 0, stream>>>
1192
+ (vx, y, dst, ncols_x, nrows_x, row_stride_x, nchannels_x, channel_stride_x);
1193
+ }
1194
+
1195
+ static void ggml_cpy_f32_f32_cuda(
1196
+ const char * cx, char * cdst, const int ne,
1197
+ const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
1198
+ const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
1199
+
1200
+ const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
1201
+ cpy_f32_f16<cpy_1_f32_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
1202
+ (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
1203
+ }
1204
+
1205
+ static void ggml_cpy_f32_f16_cuda(
1206
+ const char * cx, char * cdst, const int ne,
1207
+ const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
1208
+ const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
1209
+
1210
+ const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
1211
+ cpy_f32_f16<cpy_1_f32_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
1212
+ (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
1213
+ }
1214
+
1215
+ static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
1216
+ const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
1217
+ scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
1218
+ }
1219
+
945
1220
  static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float theta_scale, cudaStream_t stream) {
946
1221
  GGML_ASSERT(nrows % 2 == 0);
947
1222
  const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1);
@@ -950,6 +1225,19 @@ static void rope_f32_cuda(const float * x, float * dst, const int ncols, const i
950
1225
  rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, theta_scale);
951
1226
  }
952
1227
 
1228
+ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
1229
+ const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1);
1230
+ const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
1231
+ const dim3 block_nums(block_num_x, nrows_x, 1);
1232
+ diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
1233
+ }
1234
+
1235
+ static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
1236
+ const dim3 block_dims(WARP_SIZE, 1, 1);
1237
+ const dim3 block_nums(1, nrows_x, 1);
1238
+ soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
1239
+ }
1240
+
953
1241
  // buffer pool for cuda
954
1242
  #define MAX_CUDA_BUFFERS 256
955
1243
 
@@ -1105,6 +1393,9 @@ void * ggml_cuda_host_malloc(size_t size) {
1105
1393
  void * ptr = nullptr;
1106
1394
  cudaError_t err = cudaMallocHost((void **) &ptr, size);
1107
1395
  if (err != cudaSuccess) {
1396
+ // The allocation error can be bypassed. A null ptr will be assigned out of this function.
1397
+ // This can fix the OOM error in WSL.
1398
+ cudaGetLastError();
1108
1399
  fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n",
1109
1400
  size/1024.0/1024.0, cudaGetErrorString(err));
1110
1401
  return nullptr;
@@ -1117,10 +1408,25 @@ void ggml_cuda_host_free(void * ptr) {
1117
1408
  CUDA_CHECK(cudaFreeHost(ptr));
1118
1409
  }
1119
1410
 
1120
- static cudaError_t ggml_cuda_h2d_tensor_2d(
1411
+ static cudaError_t ggml_cuda_cpy_tensor_2d(
1121
1412
  void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
1122
1413
 
1123
- char * dst_char = (char *) dst;
1414
+ cudaMemcpyKind kind;
1415
+ char * src_ptr;
1416
+ if (src->backend == GGML_BACKEND_CPU) {
1417
+ kind = cudaMemcpyHostToDevice;
1418
+ src_ptr = (char *) src->data;
1419
+ } else if (src->backend == GGML_BACKEND_GPU) {
1420
+ kind = cudaMemcpyDeviceToDevice;
1421
+ struct ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra;
1422
+ int id;
1423
+ CUDA_CHECK(cudaGetDevice(&id));
1424
+ src_ptr = (char *) extra->data_device[id];
1425
+ } else {
1426
+ GGML_ASSERT(false);
1427
+ }
1428
+ char * dst_ptr = (char *) dst;
1429
+
1124
1430
  const int64_t ne0 = src->ne[0];
1125
1431
  const int64_t nb0 = src->nb[0];
1126
1432
  const int64_t nb1 = src->nb[1];
@@ -1131,17 +1437,17 @@ static cudaError_t ggml_cuda_h2d_tensor_2d(
1131
1437
  const int64_t bs = ggml_blck_size(type);
1132
1438
  int64_t i1_diff = i1_high - i1_low;
1133
1439
 
1134
- const void * x = (const void *) ((const char *) src->data + i1_low*nb1 + i2*nb2 + i3*nb3);
1440
+ const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
1135
1441
  if (nb0 == ts && nb1 == ts*ne0/bs) {
1136
- return cudaMemcpyAsync(dst_char, x, i1_diff*nb1, cudaMemcpyHostToDevice, stream);
1442
+ return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, kind, stream);
1137
1443
  } else if (nb0 == ts) {
1138
- return cudaMemcpy2DAsync(dst_char, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, cudaMemcpyHostToDevice, stream);
1444
+ return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, kind, stream);
1139
1445
  } else {
1140
1446
  for (int64_t i1 = 0; i1 < i1_diff; i1++) {
1141
1447
  const void * rx = (const void *) ((const char *) x + i1*nb1);
1142
- void * rd = (void *) (dst_char + i1*ts*ne0/bs);
1448
+ void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
1143
1449
  // pretend the row is a matrix with cols=1
1144
- cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyHostToDevice, stream);
1450
+ cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, kind, stream);
1145
1451
  if (r != cudaSuccess) return r;
1146
1452
  }
1147
1453
  return cudaSuccess;
@@ -1377,8 +1683,81 @@ inline void ggml_cuda_op_rope(
1377
1683
  (void) i1;
1378
1684
  }
1379
1685
 
1686
+ inline void ggml_cuda_op_diag_mask_inf(
1687
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
1688
+ float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
1689
+ cudaStream_t & cudaStream_main){
1690
+
1691
+ GGML_ASSERT(src0_ddf_i != nullptr);
1692
+ GGML_ASSERT(dst_ddf_i != nullptr);
1693
+
1694
+ const int64_t ne00 = src0->ne[0];
1695
+ const int64_t ne01 = src0->ne[1];
1696
+ const int64_t i01_diff = i01_high - i01_low;
1697
+
1698
+ const int n_past = ((int32_t *) src1->data)[0];
1699
+
1700
+ // compute
1701
+ diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_past, cudaStream_main);
1702
+ CUDA_CHECK(cudaGetLastError());
1703
+
1704
+ (void) dst;
1705
+ (void) src0_ddq_i;
1706
+ (void) src1_ddf_i;
1707
+ (void) i02;
1708
+ (void) i1;
1709
+ }
1710
+
1711
+ inline void ggml_cuda_op_soft_max(
1712
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
1713
+ float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
1714
+ cudaStream_t & cudaStream_main){
1715
+
1716
+ GGML_ASSERT(src0_ddf_i != nullptr);
1717
+ GGML_ASSERT(dst_ddf_i != nullptr);
1718
+
1719
+ const int64_t ne00 = src0->ne[0];
1720
+ const int64_t i01_diff = i01_high - i01_low;
1721
+
1722
+ // compute
1723
+ soft_max_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
1724
+ CUDA_CHECK(cudaGetLastError());
1725
+
1726
+ (void) src1;
1727
+ (void) dst;
1728
+ (void) src0_ddq_i;
1729
+ (void) src1_ddf_i;
1730
+ (void) i02;
1731
+ (void) i1;
1732
+ }
1733
+
1734
+ inline void ggml_cuda_op_scale(
1735
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
1736
+ float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
1737
+ cudaStream_t & cudaStream_main){
1738
+
1739
+ GGML_ASSERT(src0_ddf_i != nullptr);
1740
+ GGML_ASSERT(dst_ddf_i != nullptr);
1741
+
1742
+ const float scale = ((float *) src1->data)[0];
1743
+
1744
+ const int64_t ne00 = src0->ne[0];
1745
+ const int64_t i01_diff = i01_high - i01_low;
1746
+
1747
+ // compute
1748
+ scale_f32_cuda(src0_ddf_i, dst_ddf_i, scale, ne00*i01_diff, cudaStream_main);
1749
+ CUDA_CHECK(cudaGetLastError());
1750
+
1751
+ (void) src1;
1752
+ (void) dst;
1753
+ (void) src0_ddq_i;
1754
+ (void) src1_ddf_i;
1755
+ (void) i02;
1756
+ (void) i1;
1757
+ }
1758
+
1380
1759
  static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
1381
- ggml_cuda_op_t op, bool src0_needs_f32) {
1760
+ ggml_cuda_op_t op, bool src0_needs_f32, bool flatten_rows) {
1382
1761
  const int64_t ne00 = src0->ne[0];
1383
1762
  const int64_t ne01 = src0->ne[1];
1384
1763
  const int64_t ne02 = src0->ne[2];
@@ -1401,21 +1780,27 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
1401
1780
  GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT);
1402
1781
 
1403
1782
  // strides for iteration over dims 3 and 2
1404
- const int64_t src0_stride = ne00 * ne01;
1405
- const int64_t src1_stride = ne10 * ne11;
1406
- const int64_t dst_stride = ne0 * ne1;
1407
- const int64_t num_iters = ne02 * ne03;
1783
+ const int64_t num_iters = flatten_rows ? 1 : ne02 * ne03;
1784
+ const int64_t stride_mod = flatten_rows ? ne02 * ne03 : 1;
1785
+ const int64_t src0_stride = ne00 * ne01 * stride_mod;
1786
+ const int64_t src1_stride = ne10 * ne11 * stride_mod;
1787
+ const int64_t dst_stride = ne0 * ne1 * stride_mod;
1408
1788
 
1409
1789
  const size_t src0_ts = ggml_type_size(src0->type);
1410
1790
  const size_t src0_bs = ggml_blck_size(src0->type);
1411
1791
 
1412
- struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
1792
+ struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
1413
1793
  struct ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
1414
- struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
1794
+ struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
1415
1795
 
1416
1796
  const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT;
1797
+ const bool src0_is_contiguous = ggml_is_contiguous(src0);
1417
1798
  const bool src0_is_f32 = src0->type == GGML_TYPE_F32;
1418
1799
 
1800
+ const bool src1_is_contiguous = use_src1 && ggml_is_contiguous(src1);
1801
+ const bool src1_stays_on_host = use_src1 && (
1802
+ dst->op == GGML_OP_SCALE || dst->op == GGML_OP_DIAG_MASK_INF || dst->op == GGML_OP_ROPE);
1803
+
1419
1804
  const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
1420
1805
 
1421
1806
  const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
@@ -1424,13 +1809,13 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
1424
1809
  char * src0_ddq[GGML_CUDA_MAX_DEVICES] = {nullptr}; // quantized
1425
1810
  float * src0_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr}; // float
1426
1811
  float * src1_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr};
1427
- float * dst_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr};
1812
+ float * dst_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr};
1428
1813
 
1429
1814
  // asq = actual size quantized, asf = actual size float
1430
1815
  size_t src0_asq[GGML_CUDA_MAX_DEVICES] = {0};
1431
1816
  size_t src0_asf[GGML_CUDA_MAX_DEVICES] = {0};
1432
1817
  size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
1433
- size_t dst_asf[GGML_CUDA_MAX_DEVICES] = {0};
1818
+ size_t dst_asf[GGML_CUDA_MAX_DEVICES] = {0};
1434
1819
 
1435
1820
  for (int id = 0; id < g_device_count; ++id) {
1436
1821
  if (!split && id != g_main_device) {
@@ -1443,9 +1828,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
1443
1828
  int64_t row_low, row_high;
1444
1829
  if (split) {
1445
1830
  row_low = id == 0 ? 0 : nrows0*g_tensor_split[id];
1446
- row_low -= row_low % GGML_CUDA_DMMV_Y;
1447
1831
  row_high = id == g_device_count - 1 ? nrows0 : nrows0*g_tensor_split[id + 1];
1448
- row_high -= row_high % GGML_CUDA_DMMV_Y;
1449
1832
  } else {
1450
1833
  row_low = 0;
1451
1834
  row_high = nrows0;
@@ -1458,7 +1841,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
1458
1841
 
1459
1842
  cudaSetDevice(id);
1460
1843
 
1461
- if (src0_on_device) {
1844
+ if (src0_on_device && src0_is_contiguous) {
1462
1845
  if (src0_is_f32) {
1463
1846
  src0_ddf[id] = (float *) src0_extra->data_device[id];
1464
1847
  } else {
@@ -1476,8 +1859,8 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
1476
1859
  src0_ddf[id] = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_asf[id]);
1477
1860
  }
1478
1861
 
1479
- if (use_src1) {
1480
- if (src1_on_device) {
1862
+ if (use_src1 && !src1_stays_on_host) {
1863
+ if (src1_on_device && src1_is_contiguous) {
1481
1864
  src1_ddf[id] = (float *) src1_extra->data_device[id];
1482
1865
  } else {
1483
1866
  src1_ddf[id] = (float *) ggml_cuda_pool_malloc(num_iters*src1_stride * sizeof(float), &src1_asf[id]);
@@ -1490,26 +1873,32 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
1490
1873
  dst_ddf[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_asf[id]);
1491
1874
  }
1492
1875
 
1493
- for (int64_t i03 = 0; i03 < ne03; i03++) {
1876
+ const int64_t i03_max = flatten_rows ? 1 : ne03;
1877
+ const int64_t i02_max = flatten_rows ? 1 : ne02;
1878
+ const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01;
1879
+
1880
+ for (int64_t i03 = 0; i03 < i03_max; i03++) {
1494
1881
  const int64_t i13 = i03 % ne13;
1495
- for (int64_t i02 = 0; i02 < ne02; i02++) {
1882
+ for (int64_t i02 = 0; i02 < i02_max; i02++) {
1496
1883
  const int64_t i12 = i02 % ne12;
1497
1884
 
1498
1885
  const int64_t i0 = i03*ne02 + i02;
1499
- const int64_t i0_offset_low = row_low/ne01;
1500
- const int64_t i0_offset_high = row_high/ne01;
1886
+
1887
+ // i0 values that contain the lower/upper rows for a split tensor when using multiple GPUs
1888
+ const int64_t i0_offset_low = row_low/rows_per_iter;
1889
+ const int64_t i0_offset_high = row_high/rows_per_iter;
1501
1890
 
1502
1891
  int64_t i01_low = 0;
1503
- int64_t i01_high = ne01;
1892
+ int64_t i01_high = rows_per_iter;
1504
1893
  if (split) {
1505
1894
  if (i0 < i0_offset_low || i0 > i0_offset_high) {
1506
1895
  continue;
1507
1896
  }
1508
1897
  if (i0 == i0_offset_low) {
1509
- i01_low = row_low % ne01;
1898
+ i01_low = row_low % rows_per_iter;
1510
1899
  }
1511
1900
  if (i0 == i0_offset_high) {
1512
- i01_high = row_high % ne01;
1901
+ i01_high = row_high % rows_per_iter;
1513
1902
  }
1514
1903
  }
1515
1904
 
@@ -1518,7 +1907,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
1518
1907
  // Removing both asserts results in i01_high becoming 0 which in turn results in garbage output.
1519
1908
  // The root cause seems to be a problem with i0_offset_high becoming 0 when it should always be >0 (for single GPU).
1520
1909
  GGML_ASSERT(i01_low == 0 || g_device_count > 1);
1521
- GGML_ASSERT(i01_high == ne01 || g_device_count > 1);
1910
+ GGML_ASSERT(i01_high == rows_per_iter || g_device_count > 1);
1522
1911
 
1523
1912
  const int64_t i01_diff = i01_high - i01_low;
1524
1913
  if (i01_diff == 0) {
@@ -1526,24 +1915,23 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
1526
1915
  }
1527
1916
  const int64_t i11 = i13*ne12 + i12;
1528
1917
 
1529
- cudaStream_t cudaStream_main = g_cudaStreams_main[id][i0 % GGML_CUDA_MAX_STREAMS];
1918
+ cudaStream_t cudaStream_main = g_cudaStreams_main[id][i0 % GGML_CUDA_MAX_STREAMS];
1530
1919
  cudaStream_t cudaStream_memcpy_src1 = g_cudaStreams_memcpy_src1[id][i0 % GGML_CUDA_MAX_STREAMS];
1531
- cudaEvent_t cudaEvent_memcpy_src1 = g_cudaEvents_memcpy_src1[id][i0 % GGML_CUDA_MAX_EVENTS];
1920
+ cudaEvent_t cudaEvent_memcpy_src1 = g_cudaEvents_memcpy_src1[id][i0 % GGML_CUDA_MAX_EVENTS];
1532
1921
 
1533
1922
  // for split tensors the data begins at i0 == i0_offset_low
1534
1923
  char * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
1535
1924
  float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride;
1536
1925
  float * src1_ddf_i = src1_ddf[id] + i11*src1_stride;
1537
- float * dst_ddf_i = dst_ddf[id] + (i0 - i0_offset_low)*dst_stride;
1926
+ float * dst_ddf_i = dst_ddf[id] + (i0 - i0_offset_low)*dst_stride;
1538
1927
 
1539
1928
  // for split tensors the data pointer needs to be rounded down
1540
1929
  // to the bin edge for i03, i02 bins beyond the first
1541
1930
  if (i0 - i0_offset_low > 0) {
1931
+ GGML_ASSERT(!flatten_rows);
1542
1932
  src0_ddq_i -= (row_low % ne01)*ne00 * src0_ts/src0_bs;
1543
1933
  src0_ddf_i -= (row_low % ne01)*ne00;
1544
- }
1545
- if (i0 - i0_offset_low > 0) {
1546
- dst_ddf_i -= (row_low % ne0)*ne1;
1934
+ dst_ddf_i -= (row_low % ne0)*ne1;
1547
1935
  }
1548
1936
 
1549
1937
  // the main device memory buffer can be on VRAM scratch, with space for all partial results
@@ -1553,30 +1941,37 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
1553
1941
  }
1554
1942
 
1555
1943
  // copy src0, src1 to device if necessary
1556
- if (use_src1) {
1944
+ if (use_src1 && !src1_stays_on_host) {
1557
1945
  if (src1->backend == GGML_BACKEND_CPU) {
1558
- CUDA_CHECK(ggml_cuda_h2d_tensor_2d(src1_ddf_i, src1, i03, i02, 0, ne11, cudaStream_memcpy_src1));
1559
- } else if (src1->backend == GGML_BACKEND_GPU) {
1946
+ GGML_ASSERT(!flatten_rows || nrows0 == ggml_nrows(src1));
1947
+ int64_t nrows1 = flatten_rows ? nrows0 : ne11;
1948
+ CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_memcpy_src1));
1949
+ } else if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) {
1560
1950
  if (id != g_main_device) {
1951
+ GGML_ASSERT(!flatten_rows);
1561
1952
  float * src1_ddf_i_source = (float *) src1_extra->data_device[g_main_device];
1562
1953
  src1_ddf_i_source += i11*src1_stride;
1563
1954
  CUDA_CHECK(cudaMemcpyAsync(src1_ddf_i, src1_ddf_i_source, src1_stride*sizeof(float),
1564
1955
  cudaMemcpyDeviceToDevice, cudaStream_memcpy_src1));
1565
1956
  }
1957
+ } else if (src1_on_device && !src1_is_contiguous) {
1958
+ GGML_ASSERT(!split);
1959
+ CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, ne11, cudaStream_main));
1566
1960
  } else {
1567
1961
  GGML_ASSERT(false);
1568
1962
  }
1569
1963
  }
1570
1964
  CUDA_CHECK(cudaEventRecord(cudaEvent_memcpy_src1, cudaStream_memcpy_src1));
1571
- if (!src0_on_device) {
1965
+
1966
+ if (!src0_on_device || !src0_is_contiguous) {
1572
1967
  if (src0_is_f32) {
1573
- CUDA_CHECK(ggml_cuda_h2d_tensor_2d(src0_ddf_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
1968
+ CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
1574
1969
  } else {
1575
- CUDA_CHECK(ggml_cuda_h2d_tensor_2d(src0_ddq_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
1970
+ CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddq_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
1576
1971
  }
1577
1972
  }
1578
1973
 
1579
- // convert src0 to f32 if it's necessary for the ggml_cuda_op
1974
+ // convert src0 to f32 if it is necessary for the ggml_cuda_op
1580
1975
  if (src0_needs_f32 && !src0_is_f32) {
1581
1976
  to_fp32_cuda(src0_ddq_i, src0_ddf_i, i01_diff*ne00, cudaStream_main);
1582
1977
  CUDA_CHECK(cudaGetLastError());
@@ -1641,39 +2036,30 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
1641
2036
 
1642
2037
  void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1643
2038
  GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
1644
- ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, true);
2039
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, true, true);
1645
2040
  }
1646
2041
 
1647
2042
  void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1648
2043
  GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
1649
- ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul, true);
2044
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul, true, false); // TODO ggml_cuda_op needs modification for flatten
1650
2045
  }
1651
2046
 
1652
2047
  void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1653
2048
  GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
1654
- ggml_cuda_op(src0, src1, dst, ggml_cuda_op_silu, true);
2049
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_silu, true, true);
1655
2050
  }
1656
2051
 
1657
2052
  void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1658
2053
  GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
1659
- ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rms_norm, true);
2054
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rms_norm, true, true);
1660
2055
  }
1661
2056
 
1662
2057
  bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
1663
- GGML_ASSERT(src0->backend != GGML_BACKEND_GPU);
1664
2058
  const int64_t ne10 = src1->ne[0];
1665
2059
 
1666
2060
  const int64_t ne0 = dst->ne[0];
1667
2061
  const int64_t ne1 = dst->ne[1];
1668
2062
 
1669
- // if (strcmp(dst->name, "KQ") == 0 || strcmp(dst->name, "KQV") == 0) {
1670
- // fprintf(stderr, "(%ld, %ld, %ld, %ld) + (%ld, %ld, %ld, %ld) -> (%ld, %ld, %ld, %ld)\n",
1671
- // src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
1672
- // src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
1673
- // dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]);
1674
- // return false;
1675
- // }
1676
-
1677
2063
  // TODO: find the optimal values for these
1678
2064
  if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
1679
2065
  src1->type == GGML_TYPE_F32 &&
@@ -1685,23 +2071,158 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
1685
2071
  return false;
1686
2072
  }
1687
2073
 
2074
+ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
2075
+ GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
2076
+ GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
2077
+ GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
2078
+ GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation
2079
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
2080
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
2081
+
2082
+ const int64_t ne00 = src0->ne[0];
2083
+ const int64_t ne01 = src0->ne[1];
2084
+ const int64_t ne02 = src0->ne[2];
2085
+
2086
+ CUDA_CHECK(cudaSetDevice(g_main_device));
2087
+ cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
2088
+
2089
+ struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
2090
+ void * src0_ddq = src0_extra->data_device[g_main_device];
2091
+
2092
+ struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
2093
+ float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
2094
+
2095
+ struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
2096
+ float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
2097
+
2098
+ ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, cudaStream_main);
2099
+
2100
+ CUDA_CHECK(cudaDeviceSynchronize());
2101
+ }
2102
+
2103
+ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
2104
+ GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1));
2105
+ GGML_ASSERT(!ggml_is_permuted(src0));
2106
+ GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
2107
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
2108
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
2109
+
2110
+ const int64_t ne00 = src0->ne[0];
2111
+ const int64_t ne01 = src0->ne[1];
2112
+ const int64_t ne02 = src0->ne[2];
2113
+
2114
+ const int64_t nb01 = src0->nb[1];
2115
+ const int64_t nb02 = src0->nb[2];
2116
+
2117
+ CUDA_CHECK(cudaSetDevice(g_main_device));
2118
+ cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
2119
+
2120
+ struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
2121
+ void * src0_ddq = src0_extra->data_device[g_main_device];
2122
+
2123
+ struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
2124
+ float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
2125
+
2126
+ struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
2127
+ float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
2128
+
2129
+ const int row_stride_x = nb01 / sizeof(half);
2130
+ const int channel_stride_x = nb02 / sizeof(half);
2131
+
2132
+ ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, channel_stride_x, cudaStream_main);
2133
+
2134
+ CUDA_CHECK(cudaDeviceSynchronize());
2135
+ }
2136
+
1688
2137
  void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1689
- if (src0->type == GGML_TYPE_F32) {
1690
- ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true);
2138
+ bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
2139
+ src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU;
2140
+
2141
+ if (all_on_device && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
2142
+ ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
2143
+ } else if (all_on_device && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) {
2144
+ ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
2145
+ } else if (src0->type == GGML_TYPE_F32) {
2146
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
1691
2147
  } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
1692
- if (src1->ne[1] == 1) {
1693
- ggml_cuda_op(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
2148
+ if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src0->ne[1] % GGML_CUDA_DMMV_Y == 0) {
2149
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false, false);
1694
2150
  } else {
1695
- ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true);
2151
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
1696
2152
  }
1697
2153
  } else {
1698
2154
  GGML_ASSERT(false);
1699
2155
  }
1700
2156
  }
1701
2157
 
2158
+ void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
2159
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
2160
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_scale, true, true);
2161
+ }
2162
+
2163
+ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
2164
+ const int64_t ne = ggml_nelements(src0);
2165
+ GGML_ASSERT(ne == ggml_nelements(src1));
2166
+
2167
+ GGML_ASSERT(src0->backend == GGML_BACKEND_GPU);
2168
+ GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
2169
+
2170
+ GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
2171
+ GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
2172
+
2173
+ const int64_t ne00 = src0->ne[0];
2174
+ const int64_t ne01 = src0->ne[1];
2175
+ GGML_ASSERT(src0->ne[3] == 1);
2176
+
2177
+ const int64_t nb00 = src0->nb[0];
2178
+ const int64_t nb01 = src0->nb[1];
2179
+ const int64_t nb02 = src0->nb[2];
2180
+
2181
+ const int64_t ne10 = src1->ne[0];
2182
+ const int64_t ne11 = src1->ne[1];
2183
+ GGML_ASSERT(src1->ne[3] == 1);
2184
+
2185
+ const int64_t nb10 = src1->nb[0];
2186
+ const int64_t nb11 = src1->nb[1];
2187
+ const int64_t nb12 = src1->nb[2];
2188
+
2189
+ CUDA_CHECK(cudaSetDevice(g_main_device));
2190
+ cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
2191
+
2192
+ const struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
2193
+ const struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
2194
+
2195
+ char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
2196
+ char * src1_ddc = (char *) src1_extra->data_device[g_main_device];
2197
+
2198
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
2199
+ ggml_cpy_f32_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
2200
+ ne10, ne11, nb10, nb11, nb12, cudaStream_main);
2201
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
2202
+ ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
2203
+ ne10, ne11, nb10, nb11, nb12, cudaStream_main);
2204
+ } else {
2205
+ GGML_ASSERT(false);
2206
+ }
2207
+
2208
+ CUDA_CHECK(cudaDeviceSynchronize());
2209
+
2210
+ (void) dst;
2211
+ }
2212
+
2213
+ void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
2214
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
2215
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_diag_mask_inf, true, true);
2216
+ }
2217
+
2218
+ void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
2219
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
2220
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_soft_max, true, true);
2221
+ }
2222
+
1702
2223
  void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1703
2224
  GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
1704
- ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true);
2225
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, false); // FIXME flatten changes results
1705
2226
  }
1706
2227
 
1707
2228
  void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -1710,16 +2231,14 @@ void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
1710
2231
  (void) dst;
1711
2232
  }
1712
2233
 
1713
- void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensor, const size_t offset) {
1714
- FILE * fp = fopen(fname, "rb");
2234
+ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
1715
2235
  int nrows = ggml_nrows(tensor);
1716
2236
  const size_t nb1 = tensor->nb[1];
1717
2237
  ggml_backend backend = tensor->backend;
1718
2238
  struct ggml_tensor_extra_gpu * extra = new struct ggml_tensor_extra_gpu;
2239
+ memset(extra, 0, sizeof(*extra));
1719
2240
 
1720
2241
  for (int id = 0; id < g_device_count; ++id) {
1721
- extra->data_device[id] = nullptr;
1722
-
1723
2242
  if (backend == GGML_BACKEND_GPU && id != g_main_device) {
1724
2243
  continue;
1725
2244
  }
@@ -1732,10 +2251,7 @@ void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensor, const
1732
2251
  row_high = nrows;
1733
2252
  } else if (backend == GGML_BACKEND_GPU_SPLIT) {
1734
2253
  row_low = id == 0 ? 0 : nrows*g_tensor_split[id];
1735
- row_low -= row_low % GGML_CUDA_DMMV_Y;
1736
2254
  row_high = id == g_device_count - 1 ? nrows : nrows*g_tensor_split[id + 1];
1737
- row_high -= row_high % GGML_CUDA_DMMV_Y;
1738
- GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
1739
2255
  } else {
1740
2256
  GGML_ASSERT(false);
1741
2257
  }
@@ -1745,35 +2261,19 @@ void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensor, const
1745
2261
 
1746
2262
  int64_t nrows_split = row_high - row_low;
1747
2263
 
1748
- const size_t offset_split = offset + row_low*nb1;
2264
+ const size_t offset_split = row_low*nb1;
1749
2265
  const size_t size = ggml_nbytes_split(tensor, nrows_split);
1750
2266
 
1751
2267
  void * buf;
1752
2268
  CUDA_CHECK(cudaMalloc(&buf, size));
1753
- void * buf_host = malloc(size);
1754
-
1755
- #ifdef _WIN32
1756
- int ret = _fseeki64(fp, (__int64) offset_split, SEEK_SET);
1757
- #else
1758
- int ret = fseek(fp, (long) offset_split, SEEK_SET);
1759
- #endif
1760
- GGML_ASSERT(ret == 0); // same
1761
-
1762
- size_t ret2 = fread(buf_host, size, 1, fp);
1763
- if (ret2 != 1) {
1764
- fprintf(stderr, "unexpectedly reached end of file");
1765
- exit(1);
1766
- }
2269
+ void * buf_host = (char*)data + offset_split;
1767
2270
 
1768
2271
  cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
1769
- cudaDeviceSynchronize();
1770
2272
 
1771
- free(buf_host);
1772
2273
  extra->data_device[id] = buf;
1773
2274
  }
1774
2275
 
1775
2276
  tensor->extra = extra;
1776
- fclose(fp);
1777
2277
  }
1778
2278
 
1779
2279
  void ggml_cuda_free_data(struct ggml_tensor * tensor) {
@@ -1795,47 +2295,78 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
1795
2295
  delete extra;
1796
2296
  }
1797
2297
 
1798
- void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
1799
- if (tensor->src0 != nullptr && tensor->src0->op == GGML_OP_RESHAPE) {
1800
- ggml_cuda_assign_buffers(tensor);
2298
+ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
2299
+ if (scratch && g_scratch_size == 0) {
2300
+ return;
1801
2301
  }
1802
2302
 
1803
- const size_t size = ggml_nbytes(tensor);
1804
- GGML_ASSERT(size <= g_scratch_size);
1805
- if (g_scratch_offset + size > g_scratch_size) {
1806
- g_scratch_offset = 0;
2303
+ // recursively assign CUDA buffers until a compute tensor is found
2304
+ if (tensor->src0 != nullptr && tensor->src0->backend == GGML_BACKEND_CPU) {
2305
+ const ggml_op src0_op = tensor->src0->op;
2306
+ if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW) {
2307
+ ggml_cuda_assign_buffers_impl(tensor->src0, scratch);
2308
+ }
2309
+ }
2310
+ if (tensor->op == GGML_OP_CPY && tensor->src1->backend == GGML_BACKEND_CPU) {
2311
+ ggml_cuda_assign_buffers_impl(tensor->src1, scratch);
1807
2312
  }
1808
2313
 
1809
2314
  tensor->backend = GGML_BACKEND_GPU;
1810
2315
  struct ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu;
1811
2316
 
1812
- bool inplace = tensor->src0 != nullptr && tensor->src0->data == tensor->data;
2317
+ const bool inplace = (tensor->src0 != nullptr && tensor->src0->data == tensor->data) ||
2318
+ tensor->op == GGML_OP_VIEW;
2319
+ const size_t size = ggml_nbytes(tensor);
1813
2320
 
1814
2321
  CUDA_CHECK(cudaSetDevice(g_main_device));
1815
2322
  if (inplace && tensor->src0->backend == GGML_BACKEND_GPU) {
1816
2323
  struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src0->extra;
1817
- extra->data_device[g_main_device] = src0_extra->data_device;
1818
- GGML_ASSERT(false);
1819
- } else {
2324
+ char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
2325
+ size_t offset = 0;
2326
+ if (tensor->op == GGML_OP_VIEW) {
2327
+ memcpy(&offset, tensor->opt[0]->data, sizeof(size_t));
2328
+ }
2329
+ extra->data_device[g_main_device] = src0_ddc + offset;
2330
+ } else if (tensor->op == GGML_OP_CPY) {
2331
+ struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu * ) tensor->src1->extra;
2332
+ void * src1_ddv = src1_extra->data_device[g_main_device];
2333
+ extra->data_device[g_main_device] = src1_ddv;
2334
+ } else if (scratch) {
2335
+ GGML_ASSERT(size <= g_scratch_size);
2336
+ if (g_scratch_offset + size > g_scratch_size) {
2337
+ g_scratch_offset = 0;
2338
+ }
2339
+
1820
2340
  char * data = (char *) g_scratch_buffer;
1821
2341
  if (data == nullptr) {
1822
2342
  CUDA_CHECK(cudaMalloc(&data, g_scratch_size));
1823
2343
  g_scratch_buffer = data;
1824
2344
  }
1825
2345
  extra->data_device[g_main_device] = data + g_scratch_offset;
1826
- }
1827
2346
 
1828
- // fprintf(stderr, "data=%p offset=%ld data_device=%p\n", data, g_scratch_offset, extra->data_device[0]);
1829
- g_scratch_offset += size;
1830
- // fprintf(stderr, "%s: scratch %d, %p - %p\n",
1831
- // tensor->name, g_scratch_index, data + g_scratch_offset, data + g_scratch_offset + size);
2347
+ g_scratch_offset += size;
2348
+
2349
+ GGML_ASSERT(g_scratch_offset <= g_scratch_size);
2350
+ } else { // allocate new buffers outside of scratch
2351
+ void * data;
2352
+ CUDA_CHECK(cudaMalloc(&data, size));
2353
+ CUDA_CHECK(cudaMemset(data, 0, size));
2354
+ extra->data_device[g_main_device] = data;
2355
+ }
1832
2356
 
1833
- GGML_ASSERT(g_scratch_offset <= g_scratch_size);
1834
2357
  tensor->extra = extra;
1835
2358
  }
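When scratch is true and the tensor is neither in-place nor a CPY target, ggml_cuda_assign_buffers_impl places it in the global scratch buffer with a simple bump allocation that wraps back to offset 0 whenever the next tensor would not fit. A host-side C++ sketch of that policy (buffer and request sizes below are made up):

    #include <cstddef>
    #include <cstdio>

    int main() {
        const size_t scratch_size = 100;
        size_t offset = 0;
        const size_t requests[] = {40, 40, 30, 50};
        for (size_t size : requests) {
            if (offset + size > scratch_size) {
                offset = 0;                                    // wrap around instead of failing
            }
            std::printf("placed %zu bytes at offset %zu\n", size, offset);
            offset += size;                                    // bump for the next tensor
        }
        return 0;
    }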
1836
2359
 
2360
+ void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
2361
+ ggml_cuda_assign_buffers_impl(tensor, true);
2362
+ }
2363
+
2364
+ void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) {
2365
+ ggml_cuda_assign_buffers_impl(tensor, false);
2366
+ }
2367
+
1837
2368
  void ggml_cuda_set_main_device(int main_device) {
1838
- if (main_device > g_device_count) {
2369
+ if (main_device >= g_device_count) {
1839
2370
  fprintf(stderr, "warning: cannot set main_device=%d because there are only %d devices. Using device %d instead.\n",
1840
2371
  main_device, g_device_count, g_main_device);
1841
2372
  return;
@@ -1852,6 +2383,15 @@ void ggml_cuda_set_scratch_size(size_t scratch_size) {
1852
2383
  g_scratch_size = scratch_size;
1853
2384
  }
1854
2385
 
2386
+ void ggml_cuda_free_scratch() {
2387
+ if (g_scratch_buffer == nullptr) {
2388
+ return;
2389
+ }
2390
+
2391
+ CUDA_CHECK(cudaFree(g_scratch_buffer));
2392
+ g_scratch_buffer = nullptr;
2393
+ }
2394
+
1855
2395
  bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor){
1856
2396
  ggml_cuda_func_t func;
1857
2397
  const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
@@ -1889,12 +2429,39 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
1889
2429
  }
1890
2430
  func = ggml_cuda_mul_mat;
1891
2431
  break;
2432
+ case GGML_OP_SCALE:
2433
+ if (!any_on_device) {
2434
+ return false;
2435
+ }
2436
+ func = ggml_cuda_scale;
2437
+ break;
2438
+ case GGML_OP_CPY:
2439
+ if (!any_on_device) {
2440
+ return false;
2441
+ }
2442
+ func = ggml_cuda_cpy;
2443
+ break;
1892
2444
  case GGML_OP_RESHAPE:
2445
+ case GGML_OP_VIEW:
2446
+ case GGML_OP_PERMUTE:
2447
+ case GGML_OP_TRANSPOSE:
1893
2448
  if (!any_on_device) {
1894
2449
  return false;
1895
2450
  }
1896
2451
  func = ggml_cuda_nop;
1897
2452
  break;
2453
+ case GGML_OP_DIAG_MASK_INF:
2454
+ if (!any_on_device) {
2455
+ return false;
2456
+ }
2457
+ func = ggml_cuda_diag_mask_inf;
2458
+ break;
2459
+ case GGML_OP_SOFT_MAX:
2460
+ if (!any_on_device) {
2461
+ return false;
2462
+ }
2463
+ func = ggml_cuda_soft_max;
2464
+ break;
1898
2465
  case GGML_OP_ROPE:
1899
2466
  if (!any_on_device) {
1900
2467
  return false;