llama_cpp 0.15.2 → 0.15.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
@@ -2944,6 +2944,57 @@ namespace dpct
  using shared_memory = detail::device_memory<T, shared, Dimension>;


+ template <typename T,
+ sycl::access::address_space addressSpace =
+ sycl::access::address_space::global_space,
+ sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
+ sycl::memory_scope memoryScope = sycl::memory_scope::device>
+ inline T atomic_fetch_add(T *addr, T operand) {
+ auto atm =
+ sycl::atomic_ref<T, memoryOrder, memoryScope, addressSpace>(addr[0]);
+ return atm.fetch_add(operand);
+ }
+
+ template <sycl::access::address_space addressSpace =
+ sycl::access::address_space::global_space,
+ sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
+ sycl::memory_scope memoryScope = sycl::memory_scope::device,
+ typename T1, typename T2>
+ inline T1 atomic_fetch_add(T1 *addr, T2 operand) {
+ auto atm =
+ sycl::atomic_ref<T1, memoryOrder, memoryScope, addressSpace>(addr[0]);
+ return atm.fetch_add(operand);
+ }
+
+ template <typename T, sycl::access::address_space addressSpace =
+ sycl::access::address_space::global_space>
+ inline T atomic_fetch_add(T *addr, T operand,
+ sycl::memory_order memoryOrder) {
+ switch (memoryOrder) {
+ case sycl::memory_order::relaxed:
+ return atomic_fetch_add<T, addressSpace, sycl::memory_order::relaxed,
+ sycl::memory_scope::device>(addr, operand);
+ case sycl::memory_order::acq_rel:
+ return atomic_fetch_add<T, addressSpace, sycl::memory_order::acq_rel,
+ sycl::memory_scope::device>(addr, operand);
+ case sycl::memory_order::seq_cst:
+ return atomic_fetch_add<T, addressSpace, sycl::memory_order::seq_cst,
+ sycl::memory_scope::device>(addr, operand);
+ default:
+ assert(false && "Invalid memory_order for atomics. Valid memory_order for "
+ "atomics are: sycl::memory_order::relaxed, "
+ "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!");
+ }
+ }
+
+ template <sycl::access::address_space addressSpace =
+ sycl::access::address_space::global_space,
+ typename T1, typename T2>
+ inline T1 atomic_fetch_add(T1 *addr, T2 operand,
+ sycl::memory_order memoryOrder) {
+ atomic_fetch_add<T1, addressSpace>(addr, operand, memoryOrder);
+ }
+
  } // COPY from DPCT head files

  #define GGML_COMMON_DECL_SYCL
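The hunk above copies DPCT-style atomic helpers built on sycl::atomic_ref into the SYCL backend. A minimal usage sketch follows; the kernel below is illustrative only and not part of the package, and it relies on the defaults declared above (relaxed order, device scope, global address space):

    // Accumulate per-work-item values into a single float in global memory.
    static void sum_into(const float *x, float *out, int n, const sycl::nd_item<1> &it) {
        const int i = (int) it.get_global_id(0);
        if (i < n) {
            dpct::atomic_fetch_add(out, x[i]);   // resolves to the (T*, T) overload above
        }
    }

The overloads that take a run-time sycl::memory_order exist so callers that only know the ordering at run time can still dispatch to a correctly parameterised sycl::atomic_ref.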
@@ -2971,20 +3022,19 @@ static int g_work_group_size = 0;
  // typedef sycl::half ggml_fp16_t;

  #define __SYCL_ARCH__ DPCT_COMPATIBILITY_TEMP
- #define VER_4VEC 610 //todo for hardward optimize.
+ #define VER_4VEC 130 //todo for hardward optimize.
  #define VER_GEN9 700 //todo for hardward optimize.
  #define VER_GEN12 1000000 //todo for hardward optimize.
  #define VER_GEN13 (VER_GEN12 + 1030) //todo for hardward optimize.

  #define GGML_SYCL_MAX_NODES 8192 //TODO: adapt to hardwares

-
- //define for XMX in Intel GPU
- //TODO: currently, it's not used for XMX really.
- #define SYCL_USE_XMX
+ #if !defined(GGML_SYCL_FORCE_MMQ)
+ #define SYCL_USE_XMX
+ #endif

  // max batch size to use MMQ kernels when tensor cores are available
- #define XMX_MAX_BATCH_SIZE 32
+ #define MMQ_MAX_BATCH_SIZE 32


  #if defined(_MSC_VER)
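Two build-time knobs interact here (a hedged note, not part of the diff): SYCL_USE_XMX is now defined only when GGML_SYCL_FORCE_MMQ is not set, and the renamed MMQ_MAX_BATCH_SIZE is only consulted when SYCL_USE_XMX is defined, as the dispatch hunk near the end of this diff shows:

    // Illustrative restatement of the gating introduced by this change:
    #if !defined(GGML_SYCL_FORCE_MMQ)
    //   SYCL_USE_XMX is defined, so the MMQ path is capped:
    //   use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
    #else
    //   compiling with GGML_SYCL_FORCE_MMQ defined leaves SYCL_USE_XMX undefined,
    //   so no batch-size cap is applied to the MMQ path.
    #endif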
@@ -3060,6 +3110,7 @@ void ggml_sycl_get_device_description(int device, char * description, size_t d
  bool ggml_backend_is_sycl(ggml_backend_t backend);
  int ggml_backend_sycl_get_device(ggml_backend_t backend);
  int get_main_device();
+ static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer);
  void print_ggml_tensor(const char*name, struct ggml_tensor *src);
  void log_tensor_with_cnt(const char* name, struct ggml_tensor * src, int stop_cnt);

@@ -3847,21 +3898,27 @@ static void concat_f32(const float *x,const float *y, float *dst, const int ne
  }
  }

- static void upscale_f32(const float *x, float *dst, const int ne00, const int nb02, const int scale_factor,
- const sycl::nd_item<3> &item_ct1) {
- int ne0 = ne00 * scale_factor;
- int nidx = item_ct1.get_local_id(2) +
- item_ct1.get_group(2) * item_ct1.get_local_range(2);
- if (nidx >= ne0) {
+ static void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
+ const int nb02, const int nb03, const int ne10, const int ne11,
+ const int ne12, const int ne13, const float sf0, const float sf1,
+ const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) {
+ int index = item_ct1.get_local_id(0) +
+ item_ct1.get_group(0) * item_ct1.get_local_range(0);
+ if (index >= ne10 * ne11 * ne12 * ne13) {
  return;
  }
  // operation
- int i00 = nidx / scale_factor;
- int i01 = item_ct1.get_group(1) / scale_factor;
- int offset_src = i00 + i01 * ne00 + item_ct1.get_group(0) * nb02;
- int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
- item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
- dst[offset_dst] = x[offset_src];
+ int i10 = index % ne10;
+ int i11 = (index / ne10) % ne11;
+ int i12 = (index / (ne10 * ne11)) % ne12;
+ int i13 = (index / (ne10 * ne11 * ne12)) % ne13;
+
+ int i00 = i10 / sf0;
+ int i01 = i11 / sf1;
+ int i02 = i12 / sf2;
+ int i03 = i13 / sf3;
+
+ dst[index] = *(float *)((char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
  }

  static void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02,
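The rewritten upscale_f32 in the hunk above flattens the destination into a single 1-D index, unpacks it into four coordinates, and divides each coordinate by its scale factor to locate the source element, reading through byte strides so non-contiguous sources work. A small host-side reference in plain C++ (2-D only; names and semantics assumed for illustration, not part of the package):

    // Nearest-neighbour upscale of a contiguous 2-D float tensor, mirroring the
    // same truncating division the kernel uses (i00 = i10 / sf0, etc.).
    static void upscale_ref_2d(const float *src, float *dst,
                               int ne00, int ne10, int ne11,
                               float sf0, float sf1) {
        for (int i11 = 0; i11 < ne11; ++i11) {
            for (int i10 = 0; i10 < ne10; ++i10) {
                const int i00 = (int)(i10 / sf0);
                const int i01 = (int)(i11 / sf1);
                dst[i11 * ne10 + i10] = src[i01 * ne00 + i00];
            }
        }
    }

For example, with a 2x2 source and a 4x4 destination (sf0 = sf1 = 2.0f), the destination element at (i10 = 3, i11 = 1) reads the source element at (1, 0).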
@@ -4191,7 +4248,6 @@ static void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restri
4191
4248
  const block_q2_K * x = (const block_q2_K *) vx;
4192
4249
 
4193
4250
  const int tid = item_ct1.get_local_id(2);
4194
- #if QK_K == 256
4195
4251
  const int n = tid/32;
4196
4252
  const int l = tid - 32*n;
4197
4253
  const int is = 8*n + l/16;
@@ -4205,18 +4261,6 @@ static void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restri
4205
4261
  y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
4206
4262
  y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
4207
4263
  y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
4208
- #else
4209
- const int is = tid/16; // 0 or 1
4210
- const int il = tid%16; // 0...15
4211
- const uint8_t q = x[i].qs[il] >> (2*is);
4212
- dst_t * y = yy + i*QK_K + 16*is + il;
4213
-
4214
- float dall = x[i].dm[0];
4215
- float dmin = x[i].dm[1];
4216
- y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
4217
- y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
4218
- #endif
4219
-
4220
4264
  }
4221
4265
 
4222
4266
  template<typename dst_t>
@@ -4226,7 +4270,6 @@ static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restri
4226
4270
  const int i = item_ct1.get_group(2);
4227
4271
  const block_q3_K * x = (const block_q3_K *) vx;
4228
4272
 
4229
- #if QK_K == 256
4230
4273
  const int r = item_ct1.get_local_id(2) / 4;
4231
4274
  const int tid = r/2;
4232
4275
  const int is0 = r%2;
@@ -4250,31 +4293,8 @@ static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restri
4250
4293
  const uint8_t * hm = x[i].hmask;
4251
4294
 
4252
4295
  for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
4253
- #else
4254
- const int tid = item_ct1.get_local_id(2);
4255
- const int is = tid/16; // 0 or 1
4256
- const int il = tid%16; // 0...15
4257
- const int im = il/8; // 0...1
4258
- const int in = il%8; // 0...7
4259
-
4260
- dst_t * y = yy + i*QK_K + 16*is + il;
4261
-
4262
- const uint8_t q = x[i].qs[il] >> (2*is);
4263
- const uint8_t h = x[i].hmask[in] >> (2*is + im);
4264
- const float d = (float)x[i].d;
4265
-
4266
- if (is == 0) {
4267
- y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
4268
- y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
4269
- } else {
4270
- y[ 0] = d * ((x[i].scales[0] >> 4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
4271
- y[32] = d * ((x[i].scales[1] >> 4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
4272
- }
4273
- #endif
4274
-
4275
4296
  }
4276
4297
 
4277
- #if QK_K == 256
4278
4298
  static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
4279
4299
  if (j < 4) {
4280
4300
  d = q[j] & 63; m = q[j + 4] & 63;
@@ -4283,7 +4303,6 @@ static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8
4283
4303
  m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
4284
4304
  }
4285
4305
  }
4286
- #endif
4287
4306
 
4288
4307
  template<typename dst_t>
4289
4308
  static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
@@ -4292,7 +4311,6 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
4292
4311
 
4293
4312
  const int i = item_ct1.get_group(2);
4294
4313
 
4295
- #if QK_K == 256
4296
4314
  // assume 32 threads
4297
4315
  const int tid = item_ct1.get_local_id(2);
4298
4316
  const int il = tid/8;
@@ -4316,15 +4334,6 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
4316
4334
  y[l + 0] = d1 * (q[l] & 0xF) - m1;
4317
4335
  y[l +32] = d2 * (q[l] >> 4) - m2;
4318
4336
  }
4319
- #else
4320
- const int tid = item_ct1.get_local_id(2);
4321
- const uint8_t * q = x[i].qs;
4322
- dst_t * y = yy + i*QK_K;
4323
- const float d = (float)x[i].dm[0];
4324
- const float m = (float)x[i].dm[1];
4325
- y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
4326
- y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >> 4) - m * (x[i].scales[1] >> 4);
4327
- #endif
4328
4337
  }
4329
4338
 
4330
4339
  template<typename dst_t>
@@ -4334,7 +4343,6 @@ static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restri
4334
4343
 
4335
4344
  const int i = item_ct1.get_group(2);
4336
4345
 
4337
- #if QK_K == 256
4338
4346
  // assume 64 threads - this is very slightly better than the one below
4339
4347
  const int tid = item_ct1.get_local_id(2);
4340
4348
  const int il = tid/16; // il is in 0...3
@@ -4361,18 +4369,6 @@ static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restri
4361
4369
  hm <<= 1;
4362
4370
  y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
4363
4371
  y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
4364
- #else
4365
- const int tid = item_ct1.get_local_id(2);
4366
- const uint8_t q = x[i].qs[tid];
4367
- const int im = tid/8; // 0...3
4368
- const int in = tid%8; // 0...7
4369
- const int is = tid/16; // 0 or 1
4370
- const uint8_t h = x[i].qh[in] >> im;
4371
- const float d = x[i].d;
4372
- dst_t * y = yy + i*QK_K + tid;
4373
- y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
4374
- y[32] = d * x[i].scales[is+2] * ((q >> 4) - ((h >> 4) & 1 ? 0 : 16));
4375
- #endif
4376
4372
  }
4377
4373
 
4378
4374
  template<typename dst_t>
@@ -4381,7 +4377,6 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restri
4381
4377
  const block_q6_K * x = (const block_q6_K *) vx;
4382
4378
 
4383
4379
  const int i = item_ct1.get_group(2);
4384
- #if QK_K == 256
4385
4380
 
4386
4381
  // assume 64 threads - this is very slightly better than the one below
4387
4382
  const int tid = item_ct1.get_local_id(2);
@@ -4401,24 +4396,6 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restri
4401
4396
  y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
4402
4397
  y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
4403
4398
  y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
4404
- #else
4405
-
4406
- // assume 32 threads
4407
- const int tid = item_ct1.get_local_id(2);
4408
- const int ip = tid/16; // 0 or 1
4409
- const int il = tid - 16*ip; // 0...15
4410
-
4411
- dst_t * y = yy + i*QK_K + 16*ip + il;
4412
-
4413
- const float d = x[i].d;
4414
-
4415
- const uint8_t ql = x[i].ql[16*ip + il];
4416
- const uint8_t qh = x[i].qh[il] >> (2*ip);
4417
- const int8_t * sc = x[i].scales;
4418
-
4419
- y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
4420
- y[32] = d * sc[ip+2] * ((int8_t)((ql >> 4) | (((qh >> 4) & 3) << 4)) - 32);
4421
- #endif
4422
4399
  }
4423
4400
 
4424
4401
  template<typename dst_t>
@@ -4432,7 +4409,6 @@ static void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __res
4432
4409
  const block_iq2_xxs * x = (const block_iq2_xxs *) vx;
4433
4410
 
4434
4411
  const int tid = item_ct1.get_local_id(2);
4435
- #if QK_K == 256
4436
4412
  const int il = tid/8; // 0...3
4437
4413
  const int ib = tid%8; // 0...7
4438
4414
  dst_t * y = yy + i*QK_K + 32*ib + 8*il;
@@ -4443,10 +4419,6 @@ static void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __res
4443
4419
  const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;
4444
4420
  const uint8_t signs = ksigns_iq2xs_ptr[(aux32 >> 7*il) & 127];
4445
4421
  for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs_ptr[j] ? -1.f : 1.f);
4446
- #else
4447
- assert(false);
4448
- #endif
4449
-
4450
4422
  }
4451
4423
 
4452
4424
  template<typename dst_t>
@@ -4460,7 +4432,6 @@ static void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __rest
4460
4432
  const block_iq2_xs * x = (const block_iq2_xs *) vx;
4461
4433
 
4462
4434
  const int tid = item_ct1.get_local_id(2);
4463
- #if QK_K == 256
4464
4435
  const int il = tid/8; // 0...3
4465
4436
  const int ib = tid%8; // 0...7
4466
4437
  dst_t * y = yy + i*QK_K + 32*ib + 8*il;
@@ -4469,10 +4440,6 @@ static void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __rest
4469
4440
  const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
4470
4441
  const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];
4471
4442
  for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
4472
- #else
4473
- assert(false);
4474
- #endif
4475
-
4476
4443
  }
4477
4444
 
4478
4445
  template <typename dst_t>
@@ -4484,7 +4451,6 @@ dequantize_block_iq2_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
4484
4451
  const block_iq2_s * x = (const block_iq2_s *) vx;
4485
4452
 
4486
4453
  const int tid = item_ct1.get_local_id(2);
4487
- #if QK_K == 256
4488
4454
  const int il = tid/8; // 0...3
4489
4455
  const int ib = tid%8; // 0...7
4490
4456
  dst_t * y = yy + i*QK_K + 32*ib + 8*il;
@@ -4492,13 +4458,9 @@ dequantize_block_iq2_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
4492
4458
  const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
4493
4459
  const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];
4494
4460
  #pragma unroll
4495
- for (int j = 0; j < 8; ++j)
4461
+ for (int j = 0; j < 8; ++j) {
4496
4462
  y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
4497
- #else
4498
- assert(false);
4499
-
4500
- #endif
4501
-
4463
+ }
4502
4464
  }
4503
4465
 
4504
4466
  template<typename dst_t>
@@ -4512,7 +4474,6 @@ static void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __res
4512
4474
  const block_iq3_xxs * x = (const block_iq3_xxs *) vx;
4513
4475
 
4514
4476
  const int tid = item_ct1.get_local_id(2);
4515
- #if QK_K == 256
4516
4477
  const int il = tid/8; // 0...3
4517
4478
  const int ib = tid%8; // 0...7
4518
4479
  dst_t * y = yy + i*QK_K + 32*ib + 8*il;
@@ -4527,10 +4488,6 @@ static void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __res
4527
4488
  y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
4528
4489
  y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
4529
4490
  }
4530
- #else
4531
- assert(false);
4532
- #endif
4533
-
4534
4491
  }
4535
4492
 
4536
4493
  template <typename dst_t>
@@ -4543,7 +4500,6 @@ dequantize_block_iq3_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
4543
4500
  const block_iq3_s * x = (const block_iq3_s *) vx;
4544
4501
 
4545
4502
  const int tid = item_ct1.get_local_id(2);
4546
- #if QK_K == 256
4547
4503
  const int il = tid/8; // 0...3
4548
4504
  const int ib = tid%8; // 0...7
4549
4505
  dst_t * y = yy + i*QK_K + 32*ib + 8*il;
@@ -4557,10 +4513,6 @@ dequantize_block_iq3_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
4557
4513
  y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
4558
4514
  y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
4559
4515
  }
4560
- #else
4561
- assert(false);
4562
- #endif
4563
-
4564
4516
  }
4565
4517
 
4566
4518
  template <typename dst_t>
@@ -4573,7 +4525,6 @@ dequantize_block_iq1_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
4573
4525
  const block_iq1_s * x = (const block_iq1_s *) vx;
4574
4526
 
4575
4527
  const int tid = item_ct1.get_local_id(2);
4576
- #if QK_K == 256
4577
4528
  const int il = tid/8; // 0...3
4578
4529
  const int ib = tid%8; // 0...7
4579
4530
  dst_t * y = yy + i*QK_K + 32*ib + 8*il;
@@ -4587,10 +4538,6 @@ dequantize_block_iq1_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
4587
4538
  for (int j = 0; j < 8; ++j) {
4588
4539
  y[j] = d * (q[j] + delta);
4589
4540
  }
4590
- #else
4591
- assert(false);
4592
- #endif
4593
-
4594
4541
  }
4595
4542
 
4596
4543
  template <typename dst_t>
@@ -4603,7 +4550,6 @@ dequantize_block_iq1_m(const void *__restrict__ vx, dst_t *__restrict__ yy,
4603
4550
  const block_iq1_m * x = (const block_iq1_m *) vx;
4604
4551
 
4605
4552
  const int tid = item_ct1.get_local_id(2);
4606
- #if QK_K == 256
4607
4553
  const int il = tid/8; // 0...3
4608
4554
  const int ib = tid%8; // 0...7
4609
4555
  dst_t * y = yy + i*QK_K + 32*ib + 8*il;
@@ -4621,10 +4567,6 @@ dequantize_block_iq1_m(const void *__restrict__ vx, dst_t *__restrict__ yy,
4621
4567
  for (int j = 0; j < 8; ++j) {
4622
4568
  y[j] = d * (q[j] + delta);
4623
4569
  }
4624
- #else
4625
- assert(false);
4626
- #endif
4627
-
4628
4570
  }
4629
4571
 
4630
4572
  template <typename dst_t>
@@ -4698,7 +4640,6 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,
4698
4640
 
4699
4641
  float tmp = 0; // partial sum for thread in warp
4700
4642
 
4701
- #if QK_K == 256
4702
4643
  const int tid =
4703
4644
  item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...15
4704
4645
  const int ix =
@@ -4749,42 +4690,6 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,
4749
4690
  tmp += dall * sum1 - dmin * sum2;
4750
4691
 
4751
4692
  }
4752
- #else
4753
- const int tid = item_ct1.get_local_id(2) /
4754
- (2 * K_QUANTS_PER_ITERATION); // 0...15 or 0...7
4755
- const int ix = item_ct1.get_local_id(2) %
4756
- (2 * K_QUANTS_PER_ITERATION); // 0....1 or 0...3
4757
- const int offset = tid * K_QUANTS_PER_ITERATION;
4758
-
4759
- uint32_t uaux[2];
4760
- const uint8_t * d = (const uint8_t *)uaux;
4761
-
4762
-
4763
- for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
4764
-
4765
- const float * y = yy + i * QK_K + offset;
4766
- const uint8_t * q = x[i].qs + offset;
4767
- const uint32_t * s = (const uint32_t *)x[i].scales;
4768
-
4769
- uaux[0] = s[0] & 0x0f0f0f0f;
4770
- uaux[1] = (s[0] >> 4) & 0x0f0f0f0f;
4771
-
4772
- const sycl::float2 dall =
4773
- x[i].dm.convert<float, sycl::rounding_mode::automatic>();
4774
-
4775
- float sum1 = 0, sum2 = 0;
4776
- for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
4777
- const uint8_t ql = q[l];
4778
- sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3)
4779
- + y[l+16] * d[1] * ((ql >> 2) & 3)
4780
- + y[l+32] * d[2] * ((ql >> 4) & 3)
4781
- + y[l+48] * d[3] * ((ql >> 6) & 3);
4782
- sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7];
4783
- }
4784
- tmp += dall.x() * sum1 - dall.y() * sum2;
4785
- }
4786
-
4787
- #endif
4788
4693
 
4789
4694
  // sum up partial sums and write back result
4790
4695
  #pragma unroll
@@ -4822,8 +4727,6 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx,
4822
4727
 
4823
4728
  float tmp = 0; // partial sum for thread in warp
4824
4729
 
4825
- #if QK_K == 256
4826
-
4827
4730
  const uint16_t kmask1 = 0x0303;
4828
4731
  const uint16_t kmask2 = 0x0f0f;
4829
4732
 
@@ -4876,34 +4779,6 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx,
4876
4779
  tmp += d * sum;
4877
4780
 
4878
4781
  }
4879
- #else
4880
-
4881
- const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7
4882
- const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3
4883
- const int offset = tid * K_QUANTS_PER_ITERATION; // 0...15 or 0...14
4884
- const int in = offset/8; // 0 or 1
4885
- const int im = offset%8; // 0...7
4886
-
4887
- for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
4888
-
4889
- const float * y = yy + i * QK_K + offset;
4890
- const uint8_t * q = x[i].qs + offset;
4891
- const uint8_t * s = x[i].scales;
4892
-
4893
- const float dall = (float)x[i].d;
4894
-
4895
- float sum = 0;
4896
- for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
4897
- const uint8_t hl = x[i].hmask[im+l] >> in;
4898
- const uint8_t ql = q[l];
4899
- sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4))
4900
- + y[l+16] * dall * ((s[0] >> 4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4))
4901
- + y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4))
4902
- + y[l+48] * dall * ((s[1] >> 4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4));
4903
- }
4904
- tmp += sum;
4905
- }
4906
- #endif
4907
4782
 
4908
4783
  // sum up partial sums and write back result
4909
4784
  #pragma unroll
@@ -4938,7 +4813,6 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx,
4938
4813
 
4939
4814
  const block_q4_K * x = (const block_q4_K *)vx + ib0;
4940
4815
 
4941
- #if QK_K == 256
4942
4816
  const uint16_t kmask1 = 0x3f3f;
4943
4817
  const uint16_t kmask2 = 0x0f0f;
4944
4818
  const uint16_t kmask3 = 0xc0c0;
@@ -5027,36 +4901,6 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx,
5027
4901
  #endif
5028
4902
 
5029
4903
  }
5030
- #else
5031
- const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...15
5032
- const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION);
5033
-
5034
- const int step = tid * K_QUANTS_PER_ITERATION;
5035
-
5036
- uint16_t aux16[2];
5037
- const uint8_t * s = (const uint8_t *)aux16;
5038
-
5039
- float tmp = 0;
5040
-
5041
- for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
5042
- const uint8_t * q = x[i].qs + step;
5043
- const float * y = yy + i*QK_K + step;
5044
- const uint16_t * a = (const uint16_t *)x[i].scales;
5045
- aux16[0] = a[0] & 0x0f0f;
5046
- aux16[1] = (a[0] >> 4) & 0x0f0f;
5047
- const float d = (float)x[i].dm[0];
5048
- const float m = (float)x[i].dm[1];
5049
- float sum = 0.f;
5050
- for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
5051
- sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2])
5052
- + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2])
5053
- + y[j+32] * (d * s[1] * (q[j+ 0] >> 4) - m * s[3])
5054
- + y[j+48] * (d * s[1] * (q[j+16] >> 4) - m * s[3]);
5055
- }
5056
- tmp += sum;
5057
- }
5058
-
5059
- #endif
5060
4904
 
5061
4905
  // sum up partial sums and write back result
5062
4906
  #pragma unroll
@@ -5091,7 +4935,6 @@ static void dequantize_mul_mat_vec_q5_k(const void *__restrict__ vx,
5091
4935
 
5092
4936
  float tmp = 0; // partial sum for thread in warp
5093
4937
 
5094
- #if QK_K == 256
5095
4938
  const uint16_t kmask1 = 0x3f3f;
5096
4939
  const uint16_t kmask2 = 0x0f0f;
5097
4940
  const uint16_t kmask3 = 0xc0c0;
@@ -5168,30 +5011,6 @@ static void dequantize_mul_mat_vec_q5_k(const void *__restrict__ vx,
5168
5011
  dmin * smin;
5169
5012
  }
5170
5013
 
5171
- #else
5172
- const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...15
5173
- const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION);
5174
- const int step = tid * K_QUANTS_PER_ITERATION;
5175
- const int im = step/8;
5176
- const int in = step%8;
5177
-
5178
- for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
5179
- const uint8_t * q = x[i].qs + step;
5180
- const int8_t * s = x[i].scales;
5181
- const float * y = yy + i*QK_K + step;
5182
- const float d = x[i].d;
5183
- float sum = 0.f;
5184
- for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
5185
- const uint8_t h = x[i].qh[in+j] >> im;
5186
- sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16))
5187
- + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16))
5188
- + y[j+32] * d * s[2] * ((q[j+ 0] >> 4) - ((h >> 4) & 1 ? 0 : 16))
5189
- + y[j+48] * d * s[3] * ((q[j+16] >> 4) - ((h >> 6) & 1 ? 0 : 16));
5190
- }
5191
- tmp += sum;
5192
- }
5193
- #endif
5194
-
5195
5014
  // sum up partial sums and write back result
5196
5015
  #pragma unroll
5197
5016
  for (int mask = 16; mask > 0; mask >>= 1) {
@@ -5218,8 +5037,6 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
5218
5037
 
5219
5038
  const block_q6_K * x = (const block_q6_K *)vx + ib0;
5220
5039
 
5221
- #if QK_K == 256
5222
-
5223
5040
  const int tid =
5224
5041
  item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16
5225
5042
  const int ix =
@@ -5276,37 +5093,6 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
5276
5093
 
5277
5094
  }
5278
5095
 
5279
- #else
5280
-
5281
- const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...7
5282
- const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); // 0...3
5283
-
5284
- const int step = tid * K_QUANTS_PER_ITERATION;
5285
-
5286
- float tmp = 0; // partial sum for thread in warp
5287
-
5288
- for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
5289
-
5290
- const float * y = yy + i * QK_K + step;
5291
- const uint8_t * ql = x[i].ql + step;
5292
- const uint8_t * qh = x[i].qh + step;
5293
- const int8_t * s = x[i].scales;
5294
-
5295
- const float d = x[i+0].d;
5296
-
5297
- float sum = 0;
5298
- for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
5299
- sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32)
5300
- + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32)
5301
- + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32)
5302
- + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >> 4) | ((qh[j] & 0xc0) >> 2)) - 32);
5303
- }
5304
- tmp += sum;
5305
-
5306
- }
5307
-
5308
- #endif
5309
-
5310
5096
  // sum up partial sums and write back result
5311
5097
  #pragma unroll
5312
5098
  for (int mask = 16; mask > 0; mask >>= 1) {
@@ -6851,7 +6637,6 @@ static __dpct_inline__ float
6851
6637
  vec_dot_q4_K_q8_1(const void *__restrict__ vbq,
6852
6638
  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
6853
6639
 
6854
- #ifndef GGML_QKK_64
6855
6640
  const block_q4_K * bq4_K = (const block_q4_K *) vbq;
6856
6641
 
6857
6642
  int v[2];
@@ -6893,52 +6678,6 @@ vec_dot_q4_K_q8_1(const void *__restrict__ vbq,
6893
6678
  }
6894
6679
 
6895
6680
  return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8);
6896
-
6897
- #else
6898
-
6899
- #if __SYCL_ARCH__ >= VER_4VEC // lowest compute capability for integer intrinsics
6900
- const block_q4_K * bq4_K = (const block_q4_K *) vbq;
6901
-
6902
- float sumf_d = 0.0f;
6903
- float sumf_m = 0.0f;
6904
-
6905
- uint16_t aux16[2];
6906
- const uint8_t * s = (const uint8_t *)aux16;
6907
-
6908
- const uint16_t * a = (const uint16_t *)bq4_K->scales;
6909
- aux16[0] = a[0] & 0x0f0f;
6910
- aux16[1] = (a[0] >> 4) & 0x0f0f;
6911
-
6912
- const float dall = bq4_K->dm[0];
6913
- const float dmin = bq4_K->dm[1];
6914
-
6915
- const float d8_1 = bq8_1[0].ds[0];
6916
- const float d8_2 = bq8_1[1].ds[1];
6917
-
6918
- const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2));
6919
- const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4);
6920
- const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2));
6921
- const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4);
6922
-
6923
- const int * q4 = (const int *)bq4_K->qs + (iqs/2);
6924
- const int v1 = q4[0];
6925
- const int v2 = q4[4];
6926
-
6927
- const int dot1 = dpct::dp4a(ui2, v2 & 0x0f0f0f0f, dpct::dp4a(ui1, v1 & 0x0f0f0f0f, 0));
6928
- const int dot2 = dpct::dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, dpct::dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0));
6929
- const int dot3 = dpct::dp4a(0x01010101, ui2, dpct::dp4a(0x01010101, ui1, 0));
6930
- const int dot4 = dpct::dp4a(0x01010101, ui4, dpct::dp4a(0x01010101, ui3, 0));
6931
-
6932
- sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]);
6933
- sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]);
6934
-
6935
- return dall * sumf_d - dmin * sumf_m;
6936
-
6937
- #else
6938
- bad_arch();
6939
- #endif // __SYCL_ARCH__ >= VER_4VEC
6940
-
6941
- #endif
6942
6681
  }
6943
6682
 
6944
6683
  template <int mmq_y>
@@ -6997,11 +6736,7 @@ load_tiles_q4_K(const void *__restrict__ vx, int *__restrict__ x_ql,
6997
6736
 
6998
6737
  const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd;
6999
6738
 
7000
- #if QK_K == 256
7001
6739
  x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm;
7002
- #else
7003
- x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = {bxi->dm[0], bxi->dm[1]};
7004
- #endif
7005
6740
  }
7006
6741
 
7007
6742
  #pragma unroll
@@ -7044,7 +6779,6 @@ static __dpct_inline__ float
7044
6779
  vec_dot_q5_K_q8_1(const void *__restrict__ vbq,
7045
6780
  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
7046
6781
 
7047
- #ifndef GGML_QKK_64
7048
6782
  const block_q5_K * bq5_K = (const block_q5_K *) vbq;
7049
6783
 
7050
6784
  int vl[2];
@@ -7086,48 +6820,6 @@ vec_dot_q5_K_q8_1(const void *__restrict__ vbq,
7086
6820
  }
7087
6821
 
7088
6822
  return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, bq5_K->dm, d8);
7089
-
7090
- #else
7091
-
7092
- #if __SYCL_ARCH__ >= VER_4VEC // lowest compute capability for integer intrinsics
7093
- const block_q5_K * bq5_K = (const block_q5_K *) vbq;
7094
-
7095
- const int8_t * s = bq5_K->scales;
7096
-
7097
- const float d = bq5_K->d;
7098
-
7099
- const float d8_1 = bq8_1[0].ds[0];
7100
- const float d8_2 = bq8_1[1].ds[1];
7101
-
7102
- const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2));
7103
- const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4);
7104
- const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2));
7105
- const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4);
7106
-
7107
- const int * ql = (const int *)bq5_K->qs + (iqs/2);
7108
- const int vl1 = ql[0];
7109
- const int vl2 = ql[4];
7110
-
7111
- const int step = 4 * (iqs/2); // 0, 4, 8, 12
7112
- const int im = step/8; // = 0 for iqs = 0, 2, = 1 for iqs = 4, 6
7113
- const int in = step%8; // 0, 4, 0, 4
7114
- const int vh = (*((const int *)(bq5_K->qh + in))) >> im;
7115
-
7116
- const int v1 = (((vh << 4) & 0x10101010) ^ 0x10101010) | ((vl1 >> 0) & 0x0f0f0f0f);
7117
- const int v2 = (((vh << 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 0) & 0x0f0f0f0f);
7118
- const int v3 = (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f);
7119
- const int v4 = (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f);
7120
-
7121
- const float sumf_d = d8_1 * (dpct::dp4a(ui1, v1, 0) * s[0] + dpct::dp4a(ui2, v2, 0) * s[1])
7122
- + d8_2 * (dpct::dp4a(ui3, v3, 0) * s[2] + dpct::dp4a(ui4, v4, 0) * s[3]);
7123
-
7124
- return d * sumf_d;
7125
-
7126
- #else
7127
- bad_arch();
7128
- #endif // __SYCL_ARCH__ >= VER_4VEC
7129
-
7130
- #endif
7131
6823
  }
7132
6824
 
7133
6825
  template <int mmq_y>
@@ -7199,9 +6891,7 @@ load_tiles_q5_K(const void *__restrict__ vx, int *__restrict__ x_ql,
7199
6891
 
7200
6892
  const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd;
7201
6893
 
7202
- #if QK_K == 256
7203
6894
  x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm;
7204
- #endif
7205
6895
  }
7206
6896
 
7207
6897
  #pragma unroll
@@ -7381,7 +7071,6 @@ vec_dot_iq2_xxs_q8_1(const void *__restrict__ vbq,
7381
7071
  const block_q8_1 *__restrict__ bq8_1, const int &iqs,
7382
7072
  const uint64_t *iq2xxs_grid, const uint8_t *ksigns_iq2xs,
7383
7073
  const uint8_t *kmask_iq2xs) {
7384
- #if QK_K == 256
7385
7074
  const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq;
7386
7075
 
7387
7076
  #if QR2_XXS == 8
@@ -7422,10 +7111,6 @@ vec_dot_iq2_xxs_q8_1(const void *__restrict__ vbq,
7422
7111
  }
7423
7112
  return d * (sumi1 + sumi2);
7424
7113
  #endif
7425
- #else
7426
- assert(false);
7427
- return 0.f;
7428
- #endif
7429
7114
  }
7430
7115
 
7431
7116
  static __dpct_inline__ float
@@ -7434,7 +7119,6 @@ vec_dot_iq2_xs_q8_1(const void *__restrict__ vbq,
7434
7119
  const uint64_t *iq2xs_grid, const uint64_t *ksigns64) {
7435
7120
  #if DPCT_COMPATIBILITY_TEMP >= \
7436
7121
  MIN_CC_DP4A // lowest compute capability for integer intrinsics
7437
- #if QK_K == 256
7438
7122
  const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq;
7439
7123
 
7440
7124
  const int ib32 = iqs;
@@ -7472,16 +7156,11 @@ vec_dot_iq2_xs_q8_1(const void *__restrict__ vbq,
7472
7156
  assert(false);
7473
7157
  return 0.f;
7474
7158
  #endif
7475
- #else
7476
- assert(false);
7477
- return 0.f;
7478
- #endif
7479
7159
  }
7480
7160
 
7481
7161
  static __dpct_inline__ float
7482
7162
  vec_dot_iq2_s_q8_1(const void *__restrict__ vbq,
7483
7163
  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
7484
- #if QK_K == 256
7485
7164
  const block_iq2_s * bq2 = (const block_iq2_s *) vbq;
7486
7165
 
7487
7166
  const int ib32 = iqs;
@@ -7525,9 +7204,6 @@ vec_dot_iq2_s_q8_1(const void *__restrict__ vbq,
7525
7204
  }
7526
7205
  const float d = (float)bq2->d * bq8_1[ib32].ds[0] * 0.25f;
7527
7206
  return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
7528
- #else
7529
- assert(false);
7530
- #endif
7531
7207
  }
7532
7208
 
7533
7209
  static __dpct_inline__ float
@@ -7536,7 +7212,6 @@ vec_dot_iq3_xxs_q8_1(const void *__restrict__ vbq,
7536
7212
  const uint32_t *iq3xxs_grid, const uint64_t *ksigns64) {
7537
7213
  #if DPCT_COMPATIBILITY_TEMP >= \
7538
7214
  MIN_CC_DP4A // lowest compute capability for integer intrinsics
7539
- #if QK_K == 256
7540
7215
  const block_iq3_xxs * bq2 = (const block_iq3_xxs *) vbq;
7541
7216
 
7542
7217
  const int ib32 = iqs;
@@ -7564,17 +7239,12 @@ vec_dot_iq3_xxs_q8_1(const void *__restrict__ vbq,
7564
7239
  assert(false);
7565
7240
  return 0.f;
7566
7241
  #endif
7567
- #else
7568
- assert(false);
7569
- return 0.f;
7570
- #endif
7571
7242
  }
7572
7243
 
7573
7244
  static __dpct_inline__ float
7574
7245
  vec_dot_iq3_s_q8_1(const void *__restrict__ vbq,
7575
7246
  const block_q8_1 *__restrict__ bq8_1, const int &iqs,
7576
7247
  const uint32_t *iq3s_grid) {
7577
- #if QK_K == 256
7578
7248
  const block_iq3_s * bq2 = (const block_iq3_s *) vbq;
7579
7249
 
7580
7250
  const int ib32 = iqs;
@@ -7603,16 +7273,12 @@ vec_dot_iq3_s_q8_1(const void *__restrict__ vbq,
7603
7273
  (1 + 2 * ((bq2->scales[ib32 / 2] >> 4 * (ib32 % 2)) & 0xf)) *
7604
7274
  bq8_1[ib32].ds[0];
7605
7275
  return d * sumi;
7606
- #else
7607
- assert(false);
7608
- #endif
7609
7276
  }
7610
7277
 
7611
7278
  static __dpct_inline__ float
7612
7279
  vec_dot_iq1_s_q8_1(const void *__restrict__ vbq,
7613
7280
  const block_q8_1 *__restrict__ bq8_1, const int &iqs,
7614
7281
  const uint32_t *iq1s_grid_gpu) {
7615
- #if QK_K == 256
7616
7282
  const block_iq1_s * bq1 = (const block_iq1_s *) vbq;
7617
7283
 
7618
7284
  const int ib32 = iqs;
@@ -7631,15 +7297,11 @@ vec_dot_iq1_s_q8_1(const void *__restrict__ vbq,
7631
7297
  const float d = d1q * bq8_1[ib32].ds[0];
7632
7298
  const float m = d1q * bq8_1[ib32].ds[1];
7633
7299
  return d * sumi + m * delta;
7634
- #else
7635
- assert(false);
7636
- #endif
7637
7300
  }
7638
7301
 
7639
7302
  static __dpct_inline__ float
7640
7303
  vec_dot_iq1_m_q8_1(const void *__restrict__ vbq,
7641
7304
  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
7642
- #if QK_K == 256
7643
7305
  const block_iq1_m * bq1 = (const block_iq1_m *) vbq;
7644
7306
 
7645
7307
  const int ib32 = iqs;
@@ -7664,9 +7326,6 @@ vec_dot_iq1_m_q8_1(const void *__restrict__ vbq,
7664
7326
  scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
7665
7327
  const float d = (float)scale.f16 * bq8_1[ib32].ds[0];
7666
7328
  return d * ((sumi[0] + sumf[0]) * (2*((sc[ib32/2] >> 6*(ib32%2)) & 0x7) + 1) + (sumi[1] + sumf[1]) * (2*((sc[ib32/2] >> (6*(ib32%2)+3)) & 0x7) + 1));
7667
- #else
7668
- assert(false);
7669
- #endif
7670
7329
  }
7671
7330
 
7672
7331
  static __dpct_inline__ void get_int_from_table_16(const uint32_t &q4,
@@ -7714,7 +7373,6 @@ static __dpct_inline__ float
7714
7373
  vec_dot_iq4_xs_q8_1(const void *__restrict__ vbq,
7715
7374
  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
7716
7375
 
7717
- #if QK_K == 256
7718
7376
  const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq;
7719
7377
  const uint8_t * values = (const uint8_t *)kvalues_iq4nl;
7720
7378
 
@@ -7732,9 +7390,6 @@ vec_dot_iq4_xs_q8_1(const void *__restrict__ vbq,
7732
7390
  sumi2 = dpct::dp4a(v2, q8[j + 4], sumi2);
7733
7391
  }
7734
7392
  return d * (sumi1 + sumi2);
7735
- #else
7736
- assert(false);
7737
- #endif
7738
7393
  }
7739
7394
 
7740
7395
  template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x,
@@ -9226,12 +8881,11 @@ static void rope(
  dst[i + 1] = x0*sin_theta + x1*cos_theta;
  }

- template<typename T, bool has_pos>
+ template<typename T, bool has_pos, bool has_freq_facs>
  static void rope_neox(
  const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
- float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
- ,
- const sycl::nd_item<3> &item_ct1) {
+ float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims,
+ const float * freq_factors, const sycl::nd_item<3> &item_ct1) {
  const int col = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
  item_ct1.get_local_id(1));

@@ -9259,8 +8913,10 @@ static void rope_neox(
  float cur_rot = inv_ndims * ic - ib;

  const int p = has_pos ? pos[i2] : 0;
+ const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
+
  const float theta_base =
- p * freq_scale * dpct::pow(theta_scale, col / 2.0f);
+ p * freq_scale * dpct::pow(theta_scale, col / 2.0f)/freq_factor;

  float cos_theta, sin_theta;
  rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
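The new has_freq_facs template parameter lets rope_neox divide the base angle by a per-pair frequency factor taken from freq_factors. A standalone sketch of the angle computation implied by the lines above (plain C++, <cmath> assumed; the helper name is illustrative only):

    // theta for the dimension pair starting at column ic, before YaRN correction:
    static float neox_theta_base(int p, float freq_scale, float theta_scale,
                                 int col, const float *freq_factors, int ic) {
        const float freq_factor = freq_factors ? freq_factors[ic / 2] : 1.0f;
        return p * freq_scale * std::pow(theta_scale, col / 2.0f) / freq_factor;
    }

A factor greater than 1.0f slows the rotation of that dimension pair, which is how per-dimension long-context frequency scaling reaches the kernel.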
@@ -10085,18 +9741,17 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
  });
  }

- static void upscale_f32_sycl(const float *x, float *dst, const int ne00,
- const int ne01, const int ne02,
- const int scale_factor, dpct::queue_ptr stream) {
- int ne0 = (ne00 * scale_factor);
- int num_blocks = (ne0 + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE;
- sycl::range<3> gridDim(ne02, (ne01 * scale_factor), num_blocks);
+ static void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01,
+ const int nb02, const int nb03, const int ne10, const int ne11,
+ const int ne12, const int ne13, const float sf0, const float sf1,
+ const float sf2, const float sf3, dpct::queue_ptr stream) {
+ int dst_size = ne10 * ne11 * ne12 * ne13;
+ int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE;
+ sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE);
  stream->parallel_for(
- sycl::nd_range<3>(gridDim *
- sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE),
- sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)),
- [=](sycl::nd_item<3> item_ct1) {
- upscale_f32(x, dst, ne00, ne00 * ne01, scale_factor, item_ct1);
+ sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)),
+ [=](sycl::nd_item<1> item_ct1) {
+ upscale_f32(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1);
  });
  }

@@ -10198,7 +9853,6 @@ template <typename dst_t>
10198
9853
  static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int k,
10199
9854
  dpct::queue_ptr stream) {
10200
9855
  const int nb = k / QK_K;
10201
- #if QK_K == 256
10202
9856
  {
10203
9857
  dpct::has_capability_or_fail(stream->get_device(),
10204
9858
  {sycl::aspect::fp16});
@@ -10210,27 +9864,12 @@ static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int k,
10210
9864
  dequantize_block_q2_K(vx, y, item_ct1);
10211
9865
  });
10212
9866
  }
10213
- #else
10214
- {
10215
- dpct::has_capability_or_fail(stream->get_device(),
10216
- {sycl::aspect::fp16});
10217
-
10218
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
10219
- sycl::range<3>(1, 1, 32),
10220
- sycl::range<3>(1, 1, 32)),
10221
- [=](sycl::nd_item<3> item_ct1) {
10222
- dequantize_block_q2_K(vx, y, item_ct1);
10223
- });
10224
- }
10225
-
10226
- #endif
10227
9867
  }
10228
9868
 
10229
9869
  template <typename dst_t>
10230
9870
  static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int k,
10231
9871
  dpct::queue_ptr stream) {
10232
9872
  const int nb = k / QK_K;
10233
- #if QK_K == 256
10234
9873
  {
10235
9874
  dpct::has_capability_or_fail(stream->get_device(),
10236
9875
  {sycl::aspect::fp16});
@@ -10242,19 +9881,6 @@ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int k,
10242
9881
  dequantize_block_q3_K(vx, y, item_ct1);
10243
9882
  });
10244
9883
  }
10245
- #else
10246
- {
10247
- dpct::has_capability_or_fail(stream->get_device(),
10248
- {sycl::aspect::fp16});
10249
-
10250
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
10251
- sycl::range<3>(1, 1, 32),
10252
- sycl::range<3>(1, 1, 32)),
10253
- [=](sycl::nd_item<3> item_ct1) {
10254
- dequantize_block_q3_K(vx, y, item_ct1);
10255
- });
10256
- }
10257
- #endif
10258
9884
  }
10259
9885
 
10260
9886
  template <typename dst_t>
@@ -10315,7 +9941,6 @@ template <typename dst_t>
10315
9941
  static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int k,
10316
9942
  dpct::queue_ptr stream) {
10317
9943
  const int nb = k / QK_K;
10318
- #if QK_K == 256
10319
9944
  {
10320
9945
  dpct::has_capability_or_fail(stream->get_device(),
10321
9946
  {sycl::aspect::fp16});
@@ -10327,27 +9952,12 @@ static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int k,
10327
9952
  dequantize_block_q5_K(vx, y, item_ct1);
10328
9953
  });
10329
9954
  }
10330
- #else
10331
- {
10332
- dpct::has_capability_or_fail(stream->get_device(),
10333
- {sycl::aspect::fp16});
10334
-
10335
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
10336
- sycl::range<3>(1, 1, 32),
10337
- sycl::range<3>(1, 1, 32)),
10338
- [=](sycl::nd_item<3> item_ct1) {
10339
- dequantize_block_q5_K(vx, y, item_ct1);
10340
- });
10341
- }
10342
-
10343
- #endif
10344
9955
  }
10345
9956
 
10346
9957
  template <typename dst_t>
10347
9958
  static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int k,
10348
9959
  dpct::queue_ptr stream) {
10349
9960
  const int nb = k / QK_K;
10350
- #if QK_K == 256
10351
9961
  {
10352
9962
  dpct::has_capability_or_fail(stream->get_device(),
10353
9963
  {sycl::aspect::fp16});
@@ -10359,20 +9969,6 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int k,
10359
9969
  dequantize_block_q6_K(vx, y, item_ct1);
10360
9970
  });
10361
9971
  }
10362
- #else
10363
- {
10364
- dpct::has_capability_or_fail(stream->get_device(),
10365
- {sycl::aspect::fp16});
10366
-
10367
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
10368
- sycl::range<3>(1, 1, 32),
10369
- sycl::range<3>(1, 1, 32)),
10370
- [=](sycl::nd_item<3> item_ct1) {
10371
- dequantize_block_q6_K(vx, y, item_ct1);
10372
- });
10373
- }
10374
-
10375
- #endif
10376
9972
  }
10377
9973
 
10378
9974
  template <typename dst_t>
@@ -10524,9 +10120,6 @@ template <typename dst_t>
10524
10120
  static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int k,
10525
10121
  dpct::queue_ptr stream) {
10526
10122
  const int nb = (k + QK_K - 1) / QK_K;
10527
- #if QK_K == 64
10528
- dequantize_row_iq4_nl_sycl(vx, y, k, stream);
10529
- #else
10530
10123
  {
10531
10124
  dpct::has_capability_or_fail(stream->get_device(),
10532
10125
  {sycl::aspect::fp16});
@@ -10541,7 +10134,6 @@ static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int k,
10541
10134
  });
10542
10135
  });
10543
10136
  }
10544
- #endif
10545
10137
  }
10546
10138
 
10547
10139
 
@@ -12046,8 +11638,6 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
12046
11638
  const int nrows_y, const int nrows_dst,
12047
11639
  dpct::queue_ptr stream) try {
12048
11640
 
12049
- #if QK_K == 256
12050
-
12051
11641
  int id;
12052
11642
  SYCL_CHECK(
12053
11643
  CHECK_TRY_ERROR(id = get_current_device_id()));
@@ -12162,7 +11752,6 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
12162
11752
  });
12163
11753
  }
12164
11754
  }
12165
- #endif
12166
11755
  }
12167
11756
  catch (sycl::exception const &exc) {
12168
11757
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -12876,7 +12465,7 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
12876
12465
  const int32_t *pos, float freq_scale,
12877
12466
  int p_delta_rows, float freq_base, float ext_factor,
12878
12467
  float attn_factor, rope_corr_dims corr_dims,
12879
- dpct::queue_ptr stream) {
12468
+ const float * freq_factors, dpct::queue_ptr stream) {
12880
12469
  GGML_ASSERT(ncols % 2 == 0);
12881
12470
  const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
12882
12471
  const int num_blocks_x = (ncols + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
@@ -12886,38 +12475,48 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
12886
12475
  const float inv_ndims = -1.0f / n_dims;
12887
12476
 
12888
12477
  if (pos == nullptr) {
12889
- /*
12890
- DPCT1049:42: The work-group size passed to the SYCL kernel may exceed
12891
- the limit. To get the device limit, query
12892
- info::device::max_work_group_size. Adjust the work-group size if needed.
12893
- */
12894
12478
  dpct::has_capability_or_fail(stream->get_device(),
12895
12479
  {sycl::aspect::fp16});
12896
-
12897
- stream->parallel_for(
12898
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
12899
- [=](sycl::nd_item<3> item_ct1) {
12900
- rope_neox<T, false>(x, dst, ncols, n_dims, pos, freq_scale,
12901
- p_delta_rows, ext_factor, attn_factor,
12902
- corr_dims, theta_scale, inv_ndims,
12903
- item_ct1);
12904
- });
12480
+ if (freq_factors == nullptr) {
12481
+ stream->parallel_for(
12482
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
12483
+ [=](sycl::nd_item<3> item_ct1) {
12484
+ rope_neox<T, false, false>(x, dst, ncols, n_dims, pos, freq_scale,
12485
+ p_delta_rows, ext_factor, attn_factor,
12486
+ corr_dims, theta_scale, inv_ndims, freq_factors,
12487
+ item_ct1);
12488
+ });
12489
+ } else {
12490
+ stream->parallel_for(
12491
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
12492
+ [=](sycl::nd_item<3> item_ct1) {
12493
+ rope_neox<T, false, true>(x, dst, ncols, n_dims, pos, freq_scale,
12494
+ p_delta_rows, ext_factor, attn_factor,
12495
+ corr_dims, theta_scale, inv_ndims, freq_factors,
12496
+ item_ct1);
12497
+ });
12498
+ }
12905
12499
  } else {
12906
- /*
12907
- DPCT1049:43: The work-group size passed to the SYCL kernel may exceed
12908
- the limit. To get the device limit, query
12909
- info::device::max_work_group_size. Adjust the work-group size if needed.
12910
- */
12911
12500
  dpct::has_capability_or_fail(stream->get_device(),
12912
12501
  {sycl::aspect::fp16});
12913
12502
 
12914
- stream->parallel_for(
12915
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
12916
- [=](sycl::nd_item<3> item_ct1) {
12917
- rope_neox<T, true>(x, dst, ncols, n_dims, pos, freq_scale,
12918
- p_delta_rows, ext_factor, attn_factor,
12919
- corr_dims, theta_scale, inv_ndims, item_ct1);
12920
- });
12503
+ if (freq_factors == nullptr) {
12504
+ stream->parallel_for(
12505
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
12506
+ [=](sycl::nd_item<3> item_ct1) {
12507
+ rope_neox<T, true, false>(x, dst, ncols, n_dims, pos, freq_scale,
12508
+ p_delta_rows, ext_factor, attn_factor,
12509
+ corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1);
12510
+ });
12511
+ } else {
12512
+ stream->parallel_for(
12513
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
12514
+ [=](sycl::nd_item<3> item_ct1) {
12515
+ rope_neox<T, true, true>(x, dst, ncols, n_dims, pos, freq_scale,
12516
+ p_delta_rows, ext_factor, attn_factor,
12517
+ corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1);
12518
+ });
12519
+ }
12921
12520
  }
12922
12521
  }
12923
12522
 
@@ -13964,6 +13563,10 @@ inline void ggml_sycl_op_concat(const ggml_tensor *src0,
  const float *src0_dd, const float *src1_dd,
  float *dst_dd,
  const dpct::queue_ptr &main_stream) {
+ #pragma message("TODO: generalize concat kernel for dim != 2")
+ #pragma message(" https://github.com/ggerganov/llama.cpp/pull/7563")
+ int dim = dst->op_params[0];
+ GGML_ASSERT(dim == 2);

  GGML_ASSERT(src0->type == GGML_TYPE_F32);
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -13985,15 +13588,15 @@ inline void ggml_sycl_op_upscale(const ggml_tensor *src0,

  GGML_ASSERT(src0->type == GGML_TYPE_F32);
  GGML_ASSERT(dst->type == GGML_TYPE_F32);
- GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
-
- #pragma message("TODO: generalize upscale operator")
- #pragma message(" https://github.com/ggerganov/ggml/pull/814")
- GGML_ASSERT(false && "TODO: generalize upscale operator");

- const int scale_factor = dst->op_params[0];
+ const float sf0 = (float)dst->ne[0]/src0->ne[0];
+ const float sf1 = (float)dst->ne[1]/src0->ne[1];
+ const float sf2 = (float)dst->ne[2]/src0->ne[2];
+ const float sf3 = (float)dst->ne[3]/src0->ne[3];

- upscale_f32_sycl(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], scale_factor, main_stream);
+ upscale_f32_sycl(src0_dd, dst_dd, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+ dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3,
+ main_stream);

  (void) src1;
  (void) dst;
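Rather than reading a single integer scale_factor from op_params, the operator now derives an independent float scale per dimension from the destination and source shapes, so per-axis and non-integer factors become possible. A worked example with hypothetical shapes:

    // src0->ne = {64, 48, 3, 1}, dst->ne = {128, 96, 3, 1}
    //   sf0 = 128.0f / 64 = 2.0f,  sf1 = 96.0f / 48 = 2.0f
    //   sf2 =   3.0f /  3 = 1.0f,  sf3 =  1.0f /  1 = 1.0f
    // dst_size = 128 * 96 * 3 * 1 = 36864 elements; assuming SYCL_UPSCALE_BLOCK_SIZE
    // is 256, upscale_f32_sycl launches ceil(36864 / 256) = 144 work-groups.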
@@ -14449,6 +14052,7 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
  ggml_tensor *dst, const float *src0_dd,
  const float *src1_dd, float *dst_dd,
  const dpct::queue_ptr &main_stream) {
+ const ggml_tensor * src2 = dst->src[2];

  GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
  GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
@@ -14474,6 +14078,7 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
  memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
  memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));

+ const float * freq_factors = nullptr;
  const int32_t * pos = nullptr;
  if ((mode & 1) == 0) {
  GGML_ASSERT(src1->type == GGML_TYPE_I32);
@@ -14484,6 +14089,16 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
  const bool is_neox = mode & 2;
  const bool is_glm = mode & 4;

+ if (is_neox) {
+ pos = (const int32_t *) src1_dd;
+
+ if (src2 != nullptr) {
+ freq_factors = (const float *) src2->data;
+ }
+ } else {
+ GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox");
+ }
+
  rope_corr_dims corr_dims;
  ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);

@@ -14495,13 +14110,13 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
  if (src0->type == GGML_TYPE_F32) {
  rope_neox_sycl(
  (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
- attn_factor, corr_dims, main_stream
+ attn_factor, corr_dims, freq_factors, main_stream
  );
  } else if (src0->type == GGML_TYPE_F16) {
  rope_neox_sycl((const sycl::half *)src0_dd, (sycl::half *)dst_dd,
  ne00, n_dims, nrows, pos, freq_scale, ne01,
  freq_base, ext_factor, attn_factor, corr_dims,
- main_stream);
+ freq_factors, main_stream);
  } else {
  GGML_ASSERT(false);
  }
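The operator reads the optional frequency-factor tensor from dst->src[2] and forwards it only on the NeoX path; the rotation variant itself is packed into the low bits of mode. A hedged sketch of that decoding (the helper is hypothetical; the bit meanings mirror the context above):

    struct rope_mode_info {
        bool uses_positions;   // (mode & 1) == 0: the src1 positions tensor is read
        bool is_neox;          // mode & 2
        bool is_glm;           // mode & 4
    };
    static rope_mode_info decode_rope_mode(int mode) {
        return { (mode & 1) == 0, (mode & 2) != 0, (mode & 4) != 0 };
    }
    // freq_factors (dst->src[2]) is honoured only when is_neox is true; otherwise
    // the code asserts that no factor tensor was supplied.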
@@ -15568,7 +15183,7 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
  const int64_t r2 = ne12/ne02;
  const int64_t r3 = ne13/ne03;

- if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
+ if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
  // there is no broadcast and src0, src1 are contiguous across dims 2, 3
  SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
  *g_sycl_handles[g_main_device], oneapi::mkl::transpose::trans,
@@ -15633,6 +15248,29 @@ catch (sycl::exception const &exc) {
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
  std::exit(1);

+ inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
+ // TODO: accuracy issues in MMQ
+ return false;
+ }
+
+ bool ggml_sycl_supports_dmmv(enum ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q3_K:
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q6_K:
+ case GGML_TYPE_F16:
+ return true;
+ default:
+ return false;
+ }
+ }

  static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
  const bool all_on_device =
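The hunk below replaces the old XMX-centric branching with three explicit predicates computed up front. A condensed sketch of what each predicate means and one plausible way they could select a path (illustrative only; the actual ordering lives in the dispatch code below, which this diff truncates):

    // use_dequantize_mul_mat_vec: ggml_sycl_supports_dmmv(src0->type), F32 src1/dst,
    //                             src0->ne[0] % GGML_SYCL_DMMV_X == 0, batch of 1
    // use_mul_mat_vec_q:          quantized src0, F32 src1/dst, batch <= MMVQ_MAX_BATCH_SIZE
    // use_mul_mat_q:              ggml_sycl_supports_mmq(src0->type) (currently always false)
    static const char *pick_mul_mat_path(bool dmmv, bool mmvq, bool mmq) {
        if (mmvq) return "mul_mat_vec_q";
        if (dmmv) return "dequantize_mul_mat_vec";
        if (mmq)  return "mul_mat_q";
        return "dequantize + gemm";   // generic fallback path
    }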
@@ -15649,75 +15287,42 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
  }
  }

- #ifdef SYCL_USE_XMX
- const bool use_xmx = true;
- #else
- const bool use_xmx = false;
- #endif
+ // check data types and tensor shapes for custom matrix multiplication kernels:
+ bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type)
+ && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+ && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;

- // debug helpers
- //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
- //printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
- //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);
- //printf(" %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
- //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
- //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
+ bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
+ && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+ && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+
+ bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
+ && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+
+ // mmvq and mmq need the __dp4a instruction which is available for gen12+
+ // Workaround in https://github.com/ggerganov/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e
+ use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS);
+ #ifdef SYCL_USE_XMX
+ use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
+ #endif // SYCL_USE_XMX

- if (!split && all_on_device && !use_xmx && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+ if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
  // KQ single-batch
- // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_p021\n");
  ggml_sycl_mul_mat_vec_p021(src0, src1, dst);
- } else if (!split && all_on_device && !use_xmx && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+ } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
  // KQV single-batch
- // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_nc\n");
  ggml_sycl_mul_mat_vec_nc(src0, src1, dst);
- } else if (!split && all_on_device && use_xmx && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
+ } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
  // KQ + KQV multi-batch
- // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_batched_sycl\n");
  ggml_sycl_mul_mat_batched_sycl(src0, src1, dst);
- } else if (src0->type == GGML_TYPE_F32) {
- // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat\n");
- ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
- } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
- // GGML_SYCL_DEBUG("ggml_is_quantized or GGML_TYPE_F16\n");
- if (src1->ne[1] == 1 && src0->ne[0] % GGML_SYCL_DMMV_X == 0) {
- #ifdef GGML_SYCL_FORCE_DMMV
- const bool use_mul_mat_vec_q = false;
- #else
- bool use_mul_mat_vec_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type);
- use_mul_mat_vec_q = use_mul_mat_vec_q ||
- (src0->type == GGML_TYPE_IQ2_XXS) || (src0->type == GGML_TYPE_IQ2_XS) || (src0->type == GGML_TYPE_IQ2_S) ||
- (src0->type == GGML_TYPE_IQ3_XXS) || (src0->type == GGML_TYPE_IQ3_S) ||
15691
- (src0->type == GGML_TYPE_IQ4_NL) || (src0->type == GGML_TYPE_IQ4_XS) ||
15692
- (src0->type == GGML_TYPE_IQ1_S) || (src0->type == GGML_TYPE_IQ1_M);
15693
-
15694
-
15695
- #endif // GGML_SYCL_FORCE_DMMV
15696
-
15697
- if (use_mul_mat_vec_q) {
15698
- // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_vec_q path\n");
15699
- ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
15700
- } else {
15701
- // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_dequantize_mul_mat_vec path\n");
15702
- ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
15703
- }
15704
- } else {
15705
- bool use_mul_mat_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type);
15706
-
15707
- if (use_xmx && min_compute_capability >= VER_GEN9 && src1->ne[1] > XMX_MAX_BATCH_SIZE) {
15708
- use_mul_mat_q = false;
15709
- }
15710
-
15711
- if (use_mul_mat_q) {
15712
- // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_q path\n");
15713
- ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
15714
- } else {
15715
- // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_sycl path\n");
15716
- ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
15717
- }
15718
- }
15318
+ } else if (use_dequantize_mul_mat_vec) {
15319
+ ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
15320
+ } else if (use_mul_mat_vec_q) {
15321
+ ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
15322
+ } else if (use_mul_mat_q) {
15323
+ ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
15719
15324
  } else {
15720
- GGML_ASSERT(false);
15325
+ ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
15721
15326
  }
15722
15327
  }
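For context on the rewritten ggml_sycl_mul_mat above: kernel selection is now driven by the precomputed use_dequantize_mul_mat_vec / use_mul_mat_vec_q / use_mul_mat_q flags plus the F16 fast paths, instead of the old compute-capability and XMX branching. A self-contained restatement of the resulting priority order (names abbreviated; a readability aid, not new behaviour):

    // sketch: the selection order encoded by the if/else chain above (illustrative only)
    enum class sycl_mm_kernel { vec_p021, vec_nc, batched_gemm, dmmv, mmvq, mmq, generic };

    static sycl_mm_kernel pick_mm_kernel(bool split,
                                         bool f16_permuted_single_col,   // KQ single-batch
                                         bool f16_noncontig_single_col,  // KQV single-batch
                                         bool f16_batched,               // KQ + KQV multi-batch
                                         bool use_dmmv, bool use_mmvq, bool use_mmq) {
        if (!split && f16_permuted_single_col)  return sycl_mm_kernel::vec_p021;
        if (!split && f16_noncontig_single_col) return sycl_mm_kernel::vec_nc;
        if (!split && f16_batched)              return sycl_mm_kernel::batched_gemm;
        if (use_dmmv)                           return sycl_mm_kernel::dmmv;
        if (use_mmvq)                           return sycl_mm_kernel::mmvq;
        if (use_mmq)                            return sycl_mm_kernel::mmq;
        return sycl_mm_kernel::generic;         // falls back to ggml_sycl_op_mul_mat_sycl
    }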
15723
15328
 
@@ -15894,22 +15499,86 @@ static void ggml_sycl_mul_mat_id_sycl(ggml_tensor * dst) {
15894
15499
  }
15895
15500
  #endif
15896
15501
 
15502
+ struct mmid_row_mapping {
15503
+ int32_t i1;
15504
+ int32_t i2;
15505
+ };
15506
+
15507
+ __dpct_inline__ static void k_copy_src1_to_contiguous(
15508
+ const char *__restrict__ src1_original, char *__restrict__ src1_contiguous,
15509
+ int *__restrict__ cur_src1_row, mmid_row_mapping *__restrict__ row_mapping,
15510
+ const char *__restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
15511
+ int64_t ne11, int64_t ne10, size_t nb11, size_t nb12,
15512
+ const sycl::nd_item<3> &item_ct1, int &src1_row) {
15513
+ int32_t iid1 = item_ct1.get_group(2);
15514
+ int32_t id = item_ct1.get_group(1);
15515
+
15516
+ const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
15517
+
15518
+ if (row_id_i != i02) {
15519
+ return;
15520
+ }
15521
+
15522
+ const int64_t i11 = id % ne11;
15523
+ const int64_t i12 = iid1;
15524
+
15525
+ if (item_ct1.get_local_id(2) == 0) {
15526
+ src1_row =
15527
+ dpct::atomic_fetch_add<sycl::access::address_space::generic_space>(
15528
+ cur_src1_row, 1);
15529
+ row_mapping[src1_row] = {id, iid1};
15530
+ }
15531
+ /*
15532
+ DPCT1065:194: Consider replacing sycl::nd_item::barrier() with
15533
+ sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
15534
+ performance if there is no access to global memory.
15535
+ */
15536
+ item_ct1.barrier();
15537
+
15538
+ const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
15539
+ float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
15540
+
15541
+ #pragma unroll
15542
+ for (int i = item_ct1.get_local_id(2); i < ne10;
15543
+ i += item_ct1.get_local_range(2)) {
15544
+ src1_row_contiguous[i] = src1_row_original[i];
15545
+ }
15546
+ }
15547
+
15548
+ __dpct_inline__ static void k_copy_dst_from_contiguous(
15549
+ char *__restrict__ dst_original, const char *__restrict__ dst_contiguous,
15550
+ const mmid_row_mapping *__restrict__ row_mapping, int64_t ne0, size_t nb1,
15551
+ size_t nb2, const sycl::nd_item<3> &item_ct1) {
15552
+ int32_t i = item_ct1.get_group(2);
15553
+
15554
+ const int32_t i1 = row_mapping[i].i1;
15555
+ const int32_t i2 = row_mapping[i].i2;
15556
+
15557
+ const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
15558
+ float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);
15559
+
15560
+ #pragma unroll
15561
+ for (int j = item_ct1.get_local_id(2); j < ne0;
15562
+ j += item_ct1.get_local_range(2)) {
15563
+ dst_row_original[j] = dst_row_contiguous[j];
15564
+ }
15565
+ }
15566
+
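For context on the two device functions added above: the batched path of ggml_sycl_mul_mat_id (further below) gathers, for one expert i02 at a time, every src1 row routed to that expert into a contiguous scratch buffer. k_copy_src1_to_contiguous claims a compacted slot per matching (token, slot) pair via dpct::atomic_fetch_add on a shared counter and records the original (i1, i2) coordinates in an mmid_row_mapping entry; after the matmul, k_copy_dst_from_contiguous scatters each result row back using that mapping. A hedged CPU analogue of the gather/mapping step (illustrative only; ids here is a plain nested vector rather than a ggml tensor):

    // sketch: CPU analogue of building the row mapping for one expert
    #include <cstdint>
    #include <vector>

    struct row_mapping { int32_t i1; int32_t i2; };   // same shape as mmid_row_mapping

    static std::vector<row_mapping> gather_rows_for_expert(
            const std::vector<std::vector<int32_t>> & ids,   // ids[token][slot] = expert index
            int32_t expert) {
        std::vector<row_mapping> mapping;                    // compacted row order
        for (int32_t iid1 = 0; iid1 < (int32_t) ids.size(); ++iid1) {
            for (int32_t id = 0; id < (int32_t) ids[iid1].size(); ++id) {
                if (ids[iid1][id] == expert) {
                    mapping.push_back({id, iid1});           // mirrors row_mapping[src1_row] = {id, iid1}
                }
            }
        }
        return mapping;                                      // row k of the scratch buffer <- mapping[k]
    }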
15897
15567
  static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
15898
15568
  const ggml_tensor *src1,
15899
15569
  ggml_tensor *dst) try {
15900
- GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT &&
15901
- "mul_mat_id does not support split buffers");
15570
+ GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer) && "mul_mat_id does not support split buffers");
15571
+
15902
15572
  const ggml_tensor *ids = dst->src[2];
15903
- const dpct::queue_ptr stream = g_syclStreams[g_main_device][0];
15573
+ GGML_TENSOR_BINARY_OP_LOCALS
15904
15574
 
15905
- const size_t nb11 = src1->nb[1];
15906
- const size_t nb1 = dst->nb[1];
15575
+ const dpct::queue_ptr stream = g_syclStreams[g_main_device][0];
15907
15576
 
15908
- const int32_t id = ((int32_t *)dst->op_params)[0];
15909
- const int32_t n_as = src0->ne[2];
15577
+ const int64_t n_as = ne02;
15578
+ const int64_t n_ids = ids->ne[0];
15910
15579
 
15911
15580
  std::vector<char> ids_host(ggml_nbytes(ids));
15912
- const char *ids_dev = (const char *)ids->data;
15581
+ const char * ids_dev = (const char *) ids->data;
15913
15582
 
15914
15583
  SYCL_CHECK(CHECK_TRY_ERROR(
15915
15584
  stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
@@ -15949,24 +15618,40 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
15949
15618
 
15950
15619
  src0_row.ne[2] = 1;
15951
15620
  src0_row.ne[3] = 1;
15952
- src0_row.nb[3] = src0->nb[2];
15953
-
15954
- if (src1->ne[1] == 1) {
15955
- for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
15956
- const int32_t row_id =
15957
- *(const int32_t *)(ids_host.data() + i01 * ids->nb[1] +
15958
- id * ids->nb[0]);
15959
-
15960
- GGML_ASSERT(row_id >= 0 && row_id < n_as);
15621
+ src0_row.nb[3] = nb02;
15622
+
15623
+ src1_row.ne[1] = 1;
15624
+ src1_row.ne[2] = 1;
15625
+ src1_row.ne[3] = 1;
15626
+ src1_row.nb[2] = nb11;
15627
+ src1_row.nb[3] = nb11;
15628
+
15629
+ dst_row.ne[1] = 1;
15630
+ dst_row.ne[2] = 1;
15631
+ dst_row.ne[3] = 1;
15632
+ dst_row.nb[2] = nb1;
15633
+ dst_row.nb[3] = nb1;
15634
+ if (ne12 == 1) {
15635
+ for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
15636
+ for (int64_t id = 0; id < n_ids; id++) {
15637
+ const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
15638
+ GGML_ASSERT(i02 >= 0 && i02 < n_as);
15639
+
15640
+ const int64_t i11 = id % ne11;
15641
+ const int64_t i12 = iid1;
15642
+
15643
+ const int64_t i1 = id;
15644
+ const int64_t i2 = i12;
15961
15645
 
15962
15646
  src0_row_extra.data_device[g_main_device] =
15963
- src0_original + row_id * src0->nb[2];
15647
+ src0_original + i02*nb02;
15964
15648
  src1_row_extra.data_device[g_main_device] =
15965
- src1_original + i01 * src1->nb[1];
15649
+ src1_original + + i11*nb11 + i12*nb12;
15966
15650
  dst_row_extra.data_device[g_main_device] =
15967
- dst_original + i01 * dst->nb[1];
15651
+ dst_original + i1*nb1 + i2*nb2;
15968
15652
 
15969
15653
  ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row);
15654
+ }
15970
15655
  }
15971
15656
  } else {
15972
15657
  sycl_pool_alloc<char> src1_contiguous(sizeof(float)*ggml_nelements(src1));
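For context on the ne12 == 1 branch above: each (token, expert-slot) pair is handled by pointing the temporary per-row views (src0_row, src1_row, dst_row) at the right expert slice and row through the *_extra.data_device byte offsets, then reusing the ordinary ggml_sycl_mul_mat on those views. A small standalone sketch of just that offset arithmetic (byte strides named as in the diff; illustrative only):

    // sketch: byte offsets used to alias one expert slice of src0, one row of src1,
    // and the matching row of dst (illustrative only)
    #include <cstddef>
    #include <cstdint>

    struct mmid_offsets { size_t src0, src1, dst; };

    static mmid_offsets compute_mmid_offsets(int32_t i02, int64_t i11, int64_t i12,
                                             int64_t i1, int64_t i2,
                                             size_t nb02, size_t nb11, size_t nb12,
                                             size_t nb1, size_t nb2) {
        return {
            (size_t) i02 * nb02,                         // expert i02's weights in src0
            (size_t) i11 * nb11 + (size_t) i12 * nb12,   // selected activation row in src1
            (size_t) i1  * nb1  + (size_t) i2  * nb2     // destination row in dst
        };
    }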
@@ -15975,64 +15660,98 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
15975
15660
  src1_row_extra.data_device[g_main_device] = src1_contiguous.get();
15976
15661
  dst_row_extra.data_device[g_main_device] = dst_contiguous.get();
15977
15662
 
15978
- for (int32_t row_id = 0; row_id < n_as; ++row_id) {
15663
+ for (int64_t i02 = 0; i02 < n_as; i02++) {
15979
15664
  int64_t num_src1_rows = 0;
15980
- for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
15981
- const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
15665
+ for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
15666
+ for (int64_t id = 0; id < n_ids; id++) {
15667
+ const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
15982
15668
 
15983
- if (row_id_i != row_id) {
15984
- continue;
15985
- }
15669
+ GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
15986
15670
 
15987
- GGML_ASSERT(row_id >= 0 && row_id < n_as);
15671
+ if (row_id_i != i02) {
15672
+ continue;
15673
+ }
15988
15674
 
15989
- SYCL_CHECK(CHECK_TRY_ERROR(
15990
- stream->memcpy(src1_contiguous.get() + num_src1_rows * nb11,
15991
- src1_original + i01 * nb11, nb11)));
15992
- num_src1_rows++;
15675
+ num_src1_rows++;
15676
+ }
15993
15677
  }
15994
15678
 
15995
15679
  if (num_src1_rows == 0) {
15996
15680
  continue;
15997
15681
  }
15998
15682
 
15999
- src0_row_extra.data_device[g_main_device] =
16000
- src0_original + row_id * src0->nb[2];
16001
15683
 
15684
+ sycl_pool_alloc<int> dev_cur_src1_row(1);
15685
+ sycl_pool_alloc<mmid_row_mapping> dev_row_mapping(num_src1_rows);
15686
+ SYCL_CHECK(CHECK_TRY_ERROR(
15687
+ stream->memset(dev_cur_src1_row.get(), 0, sizeof(int))));
15688
+
15689
+ {
15690
+ sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, 768u));
15691
+ sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
15692
+ stream->submit([&](sycl::handler &cgh) {
15693
+ sycl::local_accessor<int, 0> src1_row_acc(cgh);
15694
+
15695
+ char *__restrict src1_contiguous_get =
15696
+ src1_contiguous.get();
15697
+ int *__restrict dev_cur_src1_row_get =
15698
+ dev_cur_src1_row.get();
15699
+ mmid_row_mapping *__restrict dev_row_mapping_get =
15700
+ dev_row_mapping.get();
15701
+ size_t ids_nb_ct6 = ids->nb[1];
15702
+ size_t ids_nb_ct7 = ids->nb[0];
15703
+
15704
+ cgh.parallel_for(
15705
+ sycl::nd_range<3>(grid_dims * block_dims, block_dims),
15706
+ [=](sycl::nd_item<3> item_ct1) {
15707
+ k_copy_src1_to_contiguous(
15708
+ src1_original, src1_contiguous_get,
15709
+ dev_cur_src1_row_get,
15710
+ dev_row_mapping_get, ids_dev, i02,
15711
+ ids_nb_ct6, ids_nb_ct7, ne11, ne10, nb11, nb12,
15712
+ item_ct1, src1_row_acc);
15713
+ });
15714
+ });
15715
+ }
15716
+
15717
+ src0_row_extra.data_device[g_main_device] = src0_original + i02*nb02;
15718
+
15719
+ GGML_ASSERT(nb11 == sizeof(float)*ne10);
15720
+ GGML_ASSERT(nb1 == sizeof(float)*ne0);
16002
15721
  src1_row.ne[1] = num_src1_rows;
16003
- dst_row.ne[1] = num_src1_rows;
16004
15722
 
16005
15723
  src1_row.nb[1] = nb11;
16006
15724
  src1_row.nb[2] = num_src1_rows*nb11;
16007
15725
  src1_row.nb[3] = num_src1_rows*nb11;
16008
15726
 
15727
+ dst_row.ne[1] = num_src1_rows;
16009
15728
  dst_row.nb[1] = nb1;
16010
15729
  dst_row.nb[2] = num_src1_rows*nb1;
16011
15730
  dst_row.nb[3] = num_src1_rows*nb1;
16012
15731
 
16013
15732
  ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row);
16014
15733
 
16015
- num_src1_rows = 0;
16016
- for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
16017
- const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
16018
-
16019
- if (row_id_i != row_id) {
16020
- continue;
16021
- }
16022
-
16023
- GGML_ASSERT(row_id >= 0 && row_id < n_as);
16024
-
16025
- SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(
16026
- dst_original + i01 * nb1,
16027
- dst_contiguous.get() + num_src1_rows * nb1, nb1)));
16028
- num_src1_rows++;
15734
+ {
15735
+ sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, 768u));
15736
+ sycl::range<3> grid_dims(1, 1, num_src1_rows);
15737
+ stream->submit([&](sycl::handler &cgh) {
15738
+ const char *__restrict dst_contiguous_get =
15739
+ dst_contiguous.get();
15740
+ const mmid_row_mapping *__restrict dev_row_mapping_get =
15741
+ dev_row_mapping.get();
15742
+
15743
+ cgh.parallel_for(
15744
+ sycl::nd_range<3>(grid_dims * block_dims, block_dims),
15745
+ [=](sycl::nd_item<3> item_ct1) {
15746
+ k_copy_dst_from_contiguous(dst_original,
15747
+ dst_contiguous_get,
15748
+ dev_row_mapping_get,
15749
+ ne0, nb1, nb2, item_ct1);
15750
+ });
15751
+ });
16029
15752
  }
16030
15753
  }
16031
15754
  }
16032
-
16033
- if (dst->backend == GGML_BACKEND_TYPE_CPU) {
16034
- SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
16035
- }
16036
15755
  }
16037
15756
  catch (sycl::exception const &exc) {
16038
15757
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -17015,10 +16734,9 @@ GGML_CALL static const char * ggml_backend_sycl_split_buffer_get_name(ggml_backe
17015
16734
  UNUSED(buffer);
17016
16735
  }
17017
16736
 
17018
- // unused at the moment
17019
- //static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) {
17020
- // return buffer->iface.get_name == ggml_backend_sycl_split_buffer_get_name;
17021
- //}
16737
+ static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) {
16738
+ return buffer->iface.get_name == ggml_backend_sycl_split_buffer_get_name;
16739
+ }
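For context: this helper was previously commented out as unused; it is re-enabled because the rewritten ggml_sycl_mul_mat_id earlier in this diff now asserts that src0 does not live in a split buffer. It uses the usual ggml-backend idiom of identifying a buffer type by comparing its get_name function pointer rather than a dedicated type tag. A generic, hedged illustration of that idiom (the struct and names below are hypothetical, not ggml API):

    // sketch: identifying a polymorphic buffer by function-pointer identity
    // (hypothetical types; illustrative of the idiom only)
    struct buffer_iface { const char * (*get_name)(void * ctx); };
    struct buffer       { buffer_iface iface; void * ctx; };

    static const char * split_get_name(void *) { return "Split"; }

    static bool is_split_buffer(const buffer & b) {
        return b.iface.get_name == split_get_name;   // pointer identity, not strcmp
    }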
17022
16740
 
17023
16741
  GGML_CALL static void ggml_backend_sycl_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
17024
16742
  ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;