llama_cpp 0.15.3 → 0.15.4

@@ -2944,6 +2944,57 @@ namespace dpct
     using shared_memory = detail::device_memory<T, shared, Dimension>;
 
 
+    template <typename T,
+              sycl::access::address_space addressSpace =
+                  sycl::access::address_space::global_space,
+              sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
+              sycl::memory_scope memoryScope = sycl::memory_scope::device>
+    inline T atomic_fetch_add(T *addr, T operand) {
+        auto atm =
+            sycl::atomic_ref<T, memoryOrder, memoryScope, addressSpace>(addr[0]);
+        return atm.fetch_add(operand);
+    }
+
+    template <sycl::access::address_space addressSpace =
+                  sycl::access::address_space::global_space,
+              sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
+              sycl::memory_scope memoryScope = sycl::memory_scope::device,
+              typename T1, typename T2>
+    inline T1 atomic_fetch_add(T1 *addr, T2 operand) {
+        auto atm =
+            sycl::atomic_ref<T1, memoryOrder, memoryScope, addressSpace>(addr[0]);
+        return atm.fetch_add(operand);
+    }
+
+    template <typename T, sycl::access::address_space addressSpace =
+                              sycl::access::address_space::global_space>
+    inline T atomic_fetch_add(T *addr, T operand,
+                              sycl::memory_order memoryOrder) {
+        switch (memoryOrder) {
+        case sycl::memory_order::relaxed:
+            return atomic_fetch_add<T, addressSpace, sycl::memory_order::relaxed,
+                                    sycl::memory_scope::device>(addr, operand);
+        case sycl::memory_order::acq_rel:
+            return atomic_fetch_add<T, addressSpace, sycl::memory_order::acq_rel,
+                                    sycl::memory_scope::device>(addr, operand);
+        case sycl::memory_order::seq_cst:
+            return atomic_fetch_add<T, addressSpace, sycl::memory_order::seq_cst,
+                                    sycl::memory_scope::device>(addr, operand);
+        default:
+            assert(false && "Invalid memory_order for atomics. Valid memory_order for "
+                            "atomics are: sycl::memory_order::relaxed, "
+                            "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!");
+        }
+    }
+
+    template <sycl::access::address_space addressSpace =
+                  sycl::access::address_space::global_space,
+              typename T1, typename T2>
+    inline T1 atomic_fetch_add(T1 *addr, T2 operand,
+                               sycl::memory_order memoryOrder) {
+        atomic_fetch_add<T1, addressSpace>(addr, operand, memoryOrder);
+    }
+
 } // COPY from DPCT head files
 
 #define GGML_COMMON_DECL_SYCL
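Editorial note: this hunk adds dpct::atomic_fetch_add wrappers around sycl::atomic_ref; the runtime memory_order overload simply forwards to the compile-time specializations, and the new mul_mat_id copy kernel further below calls the generic_space variant. A minimal usage sketch of the same pattern, outside the diff (the queue, counter, and sizes here are illustrative, not part of the change):

    // Illustrative only: each work-item atomically accumulates into one device counter,
    // using the same sycl::atomic_ref pattern wrapped by dpct::atomic_fetch_add.
    #include <sycl/sycl.hpp>

    int main() {
        sycl::queue q;
        int *counter = sycl::malloc_shared<int>(1, q);
        *counter = 0;
        q.parallel_for(sycl::range<1>(1024), [=](sycl::id<1>) {
            sycl::atomic_ref<int, sycl::memory_order::relaxed,
                             sycl::memory_scope::device,
                             sycl::access::address_space::global_space>
                ref(counter[0]);
            ref.fetch_add(1);           // equivalent to dpct::atomic_fetch_add(counter, 1)
        }).wait();
        // *counter == 1024 at this point
        sycl::free(counter, q);
        return 0;
    }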
@@ -2971,20 +3022,19 @@ static int g_work_group_size = 0;
 // typedef sycl::half ggml_fp16_t;
 
 #define __SYCL_ARCH__ DPCT_COMPATIBILITY_TEMP
-#define VER_4VEC 610 //todo for hardward optimize.
+#define VER_4VEC 130 //todo for hardward optimize.
 #define VER_GEN9 700 //todo for hardward optimize.
 #define VER_GEN12 1000000 //todo for hardward optimize.
 #define VER_GEN13 (VER_GEN12 + 1030) //todo for hardward optimize.
 
 #define GGML_SYCL_MAX_NODES 8192 //TODO: adapt to hardwares
 
-
-//define for XMX in Intel GPU
-//TODO: currently, it's not used for XMX really.
-#define SYCL_USE_XMX
+#if !defined(GGML_SYCL_FORCE_MMQ)
+#define SYCL_USE_XMX
+#endif
 
 // max batch size to use MMQ kernels when tensor cores are available
-#define XMX_MAX_BATCH_SIZE 32
+#define MMQ_MAX_BATCH_SIZE 32
 
 
 #if defined(_MSC_VER)
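Editorial note: SYCL_USE_XMX becomes opt-out, suppressed when GGML_SYCL_FORCE_MMQ is defined at build time, and the batch-size cap is renamed from XMX_MAX_BATCH_SIZE to MMQ_MAX_BATCH_SIZE. A hedged, standalone sketch of how the guard behaves (this is not the project's build file, just the preprocessor logic in isolation):

    // Illustrative only: compile with -DGGML_SYCL_FORCE_MMQ to drop the XMX path.
    #include <cstdio>

    #if !defined(GGML_SYCL_FORCE_MMQ)
    #define SYCL_USE_XMX
    #endif

    #define MMQ_MAX_BATCH_SIZE 32

    int main() {
    #ifdef SYCL_USE_XMX
        std::printf("XMX enabled; MMQ used up to batch size %d\n", MMQ_MAX_BATCH_SIZE);
    #else
        std::printf("GGML_SYCL_FORCE_MMQ set; XMX path disabled\n");
    #endif
        return 0;
    }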
@@ -3060,6 +3110,7 @@ void ggml_sycl_get_device_description(int device, char * description, size_t d
 bool ggml_backend_is_sycl(ggml_backend_t backend);
 int ggml_backend_sycl_get_device(ggml_backend_t backend);
 int get_main_device();
+static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer);
 void print_ggml_tensor(const char*name, struct ggml_tensor *src);
 void log_tensor_with_cnt(const char* name, struct ggml_tensor * src, int stop_cnt);
 
@@ -8830,12 +8881,11 @@ static void rope(
     dst[i + 1] = x0*sin_theta + x1*cos_theta;
 }
 
-template<typename T, bool has_pos>
+template<typename T, bool has_pos, bool has_freq_facs>
 static void rope_neox(
     const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
-    ,
-    const sycl::nd_item<3> &item_ct1) {
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims,
+    const float * freq_factors, const sycl::nd_item<3> &item_ct1) {
     const int col = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
                          item_ct1.get_local_id(1));
 
@@ -8863,8 +8913,10 @@ static void rope_neox(
     float cur_rot = inv_ndims * ic - ib;
 
     const int p = has_pos ? pos[i2] : 0;
+    const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
+
     const float theta_base =
-        p * freq_scale * dpct::pow(theta_scale, col / 2.0f);
+        p * freq_scale * dpct::pow(theta_scale, col / 2.0f)/freq_factor;
 
     float cos_theta, sin_theta;
     rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
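Editorial note: with the new has_freq_facs flag the base rotation angle is divided by an optional per-channel frequency factor (freq_factors[ic/2], 1.0 when absent), which is what Phi-3-style long-context RoPE needs. A small host-side illustration of that arithmetic; the concrete values below are made up for the example:

    // Illustrative sketch of the angle computed in rope_neox:
    // theta_base = p * freq_scale * pow(theta_scale, col/2) / freq_factor
    #include <cmath>
    #include <cstdio>

    int main() {
        const float freq_scale  = 1.0f;
        const float theta_scale = 0.9f;   // example value; the kernel derives it from freq_base
        const int   p           = 42;     // token position
        const int   col         = 8;      // rotated column pair
        const float freq_factor = 4.0f;   // freq_factors[ic/2]; 1.0f when no factors are given

        const float theta_plain = p * freq_scale * std::pow(theta_scale, col / 2.0f);
        const float theta       = theta_plain / freq_factor;

        std::printf("theta without factor = %f, with factor = %f\n", theta_plain, theta);
        return 0;
    }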
@@ -12413,7 +12465,7 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
                            const int32_t *pos, float freq_scale,
                            int p_delta_rows, float freq_base, float ext_factor,
                            float attn_factor, rope_corr_dims corr_dims,
-                           dpct::queue_ptr stream) {
+                           const float * freq_factors, dpct::queue_ptr stream) {
     GGML_ASSERT(ncols % 2 == 0);
     const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
@@ -12423,38 +12475,48 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
     const float inv_ndims = -1.0f / n_dims;
 
     if (pos == nullptr) {
-        /*
-        DPCT1049:42: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                rope_neox<T, false>(x, dst, ncols, n_dims, pos, freq_scale,
-                                    p_delta_rows, ext_factor, attn_factor,
-                                    corr_dims, theta_scale, inv_ndims,
-                                    item_ct1);
-            });
+        if (freq_factors == nullptr) {
+            stream->parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
+                    rope_neox<T, false, false>(x, dst, ncols, n_dims, pos, freq_scale,
+                                               p_delta_rows, ext_factor, attn_factor,
+                                               corr_dims, theta_scale, inv_ndims, freq_factors,
+                                               item_ct1);
+                });
+        } else {
+            stream->parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
+                    rope_neox<T, false, true>(x, dst, ncols, n_dims, pos, freq_scale,
+                                              p_delta_rows, ext_factor, attn_factor,
+                                              corr_dims, theta_scale, inv_ndims, freq_factors,
+                                              item_ct1);
+                });
+        }
     } else {
-        /*
-        DPCT1049:43: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
 
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                rope_neox<T, true>(x, dst, ncols, n_dims, pos, freq_scale,
-                                   p_delta_rows, ext_factor, attn_factor,
-                                   corr_dims, theta_scale, inv_ndims, item_ct1);
-            });
+        if (freq_factors == nullptr) {
+            stream->parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
+                    rope_neox<T, true, false>(x, dst, ncols, n_dims, pos, freq_scale,
+                                              p_delta_rows, ext_factor, attn_factor,
+                                              corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1);
+                });
+        } else {
+            stream->parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
+                    rope_neox<T, true, true>(x, dst, ncols, n_dims, pos, freq_scale,
+                                             p_delta_rows, ext_factor, attn_factor,
+                                             corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1);
+                });
+        }
     }
 }
 
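Editorial note: rope_neox_sycl now picks one of four rope_neox<T, has_pos, has_freq_facs> instantiations based on two runtime nullptr checks, so the per-element conditionals are folded at compile time inside the kernel. A reduced sketch of the same dispatch pattern, not the actual kernel launch (names here are illustrative):

    // Illustrative: runtime nullptr checks select a fully specialized template,
    // leaving no data-dependent branch in the hot loop.
    #include <cstdio>

    template <bool has_pos, bool has_freq_facs>
    static void kernel(const int *pos, const float *freq_factors, int n) {
        for (int i = 0; i < n; ++i) {
            const int   p = has_pos       ? pos[i]          : 0;     // folded by the compiler
            const float f = has_freq_facs ? freq_factors[i] : 1.0f;  // folded by the compiler
            std::printf("i=%d p=%d f=%f\n", i, p, f);
        }
    }

    static void launch(const int *pos, const float *freq_factors, int n) {
        if (pos == nullptr) {
            if (freq_factors == nullptr) kernel<false, false>(pos, freq_factors, n);
            else                         kernel<false, true >(pos, freq_factors, n);
        } else {
            if (freq_factors == nullptr) kernel<true,  false>(pos, freq_factors, n);
            else                         kernel<true,  true >(pos, freq_factors, n);
        }
    }

    int main() {
        const int   pos[2]  = {3, 7};
        const float facs[2] = {1.0f, 2.0f};
        launch(pos, facs, 2);
        launch(nullptr, nullptr, 2);
        return 0;
    }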
@@ -13501,6 +13563,10 @@ inline void ggml_sycl_op_concat(const ggml_tensor *src0,
                                 const float *src0_dd, const float *src1_dd,
                                 float *dst_dd,
                                 const dpct::queue_ptr &main_stream) {
+#pragma message("TODO: generalize concat kernel for dim != 2")
+#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7563")
+    int dim = dst->op_params[0];
+    GGML_ASSERT(dim == 2);
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -13986,9 +14052,7 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
                               ggml_tensor *dst, const float *src0_dd,
                               const float *src1_dd, float *dst_dd,
                               const dpct::queue_ptr &main_stream) {
-#pragma message("TODO: implement phi3 frequency factors support")
-#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7225")
-    GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
+    const ggml_tensor * src2 = dst->src[2];
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
     GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
@@ -14014,6 +14078,7 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
     memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
     memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
 
+    const float * freq_factors = nullptr;
     const int32_t * pos = nullptr;
     if ((mode & 1) == 0) {
         GGML_ASSERT(src1->type == GGML_TYPE_I32);
@@ -14024,6 +14089,16 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
     const bool is_neox = mode & 2;
     const bool is_glm = mode & 4;
 
+    if (is_neox) {
+        pos = (const int32_t *) src1_dd;
+
+        if (src2 != nullptr) {
+            freq_factors = (const float *) src2->data;
+        }
+    } else {
+        GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox");
+    }
+
     rope_corr_dims corr_dims;
     ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
 
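Editorial note: the optional frequency-factor tensor arrives as dst->src[2] and is only honoured on the neox path, selected from the op's mode bit-field (bit 0 clear appears to mean a positions tensor accompanies the op, bit 1 selects neox, bit 2 GLM). A tiny illustrative decode of those bits, purely for reference:

    // Illustrative only: mode bits as used in ggml_sycl_op_rope above.
    #include <cstdio>

    int main() {
        const int  mode      = 2;                 // example: neox-style RoPE
        const bool needs_pos = (mode & 1) == 0;   // src1 is an I32 positions tensor
        const bool is_neox   = mode & 2;
        const bool is_glm    = mode & 4;
        std::printf("needs_pos=%d is_neox=%d is_glm=%d\n", needs_pos, is_neox, is_glm);
        return 0;
    }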
@@ -14035,13 +14110,13 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
         if (src0->type == GGML_TYPE_F32) {
             rope_neox_sycl(
                 (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, main_stream
+                attn_factor, corr_dims, freq_factors, main_stream
             );
         } else if (src0->type == GGML_TYPE_F16) {
             rope_neox_sycl((const sycl::half *)src0_dd, (sycl::half *)dst_dd,
                            ne00, n_dims, nrows, pos, freq_scale, ne01,
                            freq_base, ext_factor, attn_factor, corr_dims,
-                           main_stream);
+                           freq_factors, main_stream);
         } else {
             GGML_ASSERT(false);
         }
@@ -15108,7 +15183,7 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
     const int64_t r2 = ne12/ne02;
     const int64_t r3 = ne13/ne03;
 
-    if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
+    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
             *g_sycl_handles[g_main_device], oneapi::mkl::transpose::trans,
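Editorial note: the hand-written stride test (nb[3] == nb[2]*ne[2] on both operands) is replaced by the shared ggml helper ggml_is_contiguous_2(), keeping the intent stated in the comment: no broadcast and contiguity across dims 2 and 3. A hedged sketch of what the old inline condition computed, using a stand-in struct rather than ggml_tensor:

    // Illustrative only: the manual stride check removed above, for a 4-D tensor layout.
    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    struct tensor_dims {
        int64_t ne[4];   // extents per dimension
        size_t  nb[4];   // strides per dimension, in bytes
    };

    static bool contiguous_across_dims_2_3(const tensor_dims & t) {
        return t.nb[2] * (size_t) t.ne[2] == t.nb[3];
    }

    int main() {
        // Example: f32 tensor of shape 128 x 32 x 8 x 2 with a fully contiguous layout.
        tensor_dims t = { {128, 32, 8, 2}, {4, 512, 16384, 131072} };
        std::printf("contiguous across dims 2,3: %d\n", contiguous_across_dims_2_3(t));
        return 0;
    }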
@@ -15173,6 +15248,29 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }
 
+inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
+    // TODO: accuracy issues in MMQ
+    return false;
+}
+
+bool ggml_sycl_supports_dmmv(enum ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_F16:
+            return true;
+        default:
+            return false;
+    }
+}
 
 static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool all_on_device =
@@ -15189,75 +15287,42 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
         }
     }
 
-#ifdef SYCL_USE_XMX
-    const bool use_xmx = true;
-#else
-    const bool use_xmx = false;
-#endif
+    // check data types and tensor shapes for custom matrix multiplication kernels:
+    bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+        && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
+
+    bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
 
-    // debug helpers
-    //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
-    //printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
-    //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);
-    //printf(" %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
-    //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
-    //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
+    bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+
+    // mmvq and mmq need the __dp4a instruction which is available for gen12+
+    // Workaround in https://github.com/ggerganov/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e
+    use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS);
+#ifdef SYCL_USE_XMX
+    use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
+#endif // SYCL_USE_XMX
 
-    if (!split && all_on_device && !use_xmx && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+    if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
         // KQ single-batch
-        // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_p021\n");
         ggml_sycl_mul_mat_vec_p021(src0, src1, dst);
-    } else if (!split && all_on_device && !use_xmx && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
         // KQV single-batch
-        // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_nc\n");
         ggml_sycl_mul_mat_vec_nc(src0, src1, dst);
-    } else if (!split && all_on_device && use_xmx && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
+    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // KQ + KQV multi-batch
-        // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_batched_sycl\n");
         ggml_sycl_mul_mat_batched_sycl(src0, src1, dst);
-    } else if (src0->type == GGML_TYPE_F32) {
-        // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat\n");
-        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
-    } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
-        // GGML_SYCL_DEBUG("ggml_is_quantized or GGML_TYPE_F16\n");
-        if (src1->ne[1] == 1 && src0->ne[0] % GGML_SYCL_DMMV_X == 0) {
-#ifdef GGML_SYCL_FORCE_DMMV
-            const bool use_mul_mat_vec_q = false;
-#else
-            bool use_mul_mat_vec_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type);
-            use_mul_mat_vec_q = use_mul_mat_vec_q ||
-                (src0->type == GGML_TYPE_IQ2_XXS) || (src0->type == GGML_TYPE_IQ2_XS) || (src0->type == GGML_TYPE_IQ2_S) ||
-                (src0->type == GGML_TYPE_IQ3_XXS) || (src0->type == GGML_TYPE_IQ3_S) ||
-                (src0->type == GGML_TYPE_IQ4_NL) || (src0->type == GGML_TYPE_IQ4_XS) ||
-                (src0->type == GGML_TYPE_IQ1_S) || (src0->type == GGML_TYPE_IQ1_M);
-
-
-#endif // GGML_SYCL_FORCE_DMMV
-
-            if (use_mul_mat_vec_q) {
-                // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_vec_q path\n");
-                ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
-            } else {
-                // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_dequantize_mul_mat_vec path\n");
-                ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
-            }
-        } else {
-            bool use_mul_mat_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type);
-
-            if (use_xmx && min_compute_capability >= VER_GEN9 && src1->ne[1] > XMX_MAX_BATCH_SIZE) {
-                use_mul_mat_q = false;
-            }
-
-            if (use_mul_mat_q) {
-                // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_q path\n");
-                ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
-            } else {
-                // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_sycl path\n");
-                ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
-            }
-        }
+    } else if (use_dequantize_mul_mat_vec) {
+        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
+    } else if (use_mul_mat_vec_q) {
+        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
+    } else if (use_mul_mat_q) {
+        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
     } else {
-        GGML_ASSERT(false);
+        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
     }
 }
 
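Editorial note: kernel selection in ggml_sycl_mul_mat is now driven by three precomputed predicates instead of nested capability checks, and the order of the if/else chain defines precedence: the F16 fast paths first, then dequantize-mul-mat-vec, then mmvq, then mmq, then the generic SYCL/oneMKL path. A condensed, illustrative summary of that precedence (not generated from the code; the enum and function names are invented for the sketch):

    // Illustrative decision table for the new dispatch, assuming the predicates
    // have already been computed as in the hunk above.
    #include <cstdio>

    enum class mm_path { vec_p021, vec_nc, batched, dmmv, mmvq, mmq, generic };

    static mm_path choose(bool f16_permuted_single, bool f16_noncont_single, bool f16_multi_batch,
                          bool use_dequantize_mul_mat_vec, bool use_mul_mat_vec_q, bool use_mul_mat_q) {
        if (f16_permuted_single)        return mm_path::vec_p021;  // KQ single-batch
        if (f16_noncont_single)         return mm_path::vec_nc;    // KQV single-batch
        if (f16_multi_batch)            return mm_path::batched;   // KQ + KQV multi-batch
        if (use_dequantize_mul_mat_vec) return mm_path::dmmv;
        if (use_mul_mat_vec_q)          return mm_path::mmvq;
        if (use_mul_mat_q)              return mm_path::mmq;
        return mm_path::generic;                                   // oneMKL / SYCL fallback
    }

    int main() {
        // dmmv takes precedence over mmvq when both predicates hold:
        std::printf("%d\n", (int) choose(false, false, false, true, true, false));
        return 0;
    }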
@@ -15434,22 +15499,86 @@ static void ggml_sycl_mul_mat_id_sycl(ggml_tensor * dst) {
 }
 #endif
 
+struct mmid_row_mapping {
+    int32_t i1;
+    int32_t i2;
+};
+
+__dpct_inline__ static void k_copy_src1_to_contiguous(
+    const char *__restrict__ src1_original, char *__restrict__ src1_contiguous,
+    int *__restrict__ cur_src1_row, mmid_row_mapping *__restrict__ row_mapping,
+    const char *__restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
+    int64_t ne11, int64_t ne10, size_t nb11, size_t nb12,
+    const sycl::nd_item<3> &item_ct1, int &src1_row) {
+    int32_t iid1 = item_ct1.get_group(2);
+    int32_t id = item_ct1.get_group(1);
+
+    const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
+
+    if (row_id_i != i02) {
+        return;
+    }
+
+    const int64_t i11 = id % ne11;
+    const int64_t i12 = iid1;
+
+    if (item_ct1.get_local_id(2) == 0) {
+        src1_row =
+            dpct::atomic_fetch_add<sycl::access::address_space::generic_space>(
+                cur_src1_row, 1);
+        row_mapping[src1_row] = {id, iid1};
+    }
+    /*
+    DPCT1065:194: Consider replacing sycl::nd_item::barrier() with
+    sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
+    performance if there is no access to global memory.
+    */
+    item_ct1.barrier();
+
+    const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
+    float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
+
+#pragma unroll
+    for (int i = item_ct1.get_local_id(2); i < ne10;
+         i += item_ct1.get_local_range(2)) {
+        src1_row_contiguous[i] = src1_row_original[i];
+    }
+}
+
+__dpct_inline__ static void k_copy_dst_from_contiguous(
+    char *__restrict__ dst_original, const char *__restrict__ dst_contiguous,
+    const mmid_row_mapping *__restrict__ row_mapping, int64_t ne0, size_t nb1,
+    size_t nb2, const sycl::nd_item<3> &item_ct1) {
+    int32_t i = item_ct1.get_group(2);
+
+    const int32_t i1 = row_mapping[i].i1;
+    const int32_t i2 = row_mapping[i].i2;
+
+    const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
+    float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);
+
+#pragma unroll
+    for (int j = item_ct1.get_local_id(2); j < ne0;
+         j += item_ct1.get_local_range(2)) {
+        dst_row_original[j] = dst_row_contiguous[j];
+    }
+}
+
 static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
                                  const ggml_tensor *src1,
                                  ggml_tensor *dst) try {
-    GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT &&
-                "mul_mat_id does not support split buffers");
+    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer) && "mul_mat_id does not support split buffers");
+
     const ggml_tensor *ids = dst->src[2];
-    const dpct::queue_ptr stream = g_syclStreams[g_main_device][0];
+    GGML_TENSOR_BINARY_OP_LOCALS
 
-    const size_t nb11 = src1->nb[1];
-    const size_t nb1 = dst->nb[1];
+    const dpct::queue_ptr stream = g_syclStreams[g_main_device][0];
 
-    const int32_t id = ((int32_t *)dst->op_params)[0];
-    const int32_t n_as = src0->ne[2];
+    const int64_t n_as = ne02;
+    const int64_t n_ids = ids->ne[0];
 
     std::vector<char> ids_host(ggml_nbytes(ids));
-    const char *ids_dev = (const char *)ids->data;
+    const char * ids_dev = (const char *) ids->data;
 
     SYCL_CHECK(CHECK_TRY_ERROR(
         stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
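Editorial note: mul_mat_id now batches expert rows on the device. k_copy_src1_to_contiguous packs every src1 row routed to expert i02 into a contiguous buffer, recording the (i1, i2) origin of each packed row through an atomically incremented counter, and k_copy_dst_from_contiguous later scatters the results back. A CPU-side sketch of the same compaction idea (illustrative only; the real kernels do this per work-group with dpct::atomic_fetch_add and the example ids table below is invented):

    // Illustrative CPU analogue of the gather phase: pack rows whose routing id
    // matches the current expert and remember where each packed row came from.
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct mmid_row_mapping { int32_t i1; int32_t i2; };

    int main() {
        const int32_t expert = 1;
        // ids[token][slot] = expert chosen for that (token, slot) pair
        const std::vector<std::vector<int32_t>> ids = { {0, 1}, {1, 2}, {0, 1} };

        std::vector<mmid_row_mapping> mapping;
        int cur_row = 0;                            // incremented atomically on the device
        for (int32_t iid1 = 0; iid1 < (int32_t) ids.size(); ++iid1) {
            for (int32_t id = 0; id < (int32_t) ids[iid1].size(); ++id) {
                if (ids[iid1][id] != expert) continue;
                mapping.push_back({id, iid1});      // source of packed row cur_row
                ++cur_row;
            }
        }
        std::printf("expert %d gets %d packed rows\n", expert, cur_row);
        return 0;
    }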
@@ -15489,24 +15618,40 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
 
     src0_row.ne[2] = 1;
     src0_row.ne[3] = 1;
-    src0_row.nb[3] = src0->nb[2];
-
-    if (src1->ne[1] == 1) {
-        for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-            const int32_t row_id =
-                *(const int32_t *)(ids_host.data() + i01 * ids->nb[1] +
-                                   id * ids->nb[0]);
-
-            GGML_ASSERT(row_id >= 0 && row_id < n_as);
+    src0_row.nb[3] = nb02;
+
+    src1_row.ne[1] = 1;
+    src1_row.ne[2] = 1;
+    src1_row.ne[3] = 1;
+    src1_row.nb[2] = nb11;
+    src1_row.nb[3] = nb11;
+
+    dst_row.ne[1] = 1;
+    dst_row.ne[2] = 1;
+    dst_row.ne[3] = 1;
+    dst_row.nb[2] = nb1;
+    dst_row.nb[3] = nb1;
+    if (ne12 == 1) {
+        for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+            for (int64_t id = 0; id < n_ids; id++) {
+                const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+                GGML_ASSERT(i02 >= 0 && i02 < n_as);
+
+                const int64_t i11 = id % ne11;
+                const int64_t i12 = iid1;
+
+                const int64_t i1 = id;
+                const int64_t i2 = i12;
 
             src0_row_extra.data_device[g_main_device] =
-                src0_original + row_id * src0->nb[2];
+                src0_original + i02*nb02;
             src1_row_extra.data_device[g_main_device] =
-                src1_original + i01 * src1->nb[1];
+                src1_original + + i11*nb11 + i12*nb12;
             dst_row_extra.data_device[g_main_device] =
-                dst_original + i01 * dst->nb[1];
+                dst_original + i1*nb1 + i2*nb2;
 
             ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row);
+            }
         }
     } else {
         sycl_pool_alloc<char> src1_contiguous(sizeof(float)*ggml_nelements(src1));
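Editorial note: in the single-batch branch above (ne12 == 1), each (token iid1, expert slot id) pair is multiplied in place; the source row is addressed by (i11 = id % ne11, i12 = iid1) and the destination row by (i1 = id, i2 = iid1). A tiny illustration of that index mapping with made-up strides:

    // Illustrative only: the per-pair byte offsets used by the ne12 == 1 path.
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t ne11 = 1;                                       // one src1 row per token here
        const int64_t nb11 = 4096, nb12 = 4096, nb1 = 8192, nb2 = 8192; // example strides in bytes

        const int64_t iid1 = 3;                                       // token index
        const int64_t id   = 1;                                       // expert slot index

        const int64_t i11 = id % ne11, i12 = iid1;                    // src1 row
        const int64_t i1  = id,        i2  = i12;                     // dst row

        std::printf("src1 offset = %lld, dst offset = %lld\n",
                    (long long)(i11*nb11 + i12*nb12), (long long)(i1*nb1 + i2*nb2));
        return 0;
    }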
@@ -15515,64 +15660,98 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
         src1_row_extra.data_device[g_main_device] = src1_contiguous.get();
         dst_row_extra.data_device[g_main_device] = dst_contiguous.get();
 
-        for (int32_t row_id = 0; row_id < n_as; ++row_id) {
+        for (int64_t i02 = 0; i02 < n_as; i02++) {
             int64_t num_src1_rows = 0;
-            for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-                const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
+            for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+                for (int64_t id = 0; id < n_ids; id++) {
+                    const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
 
-                if (row_id_i != row_id) {
-                    continue;
-                }
+                    GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
 
-                GGML_ASSERT(row_id >= 0 && row_id < n_as);
+                    if (row_id_i != i02) {
+                        continue;
+                    }
 
-                SYCL_CHECK(CHECK_TRY_ERROR(
-                    stream->memcpy(src1_contiguous.get() + num_src1_rows * nb11,
-                                   src1_original + i01 * nb11, nb11)));
-                num_src1_rows++;
+                    num_src1_rows++;
+                }
             }
 
             if (num_src1_rows == 0) {
                 continue;
            }
 
-            src0_row_extra.data_device[g_main_device] =
-                src0_original + row_id * src0->nb[2];
 
+            sycl_pool_alloc<int> dev_cur_src1_row(1);
+            sycl_pool_alloc<mmid_row_mapping> dev_row_mapping(num_src1_rows);
+            SYCL_CHECK(CHECK_TRY_ERROR(
+                stream->memset(dev_cur_src1_row.get(), 0, sizeof(int))));
+
+            {
+                sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, 768u));
+                sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
+                stream->submit([&](sycl::handler &cgh) {
+                    sycl::local_accessor<int, 0> src1_row_acc(cgh);
+
+                    char *__restrict src1_contiguous_get =
+                        src1_contiguous.get();
+                    int *__restrict dev_cur_src1_row_get =
+                        dev_cur_src1_row.get();
+                    mmid_row_mapping *__restrict dev_row_mapping_get =
+                        dev_row_mapping.get();
+                    size_t ids_nb_ct6 = ids->nb[1];
+                    size_t ids_nb_ct7 = ids->nb[0];
+
+                    cgh.parallel_for(
+                        sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                        [=](sycl::nd_item<3> item_ct1) {
+                            k_copy_src1_to_contiguous(
+                                src1_original, src1_contiguous_get,
+                                dev_cur_src1_row_get,
+                                dev_row_mapping_get, ids_dev, i02,
+                                ids_nb_ct6, ids_nb_ct7, ne11, ne10, nb11, nb12,
+                                item_ct1, src1_row_acc);
+                        });
+                });
+            }
+
+            src0_row_extra.data_device[g_main_device] = src0_original + i02*nb02;
+
+            GGML_ASSERT(nb11 == sizeof(float)*ne10);
+            GGML_ASSERT(nb1 == sizeof(float)*ne0);
             src1_row.ne[1] = num_src1_rows;
-            dst_row.ne[1] = num_src1_rows;
 
             src1_row.nb[1] = nb11;
             src1_row.nb[2] = num_src1_rows*nb11;
            src1_row.nb[3] = num_src1_rows*nb11;
 
+            dst_row.ne[1] = num_src1_rows;
            dst_row.nb[1] = nb1;
            dst_row.nb[2] = num_src1_rows*nb1;
            dst_row.nb[3] = num_src1_rows*nb1;
 
            ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row);
 
-            num_src1_rows = 0;
-            for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-                const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
-
-                if (row_id_i != row_id) {
-                    continue;
-                }
-
-                GGML_ASSERT(row_id >= 0 && row_id < n_as);
-
-                SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(
-                    dst_original + i01 * nb1,
-                    dst_contiguous.get() + num_src1_rows * nb1, nb1)));
-                num_src1_rows++;
+            {
+                sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, 768u));
+                sycl::range<3> grid_dims(1, 1, num_src1_rows);
+                stream->submit([&](sycl::handler &cgh) {
+                    const char *__restrict dst_contiguous_get =
+                        dst_contiguous.get();
+                    const mmid_row_mapping *__restrict dev_row_mapping_get =
+                        dev_row_mapping.get();
+
+                    cgh.parallel_for(
+                        sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                        [=](sycl::nd_item<3> item_ct1) {
+                            k_copy_dst_from_contiguous(dst_original,
+                                                       dst_contiguous_get,
+                                                       dev_row_mapping_get,
+                                                       ne0, nb1, nb2, item_ct1);
+                        });
+                });
+            }
         }
     }
 }
-
-    if (dst->backend == GGML_BACKEND_TYPE_CPU) {
-        SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
-    }
 }
 catch (sycl::exception const &exc) {
     std::cerr << exc.what() << "Exception caught at file:" << __FILE__
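Editorial note: both copy kernels above launch one work-group per row (experts x tokens for the gather, one group per packed row for the scatter) and cap the work-group size at min(row_width, 768); each work-item then strides across the row. A reduced host-side sketch of that stride loop (illustrative, not the device code):

    // Illustrative grid-stride row copy, as used by the two contiguous-copy kernels:
    // local_size work-items cooperate on one row of ne elements.
    #include <algorithm>
    #include <cstdio>
    #include <vector>

    static void copy_row(const float *src, float *dst, int ne, int local_size) {
        for (int tid = 0; tid < local_size; ++tid) {        // stand-in for the work-items
            for (int i = tid; i < ne; i += local_size) {     // the stride loop from the kernels
                dst[i] = src[i];
            }
        }
    }

    int main() {
        const int ne10 = 1000;
        const int local_size = std::min(ne10, 768);          // same cap as in the diff
        std::vector<float> src(ne10, 1.5f), dst(ne10, 0.0f);
        copy_row(src.data(), dst.data(), ne10, local_size);
        std::printf("dst[999] = %f (local_size = %d)\n", dst[999], local_size);
        return 0;
    }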
@@ -16555,10 +16734,9 @@ GGML_CALL static const char * ggml_backend_sycl_split_buffer_get_name(ggml_backe
     UNUSED(buffer);
 }
 
-// unused at the moment
-//static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) {
-//    return buffer->iface.get_name == ggml_backend_sycl_split_buffer_get_name;
-//}
+static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) {
+    return buffer->iface.get_name == ggml_backend_sycl_split_buffer_get_name;
+}
 
 GGML_CALL static void ggml_backend_sycl_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
     ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;