llama_cpp 0.5.3 → 0.6.0

@@ -77,7 +77,7 @@ struct free_block {
  size_t size;
  };

- #define MAX_FREE_BLOCKS 128
+ #define MAX_FREE_BLOCKS 256

  struct ggml_allocr {
  void * data;
@@ -187,6 +187,7 @@ void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor)
  }

  tensor->data = addr;
+ AT_PRINTF("%s: allocated data at %p\n", __func__, tensor->data);

  #ifdef GGML_ALLOCATOR_DEBUG
  add_allocated_tensor(alloc, tensor);
@@ -218,7 +219,8 @@ static void ggml_allocr_free_tensor(struct ggml_allocr * alloc, struct ggml_tens

  size_t size = ggml_allocr_get_alloc_size(alloc, tensor);
  size = aligned_offset(NULL, size, alloc->alignment);
- AT_PRINTF("%s: freeing %s (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, size, alloc->n_free_blocks);
+ AT_PRINTF("%s: freeing %s at %p (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, ptr, size, alloc->n_free_blocks);
+ AT_PRINTF("%s: alloc->data = %p alloc->data+alloc->size = %p alloc->data+alloc->max_size = %p\n", __func__, alloc->data, (char*)alloc->data + alloc->size, (char*)alloc->data + alloc->max_size);

  #ifdef GGML_ALLOCATOR_DEBUG
  remove_allocated_tensor(alloc, tensor);
@@ -631,3 +633,7 @@ static size_t ggml_allocr_alloc_graph_tensors_n(
  size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph) {
  return ggml_allocr_alloc_graph_tensors_n(alloc, &graph, 1, NULL, NULL);
  }
+
+ size_t ggml_allocr_max_size(struct ggml_allocr * alloc) {
+ return alloc->max_size;
+ }
@@ -19,6 +19,7 @@ GGML_API bool ggml_allocr_is_measure(struct ggml_allocr * alloc);
  GGML_API void ggml_allocr_reset(struct ggml_allocr * alloc);
  GGML_API void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor);
  GGML_API size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph);
+ GGML_API size_t ggml_allocr_max_size(struct ggml_allocr * alloc);


  #ifdef __cplusplus
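
Note: the new ggml_allocr_max_size() exposes the allocator's high-water mark. A minimal usage sketch, assuming the existing measure-allocator API (ggml_allocr_new_measure()/ggml_allocr_free()) is unchanged; build_graph() and the alignment value are placeholders, not part of this diff:

    #include "ggml-alloc.h"

    // measure pass: no real allocations, only the worst-case size is tracked
    struct ggml_allocr * measure = ggml_allocr_new_measure(/*alignment =*/ 32);
    struct ggml_cgraph * gf      = build_graph(/* ... */);   // hypothetical helper
    ggml_allocr_alloc_graph(measure, gf);                    // dry run over the graph
    size_t buf_size = ggml_allocr_max_size(measure);         // high-water mark in bytes
    ggml_allocr_free(measure);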
@@ -1,3 +1,4 @@
+ #include <algorithm>
  #include <cstddef>
  #include <cstdint>
  #include <limits>
@@ -14,9 +15,11 @@
  // for rocblas_initialize()
  #include "rocblas/rocblas.h"
  #endif // __HIP_PLATFORM_AMD__
+ #define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
  #define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
  #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
  #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
+ #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
  #define CUBLAS_OP_N HIPBLAS_OP_N
  #define CUBLAS_OP_T HIPBLAS_OP_T
  #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
@@ -235,8 +238,12 @@ static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t *
  return *((int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
  }

+ template<typename T>
+ using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int k, cudaStream_t stream);
+ typedef to_t_cuda_t<float> to_fp32_cuda_t;
+ typedef to_t_cuda_t<half> to_fp16_cuda_t;
+
  typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
- typedef void (*to_fp32_cuda_t)(const void * __restrict__ x, float * __restrict__ y, int k, cudaStream_t stream);
  typedef void (*dot_kernel_k_t)(const void * __restrict__ vx, const int ib, const int iqs, const float * __restrict__ y, float & v);
  typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
  typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
@@ -461,7 +468,7 @@ static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
  static bool g_mul_mat_q = true;

  static void * g_scratch_buffer = nullptr;
- static size_t g_scratch_size = 1024*1024*1024; // 1 GB by default
+ static size_t g_scratch_size = 0; // disabled by default
  static size_t g_scratch_offset = 0;

  static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
@@ -1515,6 +1522,14 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
  v.y = x[ib + iqs + 1];
  }

+ static __device__ void convert_f32(const void * vx, const int ib, const int iqs, dfloat2 & v){
+ const float * x = (const float *) vx;
+
+ // automatic half -> float type cast if dfloat == float
+ v.x = x[ib + iqs + 0];
+ v.y = x[ib + iqs + 1];
+ }
+
  static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) {
  const int ix = blockDim.x*blockIdx.x + threadIdx.x;

@@ -1554,8 +1569,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
  reinterpret_cast<half&>(y[ib].ds.y) = sum;
  }

- template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
- static __global__ void dequantize_block(const void * __restrict__ vx, float * __restrict__ y, const int k) {
+ template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
  const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;

  if (i >= k) {
@@ -4355,8 +4370,10 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
  }

  // rope == RoPE == rotary positional embedding
- static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p0,
- const float p_delta, const int p_delta_rows, const float theta_scale) {
+
+ template<typename T, bool has_pos>
+ static __global__ void rope(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale,
+ const int p_delta_rows, const float theta_scale) {
  const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);

  if (col >= ncols) {
@@ -4365,8 +4382,11 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c

  const int row = blockDim.x*blockIdx.x + threadIdx.x;
  const int i = row*ncols + col;
+ const int i2 = row/p_delta_rows;

- const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
+ const int p = has_pos ? pos[i2] : 0;
+ const float p0 = p*freq_scale;
+ const float theta = p0*powf(theta_scale, col/2);
  const float sin_theta = sinf(theta);
  const float cos_theta = cosf(theta);

@@ -4377,8 +4397,9 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
  dst[i + 1] = x0*sin_theta + x1*cos_theta;
  }

- static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float p0,
- const float p_delta, const int p_delta_rows, const float theta_scale) {
+ template<typename T, bool has_pos>
+ static __global__ void rope_neox(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale,
+ const int p_delta_rows, const float theta_scale) {
  const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);

  if (col >= ncols) {
@@ -4387,8 +4408,11 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco

  const int row = blockDim.x*blockIdx.x + threadIdx.x;
  const int i = row*ncols + col/2;
+ const int i2 = row/p_delta_rows;

- const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
+ const int p = has_pos ? pos[i2] : 0;
+ const float p0 = p*freq_scale;
+ const float theta = p0*powf(theta_scale, col/2);
  const float sin_theta = sinf(theta);
  const float cos_theta = cosf(theta);

@@ -4399,8 +4423,8 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco
  dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
  }

- static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p0,
- const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx) {
+ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale,
+ const int p_delta_rows, const float theta_scale, const int n_ctx) {
  const int col = blockDim.x*blockIdx.x + threadIdx.x;
  const int half_n_dims = ncols/4;

@@ -4410,11 +4434,13 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol

  const int row = blockDim.y*blockIdx.y + threadIdx.y;
  const int i = row*ncols + col;
+ const int i2 = row/p_delta_rows;

  const float col_theta_scale = powf(theta_scale, col);
- const float p = p0 + p_delta*(row/p_delta_rows);
+ // FIXME: this is likely wrong
+ const int p = pos != nullptr ? pos[i2] : 0;

- const float theta = min(p, p_delta*(n_ctx - 2))*col_theta_scale;
+ const float theta = min(p, n_ctx - 2)*freq_scale*col_theta_scale;
  const float sin_theta = sinf(theta);
  const float cos_theta = cosf(theta);

@@ -4424,7 +4450,7 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
  dst[i + 0] = x0*cos_theta - x1*sin_theta;
  dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;

- const float block_theta = max(p - p_delta*(n_ctx - 2), 0.f)*col_theta_scale;
+ const float block_theta = ((float)max(p - n_ctx - 2, 0))*col_theta_scale;
  const float sin_block_theta = sinf(block_theta);
  const float cos_block_theta = cosf(block_theta);

@@ -4826,6 +4852,11 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
  dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
  }

+ static void convert_fp32_to_fp16_cuda(const void * vx, half * y, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+ dequantize_block<1, 1, convert_f32><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+ }
+
  static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
@@ -4835,6 +4866,15 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa
  <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
  }

+ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_F32:
+ return convert_fp32_to_fp16_cuda;
+ default:
+ return nullptr;
+ }
+ }
+
  static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
  switch (type) {
  case GGML_TYPE_Q4_0:
@@ -5361,31 +5401,41 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
  scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
  }

- static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
- const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
+ template<typename T>
+ static void rope_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
+ const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
  GGML_ASSERT(ncols % 2 == 0);
  const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
  const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
  const dim3 block_nums(nrows, num_blocks_x, 1);
- rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
+ if (pos == nullptr) {
+ rope<T, false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+ } else {
+ rope<T, true><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+ }
  }

- static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
- const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
+ template<typename T>
+ static void rope_neox_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
+ const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
  GGML_ASSERT(ncols % 2 == 0);
  const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
  const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
  const dim3 block_nums(nrows, num_blocks_x, 1);
- rope_neox_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
+ if (pos == nullptr) {
+ rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+ } else {
+ rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+ }
  }

- static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
- const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) {
+ static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
+ const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) {
  GGML_ASSERT(ncols % 4 == 0);
  const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1);
  const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE;
  const dim3 block_nums(num_blocks_x, nrows, 1);
- rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale, n_ctx);
+ rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale, n_ctx);
  }

  static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows,
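
Note: the rope/rope_neox kernels above now take per-token positions (pos) plus freq_scale instead of the old p0/p_delta pair. A CPU-side sketch of the per-row rotation they compute, using only the formulas visible in this diff (illustrative reference, not the kernel code; theta_scale is powf(freq_base, -2.0f/n_dims), as set in ggml_cuda_op_rope further below):

    #include <cmath>
    #include <cstdint>

    // reference rotation for one row of ncols values (non-neox layout: adjacent pairs)
    static void rope_row_ref(const float * x, float * dst, int ncols,
                             int32_t pos, float freq_scale, float theta_scale) {
        const float p0 = pos * freq_scale;            // per-token position, scaled
        for (int col = 0; col < ncols; col += 2) {
            const float theta = p0 * powf(theta_scale, col/2);
            const float x0 = x[col + 0];
            const float x1 = x[col + 1];
            dst[col + 0] = x0*cosf(theta) - x1*sinf(theta);
            dst[col + 1] = x0*sinf(theta) + x1*cosf(theta);
        }
    }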
@@ -6016,8 +6066,6 @@ inline void ggml_cuda_op_mul_mat_cublas(
  GGML_ASSERT(src1_ddf_i != nullptr);
  GGML_ASSERT(dst_dd_i != nullptr);

- const float alpha = 1.0f;
- const float beta = 0.0f;

  const int64_t ne00 = src0->ne[0];

@@ -6026,16 +6074,6 @@ inline void ggml_cuda_op_mul_mat_cublas(
  const int64_t ne0 = dst->ne[0];
  const int64_t row_diff = row_high - row_low;

- float * src0_ddq_as_f32;
- size_t src0_as = 0;
-
- if (src0->type != GGML_TYPE_F32) {
- const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
- src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT
- to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
- }
- const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
-
  int id;
  CUDA_CHECK(cudaGetDevice(&id));

@@ -6043,16 +6081,72 @@ inline void ggml_cuda_op_mul_mat_cublas(
  // ldc == nrows of the matrix that cuBLAS writes into
  int ldc = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;

- CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream));
- CUBLAS_CHECK(
- cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
- row_diff, src1_ncols, ne10,
- &alpha, src0_ddf_i, ne00,
- src1_ddf_i, ne10,
- &beta, dst_dd_i, ldc));
+ const int compute_capability = g_compute_capabilities[id];
+
+ if (compute_capability >= CC_TURING && src0->type == GGML_TYPE_F16 && ggml_is_contiguous(src0) && ldc == row_diff) {
+ // convert src1 to fp16, multiply as fp16, convert dst to fp32
+ half * src1_as_f16 = nullptr;
+ size_t src1_as = 0;
+ if (src1->type != GGML_TYPE_F16) {
+ const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
+ GGML_ASSERT(to_fp16_cuda != nullptr);
+ size_t ne = src1_ncols*ne10;
+ src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
+ to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
+ }
+ const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
+
+ size_t dst_as = 0;
+ half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
+
+ const half alpha_f16 = 1.0f;
+ const half beta_f16 = 0.0f;
+
+ CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream));
+ CUBLAS_CHECK(
+ cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+ row_diff, src1_ncols, ne10,
+ &alpha_f16, src0_dd_i, CUDA_R_16F, ne00,
+ src1_ptr, CUDA_R_16F, ne10,
+ &beta_f16, dst_f16, CUDA_R_16F, ldc,
+ CUBLAS_COMPUTE_16F,
+ CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+ const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+ to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
+
+ ggml_cuda_pool_free(dst_f16, dst_as);

- if (src0_as > 0) {
- ggml_cuda_pool_free(src0_ddq_as_f32, src0_as);
+ if (src1_as != 0) {
+ ggml_cuda_pool_free(src1_as_f16, src1_as);
+ }
+ }
+ else {
+ float * src0_ddq_as_f32 = nullptr;
+ size_t src0_as = 0;
+
+ if (src0->type != GGML_TYPE_F32) {
+ const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
+ GGML_ASSERT(to_fp32_cuda != nullptr);
+ src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT
+ to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
+ }
+ const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
+
+ const float alpha = 1.0f;
+ const float beta = 0.0f;
+
+ CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream));
+ CUBLAS_CHECK(
+ cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+ row_diff, src1_ncols, ne10,
+ &alpha, src0_ddf_i, ne00,
+ src1_ddf_i, ne10,
+ &beta, dst_dd_i, ldc));
+
+ if (src0_as != 0) {
+ ggml_cuda_pool_free(src0_ddq_as_f32, src0_as);
+ }
  }

  (void) dst;
@@ -6064,14 +6158,16 @@ inline void ggml_cuda_op_rope(
  const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
  const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {

- GGML_ASSERT(src0->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+ GGML_ASSERT(src0->type == dst->type);

  const int64_t ne00 = src0->ne[0];
  const int64_t ne01 = src0->ne[1];
+ const int64_t ne2 = dst->ne[2];
  const int64_t nrows = ggml_nrows(src0);

- const int n_past = ((int32_t *) dst->op_params)[0];
+ //const int n_past = ((int32_t *) dst->op_params)[0];
  const int n_dims = ((int32_t *) dst->op_params)[1];
  const int mode = ((int32_t *) dst->op_params)[2];
  const int n_ctx = ((int32_t *) dst->op_params)[3];
@@ -6082,19 +6178,38 @@ inline void ggml_cuda_op_rope(
  memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));

  const float theta_scale = powf(freq_base, -2.0f/n_dims);
- const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
+
+ const int32_t * pos = nullptr;
+ if ((mode & 1) == 0) {
+ GGML_ASSERT(src1->type == GGML_TYPE_I32);
+ GGML_ASSERT(src1->ne[0] == ne2);
+ pos = (const int32_t *) src1_dd;
+ }

  const bool is_neox = mode & 2;
  const bool is_glm = mode & 4;

  // compute
  if (is_glm) {
- rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, n_ctx, main_stream);
+ GGML_ASSERT(false);
+ rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, n_ctx, main_stream);
  } else if (is_neox) {
  GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
- rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, main_stream);
+ if (src0->type == GGML_TYPE_F32) {
+ rope_neox_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+ } else if (src0->type == GGML_TYPE_F16) {
+ rope_neox_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+ } else {
+ GGML_ASSERT(false);
+ }
  } else {
- rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, main_stream);
+ if (src0->type == GGML_TYPE_F32) {
+ rope_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+ } else if (src0->type == GGML_TYPE_F16) {
+ rope_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+ } else {
+ GGML_ASSERT(false);
+ }
  }

  (void) src1;
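
Note: with this change the CUDA RoPE op reads token positions from src1 (an I32 tensor with one entry per token, asserted as src1->ne[0] == dst->ne[2]) instead of the old n_past scalar in op_params. A hedged sketch of how a caller might build such a positions tensor on the ggml side; ctx, n_tokens and n_past are assumed names for illustration, not part of this diff:

    // one absolute position per token in the current batch
    struct ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens);
    for (int i = 0; i < n_tokens; ++i) {
        ((int32_t *) pos->data)[i] = n_past + i;
    }
    // pos is then supplied as the second operand (src1) of the rope op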
@@ -6265,7 +6380,7 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
  }
  }

- void ggml_cuda_set_peer_access(const int n_tokens) {
+ static void ggml_cuda_set_peer_access(const int n_tokens) {
  static bool peer_access_enabled = false;

  const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE;
@@ -6593,27 +6708,27 @@ static void ggml_cuda_op_mul_mat(
  }
  }

- void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ static void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_add);
  }

- void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ static void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul);
  }

- void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ static void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu);
  }

- void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ static void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu);
  }

- void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm);
  }

- void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm);
  }

@@ -6624,17 +6739,13 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
  const int64_t ne1 = dst->ne[1];

  // TODO: find the optimal values for these
- if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
- src1->type == GGML_TYPE_F32 &&
- dst->type == GGML_TYPE_F32 &&
- (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
- return true;
- }
-
- return false;
+ return (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
+ src1->type == GGML_TYPE_F32 &&
+ dst->type == GGML_TYPE_F32 &&
+ (ne0 >= 32 && ne1 >= 32 && ne10 >= 32);
  }

- void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+ static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
  GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
  GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
  GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
@@ -6663,7 +6774,7 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
  ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
  }

- void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
  GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1));
  GGML_ASSERT(!ggml_is_permuted(src0));
  GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
@@ -6697,7 +6808,7 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
  ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
  }

- void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
  bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
  src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU;

@@ -6741,11 +6852,11 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
  }
  }

- void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
  }

- void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
  const int64_t ne = ggml_nelements(src0);
  GGML_ASSERT(ne == ggml_nelements(src1));

@@ -6787,35 +6898,37 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
  ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
  ne10, ne11, nb10, nb11, nb12, main_stream);
  } else {
+ fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
+ ggml_type_name(src0->type), ggml_type_name(src1->type));
  GGML_ASSERT(false);
  }

  (void) dst;
  }

- void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ static void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
  ggml_cuda_cpy(src0, dst, nullptr);
  (void) src1;
  }

- void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ static void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_diag_mask_inf);
  }

- void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ static void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_soft_max);
  }

- void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ static void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
  GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rope);
  }

- void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
  }

- void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
  (void) src0;
  (void) src1;
  (void) dst;
@@ -6938,11 +7051,13 @@ static struct ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
  return extra;
  }

- void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) {
+ static void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) {
  if (scratch && g_scratch_size == 0) {
  return;
  }

+ tensor->backend = GGML_BACKEND_GPU;
+
  // recursively assign CUDA buffers until a compute tensor is found
  if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_CPU) {
  const ggml_op src0_op = tensor->src[0]->op;
@@ -6954,8 +7069,6 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
  ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace, no_alloc);
  }

- tensor->backend = GGML_BACKEND_GPU;
-
  if (scratch && no_alloc) {
  return;
  }
@@ -7040,6 +7153,15 @@ void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset)
  tensor->extra = extra;
  }

+ void ggml_cuda_copy_to_device(struct ggml_tensor * tensor) {
+ GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
+ GGML_ASSERT(ggml_is_contiguous(tensor));
+
+ struct ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
+ CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+ CUDA_CHECK(cudaMemcpy(extra->data_device[g_main_device], tensor->data, ggml_nbytes(tensor), cudaMemcpyHostToDevice));
+ }
+
  void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
  ggml_cuda_assign_buffers_impl(tensor, true, false, false);
  }
@@ -7075,7 +7197,12 @@ void ggml_cuda_set_mul_mat_q(const bool mul_mat_q) {
  }

  void ggml_cuda_set_scratch_size(const size_t scratch_size) {
- g_scratch_size = scratch_size;
+ // this is a hack to not completely break llama.cpp when using multiple models or contexts simultaneously
+ // it still won't always work as expected, but it's better than nothing
+ if (scratch_size > g_scratch_size) {
+ ggml_cuda_free_scratch();
+ }
+ g_scratch_size = std::max(g_scratch_size, scratch_size);
  }

  void ggml_cuda_free_scratch() {
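
Note: per the comment added above, ggml_cuda_set_scratch_size() now only ever grows the requested scratch size and frees the old buffer when a larger size is asked for (it is reallocated lazily). A behavior sketch with arbitrary example sizes:

    ggml_cuda_set_scratch_size(512u*1024*1024);  // scratch size becomes 512 MiB
    ggml_cuda_set_scratch_size(256u*1024*1024);  // kept at 512 MiB: the size never shrinks
    ggml_cuda_set_scratch_size(768u*1024*1024);  // old scratch buffer freed, size grows to 768 MiB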
@@ -31,6 +31,7 @@ GGML_API void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tens

  GGML_API void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor);
  GGML_API void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset);
+ GGML_API void ggml_cuda_copy_to_device(struct ggml_tensor * tensor);

  GGML_API void ggml_cuda_set_main_device(int main_device);
  GGML_API void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
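
Note: a hedged sketch of how the new ggml_cuda_copy_to_device() might be used. The tensor must already be GPU-backed with a device buffer bound (for example via the no-alloc/scratch-offset pair declared above); the helper then uploads the tensor's current host contents. The offset variable is an assumed, caller-managed scratch offset, not part of this diff:

    ggml_cuda_assign_buffers_no_alloc(tensor);        // mark tensor as GPU-backed, no allocation yet
    ggml_cuda_assign_scratch_offset(tensor, offset);  // bind it to a spot in the scratch buffer
    ggml_cuda_copy_to_device(tensor);                 // cudaMemcpy of the host data into that buffer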
@@ -19,6 +19,8 @@

  #pragma once

+ #include "ggml.h"
+
  #include <stddef.h>
  #include <stdbool.h>

@@ -33,6 +35,8 @@ struct ggml_cgraph;
  extern "C" {
  #endif

+ void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_data);
+
  struct ggml_metal_context;

  // number of command buffers to use
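
Note: a hedged usage sketch for the new Metal log hook; it assumes ggml_log_callback from the newly included ggml.h has the usual (level, text, user_data) shape:

    #include <stdio.h>
    #include "ggml.h"
    #include "ggml-metal.h"

    static void metal_log_cb(enum ggml_log_level level, const char * text, void * user_data) {
        (void) level; (void) user_data;
        fputs(text, stderr);   // forward backend messages to stderr
    }

    static void init_metal_logging(void) {
        // install before creating the Metal context so init messages are captured too
        ggml_metal_log_set_callback(metal_log_cb, NULL);
    }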