llama_cpp 0.8.0 → 0.9.0

@@ -29,6 +29,8 @@
  #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
  #define cublasCreate hipblasCreate
  #define cublasGemmEx hipblasGemmEx
+ #define cublasGemmBatchedEx hipblasGemmBatchedEx
+ #define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
  #define cublasHandle_t hipblasHandle_t
  #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
  #define cublasSetStream hipblasSetStream
@@ -4326,13 +4328,13 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous

  const half * x = (const half *) vx;

- const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
- const int channel = blockDim.z*blockIdx.z + threadIdx.z;
+ const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
+ const int channel = blockDim.z*blockIdx.z + threadIdx.z;
  const int channel_x = channel / channel_x_divisor;

- const int nrows_y = ncols_x;
+ const int nrows_y = ncols_x;
  const int nrows_dst = nrows_x;
- const int row_dst = row_x;
+ const int row_dst = row_x;

  const int idst = channel*nrows_dst + row_dst;

@@ -4345,13 +4347,13 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
  break;
  }

- const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
- const float xi = __half2float(x[ix]);
-
  const int row_y = col_x;

+ const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
  const int iy = channel*nrows_y + row_y;

+ const float xi = __half2float(x[ix]);
+
  tmp += xi * y[iy];
  }

@@ -5662,10 +5664,10 @@ void ggml_init_cublas() {
  GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);
  int64_t total_vram = 0;
  fprintf(stderr, "%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, g_device_count);
- for (int64_t id = 0; id < g_device_count; ++id) {
+ for (int id = 0; id < g_device_count; ++id) {
  cudaDeviceProp prop;
  CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
- fprintf(stderr, " Device %ld: %s, compute capability %d.%d\n", id, prop.name, prop.major, prop.minor);
+ fprintf(stderr, " Device %d: %s, compute capability %d.%d\n", id, prop.name, prop.major, prop.minor);

  g_tensor_split[id] = total_vram;
  total_vram += prop.totalGlobalMem;
@@ -5675,15 +5677,15 @@ void ggml_init_cublas() {
  g_compute_capabilities[id] = 100*prop.major + 10*prop.minor;
  #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
  }
- for (int64_t id = 0; id < g_device_count; ++id) {
+ for (int id = 0; id < g_device_count; ++id) {
  g_tensor_split[id] /= total_vram;
  }

- for (int64_t id = 0; id < g_device_count; ++id) {
+ for (int id = 0; id < g_device_count; ++id) {
  CUDA_CHECK(ggml_cuda_set_device(id));

  // create cuda streams
- for (int64_t is = 0; is < MAX_STREAMS; ++is) {
+ for (int is = 0; is < MAX_STREAMS; ++is) {
  CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams[id][is], cudaStreamNonBlocking));
  }

@@ -6252,16 +6254,15 @@ inline void ggml_cuda_op_mul_mat_cublas(
  const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
  const int64_t src1_padded_row_size, const cudaStream_t & stream) {

- GGML_ASSERT(src0_dd_i != nullptr);
+ GGML_ASSERT(src0_dd_i != nullptr);
  GGML_ASSERT(src1_ddf_i != nullptr);
- GGML_ASSERT(dst_dd_i != nullptr);
-
+ GGML_ASSERT(dst_dd_i != nullptr);

  const int64_t ne00 = src0->ne[0];
-
  const int64_t ne10 = src1->ne[0];

  const int64_t ne0 = dst->ne[0];
+
  const int64_t row_diff = row_high - row_low;

  int id;
@@ -7013,7 +7014,8 @@ static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tens
  }

  static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
- GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1));
+ GGML_ASSERT(!ggml_is_transposed(src0));
+ GGML_ASSERT(!ggml_is_transposed(src1));
  GGML_ASSERT(!ggml_is_permuted(src0));
  GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
@@ -7023,11 +7025,11 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
  const int64_t ne01 = src0->ne[1];
  const int64_t ne02 = src0->ne[2];

- const int64_t ne12 = src1->ne[2];
-
  const int64_t nb01 = src0->nb[1];
  const int64_t nb02 = src0->nb[2];

+ const int64_t ne12 = src1->ne[2];
+
  CUDA_CHECK(ggml_cuda_set_device(g_main_device));
  cudaStream_t main_stream = g_cudaStreams[g_main_device][0];

@@ -7046,6 +7048,159 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
  ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
  }

+ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+ GGML_ASSERT(!ggml_is_transposed(src0));
+ GGML_ASSERT(!ggml_is_transposed(src1));
+ GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+ const int64_t ne00 = src0->ne[0]; GGML_UNUSED(ne00);
+ const int64_t ne01 = src0->ne[1];
+ const int64_t ne02 = src0->ne[2];
+ const int64_t ne03 = src0->ne[3];
+
+ const int64_t nb01 = src0->nb[1];
+ const int64_t nb02 = src0->nb[2]; GGML_UNUSED(nb02);
+ const int64_t nb03 = src0->nb[3]; GGML_UNUSED(nb03);
+
+ const int64_t ne10 = src1->ne[0];
+ const int64_t ne11 = src1->ne[1];
+ const int64_t ne12 = src1->ne[2];
+ const int64_t ne13 = src1->ne[3];
+
+ const int64_t nb11 = src1->nb[1];
+ const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12);
+ const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13);
+
+ const int64_t ne1 = ggml_nelements(src1);
+ const int64_t ne = ggml_nelements(dst);
+
+ CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+ cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
+
+ int id;
+ CUDA_CHECK(cudaGetDevice(&id));
+ CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream));
+
+ ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+ void * src0_ddq = src0_extra->data_device[g_main_device];
+ half * src0_as_f16 = (half *) src0_ddq;
+
+ ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
+ float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
+
+ ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+ float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
+
+ // convert src1 to fp16
+ const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
+ GGML_ASSERT(to_fp16_cuda != nullptr);
+
+ size_t src1_as = 0;
+ half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
+ to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
+
+ size_t dst_as = 0;
+ half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
+
+ GGML_ASSERT(ne12 % ne02 == 0);
+ GGML_ASSERT(ne13 % ne03 == 0);
+
+ // broadcast factors
+ const int64_t r2 = ne12/ne02;
+ const int64_t r3 = ne13/ne03;
+
+ const half alpha_f16 = 1.0f;
+ const half beta_f16 = 0.0f;
+
+ #if 0
+ // use cublasGemmEx
+ {
+ for (int i13 = 0; i13 < ne13; ++i13) {
+ for (int i12 = 0; i12 < ne12; ++i12) {
+ int i03 = i13 / r3;
+ int i02 = i12 / r2;
+
+ CUBLAS_CHECK(
+ cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+ ne01, ne11, ne10,
+ &alpha_f16, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
+ (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
+ &beta_f16, ( char *) dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F, ne01,
+ CUBLAS_COMPUTE_16F,
+ CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+ }
+ }
+ }
+ #else
+ if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
+ // there is no broadcast and src0, src1 are contiguous across dims 2, 3
+ // use cublasGemmStridedBatchedEx
+ CUBLAS_CHECK(
+ cublasGemmStridedBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+ ne01, ne11, ne10,
+ &alpha_f16, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA
+ (const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
+ &beta_f16, ( char *) dst_f16, CUDA_R_16F, ne01, dst->nb[2]/sizeof(float), // strideC
+ ne12*ne13,
+ CUBLAS_COMPUTE_16F,
+ CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+ } else {
+ // use cublasGemmBatchedEx
+ // TODO: https://github.com/ggerganov/llama.cpp/pull/3749#discussion_r1369997000
+ const int ne23 = ne12*ne13;
+
+ // TODO: avoid this alloc
+ void ** ptrs = (void **) malloc(3*ne23*sizeof(void *));
+
+ for (int i13 = 0; i13 < ne13; ++i13) {
+ for (int i12 = 0; i12 < ne12; ++i12) {
+ int i03 = i13 / r3;
+ int i02 = i12 / r2;
+
+ ptrs[0*ne23 + i12 + i13*ne12] = (char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3];
+ ptrs[1*ne23 + i12 + i13*ne12] = (char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2;
+ ptrs[2*ne23 + i12 + i13*ne12] = (char *) dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2;
+ }
+ }
+
+ // allocate device memory for pointers
+ void ** ptrs_as = nullptr;
+ CUDA_CHECK(cudaMalloc(&ptrs_as, 3*ne23*sizeof(void *)));
+
+ // TODO: this does not work for some reason -- not sure why?
+ //size_t ptrs_s = 0;
+ //ptrs_as = (void **) ggml_cuda_pool_malloc(3*ne23*sizeof(void *), &ptrs_s);
+
+ // copy pointers to device
+ CUDA_CHECK(cudaMemcpy(ptrs_as, ptrs, 3*ne23*sizeof(void *), cudaMemcpyHostToDevice));
+
+ free(ptrs);
+
+ CUBLAS_CHECK(
+ cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+ ne01, ne11, ne10,
+ &alpha_f16, (const void **) (ptrs_as + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
+ (const void **) (ptrs_as + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
+ &beta_f16, ( void **) (ptrs_as + 2*ne23), CUDA_R_16F, ne01,
+ ne23,
+ CUBLAS_COMPUTE_16F,
+ CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+ // free device memory for pointers
+ CUDA_CHECK(cudaFree(ptrs_as));
+ //ggml_cuda_pool_free(ptrs_as, ptrs_s);
+ }
+ #endif
+
+ const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+ to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
+
+ ggml_cuda_pool_free(src1_as_f16, src1_as);
+ ggml_cuda_pool_free(dst_f16, dst_as);
+ }
+
  static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
  bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
  src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU;
@@ -7058,10 +7213,23 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
  }
  }

+ // debug helpers
+ //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
+ //printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
+ //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);
+ //printf(" %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
+ //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
+ //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
+
  if (all_on_device && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+ // KQ single-batch
  ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
- } else if (all_on_device && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) {
+ } else if (all_on_device && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+ // KQV single-batch
  ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
+ } else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+ // KQ + KQV multi-batch
+ ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
  } else if (src0->type == GGML_TYPE_F32) {
  ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
  } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
@@ -62,6 +62,7 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(mul);
  GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
  GGML_METAL_DECL_KERNEL(scale);
+ GGML_METAL_DECL_KERNEL(scale_4);
  GGML_METAL_DECL_KERNEL(silu);
  GGML_METAL_DECL_KERNEL(relu);
  GGML_METAL_DECL_KERNEL(gelu);
@@ -249,6 +250,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(mul);
  GGML_METAL_ADD_KERNEL(mul_row);
  GGML_METAL_ADD_KERNEL(scale);
+ GGML_METAL_ADD_KERNEL(scale_4);
  GGML_METAL_ADD_KERNEL(silu);
  GGML_METAL_ADD_KERNEL(relu);
  GGML_METAL_ADD_KERNEL(gelu);
@@ -347,6 +349,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(mul);
  GGML_METAL_DEL_KERNEL(mul_row);
  GGML_METAL_DEL_KERNEL(scale);
+ GGML_METAL_DEL_KERNEL(scale_4);
  GGML_METAL_DEL_KERNEL(silu);
  GGML_METAL_DEL_KERNEL(relu);
  GGML_METAL_DEL_KERNEL(gelu);
@@ -923,15 +926,20 @@ void ggml_metal_graph_compute(

  const float scale = *(const float *) src1->data;

- [encoder setComputePipelineState:ctx->pipeline_scale];
+ int64_t n = ggml_nelements(dst);
+
+ if (n % 4 == 0) {
+ n /= 4;
+ [encoder setComputePipelineState:ctx->pipeline_scale_4];
+ } else {
+ [encoder setComputePipelineState:ctx->pipeline_scale];
+ }
+
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
  [encoder setBytes:&scale length:sizeof(scale) atIndex:2];

- const int64_t n = ggml_nelements(dst);
- GGML_ASSERT(n % 4 == 0);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
  case GGML_OP_UNARY:
  switch (ggml_get_unary_op(gf->nodes[i])) {
@@ -125,9 +125,17 @@ kernel void kernel_mul_row(
  }

  kernel void kernel_scale(
+ device const float * src0,
+ device float * dst,
+ constant float & scale,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] * scale;
+ }
+
+ kernel void kernel_scale_4(
  device const float4 * src0,
  device float4 * dst,
- constant float & scale,
+ constant float & scale,
  uint tpig[[thread_position_in_grid]]) {
  dst[tpig] = src0[tpig] * scale;
  }
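Note on the two Metal scale kernels above: the encoder change in ggml_metal_graph_compute drops the old GGML_ASSERT(n % 4 == 0) and instead dispatches n/4 threadgroups to kernel_scale_4 (one float4 per thread) when the element count is divisible by 4, falling back to the scalar kernel_scale (one float per thread) otherwise. A hedged C++ model of that decision and of the work each variant does (illustrative only, not the Metal API):

    // CPU model of the scale dispatch: vectorized path when n % 4 == 0, scalar otherwise
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // models kernel_scale: one element per thread
    static void scale_scalar(const float * src, float * dst, float s, int64_t n) {
        for (int64_t i = 0; i < n; ++i) dst[i] = src[i] * s;
    }

    // models kernel_scale_4: one float4 (four elements) per thread
    static void scale_vec4(const float * src, float * dst, float s, int64_t n4) {
        for (int64_t i = 0; i < n4; ++i) {
            for (int k = 0; k < 4; ++k) dst[4*i + k] = src[4*i + k] * s;
        }
    }

    int main() {
        std::vector<float> src(10, 2.0f), dst(10);   // 10 elements: not divisible by 4
        int64_t n = (int64_t) src.size();
        const float scale = 0.5f;

        if (n % 4 == 0) {
            scale_vec4(src.data(), dst.data(), scale, n / 4);  // n/4 "threadgroups"
        } else {
            scale_scalar(src.data(), dst.data(), scale, n);    // n "threadgroups"
        }
        printf("dst[0] = %f\n", dst[0]);
        return 0;
    }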