llama_cpp 0.15.2 → 0.15.4

@@ -2944,6 +2944,57 @@ namespace dpct
  using shared_memory = detail::device_memory<T, shared, Dimension>;
 
 
+ template <typename T,
+ sycl::access::address_space addressSpace =
+ sycl::access::address_space::global_space,
+ sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
+ sycl::memory_scope memoryScope = sycl::memory_scope::device>
+ inline T atomic_fetch_add(T *addr, T operand) {
+ auto atm =
+ sycl::atomic_ref<T, memoryOrder, memoryScope, addressSpace>(addr[0]);
+ return atm.fetch_add(operand);
+ }
+
+ template <sycl::access::address_space addressSpace =
+ sycl::access::address_space::global_space,
+ sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
+ sycl::memory_scope memoryScope = sycl::memory_scope::device,
+ typename T1, typename T2>
+ inline T1 atomic_fetch_add(T1 *addr, T2 operand) {
+ auto atm =
+ sycl::atomic_ref<T1, memoryOrder, memoryScope, addressSpace>(addr[0]);
+ return atm.fetch_add(operand);
+ }
+
+ template <typename T, sycl::access::address_space addressSpace =
+ sycl::access::address_space::global_space>
+ inline T atomic_fetch_add(T *addr, T operand,
+ sycl::memory_order memoryOrder) {
+ switch (memoryOrder) {
+ case sycl::memory_order::relaxed:
+ return atomic_fetch_add<T, addressSpace, sycl::memory_order::relaxed,
+ sycl::memory_scope::device>(addr, operand);
+ case sycl::memory_order::acq_rel:
+ return atomic_fetch_add<T, addressSpace, sycl::memory_order::acq_rel,
+ sycl::memory_scope::device>(addr, operand);
+ case sycl::memory_order::seq_cst:
+ return atomic_fetch_add<T, addressSpace, sycl::memory_order::seq_cst,
+ sycl::memory_scope::device>(addr, operand);
+ default:
+ assert(false && "Invalid memory_order for atomics. Valid memory_order for "
+ "atomics are: sycl::memory_order::relaxed, "
+ "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!");
+ }
+ }
+
+ template <sycl::access::address_space addressSpace =
+ sycl::access::address_space::global_space,
+ typename T1, typename T2>
+ inline T1 atomic_fetch_add(T1 *addr, T2 operand,
+ sycl::memory_order memoryOrder) {
+ atomic_fetch_add<T1, addressSpace>(addr, operand, memoryOrder);
+ }
+
  } // COPY from DPCT head files
 
  #define GGML_COMMON_DECL_SYCL
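
The hunk above adds dpct-style atomic_fetch_add helpers built on sycl::atomic_ref. A minimal usage sketch, illustrative only and not part of the diff (assumes a SYCL 2020 compiler and USM shared memory):

#include <sycl/sycl.hpp>

int main() {
    sycl::queue q;
    float *sum = sycl::malloc_shared<float>(1, q);
    *sum = 0.0f;
    q.parallel_for(sycl::range<1>(1024), [=](sycl::id<1>) {
        // equivalent to calling the new dpct::atomic_fetch_add<float>(sum, 1.0f)
        sycl::atomic_ref<float, sycl::memory_order::relaxed,
                         sycl::memory_scope::device,
                         sycl::access::address_space::global_space> ref(sum[0]);
        ref.fetch_add(1.0f);
    }).wait();
    // *sum is now 1024.0f
    sycl::free(sum, q);
}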
@@ -2971,20 +3022,19 @@ static int g_work_group_size = 0;
  // typedef sycl::half ggml_fp16_t;
 
  #define __SYCL_ARCH__ DPCT_COMPATIBILITY_TEMP
- #define VER_4VEC 610 //todo for hardward optimize.
+ #define VER_4VEC 130 //todo for hardward optimize.
  #define VER_GEN9 700 //todo for hardward optimize.
  #define VER_GEN12 1000000 //todo for hardward optimize.
  #define VER_GEN13 (VER_GEN12 + 1030) //todo for hardward optimize.
 
  #define GGML_SYCL_MAX_NODES 8192 //TODO: adapt to hardwares
 
-
- //define for XMX in Intel GPU
- //TODO: currently, it's not used for XMX really.
- #define SYCL_USE_XMX
+ #if !defined(GGML_SYCL_FORCE_MMQ)
+ #define SYCL_USE_XMX
+ #endif
 
  // max batch size to use MMQ kernels when tensor cores are available
- #define XMX_MAX_BATCH_SIZE 32
+ #define MMQ_MAX_BATCH_SIZE 32
 
 
  #if defined(_MSC_VER)
@@ -3060,6 +3110,7 @@ void ggml_sycl_get_device_description(int device, char * description, size_t d
  bool ggml_backend_is_sycl(ggml_backend_t backend);
  int ggml_backend_sycl_get_device(ggml_backend_t backend);
  int get_main_device();
+ static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer);
  void print_ggml_tensor(const char*name, struct ggml_tensor *src);
  void log_tensor_with_cnt(const char* name, struct ggml_tensor * src, int stop_cnt);
 
@@ -3847,21 +3898,27 @@ static void concat_f32(const float *x,const float *y, float *dst, const int ne
  }
  }
 
- static void upscale_f32(const float *x, float *dst, const int ne00, const int nb02, const int scale_factor,
- const sycl::nd_item<3> &item_ct1) {
- int ne0 = ne00 * scale_factor;
- int nidx = item_ct1.get_local_id(2) +
- item_ct1.get_group(2) * item_ct1.get_local_range(2);
- if (nidx >= ne0) {
+ static void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
+ const int nb02, const int nb03, const int ne10, const int ne11,
+ const int ne12, const int ne13, const float sf0, const float sf1,
+ const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) {
+ int index = item_ct1.get_local_id(0) +
+ item_ct1.get_group(0) * item_ct1.get_local_range(0);
+ if (index >= ne10 * ne11 * ne12 * ne13) {
  return;
  }
  // operation
- int i00 = nidx / scale_factor;
- int i01 = item_ct1.get_group(1) / scale_factor;
- int offset_src = i00 + i01 * ne00 + item_ct1.get_group(0) * nb02;
- int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
- item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
- dst[offset_dst] = x[offset_src];
+ int i10 = index % ne10;
+ int i11 = (index / ne10) % ne11;
+ int i12 = (index / (ne10 * ne11)) % ne12;
+ int i13 = (index / (ne10 * ne11 * ne12)) % ne13;
+
+ int i00 = i10 / sf0;
+ int i01 = i11 / sf1;
+ int i02 = i12 / sf2;
+ int i03 = i13 / sf3;
+
+ dst[index] = *(float *)((char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
  }
 
  static void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02,
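
The rewritten upscale_f32 kernel performs nearest-neighbor upscaling with independent per-dimension scale factors, mapping each destination index back to a byte-strided source element. A host-side reference sketch of the same index math (illustrative only; nb* are byte strides of the source, ne1* are destination extents):

static void upscale_reference(const float *x, float *dst,
                              int nb00, int nb01, int nb02, int nb03,
                              int ne10, int ne11, int ne12, int ne13,
                              float sf0, float sf1, float sf2, float sf3) {
    for (int index = 0; index < ne10 * ne11 * ne12 * ne13; ++index) {
        // unflatten the destination index into 4-D coordinates
        int i10 = index % ne10;
        int i11 = (index / ne10) % ne11;
        int i12 = (index / (ne10 * ne11)) % ne12;
        int i13 = (index / (ne10 * ne11 * ne12)) % ne13;
        // map each destination coordinate back to its source coordinate
        int i00 = (int)(i10 / sf0);
        int i01 = (int)(i11 / sf1);
        int i02 = (int)(i12 / sf2);
        int i03 = (int)(i13 / sf3);
        dst[index] = *(const float *)((const char *)x +
                     i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
    }
}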
@@ -4191,7 +4248,6 @@ static void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restri
4191
4248
  const block_q2_K * x = (const block_q2_K *) vx;
4192
4249
 
4193
4250
  const int tid = item_ct1.get_local_id(2);
4194
- #if QK_K == 256
4195
4251
  const int n = tid/32;
4196
4252
  const int l = tid - 32*n;
4197
4253
  const int is = 8*n + l/16;
@@ -4205,18 +4261,6 @@ static void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restri
4205
4261
  y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
4206
4262
  y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
4207
4263
  y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
4208
- #else
4209
- const int is = tid/16; // 0 or 1
4210
- const int il = tid%16; // 0...15
4211
- const uint8_t q = x[i].qs[il] >> (2*is);
4212
- dst_t * y = yy + i*QK_K + 16*is + il;
4213
-
4214
- float dall = x[i].dm[0];
4215
- float dmin = x[i].dm[1];
4216
- y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
4217
- y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
4218
- #endif
4219
-
4220
4264
  }
4221
4265
 
4222
4266
  template<typename dst_t>
@@ -4226,7 +4270,6 @@ static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restri
4226
4270
  const int i = item_ct1.get_group(2);
4227
4271
  const block_q3_K * x = (const block_q3_K *) vx;
4228
4272
 
4229
- #if QK_K == 256
4230
4273
  const int r = item_ct1.get_local_id(2) / 4;
4231
4274
  const int tid = r/2;
4232
4275
  const int is0 = r%2;
@@ -4250,31 +4293,8 @@ static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restri
4250
4293
  const uint8_t * hm = x[i].hmask;
4251
4294
 
4252
4295
  for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
4253
- #else
4254
- const int tid = item_ct1.get_local_id(2);
4255
- const int is = tid/16; // 0 or 1
4256
- const int il = tid%16; // 0...15
4257
- const int im = il/8; // 0...1
4258
- const int in = il%8; // 0...7
4259
-
4260
- dst_t * y = yy + i*QK_K + 16*is + il;
4261
-
4262
- const uint8_t q = x[i].qs[il] >> (2*is);
4263
- const uint8_t h = x[i].hmask[in] >> (2*is + im);
4264
- const float d = (float)x[i].d;
4265
-
4266
- if (is == 0) {
4267
- y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
4268
- y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
4269
- } else {
4270
- y[ 0] = d * ((x[i].scales[0] >> 4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
4271
- y[32] = d * ((x[i].scales[1] >> 4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
4272
- }
4273
- #endif
4274
-
4275
4296
  }
4276
4297
 
4277
- #if QK_K == 256
4278
4298
  static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
4279
4299
  if (j < 4) {
4280
4300
  d = q[j] & 63; m = q[j + 4] & 63;
@@ -4283,7 +4303,6 @@ static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8
4283
4303
  m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
4284
4304
  }
4285
4305
  }
4286
- #endif
4287
4306
 
4288
4307
  template<typename dst_t>
4289
4308
  static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
@@ -4292,7 +4311,6 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
4292
4311
 
4293
4312
  const int i = item_ct1.get_group(2);
4294
4313
 
4295
- #if QK_K == 256
4296
4314
  // assume 32 threads
4297
4315
  const int tid = item_ct1.get_local_id(2);
4298
4316
  const int il = tid/8;
@@ -4316,15 +4334,6 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
4316
4334
  y[l + 0] = d1 * (q[l] & 0xF) - m1;
4317
4335
  y[l +32] = d2 * (q[l] >> 4) - m2;
4318
4336
  }
4319
- #else
4320
- const int tid = item_ct1.get_local_id(2);
4321
- const uint8_t * q = x[i].qs;
4322
- dst_t * y = yy + i*QK_K;
4323
- const float d = (float)x[i].dm[0];
4324
- const float m = (float)x[i].dm[1];
4325
- y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
4326
- y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >> 4) - m * (x[i].scales[1] >> 4);
4327
- #endif
4328
4337
  }
4329
4338
 
4330
4339
  template<typename dst_t>
@@ -4334,7 +4343,6 @@ static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restri
4334
4343
 
4335
4344
  const int i = item_ct1.get_group(2);
4336
4345
 
4337
- #if QK_K == 256
4338
4346
  // assume 64 threads - this is very slightly better than the one below
4339
4347
  const int tid = item_ct1.get_local_id(2);
4340
4348
  const int il = tid/16; // il is in 0...3
@@ -4361,18 +4369,6 @@ static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restri
4361
4369
  hm <<= 1;
4362
4370
  y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
4363
4371
  y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
4364
- #else
4365
- const int tid = item_ct1.get_local_id(2);
4366
- const uint8_t q = x[i].qs[tid];
4367
- const int im = tid/8; // 0...3
4368
- const int in = tid%8; // 0...7
4369
- const int is = tid/16; // 0 or 1
4370
- const uint8_t h = x[i].qh[in] >> im;
4371
- const float d = x[i].d;
4372
- dst_t * y = yy + i*QK_K + tid;
4373
- y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
4374
- y[32] = d * x[i].scales[is+2] * ((q >> 4) - ((h >> 4) & 1 ? 0 : 16));
4375
- #endif
4376
4372
  }
4377
4373
 
4378
4374
  template<typename dst_t>
@@ -4381,7 +4377,6 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restri
4381
4377
  const block_q6_K * x = (const block_q6_K *) vx;
4382
4378
 
4383
4379
  const int i = item_ct1.get_group(2);
4384
- #if QK_K == 256
4385
4380
 
4386
4381
  // assume 64 threads - this is very slightly better than the one below
4387
4382
  const int tid = item_ct1.get_local_id(2);
@@ -4401,24 +4396,6 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restri
4401
4396
  y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
4402
4397
  y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
4403
4398
  y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
4404
- #else
4405
-
4406
- // assume 32 threads
4407
- const int tid = item_ct1.get_local_id(2);
4408
- const int ip = tid/16; // 0 or 1
4409
- const int il = tid - 16*ip; // 0...15
4410
-
4411
- dst_t * y = yy + i*QK_K + 16*ip + il;
4412
-
4413
- const float d = x[i].d;
4414
-
4415
- const uint8_t ql = x[i].ql[16*ip + il];
4416
- const uint8_t qh = x[i].qh[il] >> (2*ip);
4417
- const int8_t * sc = x[i].scales;
4418
-
4419
- y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
4420
- y[32] = d * sc[ip+2] * ((int8_t)((ql >> 4) | (((qh >> 4) & 3) << 4)) - 32);
4421
- #endif
4422
4399
  }
4423
4400
 
4424
4401
  template<typename dst_t>
@@ -4432,7 +4409,6 @@ static void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __res
4432
4409
  const block_iq2_xxs * x = (const block_iq2_xxs *) vx;
4433
4410
 
4434
4411
  const int tid = item_ct1.get_local_id(2);
4435
- #if QK_K == 256
4436
4412
  const int il = tid/8; // 0...3
4437
4413
  const int ib = tid%8; // 0...7
4438
4414
  dst_t * y = yy + i*QK_K + 32*ib + 8*il;
@@ -4443,10 +4419,6 @@ static void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __res
4443
4419
  const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;
4444
4420
  const uint8_t signs = ksigns_iq2xs_ptr[(aux32 >> 7*il) & 127];
4445
4421
  for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs_ptr[j] ? -1.f : 1.f);
4446
- #else
4447
- assert(false);
4448
- #endif
4449
-
4450
4422
  }
4451
4423
 
4452
4424
  template<typename dst_t>
@@ -4460,7 +4432,6 @@ static void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __rest
4460
4432
  const block_iq2_xs * x = (const block_iq2_xs *) vx;
4461
4433
 
4462
4434
  const int tid = item_ct1.get_local_id(2);
4463
- #if QK_K == 256
4464
4435
  const int il = tid/8; // 0...3
4465
4436
  const int ib = tid%8; // 0...7
4466
4437
  dst_t * y = yy + i*QK_K + 32*ib + 8*il;
@@ -4469,10 +4440,6 @@ static void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __rest
4469
4440
  const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
4470
4441
  const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];
4471
4442
  for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
4472
- #else
4473
- assert(false);
4474
- #endif
4475
-
4476
4443
  }
4477
4444
 
4478
4445
  template <typename dst_t>
@@ -4484,7 +4451,6 @@ dequantize_block_iq2_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
4484
4451
  const block_iq2_s * x = (const block_iq2_s *) vx;
4485
4452
 
4486
4453
  const int tid = item_ct1.get_local_id(2);
4487
- #if QK_K == 256
4488
4454
  const int il = tid/8; // 0...3
4489
4455
  const int ib = tid%8; // 0...7
4490
4456
  dst_t * y = yy + i*QK_K + 32*ib + 8*il;
@@ -4492,13 +4458,9 @@ dequantize_block_iq2_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
4492
4458
  const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
4493
4459
  const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];
4494
4460
  #pragma unroll
4495
- for (int j = 0; j < 8; ++j)
4461
+ for (int j = 0; j < 8; ++j) {
4496
4462
  y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
4497
- #else
4498
- assert(false);
4499
-
4500
- #endif
4501
-
4463
+ }
4502
4464
  }
4503
4465
 
4504
4466
  template<typename dst_t>
@@ -4512,7 +4474,6 @@ static void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __res
4512
4474
  const block_iq3_xxs * x = (const block_iq3_xxs *) vx;
4513
4475
 
4514
4476
  const int tid = item_ct1.get_local_id(2);
4515
- #if QK_K == 256
4516
4477
  const int il = tid/8; // 0...3
4517
4478
  const int ib = tid%8; // 0...7
4518
4479
  dst_t * y = yy + i*QK_K + 32*ib + 8*il;
@@ -4527,10 +4488,6 @@ static void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __res
4527
4488
  y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
4528
4489
  y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
4529
4490
  }
4530
- #else
4531
- assert(false);
4532
- #endif
4533
-
4534
4491
  }
4535
4492
 
4536
4493
  template <typename dst_t>
@@ -4543,7 +4500,6 @@ dequantize_block_iq3_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
4543
4500
  const block_iq3_s * x = (const block_iq3_s *) vx;
4544
4501
 
4545
4502
  const int tid = item_ct1.get_local_id(2);
4546
- #if QK_K == 256
4547
4503
  const int il = tid/8; // 0...3
4548
4504
  const int ib = tid%8; // 0...7
4549
4505
  dst_t * y = yy + i*QK_K + 32*ib + 8*il;
@@ -4557,10 +4513,6 @@ dequantize_block_iq3_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
4557
4513
  y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
4558
4514
  y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
4559
4515
  }
4560
- #else
4561
- assert(false);
4562
- #endif
4563
-
4564
4516
  }
4565
4517
 
4566
4518
  template <typename dst_t>
@@ -4573,7 +4525,6 @@ dequantize_block_iq1_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
4573
4525
  const block_iq1_s * x = (const block_iq1_s *) vx;
4574
4526
 
4575
4527
  const int tid = item_ct1.get_local_id(2);
4576
- #if QK_K == 256
4577
4528
  const int il = tid/8; // 0...3
4578
4529
  const int ib = tid%8; // 0...7
4579
4530
  dst_t * y = yy + i*QK_K + 32*ib + 8*il;
@@ -4587,10 +4538,6 @@ dequantize_block_iq1_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
4587
4538
  for (int j = 0; j < 8; ++j) {
4588
4539
  y[j] = d * (q[j] + delta);
4589
4540
  }
4590
- #else
4591
- assert(false);
4592
- #endif
4593
-
4594
4541
  }
4595
4542
 
4596
4543
  template <typename dst_t>
@@ -4603,7 +4550,6 @@ dequantize_block_iq1_m(const void *__restrict__ vx, dst_t *__restrict__ yy,
4603
4550
  const block_iq1_m * x = (const block_iq1_m *) vx;
4604
4551
 
4605
4552
  const int tid = item_ct1.get_local_id(2);
4606
- #if QK_K == 256
4607
4553
  const int il = tid/8; // 0...3
4608
4554
  const int ib = tid%8; // 0...7
4609
4555
  dst_t * y = yy + i*QK_K + 32*ib + 8*il;
@@ -4621,10 +4567,6 @@ dequantize_block_iq1_m(const void *__restrict__ vx, dst_t *__restrict__ yy,
4621
4567
  for (int j = 0; j < 8; ++j) {
4622
4568
  y[j] = d * (q[j] + delta);
4623
4569
  }
4624
- #else
4625
- assert(false);
4626
- #endif
4627
-
4628
4570
  }
4629
4571
 
4630
4572
  template <typename dst_t>
@@ -4698,7 +4640,6 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,
4698
4640
 
4699
4641
  float tmp = 0; // partial sum for thread in warp
4700
4642
 
4701
- #if QK_K == 256
4702
4643
  const int tid =
4703
4644
  item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...15
4704
4645
  const int ix =
@@ -4749,42 +4690,6 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,
4749
4690
  tmp += dall * sum1 - dmin * sum2;
4750
4691
 
4751
4692
  }
4752
- #else
4753
- const int tid = item_ct1.get_local_id(2) /
4754
- (2 * K_QUANTS_PER_ITERATION); // 0...15 or 0...7
4755
- const int ix = item_ct1.get_local_id(2) %
4756
- (2 * K_QUANTS_PER_ITERATION); // 0....1 or 0...3
4757
- const int offset = tid * K_QUANTS_PER_ITERATION;
4758
-
4759
- uint32_t uaux[2];
4760
- const uint8_t * d = (const uint8_t *)uaux;
4761
-
4762
-
4763
- for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
4764
-
4765
- const float * y = yy + i * QK_K + offset;
4766
- const uint8_t * q = x[i].qs + offset;
4767
- const uint32_t * s = (const uint32_t *)x[i].scales;
4768
-
4769
- uaux[0] = s[0] & 0x0f0f0f0f;
4770
- uaux[1] = (s[0] >> 4) & 0x0f0f0f0f;
4771
-
4772
- const sycl::float2 dall =
4773
- x[i].dm.convert<float, sycl::rounding_mode::automatic>();
4774
-
4775
- float sum1 = 0, sum2 = 0;
4776
- for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
4777
- const uint8_t ql = q[l];
4778
- sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3)
4779
- + y[l+16] * d[1] * ((ql >> 2) & 3)
4780
- + y[l+32] * d[2] * ((ql >> 4) & 3)
4781
- + y[l+48] * d[3] * ((ql >> 6) & 3);
4782
- sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7];
4783
- }
4784
- tmp += dall.x() * sum1 - dall.y() * sum2;
4785
- }
4786
-
4787
- #endif
4788
4693
 
4789
4694
  // sum up partial sums and write back result
4790
4695
  #pragma unroll
@@ -4822,8 +4727,6 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx,
4822
4727
 
4823
4728
  float tmp = 0; // partial sum for thread in warp
4824
4729
 
4825
- #if QK_K == 256
4826
-
4827
4730
  const uint16_t kmask1 = 0x0303;
4828
4731
  const uint16_t kmask2 = 0x0f0f;
4829
4732
 
@@ -4876,34 +4779,6 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx,
4876
4779
  tmp += d * sum;
4877
4780
 
4878
4781
  }
4879
- #else
4880
-
4881
- const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7
4882
- const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3
4883
- const int offset = tid * K_QUANTS_PER_ITERATION; // 0...15 or 0...14
4884
- const int in = offset/8; // 0 or 1
4885
- const int im = offset%8; // 0...7
4886
-
4887
- for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
4888
-
4889
- const float * y = yy + i * QK_K + offset;
4890
- const uint8_t * q = x[i].qs + offset;
4891
- const uint8_t * s = x[i].scales;
4892
-
4893
- const float dall = (float)x[i].d;
4894
-
4895
- float sum = 0;
4896
- for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
4897
- const uint8_t hl = x[i].hmask[im+l] >> in;
4898
- const uint8_t ql = q[l];
4899
- sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4))
4900
- + y[l+16] * dall * ((s[0] >> 4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4))
4901
- + y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4))
4902
- + y[l+48] * dall * ((s[1] >> 4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4));
4903
- }
4904
- tmp += sum;
4905
- }
4906
- #endif
4907
4782
 
4908
4783
  // sum up partial sums and write back result
4909
4784
  #pragma unroll
@@ -4938,7 +4813,6 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx,
4938
4813
 
4939
4814
  const block_q4_K * x = (const block_q4_K *)vx + ib0;
4940
4815
 
4941
- #if QK_K == 256
4942
4816
  const uint16_t kmask1 = 0x3f3f;
4943
4817
  const uint16_t kmask2 = 0x0f0f;
4944
4818
  const uint16_t kmask3 = 0xc0c0;
@@ -5027,36 +4901,6 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx,
5027
4901
  #endif
5028
4902
 
5029
4903
  }
5030
- #else
5031
- const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...15
5032
- const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION);
5033
-
5034
- const int step = tid * K_QUANTS_PER_ITERATION;
5035
-
5036
- uint16_t aux16[2];
5037
- const uint8_t * s = (const uint8_t *)aux16;
5038
-
5039
- float tmp = 0;
5040
-
5041
- for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
5042
- const uint8_t * q = x[i].qs + step;
5043
- const float * y = yy + i*QK_K + step;
5044
- const uint16_t * a = (const uint16_t *)x[i].scales;
5045
- aux16[0] = a[0] & 0x0f0f;
5046
- aux16[1] = (a[0] >> 4) & 0x0f0f;
5047
- const float d = (float)x[i].dm[0];
5048
- const float m = (float)x[i].dm[1];
5049
- float sum = 0.f;
5050
- for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
5051
- sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2])
5052
- + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2])
5053
- + y[j+32] * (d * s[1] * (q[j+ 0] >> 4) - m * s[3])
5054
- + y[j+48] * (d * s[1] * (q[j+16] >> 4) - m * s[3]);
5055
- }
5056
- tmp += sum;
5057
- }
5058
-
5059
- #endif
5060
4904
 
5061
4905
  // sum up partial sums and write back result
5062
4906
  #pragma unroll
@@ -5091,7 +4935,6 @@ static void dequantize_mul_mat_vec_q5_k(const void *__restrict__ vx,
5091
4935
 
5092
4936
  float tmp = 0; // partial sum for thread in warp
5093
4937
 
5094
- #if QK_K == 256
5095
4938
  const uint16_t kmask1 = 0x3f3f;
5096
4939
  const uint16_t kmask2 = 0x0f0f;
5097
4940
  const uint16_t kmask3 = 0xc0c0;
@@ -5168,30 +5011,6 @@ static void dequantize_mul_mat_vec_q5_k(const void *__restrict__ vx,
5168
5011
  dmin * smin;
5169
5012
  }
5170
5013
 
5171
- #else
5172
- const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...15
5173
- const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION);
5174
- const int step = tid * K_QUANTS_PER_ITERATION;
5175
- const int im = step/8;
5176
- const int in = step%8;
5177
-
5178
- for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
5179
- const uint8_t * q = x[i].qs + step;
5180
- const int8_t * s = x[i].scales;
5181
- const float * y = yy + i*QK_K + step;
5182
- const float d = x[i].d;
5183
- float sum = 0.f;
5184
- for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
5185
- const uint8_t h = x[i].qh[in+j] >> im;
5186
- sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16))
5187
- + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16))
5188
- + y[j+32] * d * s[2] * ((q[j+ 0] >> 4) - ((h >> 4) & 1 ? 0 : 16))
5189
- + y[j+48] * d * s[3] * ((q[j+16] >> 4) - ((h >> 6) & 1 ? 0 : 16));
5190
- }
5191
- tmp += sum;
5192
- }
5193
- #endif
5194
-
5195
5014
  // sum up partial sums and write back result
5196
5015
  #pragma unroll
5197
5016
  for (int mask = 16; mask > 0; mask >>= 1) {
@@ -5218,8 +5037,6 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
5218
5037
 
5219
5038
  const block_q6_K * x = (const block_q6_K *)vx + ib0;
5220
5039
 
5221
- #if QK_K == 256
5222
-
5223
5040
  const int tid =
5224
5041
  item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16
5225
5042
  const int ix =
@@ -5276,37 +5093,6 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
5276
5093
 
5277
5094
  }
5278
5095
 
5279
- #else
5280
-
5281
- const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...7
5282
- const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); // 0...3
5283
-
5284
- const int step = tid * K_QUANTS_PER_ITERATION;
5285
-
5286
- float tmp = 0; // partial sum for thread in warp
5287
-
5288
- for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
5289
-
5290
- const float * y = yy + i * QK_K + step;
5291
- const uint8_t * ql = x[i].ql + step;
5292
- const uint8_t * qh = x[i].qh + step;
5293
- const int8_t * s = x[i].scales;
5294
-
5295
- const float d = x[i+0].d;
5296
-
5297
- float sum = 0;
5298
- for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
5299
- sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32)
5300
- + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32)
5301
- + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32)
5302
- + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >> 4) | ((qh[j] & 0xc0) >> 2)) - 32);
5303
- }
5304
- tmp += sum;
5305
-
5306
- }
5307
-
5308
- #endif
5309
-
5310
5096
  // sum up partial sums and write back result
5311
5097
  #pragma unroll
5312
5098
  for (int mask = 16; mask > 0; mask >>= 1) {
@@ -6851,7 +6637,6 @@ static __dpct_inline__ float
6851
6637
  vec_dot_q4_K_q8_1(const void *__restrict__ vbq,
6852
6638
  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
6853
6639
 
6854
- #ifndef GGML_QKK_64
6855
6640
  const block_q4_K * bq4_K = (const block_q4_K *) vbq;
6856
6641
 
6857
6642
  int v[2];
@@ -6893,52 +6678,6 @@ vec_dot_q4_K_q8_1(const void *__restrict__ vbq,
6893
6678
  }
6894
6679
 
6895
6680
  return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8);
6896
-
6897
- #else
6898
-
6899
- #if __SYCL_ARCH__ >= VER_4VEC // lowest compute capability for integer intrinsics
6900
- const block_q4_K * bq4_K = (const block_q4_K *) vbq;
6901
-
6902
- float sumf_d = 0.0f;
6903
- float sumf_m = 0.0f;
6904
-
6905
- uint16_t aux16[2];
6906
- const uint8_t * s = (const uint8_t *)aux16;
6907
-
6908
- const uint16_t * a = (const uint16_t *)bq4_K->scales;
6909
- aux16[0] = a[0] & 0x0f0f;
6910
- aux16[1] = (a[0] >> 4) & 0x0f0f;
6911
-
6912
- const float dall = bq4_K->dm[0];
6913
- const float dmin = bq4_K->dm[1];
6914
-
6915
- const float d8_1 = bq8_1[0].ds[0];
6916
- const float d8_2 = bq8_1[1].ds[1];
6917
-
6918
- const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2));
6919
- const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4);
6920
- const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2));
6921
- const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4);
6922
-
6923
- const int * q4 = (const int *)bq4_K->qs + (iqs/2);
6924
- const int v1 = q4[0];
6925
- const int v2 = q4[4];
6926
-
6927
- const int dot1 = dpct::dp4a(ui2, v2 & 0x0f0f0f0f, dpct::dp4a(ui1, v1 & 0x0f0f0f0f, 0));
6928
- const int dot2 = dpct::dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, dpct::dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0));
6929
- const int dot3 = dpct::dp4a(0x01010101, ui2, dpct::dp4a(0x01010101, ui1, 0));
6930
- const int dot4 = dpct::dp4a(0x01010101, ui4, dpct::dp4a(0x01010101, ui3, 0));
6931
-
6932
- sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]);
6933
- sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]);
6934
-
6935
- return dall * sumf_d - dmin * sumf_m;
6936
-
6937
- #else
6938
- bad_arch();
6939
- #endif // __SYCL_ARCH__ >= VER_4VEC
6940
-
6941
- #endif
6942
6681
  }
6943
6682
 
6944
6683
  template <int mmq_y>
@@ -6997,11 +6736,7 @@ load_tiles_q4_K(const void *__restrict__ vx, int *__restrict__ x_ql,
6997
6736
 
6998
6737
  const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd;
6999
6738
 
7000
- #if QK_K == 256
7001
6739
  x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm;
7002
- #else
7003
- x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = {bxi->dm[0], bxi->dm[1]};
7004
- #endif
7005
6740
  }
7006
6741
 
7007
6742
  #pragma unroll
@@ -7044,7 +6779,6 @@ static __dpct_inline__ float
7044
6779
  vec_dot_q5_K_q8_1(const void *__restrict__ vbq,
7045
6780
  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
7046
6781
 
7047
- #ifndef GGML_QKK_64
7048
6782
  const block_q5_K * bq5_K = (const block_q5_K *) vbq;
7049
6783
 
7050
6784
  int vl[2];
@@ -7086,48 +6820,6 @@ vec_dot_q5_K_q8_1(const void *__restrict__ vbq,
7086
6820
  }
7087
6821
 
7088
6822
  return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, bq5_K->dm, d8);
7089
-
7090
- #else
7091
-
7092
- #if __SYCL_ARCH__ >= VER_4VEC // lowest compute capability for integer intrinsics
7093
- const block_q5_K * bq5_K = (const block_q5_K *) vbq;
7094
-
7095
- const int8_t * s = bq5_K->scales;
7096
-
7097
- const float d = bq5_K->d;
7098
-
7099
- const float d8_1 = bq8_1[0].ds[0];
7100
- const float d8_2 = bq8_1[1].ds[1];
7101
-
7102
- const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2));
7103
- const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4);
7104
- const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2));
7105
- const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4);
7106
-
7107
- const int * ql = (const int *)bq5_K->qs + (iqs/2);
7108
- const int vl1 = ql[0];
7109
- const int vl2 = ql[4];
7110
-
7111
- const int step = 4 * (iqs/2); // 0, 4, 8, 12
7112
- const int im = step/8; // = 0 for iqs = 0, 2, = 1 for iqs = 4, 6
7113
- const int in = step%8; // 0, 4, 0, 4
7114
- const int vh = (*((const int *)(bq5_K->qh + in))) >> im;
7115
-
7116
- const int v1 = (((vh << 4) & 0x10101010) ^ 0x10101010) | ((vl1 >> 0) & 0x0f0f0f0f);
7117
- const int v2 = (((vh << 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 0) & 0x0f0f0f0f);
7118
- const int v3 = (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f);
7119
- const int v4 = (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f);
7120
-
7121
- const float sumf_d = d8_1 * (dpct::dp4a(ui1, v1, 0) * s[0] + dpct::dp4a(ui2, v2, 0) * s[1])
7122
- + d8_2 * (dpct::dp4a(ui3, v3, 0) * s[2] + dpct::dp4a(ui4, v4, 0) * s[3]);
7123
-
7124
- return d * sumf_d;
7125
-
7126
- #else
7127
- bad_arch();
7128
- #endif // __SYCL_ARCH__ >= VER_4VEC
7129
-
7130
- #endif
7131
6823
  }
7132
6824
 
7133
6825
  template <int mmq_y>
@@ -7199,9 +6891,7 @@ load_tiles_q5_K(const void *__restrict__ vx, int *__restrict__ x_ql,
7199
6891
 
7200
6892
  const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd;
7201
6893
 
7202
- #if QK_K == 256
7203
6894
  x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm;
7204
- #endif
7205
6895
  }
7206
6896
 
7207
6897
  #pragma unroll
@@ -7381,7 +7071,6 @@ vec_dot_iq2_xxs_q8_1(const void *__restrict__ vbq,
7381
7071
  const block_q8_1 *__restrict__ bq8_1, const int &iqs,
7382
7072
  const uint64_t *iq2xxs_grid, const uint8_t *ksigns_iq2xs,
7383
7073
  const uint8_t *kmask_iq2xs) {
7384
- #if QK_K == 256
7385
7074
  const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq;
7386
7075
 
7387
7076
  #if QR2_XXS == 8
@@ -7422,10 +7111,6 @@ vec_dot_iq2_xxs_q8_1(const void *__restrict__ vbq,
7422
7111
  }
7423
7112
  return d * (sumi1 + sumi2);
7424
7113
  #endif
7425
- #else
7426
- assert(false);
7427
- return 0.f;
7428
- #endif
7429
7114
  }
7430
7115
 
7431
7116
  static __dpct_inline__ float
@@ -7434,7 +7119,6 @@ vec_dot_iq2_xs_q8_1(const void *__restrict__ vbq,
7434
7119
  const uint64_t *iq2xs_grid, const uint64_t *ksigns64) {
7435
7120
  #if DPCT_COMPATIBILITY_TEMP >= \
7436
7121
  MIN_CC_DP4A // lowest compute capability for integer intrinsics
7437
- #if QK_K == 256
7438
7122
  const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq;
7439
7123
 
7440
7124
  const int ib32 = iqs;
@@ -7472,16 +7156,11 @@ vec_dot_iq2_xs_q8_1(const void *__restrict__ vbq,
7472
7156
  assert(false);
7473
7157
  return 0.f;
7474
7158
  #endif
7475
- #else
7476
- assert(false);
7477
- return 0.f;
7478
- #endif
7479
7159
  }
7480
7160
 
7481
7161
  static __dpct_inline__ float
7482
7162
  vec_dot_iq2_s_q8_1(const void *__restrict__ vbq,
7483
7163
  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
7484
- #if QK_K == 256
7485
7164
  const block_iq2_s * bq2 = (const block_iq2_s *) vbq;
7486
7165
 
7487
7166
  const int ib32 = iqs;
@@ -7525,9 +7204,6 @@ vec_dot_iq2_s_q8_1(const void *__restrict__ vbq,
7525
7204
  }
7526
7205
  const float d = (float)bq2->d * bq8_1[ib32].ds[0] * 0.25f;
7527
7206
  return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
7528
- #else
7529
- assert(false);
7530
- #endif
7531
7207
  }
7532
7208
 
7533
7209
  static __dpct_inline__ float
@@ -7536,7 +7212,6 @@ vec_dot_iq3_xxs_q8_1(const void *__restrict__ vbq,
7536
7212
  const uint32_t *iq3xxs_grid, const uint64_t *ksigns64) {
7537
7213
  #if DPCT_COMPATIBILITY_TEMP >= \
7538
7214
  MIN_CC_DP4A // lowest compute capability for integer intrinsics
7539
- #if QK_K == 256
7540
7215
  const block_iq3_xxs * bq2 = (const block_iq3_xxs *) vbq;
7541
7216
 
7542
7217
  const int ib32 = iqs;
@@ -7564,17 +7239,12 @@ vec_dot_iq3_xxs_q8_1(const void *__restrict__ vbq,
7564
7239
  assert(false);
7565
7240
  return 0.f;
7566
7241
  #endif
7567
- #else
7568
- assert(false);
7569
- return 0.f;
7570
- #endif
7571
7242
  }
7572
7243
 
7573
7244
  static __dpct_inline__ float
7574
7245
  vec_dot_iq3_s_q8_1(const void *__restrict__ vbq,
7575
7246
  const block_q8_1 *__restrict__ bq8_1, const int &iqs,
7576
7247
  const uint32_t *iq3s_grid) {
7577
- #if QK_K == 256
7578
7248
  const block_iq3_s * bq2 = (const block_iq3_s *) vbq;
7579
7249
 
7580
7250
  const int ib32 = iqs;
@@ -7603,16 +7273,12 @@ vec_dot_iq3_s_q8_1(const void *__restrict__ vbq,
7603
7273
  (1 + 2 * ((bq2->scales[ib32 / 2] >> 4 * (ib32 % 2)) & 0xf)) *
7604
7274
  bq8_1[ib32].ds[0];
7605
7275
  return d * sumi;
7606
- #else
7607
- assert(false);
7608
- #endif
7609
7276
  }
7610
7277
 
7611
7278
  static __dpct_inline__ float
7612
7279
  vec_dot_iq1_s_q8_1(const void *__restrict__ vbq,
7613
7280
  const block_q8_1 *__restrict__ bq8_1, const int &iqs,
7614
7281
  const uint32_t *iq1s_grid_gpu) {
7615
- #if QK_K == 256
7616
7282
  const block_iq1_s * bq1 = (const block_iq1_s *) vbq;
7617
7283
 
7618
7284
  const int ib32 = iqs;
@@ -7631,15 +7297,11 @@ vec_dot_iq1_s_q8_1(const void *__restrict__ vbq,
7631
7297
  const float d = d1q * bq8_1[ib32].ds[0];
7632
7298
  const float m = d1q * bq8_1[ib32].ds[1];
7633
7299
  return d * sumi + m * delta;
7634
- #else
7635
- assert(false);
7636
- #endif
7637
7300
  }
7638
7301
 
7639
7302
  static __dpct_inline__ float
7640
7303
  vec_dot_iq1_m_q8_1(const void *__restrict__ vbq,
7641
7304
  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
7642
- #if QK_K == 256
7643
7305
  const block_iq1_m * bq1 = (const block_iq1_m *) vbq;
7644
7306
 
7645
7307
  const int ib32 = iqs;
@@ -7664,9 +7326,6 @@ vec_dot_iq1_m_q8_1(const void *__restrict__ vbq,
7664
7326
  scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
7665
7327
  const float d = (float)scale.f16 * bq8_1[ib32].ds[0];
7666
7328
  return d * ((sumi[0] + sumf[0]) * (2*((sc[ib32/2] >> 6*(ib32%2)) & 0x7) + 1) + (sumi[1] + sumf[1]) * (2*((sc[ib32/2] >> (6*(ib32%2)+3)) & 0x7) + 1));
7667
- #else
7668
- assert(false);
7669
- #endif
7670
7329
  }
7671
7330
 
7672
7331
  static __dpct_inline__ void get_int_from_table_16(const uint32_t &q4,
@@ -7714,7 +7373,6 @@ static __dpct_inline__ float
7714
7373
  vec_dot_iq4_xs_q8_1(const void *__restrict__ vbq,
7715
7374
  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
7716
7375
 
7717
- #if QK_K == 256
7718
7376
  const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq;
7719
7377
  const uint8_t * values = (const uint8_t *)kvalues_iq4nl;
7720
7378
 
@@ -7732,9 +7390,6 @@ vec_dot_iq4_xs_q8_1(const void *__restrict__ vbq,
7732
7390
  sumi2 = dpct::dp4a(v2, q8[j + 4], sumi2);
7733
7391
  }
7734
7392
  return d * (sumi1 + sumi2);
7735
- #else
7736
- assert(false);
7737
- #endif
7738
7393
  }
7739
7394
 
7740
7395
  template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x,
@@ -9226,12 +8881,11 @@ static void rope(
  dst[i + 1] = x0*sin_theta + x1*cos_theta;
  }
 
- template<typename T, bool has_pos>
+ template<typename T, bool has_pos, bool has_freq_facs>
  static void rope_neox(
  const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
- float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
- ,
- const sycl::nd_item<3> &item_ct1) {
+ float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims,
+ const float * freq_factors, const sycl::nd_item<3> &item_ct1) {
  const int col = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
  item_ct1.get_local_id(1));
 
@@ -9259,8 +8913,10 @@ static void rope_neox(
  float cur_rot = inv_ndims * ic - ib;
 
  const int p = has_pos ? pos[i2] : 0;
+ const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
+
  const float theta_base =
- p * freq_scale * dpct::pow(theta_scale, col / 2.0f);
+ p * freq_scale * dpct::pow(theta_scale, col / 2.0f)/freq_factor;
 
  float cos_theta, sin_theta;
  rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
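
With the new has_freq_facs template parameter, the base angle of each rotated pair is divided by its frequency factor before rope_yarn is applied. A standalone sketch of that per-element computation (illustrative only; it simply mirrors the kernel expression above):

#include <cmath>

float rope_neox_theta(int p, float freq_scale, float theta_scale,
                      int col, const float *freq_factors, int ic) {
    // freq_factors may be null, in which case the factor defaults to 1.0f
    const float freq_factor = freq_factors ? freq_factors[ic / 2] : 1.0f;
    return p * freq_scale * std::pow(theta_scale, col / 2.0f) / freq_factor;
}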
@@ -10085,18 +9741,17 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
  });
  }
 
- static void upscale_f32_sycl(const float *x, float *dst, const int ne00,
- const int ne01, const int ne02,
- const int scale_factor, dpct::queue_ptr stream) {
- int ne0 = (ne00 * scale_factor);
- int num_blocks = (ne0 + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE;
- sycl::range<3> gridDim(ne02, (ne01 * scale_factor), num_blocks);
+ static void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01,
+ const int nb02, const int nb03, const int ne10, const int ne11,
+ const int ne12, const int ne13, const float sf0, const float sf1,
+ const float sf2, const float sf3, dpct::queue_ptr stream) {
+ int dst_size = ne10 * ne11 * ne12 * ne13;
+ int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE;
+ sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE);
  stream->parallel_for(
- sycl::nd_range<3>(gridDim *
- sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE),
- sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)),
- [=](sycl::nd_item<3> item_ct1) {
- upscale_f32(x, dst, ne00, ne00 * ne01, scale_factor, item_ct1);
+ sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)),
+ [=](sycl::nd_item<1> item_ct1) {
+ upscale_f32(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1);
  });
  }
 
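The launcher above flattens the destination to a 1-D range and rounds the work-item count up to a multiple of SYCL_UPSCALE_BLOCK_SIZE. A small sketch of that ceil-division sizing (illustrative only):

int upscale_num_blocks(int dst_size, int block_size) {
    // ceil(dst_size / block_size); global range = result * block_size work-items
    return (dst_size + block_size - 1) / block_size;
}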
@@ -10198,7 +9853,6 @@ template <typename dst_t>
10198
9853
  static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int k,
10199
9854
  dpct::queue_ptr stream) {
10200
9855
  const int nb = k / QK_K;
10201
- #if QK_K == 256
10202
9856
  {
10203
9857
  dpct::has_capability_or_fail(stream->get_device(),
10204
9858
  {sycl::aspect::fp16});
@@ -10210,27 +9864,12 @@ static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int k,
10210
9864
  dequantize_block_q2_K(vx, y, item_ct1);
10211
9865
  });
10212
9866
  }
10213
- #else
10214
- {
10215
- dpct::has_capability_or_fail(stream->get_device(),
10216
- {sycl::aspect::fp16});
10217
-
10218
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
10219
- sycl::range<3>(1, 1, 32),
10220
- sycl::range<3>(1, 1, 32)),
10221
- [=](sycl::nd_item<3> item_ct1) {
10222
- dequantize_block_q2_K(vx, y, item_ct1);
10223
- });
10224
- }
10225
-
10226
- #endif
10227
9867
  }
10228
9868
 
10229
9869
  template <typename dst_t>
10230
9870
  static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int k,
10231
9871
  dpct::queue_ptr stream) {
10232
9872
  const int nb = k / QK_K;
10233
- #if QK_K == 256
10234
9873
  {
10235
9874
  dpct::has_capability_or_fail(stream->get_device(),
10236
9875
  {sycl::aspect::fp16});
@@ -10242,19 +9881,6 @@ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int k,
10242
9881
  dequantize_block_q3_K(vx, y, item_ct1);
10243
9882
  });
10244
9883
  }
10245
- #else
10246
- {
10247
- dpct::has_capability_or_fail(stream->get_device(),
10248
- {sycl::aspect::fp16});
10249
-
10250
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
10251
- sycl::range<3>(1, 1, 32),
10252
- sycl::range<3>(1, 1, 32)),
10253
- [=](sycl::nd_item<3> item_ct1) {
10254
- dequantize_block_q3_K(vx, y, item_ct1);
10255
- });
10256
- }
10257
- #endif
10258
9884
  }
10259
9885
 
10260
9886
  template <typename dst_t>
@@ -10315,7 +9941,6 @@ template <typename dst_t>
10315
9941
  static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int k,
10316
9942
  dpct::queue_ptr stream) {
10317
9943
  const int nb = k / QK_K;
10318
- #if QK_K == 256
10319
9944
  {
10320
9945
  dpct::has_capability_or_fail(stream->get_device(),
10321
9946
  {sycl::aspect::fp16});
@@ -10327,27 +9952,12 @@ static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int k,
10327
9952
  dequantize_block_q5_K(vx, y, item_ct1);
10328
9953
  });
10329
9954
  }
10330
- #else
10331
- {
10332
- dpct::has_capability_or_fail(stream->get_device(),
10333
- {sycl::aspect::fp16});
10334
-
10335
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
10336
- sycl::range<3>(1, 1, 32),
10337
- sycl::range<3>(1, 1, 32)),
10338
- [=](sycl::nd_item<3> item_ct1) {
10339
- dequantize_block_q5_K(vx, y, item_ct1);
10340
- });
10341
- }
10342
-
10343
- #endif
10344
9955
  }
10345
9956
 
10346
9957
  template <typename dst_t>
10347
9958
  static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int k,
10348
9959
  dpct::queue_ptr stream) {
10349
9960
  const int nb = k / QK_K;
10350
- #if QK_K == 256
10351
9961
  {
10352
9962
  dpct::has_capability_or_fail(stream->get_device(),
10353
9963
  {sycl::aspect::fp16});
@@ -10359,20 +9969,6 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int k,
10359
9969
  dequantize_block_q6_K(vx, y, item_ct1);
10360
9970
  });
10361
9971
  }
10362
- #else
10363
- {
10364
- dpct::has_capability_or_fail(stream->get_device(),
10365
- {sycl::aspect::fp16});
10366
-
10367
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
10368
- sycl::range<3>(1, 1, 32),
10369
- sycl::range<3>(1, 1, 32)),
10370
- [=](sycl::nd_item<3> item_ct1) {
10371
- dequantize_block_q6_K(vx, y, item_ct1);
10372
- });
10373
- }
10374
-
10375
- #endif
10376
9972
  }
10377
9973
 
10378
9974
  template <typename dst_t>
@@ -10524,9 +10120,6 @@ template <typename dst_t>
10524
10120
  static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int k,
10525
10121
  dpct::queue_ptr stream) {
10526
10122
  const int nb = (k + QK_K - 1) / QK_K;
10527
- #if QK_K == 64
10528
- dequantize_row_iq4_nl_sycl(vx, y, k, stream);
10529
- #else
10530
10123
  {
10531
10124
  dpct::has_capability_or_fail(stream->get_device(),
10532
10125
  {sycl::aspect::fp16});
@@ -10541,7 +10134,6 @@ static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int k,
10541
10134
  });
10542
10135
  });
10543
10136
  }
10544
- #endif
10545
10137
  }
10546
10138
 
10547
10139
 
@@ -12046,8 +11638,6 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
12046
11638
  const int nrows_y, const int nrows_dst,
12047
11639
  dpct::queue_ptr stream) try {
12048
11640
 
12049
- #if QK_K == 256
12050
-
12051
11641
  int id;
12052
11642
  SYCL_CHECK(
12053
11643
  CHECK_TRY_ERROR(id = get_current_device_id()));
@@ -12162,7 +11752,6 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
12162
11752
  });
12163
11753
  }
12164
11754
  }
12165
- #endif
12166
11755
  }
12167
11756
  catch (sycl::exception const &exc) {
12168
11757
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -12876,7 +12465,7 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
  const int32_t *pos, float freq_scale,
  int p_delta_rows, float freq_base, float ext_factor,
  float attn_factor, rope_corr_dims corr_dims,
- dpct::queue_ptr stream) {
+ const float * freq_factors, dpct::queue_ptr stream) {
  GGML_ASSERT(ncols % 2 == 0);
  const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
  const int num_blocks_x = (ncols + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
@@ -12886,38 +12475,48 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
  const float inv_ndims = -1.0f / n_dims;
 
  if (pos == nullptr) {
- /*
- DPCT1049:42: The work-group size passed to the SYCL kernel may exceed
- the limit. To get the device limit, query
- info::device::max_work_group_size. Adjust the work-group size if needed.
- */
  dpct::has_capability_or_fail(stream->get_device(),
  {sycl::aspect::fp16});
-
- stream->parallel_for(
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) {
- rope_neox<T, false>(x, dst, ncols, n_dims, pos, freq_scale,
- p_delta_rows, ext_factor, attn_factor,
- corr_dims, theta_scale, inv_ndims,
- item_ct1);
- });
+ if (freq_factors == nullptr) {
+ stream->parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ rope_neox<T, false, false>(x, dst, ncols, n_dims, pos, freq_scale,
+ p_delta_rows, ext_factor, attn_factor,
+ corr_dims, theta_scale, inv_ndims, freq_factors,
+ item_ct1);
+ });
+ } else {
+ stream->parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ rope_neox<T, false, true>(x, dst, ncols, n_dims, pos, freq_scale,
+ p_delta_rows, ext_factor, attn_factor,
+ corr_dims, theta_scale, inv_ndims, freq_factors,
+ item_ct1);
+ });
+ }
  } else {
- /*
- DPCT1049:43: The work-group size passed to the SYCL kernel may exceed
- the limit. To get the device limit, query
- info::device::max_work_group_size. Adjust the work-group size if needed.
- */
  dpct::has_capability_or_fail(stream->get_device(),
  {sycl::aspect::fp16});
 
- stream->parallel_for(
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) {
- rope_neox<T, true>(x, dst, ncols, n_dims, pos, freq_scale,
- p_delta_rows, ext_factor, attn_factor,
- corr_dims, theta_scale, inv_ndims, item_ct1);
- });
+ if (freq_factors == nullptr) {
+ stream->parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ rope_neox<T, true, false>(x, dst, ncols, n_dims, pos, freq_scale,
+ p_delta_rows, ext_factor, attn_factor,
+ corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1);
+ });
+ } else {
+ stream->parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ rope_neox<T, true, true>(x, dst, ncols, n_dims, pos, freq_scale,
+ p_delta_rows, ext_factor, attn_factor,
+ corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1);
+ });
+ }
  }
  }
 
@@ -13964,6 +13563,10 @@ inline void ggml_sycl_op_concat(const ggml_tensor *src0,
  const float *src0_dd, const float *src1_dd,
  float *dst_dd,
  const dpct::queue_ptr &main_stream) {
+ #pragma message("TODO: generalize concat kernel for dim != 2")
+ #pragma message(" https://github.com/ggerganov/llama.cpp/pull/7563")
+ int dim = dst->op_params[0];
+ GGML_ASSERT(dim == 2);
 
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -13985,15 +13588,15 @@ inline void ggml_sycl_op_upscale(const ggml_tensor *src0,
 
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
  GGML_ASSERT(dst->type == GGML_TYPE_F32);
- GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
-
- #pragma message("TODO: generalize upscale operator")
- #pragma message(" https://github.com/ggerganov/ggml/pull/814")
- GGML_ASSERT(false && "TODO: generalize upscale operator");
 
- const int scale_factor = dst->op_params[0];
+ const float sf0 = (float)dst->ne[0]/src0->ne[0];
+ const float sf1 = (float)dst->ne[1]/src0->ne[1];
+ const float sf2 = (float)dst->ne[2]/src0->ne[2];
+ const float sf3 = (float)dst->ne[3]/src0->ne[3];
 
- upscale_f32_sycl(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], scale_factor, main_stream);
+ upscale_f32_sycl(src0_dd, dst_dd, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+ dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3,
+ main_stream);
 
  (void) src1;
  (void) dst;
@@ -14449,6 +14052,7 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
  ggml_tensor *dst, const float *src0_dd,
  const float *src1_dd, float *dst_dd,
  const dpct::queue_ptr &main_stream) {
+ const ggml_tensor * src2 = dst->src[2];
 
  GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
  GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
@@ -14474,6 +14078,7 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
  memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
  memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
 
+ const float * freq_factors = nullptr;
  const int32_t * pos = nullptr;
  if ((mode & 1) == 0) {
  GGML_ASSERT(src1->type == GGML_TYPE_I32);
@@ -14484,6 +14089,16 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
  const bool is_neox = mode & 2;
  const bool is_glm = mode & 4;
 
+ if (is_neox) {
+ pos = (const int32_t *) src1_dd;
+
+ if (src2 != nullptr) {
+ freq_factors = (const float *) src2->data;
+ }
+ } else {
+ GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox");
+ }
+
  rope_corr_dims corr_dims;
  ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
 
@@ -14495,13 +14110,13 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
  if (src0->type == GGML_TYPE_F32) {
  rope_neox_sycl(
  (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
- attn_factor, corr_dims, main_stream
+ attn_factor, corr_dims, freq_factors, main_stream
  );
  } else if (src0->type == GGML_TYPE_F16) {
  rope_neox_sycl((const sycl::half *)src0_dd, (sycl::half *)dst_dd,
  ne00, n_dims, nrows, pos, freq_scale, ne01,
  freq_base, ext_factor, attn_factor, corr_dims,
- main_stream);
+ freq_factors, main_stream);
  } else {
  GGML_ASSERT(false);
  }
@@ -15568,7 +15183,7 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
  const int64_t r2 = ne12/ne02;
  const int64_t r3 = ne13/ne03;
 
- if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
+ if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
  // there is no broadcast and src0, src1 are contiguous across dims 2, 3
  SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
  *g_sycl_handles[g_main_device], oneapi::mkl::transpose::trans,
@@ -15633,6 +15248,29 @@ catch (sycl::exception const &exc) {
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
  std::exit(1);
  }
+ inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
+ // TODO: accuracy issues in MMQ
+ return false;
+ }
+
+ bool ggml_sycl_supports_dmmv(enum ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q3_K:
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q6_K:
+ case GGML_TYPE_F16:
+ return true;
+ default:
+ return false;
+ }
+ }
 
  static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
  const bool all_on_device =
@@ -15649,75 +15287,42 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
  }
  }
 
- #ifdef SYCL_USE_XMX
- const bool use_xmx = true;
- #else
- const bool use_xmx = false;
- #endif
+ // check data types and tensor shapes for custom matrix multiplication kernels:
+ bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type)
+ && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+ && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
 
- // debug helpers
- //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
- //printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
- //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);
- //printf(" %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
- //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
- //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
+ bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
+ && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+ && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+
+ bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
+ && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+
+ // mmvq and mmq need the __dp4a instruction which is available for gen12+
+ // Workaround in https://github.com/ggerganov/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e
+ use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS);
+ #ifdef SYCL_USE_XMX
+ use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
+ #endif // SYCL_USE_XMX
 
- if (!split && all_on_device && !use_xmx && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+ if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
  // KQ single-batch
- // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_p021\n");
  ggml_sycl_mul_mat_vec_p021(src0, src1, dst);
- } else if (!split && all_on_device && !use_xmx && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+ } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
  // KQV single-batch
- // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_nc\n");
  ggml_sycl_mul_mat_vec_nc(src0, src1, dst);
- } else if (!split && all_on_device && use_xmx && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
+ } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
  // KQ + KQV multi-batch
- // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_batched_sycl\n");
  ggml_sycl_mul_mat_batched_sycl(src0, src1, dst);
- } else if (src0->type == GGML_TYPE_F32) {
- // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat\n");
- ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
- } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
- // GGML_SYCL_DEBUG("ggml_is_quantized or GGML_TYPE_F16\n");
- if (src1->ne[1] == 1 && src0->ne[0] % GGML_SYCL_DMMV_X == 0) {
- #ifdef GGML_SYCL_FORCE_DMMV
- const bool use_mul_mat_vec_q = false;
- #else
- bool use_mul_mat_vec_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type);
- use_mul_mat_vec_q = use_mul_mat_vec_q ||
- (src0->type == GGML_TYPE_IQ2_XXS) || (src0->type == GGML_TYPE_IQ2_XS) || (src0->type == GGML_TYPE_IQ2_S) ||
- (src0->type == GGML_TYPE_IQ3_XXS) || (src0->type == GGML_TYPE_IQ3_S) ||
- (src0->type == GGML_TYPE_IQ4_NL) || (src0->type == GGML_TYPE_IQ4_XS) ||
- (src0->type == GGML_TYPE_IQ1_S) || (src0->type == GGML_TYPE_IQ1_M);
15693
-
15694
-
15695
- #endif // GGML_SYCL_FORCE_DMMV
15696
-
15697
- if (use_mul_mat_vec_q) {
15698
- // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_vec_q path\n");
15699
- ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
15700
- } else {
15701
- // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_dequantize_mul_mat_vec path\n");
15702
- ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
15703
- }
15704
- } else {
15705
- bool use_mul_mat_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type);
15706
-
15707
- if (use_xmx && min_compute_capability >= VER_GEN9 && src1->ne[1] > XMX_MAX_BATCH_SIZE) {
15708
- use_mul_mat_q = false;
15709
- }
15710
-
15711
- if (use_mul_mat_q) {
15712
- // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_q path\n");
15713
- ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
15714
- } else {
15715
- // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_sycl path\n");
15716
- ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
15717
- }
15718
- }
15318
+ } else if (use_dequantize_mul_mat_vec) {
15319
+ ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
15320
+ } else if (use_mul_mat_vec_q) {
15321
+ ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
15322
+ } else if (use_mul_mat_q) {
15323
+ ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
15719
15324
  } else {
15720
- GGML_ASSERT(false);
15325
+ ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
15721
15326
  }
15722
15327
  }
15723
15328
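Worked example for the rewritten dispatch above. Take a Q4_0 weight src0 of shape [4096, 4096] with an F32 src1 of shape [4096, 1] and an F32 dst (a single-token decode), assuming 4096 is a multiple of GGML_SYCL_DMMV_X:

    // Predicate values for that case (illustrative, not code from the diff):
    //   use_dequantize_mul_mat_vec = true   (Q4_0 is in the dmmv whitelist, ne[1] == 1)
    //   use_mul_mat_vec_q          = true   (quantized src0, ne[1] <= MMVQ_MAX_BATCH_SIZE)
    //   use_mul_mat_q              = false  (ggml_sycl_supports_mmq() currently returns false)
    // The first matching branch wins, so this call goes through
    // ggml_sycl_op_dequantize_mul_mat_vec. Batches with 1 < ne[1] <= MMVQ_MAX_BATCH_SIZE
    // take the mul_mat_vec_q path instead, and anything larger falls through to the
    // generic ggml_sycl_op_mul_mat_sycl fallback.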
 
@@ -15894,22 +15499,86 @@ static void ggml_sycl_mul_mat_id_sycl(ggml_tensor * dst) {
15894
15499
  }
15895
15500
  #endif
15896
15501
 
15502
+ struct mmid_row_mapping {
15503
+ int32_t i1;
15504
+ int32_t i2;
15505
+ };
15506
+
15507
+ __dpct_inline__ static void k_copy_src1_to_contiguous(
15508
+ const char *__restrict__ src1_original, char *__restrict__ src1_contiguous,
15509
+ int *__restrict__ cur_src1_row, mmid_row_mapping *__restrict__ row_mapping,
15510
+ const char *__restrict__ ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
15511
+ int64_t ne11, int64_t ne10, size_t nb11, size_t nb12,
15512
+ const sycl::nd_item<3> &item_ct1, int &src1_row) {
15513
+ int32_t iid1 = item_ct1.get_group(2);
15514
+ int32_t id = item_ct1.get_group(1);
15515
+
15516
+ const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
15517
+
15518
+ if (row_id_i != i02) {
15519
+ return;
15520
+ }
15521
+
15522
+ const int64_t i11 = id % ne11;
15523
+ const int64_t i12 = iid1;
15524
+
15525
+ if (item_ct1.get_local_id(2) == 0) {
15526
+ src1_row =
15527
+ dpct::atomic_fetch_add<sycl::access::address_space::generic_space>(
15528
+ cur_src1_row, 1);
15529
+ row_mapping[src1_row] = {id, iid1};
15530
+ }
15531
+ /*
15532
+ DPCT1065:194: Consider replacing sycl::nd_item::barrier() with
15533
+ sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
15534
+ performance if there is no access to global memory.
15535
+ */
15536
+ item_ct1.barrier();
15537
+
15538
+ const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
15539
+ float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
15540
+
15541
+ #pragma unroll
15542
+ for (int i = item_ct1.get_local_id(2); i < ne10;
15543
+ i += item_ct1.get_local_range(2)) {
15544
+ src1_row_contiguous[i] = src1_row_original[i];
15545
+ }
15546
+ }
15547
+
15548
+ __dpct_inline__ static void k_copy_dst_from_contiguous(
15549
+ char *__restrict__ dst_original, const char *__restrict__ dst_contiguous,
15550
+ const mmid_row_mapping *__restrict__ row_mapping, int64_t ne0, size_t nb1,
15551
+ size_t nb2, const sycl::nd_item<3> &item_ct1) {
15552
+ int32_t i = item_ct1.get_group(2);
15553
+
15554
+ const int32_t i1 = row_mapping[i].i1;
15555
+ const int32_t i2 = row_mapping[i].i2;
15556
+
15557
+ const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
15558
+ float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);
15559
+
15560
+ #pragma unroll
15561
+ for (int j = item_ct1.get_local_id(2); j < ne0;
15562
+ j += item_ct1.get_local_range(2)) {
15563
+ dst_row_original[j] = dst_row_contiguous[j];
15564
+ }
15565
+ }
15566
+
15897
15567
  static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
15898
15568
  const ggml_tensor *src1,
15899
15569
  ggml_tensor *dst) try {
15900
- GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT &&
15901
- "mul_mat_id does not support split buffers");
15570
+ GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer) && "mul_mat_id does not support split buffers");
15571
+
15902
15572
  const ggml_tensor *ids = dst->src[2];
15903
- const dpct::queue_ptr stream = g_syclStreams[g_main_device][0];
15573
+ GGML_TENSOR_BINARY_OP_LOCALS
15904
15574
 
15905
- const size_t nb11 = src1->nb[1];
15906
- const size_t nb1 = dst->nb[1];
15575
+ const dpct::queue_ptr stream = g_syclStreams[g_main_device][0];
15907
15576
 
15908
- const int32_t id = ((int32_t *)dst->op_params)[0];
15909
- const int32_t n_as = src0->ne[2];
15577
+ const int64_t n_as = ne02;
15578
+ const int64_t n_ids = ids->ne[0];
15910
15579
 
15911
15580
  std::vector<char> ids_host(ggml_nbytes(ids));
15912
- const char *ids_dev = (const char *)ids->data;
15581
+ const char * ids_dev = (const char *) ids->data;
15913
15582
 
15914
15583
  SYCL_CHECK(CHECK_TRY_ERROR(
15915
15584
  stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
@@ -15949,24 +15618,40 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
15949
15618
 
15950
15619
  src0_row.ne[2] = 1;
15951
15620
  src0_row.ne[3] = 1;
15952
- src0_row.nb[3] = src0->nb[2];
15953
-
15954
- if (src1->ne[1] == 1) {
15955
- for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
15956
- const int32_t row_id =
15957
- *(const int32_t *)(ids_host.data() + i01 * ids->nb[1] +
15958
- id * ids->nb[0]);
15959
-
15960
- GGML_ASSERT(row_id >= 0 && row_id < n_as);
15621
+ src0_row.nb[3] = nb02;
15622
+
15623
+ src1_row.ne[1] = 1;
15624
+ src1_row.ne[2] = 1;
15625
+ src1_row.ne[3] = 1;
15626
+ src1_row.nb[2] = nb11;
15627
+ src1_row.nb[3] = nb11;
15628
+
15629
+ dst_row.ne[1] = 1;
15630
+ dst_row.ne[2] = 1;
15631
+ dst_row.ne[3] = 1;
15632
+ dst_row.nb[2] = nb1;
15633
+ dst_row.nb[3] = nb1;
15634
+ if (ne12 == 1) {
15635
+ for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
15636
+ for (int64_t id = 0; id < n_ids; id++) {
15637
+ const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
15638
+ GGML_ASSERT(i02 >= 0 && i02 < n_as);
15639
+
15640
+ const int64_t i11 = id % ne11;
15641
+ const int64_t i12 = iid1;
15642
+
15643
+ const int64_t i1 = id;
15644
+ const int64_t i2 = i12;
15961
15645
 
15962
15646
  src0_row_extra.data_device[g_main_device] =
15963
- src0_original + row_id * src0->nb[2];
15647
+ src0_original + i02*nb02;
15964
15648
  src1_row_extra.data_device[g_main_device] =
15965
- src1_original + i01 * src1->nb[1];
15649
+ src1_original + i11*nb11 + i12*nb12;
15966
15650
  dst_row_extra.data_device[g_main_device] =
15967
- dst_original + i01 * dst->nb[1];
15651
+ dst_original + i1*nb1 + i2*nb2;
15968
15652
 
15969
15653
  ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row);
15654
+ }
15970
15655
  }
15971
15656
  } else {
15972
15657
  sycl_pool_alloc<char> src1_contiguous(sizeof(float)*ggml_nelements(src1));
@@ -15975,64 +15660,98 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
15975
15660
  src1_row_extra.data_device[g_main_device] = src1_contiguous.get();
15976
15661
  dst_row_extra.data_device[g_main_device] = dst_contiguous.get();
15977
15662
 
15978
- for (int32_t row_id = 0; row_id < n_as; ++row_id) {
15663
+ for (int64_t i02 = 0; i02 < n_as; i02++) {
15979
15664
  int64_t num_src1_rows = 0;
15980
- for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
15981
- const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
15665
+ for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
15666
+ for (int64_t id = 0; id < n_ids; id++) {
15667
+ const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
15982
15668
 
15983
- if (row_id_i != row_id) {
15984
- continue;
15985
- }
15669
+ GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
15986
15670
 
15987
- GGML_ASSERT(row_id >= 0 && row_id < n_as);
15671
+ if (row_id_i != i02) {
15672
+ continue;
15673
+ }
15988
15674
 
15989
- SYCL_CHECK(CHECK_TRY_ERROR(
15990
- stream->memcpy(src1_contiguous.get() + num_src1_rows * nb11,
15991
- src1_original + i01 * nb11, nb11)));
15992
- num_src1_rows++;
15675
+ num_src1_rows++;
15676
+ }
15993
15677
  }
15994
15678
 
15995
15679
  if (num_src1_rows == 0) {
15996
15680
  continue;
15997
15681
  }
15998
15682
 
15999
- src0_row_extra.data_device[g_main_device] =
16000
- src0_original + row_id * src0->nb[2];
16001
15683
 
15684
+ sycl_pool_alloc<int> dev_cur_src1_row(1);
15685
+ sycl_pool_alloc<mmid_row_mapping> dev_row_mapping(num_src1_rows);
15686
+ SYCL_CHECK(CHECK_TRY_ERROR(
15687
+ stream->memset(dev_cur_src1_row.get(), 0, sizeof(int))));
15688
+
15689
+ {
15690
+ sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, 768u));
15691
+ sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
15692
+ stream->submit([&](sycl::handler &cgh) {
15693
+ sycl::local_accessor<int, 0> src1_row_acc(cgh);
15694
+
15695
+ char *__restrict src1_contiguous_get =
15696
+ src1_contiguous.get();
15697
+ int *__restrict dev_cur_src1_row_get =
15698
+ dev_cur_src1_row.get();
15699
+ mmid_row_mapping *__restrict dev_row_mapping_get =
15700
+ dev_row_mapping.get();
15701
+ size_t ids_nb_ct6 = ids->nb[1];
15702
+ size_t ids_nb_ct7 = ids->nb[0];
15703
+
15704
+ cgh.parallel_for(
15705
+ sycl::nd_range<3>(grid_dims * block_dims, block_dims),
15706
+ [=](sycl::nd_item<3> item_ct1) {
15707
+ k_copy_src1_to_contiguous(
15708
+ src1_original, src1_contiguous_get,
15709
+ dev_cur_src1_row_get,
15710
+ dev_row_mapping_get, ids_dev, i02,
15711
+ ids_nb_ct6, ids_nb_ct7, ne11, ne10, nb11, nb12,
15712
+ item_ct1, src1_row_acc);
15713
+ });
15714
+ });
15715
+ }
15716
+
15717
+ src0_row_extra.data_device[g_main_device] = src0_original + i02*nb02;
15718
+
15719
+ GGML_ASSERT(nb11 == sizeof(float)*ne10);
15720
+ GGML_ASSERT(nb1 == sizeof(float)*ne0);
16002
15721
  src1_row.ne[1] = num_src1_rows;
16003
- dst_row.ne[1] = num_src1_rows;
16004
15722
 
16005
15723
  src1_row.nb[1] = nb11;
16006
15724
  src1_row.nb[2] = num_src1_rows*nb11;
16007
15725
  src1_row.nb[3] = num_src1_rows*nb11;
16008
15726
 
15727
+ dst_row.ne[1] = num_src1_rows;
16009
15728
  dst_row.nb[1] = nb1;
16010
15729
  dst_row.nb[2] = num_src1_rows*nb1;
16011
15730
  dst_row.nb[3] = num_src1_rows*nb1;
16012
15731
 
16013
15732
  ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row);
16014
15733
 
16015
- num_src1_rows = 0;
16016
- for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
16017
- const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
16018
-
16019
- if (row_id_i != row_id) {
16020
- continue;
16021
- }
16022
-
16023
- GGML_ASSERT(row_id >= 0 && row_id < n_as);
16024
-
16025
- SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(
16026
- dst_original + i01 * nb1,
16027
- dst_contiguous.get() + num_src1_rows * nb1, nb1)));
16028
- num_src1_rows++;
15734
+ {
15735
+ sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, 768u));
15736
+ sycl::range<3> grid_dims(1, 1, num_src1_rows);
15737
+ stream->submit([&](sycl::handler &cgh) {
15738
+ const char *__restrict dst_contiguous_get =
15739
+ dst_contiguous.get();
15740
+ const mmid_row_mapping *__restrict dev_row_mapping_get =
15741
+ dev_row_mapping.get();
15742
+
15743
+ cgh.parallel_for(
15744
+ sycl::nd_range<3>(grid_dims * block_dims, block_dims),
15745
+ [=](sycl::nd_item<3> item_ct1) {
15746
+ k_copy_dst_from_contiguous(dst_original,
15747
+ dst_contiguous_get,
15748
+ dev_row_mapping_get,
15749
+ ne0, nb1, nb2, item_ct1);
15750
+ });
15751
+ });
16029
15752
  }
16030
15753
  }
16031
15754
  }
16032
-
16033
- if (dst->backend == GGML_BACKEND_TYPE_CPU) {
16034
- SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
16035
- }
16036
15755
  }
16037
15756
  catch (sycl::exception const &exc) {
16038
15757
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
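One more note on the mul_mat_id rewrite above: in k_copy_src1_to_contiguous, the first work-item of each work-group claims a slot in the contiguous src1 buffer by atomically incrementing cur_src1_row, and records {i1, i2} in row_mapping so that k_copy_dst_from_contiguous can scatter the batched result back to the right expert rows. The dpct::atomic_fetch_add call used for that claim is, in effect, a plain SYCL 2020 atomic; a rough equivalent, shown as a standalone helper for illustration:

    // Roughly what the slot-claim boils down to (relaxed order, device scope,
    // generic address space); hypothetical helper, not code from the diff.
    inline int claim_src1_slot(int * cur_src1_row) {
        sycl::atomic_ref<int, sycl::memory_order::relaxed, sycl::memory_scope::device,
                         sycl::access::address_space::generic_space> counter(*cur_src1_row);
        return counter.fetch_add(1);   // previous value == this row's slot index
    }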
@@ -17015,10 +16734,9 @@ GGML_CALL static const char * ggml_backend_sycl_split_buffer_get_name(ggml_backe
17015
16734
  UNUSED(buffer);
17016
16735
  }
17017
16736
 
17018
- // unused at the moment
17019
- //static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) {
17020
- // return buffer->iface.get_name == ggml_backend_sycl_split_buffer_get_name;
17021
- //}
16737
+ static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) {
16738
+ return buffer->iface.get_name == ggml_backend_sycl_split_buffer_get_name;
16739
+ }
17022
16740
 
17023
16741
  GGML_CALL static void ggml_backend_sycl_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
17024
16742
  ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;