llama_cpp 0.14.4 → 0.14.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1664,24 +1664,6 @@ namespace dpct
  const void *alpha, const void *a, int lda, const void *b,
  int ldb, const void *beta, void *c, int ldc)
  {
- #ifndef __INTEL_MKL__
- GGML_UNUSED(q);
- GGML_UNUSED(a_trans);
- GGML_UNUSED(b_trans);
- GGML_UNUSED(m);
- GGML_UNUSED(n);
- GGML_UNUSED(k);
- GGML_UNUSED(alpha);
- GGML_UNUSED(a);
- GGML_UNUSED(lda);
- GGML_UNUSED(b);
- GGML_UNUSED(ldb);
- GGML_UNUSED(beta);
- GGML_UNUSED(c);
- GGML_UNUSED(ldc);
- throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces "
- "Project does not support this API.");
- #else
  Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
  Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
  auto data_a = get_memory<const Ta>(a);
@@ -1690,7 +1672,6 @@ namespace dpct
  oneapi::mkl::blas::column_major::gemm(
  q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
  data_b, ldb, beta_value, data_c, ldc);
- #endif
  }
 
  template <typename VecT, class BinaryOperation, class = void>
@@ -2330,6 +2311,7 @@ namespace dpct
  lda, b, ldb, beta, c, ldc);
  break;
  }
+ #ifdef __INTEL_MKL__
  case detail::get_type_combination_id(
  library_data_t::real_bfloat16, library_data_t::real_bfloat16,
  library_data_t::real_float, library_data_t::real_float):
@@ -2391,6 +2373,7 @@ namespace dpct
  q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc);
  break;
  }
+ #endif // __INTEL_MKL__
  default:
  throw std::runtime_error("the combination of data type is unsupported");
  }
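Note: the hunks above move the oneMKL guard out of the low-level gemm helper and into the type-dispatch switch, so builds against the open oneMKL Interfaces project simply never compile the bf16 cases and fall through to the existing "unsupported" error. A minimal sketch of that pattern (hypothetical `combo` enum, not the package's types):

    #include <stdexcept>

    enum class combo { f32_f32, bf16_f32 };

    void gemm_dispatch(combo c) {
        switch (c) {
        case combo::f32_f32:
            // always available
            break;
    #ifdef __INTEL_MKL__
        case combo::bf16_f32:
            // bf16 path is only compiled when Intel oneMKL is present
            break;
    #endif
        default:
            throw std::runtime_error("the combination of data type is unsupported");
        }
    }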
@@ -3055,6 +3038,10 @@ typedef float dfloat; // dequantize float
  typedef sycl::float2 dfloat2;
  #endif //GGML_SYCL_F16
 
+ #define MMVQ_MAX_BATCH_SIZE 8
+
+ static const int8_t kvalues_iq4nl[16]={-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
+
  bool ggml_sycl_loaded(void);
  void * ggml_sycl_host_malloc(size_t size);
  void ggml_sycl_host_free(void * ptr);
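The kvalues_iq4nl table added above is the 16-entry non-linear codebook shared by the IQ4_NL and IQ4_XS paths: each stored 4-bit index selects one of these signed values, which is then scaled by the block's d. A minimal host-side sketch of that lookup (hypothetical helper, assuming the same low-nibbles-first layout used by the dequantize_block_iq4_nl kernel added later in this diff):

    #include <cstdint>

    static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10,
                                                1,   13,  25,  38,  53,  69,  89, 113};

    // qs packs 32 four-bit indices into 16 bytes; d is the per-block scale.
    void dequant_iq4nl_block(const uint8_t qs[16], float d, float out[32]) {
        for (int j = 0; j < 16; ++j) {
            out[j]      = d * kvalues_iq4nl[qs[j] & 0x0f];  // low nibbles -> first half
            out[j + 16] = d * kvalues_iq4nl[qs[j] >> 4];    // high nibbles -> second half
        }
    }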
@@ -3167,7 +3154,6 @@ typedef float (*vec_dot_q_mul_mat_sycl_t)(
  #define SYCL_SCALE_BLOCK_SIZE 256
  #define SYCL_CLAMP_BLOCK_SIZE 256
  #define SYCL_ROPE_BLOCK_SIZE 256
- #define SYCL_SOFT_MAX_BLOCK_SIZE 1024
  #define SYCL_ALIBI_BLOCK_SIZE 32
  #define SYCL_DIAG_MASK_INF_BLOCK_SIZE 32
  #define SYCL_QUANTIZE_BLOCK_SIZE 256
@@ -4490,6 +4476,32 @@ static void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __rest
 
  }
 
+ template <typename dst_t>
+ __dpct_inline__ static void
+ dequantize_block_iq2_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
+ const sycl::nd_item<3> &item_ct1) {
+
+ const int i = item_ct1.get_group(2);
+ const block_iq2_s * x = (const block_iq2_s *) vx;
+
+ const int tid = item_ct1.get_local_id(2);
+ #if QK_K == 256
+ const int il = tid/8; // 0...3
+ const int ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+ const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
+ const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
+ const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];
+ #pragma unroll
+ for (int j = 0; j < 8; ++j)
+ y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+ #else
+ assert(false);
+
+ #endif
+
+ }
+
  template<typename dst_t>
  static void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy,
  const sycl::nd_item<3> &item_ct1,
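The new dequantize_block_iq2_s kernel above follows the same sign-mask scheme as the other iq2/iq3 kernels: each group of eight grid values carries one packed sign byte, and bit j (tested against kmask_iq2xs = {1, 2, 4, ..., 128}) flips value j. A scalar sketch of that step, assuming the same mask table:

    #include <cstdint>

    static const uint8_t kmask[8] = {1, 2, 4, 8, 16, 32, 64, 128};

    // grid: eight unsigned magnitudes from the grid table; signs: one packed sign byte.
    void apply_signs(const uint8_t grid[8], uint8_t signs, float d, float out[8]) {
        for (int j = 0; j < 8; ++j) {
            out[j] = d * grid[j] * ((signs & kmask[j]) ? -1.0f : 1.0f);
        }
    }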
@@ -4522,26 +4534,26 @@ static void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __res
4522
4534
 
4523
4535
  }
4524
4536
 
4525
- template<typename dst_t>
4526
- static void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy,
4527
- const sycl::nd_item<3> &item_ct1,
4528
- const uint32_t *iq3s_grid,
4529
- const uint8_t *ksigns_iq2xs,
4530
- const uint8_t *kmask_iq2xs) {
4537
+ template <typename dst_t>
4538
+ __dpct_inline__ static void
4539
+ dequantize_block_iq3_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
4540
+ const sycl::nd_item<3> &item_ct1,
4541
+ const uint8_t *kmask_iq2xs, const uint32_t *iq3s_grid) {
4531
4542
 
4532
4543
  const int i = item_ct1.get_group(2);
4533
- const block_iq3_s * x = (const block_iq3_s *) vx;
4544
+ const block_iq3_s * x = (const block_iq3_s *) vx;
4534
4545
 
4535
4546
  const int tid = item_ct1.get_local_id(2);
4536
4547
  #if QK_K == 256
4537
4548
  const int il = tid/8; // 0...3
4538
4549
  const int ib = tid%8; // 0...7
4539
4550
  dst_t * y = yy + i*QK_K + 32*ib + 8*il;
4540
- const uint8_t * qs = x[i].qs + 8*ib;
4541
- const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + qs[2*il+0]);
4542
- const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + qs[2*il+1]);
4551
+ const uint8_t * qs = x[i].qs + 8*ib;
4552
+ const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
4553
+ const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)));
4543
4554
  const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf));
4544
4555
  const uint8_t signs = x[i].signs[4*ib + il];
4556
+ #pragma unroll
4545
4557
  for (int j = 0; j < 4; ++j) {
4546
4558
  y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
4547
4559
  y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
@@ -4552,12 +4564,12 @@ static void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restr
4552
4564
 
4553
4565
  }
4554
4566
 
4555
- template<typename dst_t>
4556
- static void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy,
4557
- const sycl::nd_item<3> &item_ct1,
4558
- const uint32_t *iq1s_grid,
4559
- const uint8_t *ksigns_iq2xs,
4560
- const uint8_t *kmask_iq2xs) {
4567
+ template <typename dst_t>
4568
+ __dpct_inline__ static void
4569
+ dequantize_block_iq1_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
4570
+ const sycl::nd_item<3> &item_ct1,
4571
+ const uint32_t *iq1s_grid_gpu) {
4572
+
4561
4573
  const int i = item_ct1.get_group(2);
4562
4574
  const block_iq1_s * x = (const block_iq1_s *) vx;
4563
4575
 
@@ -4566,14 +4578,49 @@ static void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restr
4566
4578
  const int il = tid/8; // 0...3
4567
4579
  const int ib = tid%8; // 0...7
4568
4580
  dst_t * y = yy + i*QK_K + 32*ib + 8*il;
4569
- const uint8_t * qs = x[i].qs + 8*ib;
4570
- const uint8_t * grid1 = (const uint8_t *)(iq1s_grid + qs[2*il+0]);
4571
- const uint8_t * grid2 = (const uint8_t *)(iq1s_grid + qs[2*il+1]);
4572
- const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 0xf) + 1);
4573
- const uint8_t signs = ksigns_iq2xs[(x[i].qh[ib] >> 3*il) & 7];
4574
- for (int j = 0; j < 4; ++j) {
4575
- y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
4576
- y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
4581
+ const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
4582
+ const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
4583
+ uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
4584
+ grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)];
4585
+ grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
4586
+ grid32[0] &= 0x0f0f0f0f;
4587
+ #pragma unroll
4588
+ for (int j = 0; j < 8; ++j) {
4589
+ y[j] = d * (q[j] + delta);
4590
+ }
4591
+ #else
4592
+ assert(false);
4593
+ #endif
4594
+
4595
+ }
4596
+
4597
+ template <typename dst_t>
4598
+ __dpct_inline__ static void
4599
+ dequantize_block_iq1_m(const void *__restrict__ vx, dst_t *__restrict__ yy,
4600
+ const sycl::nd_item<3> &item_ct1,
4601
+ const uint32_t *iq1s_grid_gpu) {
4602
+
4603
+ const int i = item_ct1.get_group(2);
4604
+ const block_iq1_m * x = (const block_iq1_m *) vx;
4605
+
4606
+ const int tid = item_ct1.get_local_id(2);
4607
+ #if QK_K == 256
4608
+ const int il = tid/8; // 0...3
4609
+ const int ib = tid%8; // 0...7
4610
+ dst_t * y = yy + i*QK_K + 32*ib + 8*il;
4611
+ const uint16_t * sc = (const uint16_t *)x[i].scales;
4612
+ iq1m_scale_t scale;
4613
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
4614
+ const int ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
4615
+ const float d = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);
4616
+ const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
4617
+ uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
4618
+ grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[2*ib+il/2] >> 4*(il%2)) & 7) << 8)];
4619
+ grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
4620
+ grid32[0] &= 0x0f0f0f0f;
4621
+ #pragma unroll
4622
+ for (int j = 0; j < 8; ++j) {
4623
+ y[j] = d * (q[j] + delta);
4577
4624
  }
4578
4625
  #else
4579
4626
  assert(false);
@@ -4581,6 +4628,51 @@ static void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restr
4581
4628
 
4582
4629
  }
4583
4630
 
4631
+ template <typename dst_t>
4632
+ __dpct_inline__ static void
4633
+ dequantize_block_iq4_nl(const void *__restrict__ vx, dst_t *__restrict__ yy,
4634
+ const sycl::nd_item<3> &item_ct1) {
4635
+
4636
+ const int i = item_ct1.get_group(2);
4637
+ const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
4638
+
4639
+ const int tid = item_ct1.get_local_id(2);
4640
+ const int il = tid/8; // 0...3
4641
+ const int ib = tid%8; // 0...7
4642
+ dst_t * y = yy + i*QK_K + 32*ib + 4*il;
4643
+ const uint8_t * q4 = x[ib].qs + 4*il;
4644
+ const float d = (float)x[ib].d;
4645
+ #pragma unroll
4646
+ for (int j = 0; j < 4; ++j) {
4647
+ y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
4648
+ y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
4649
+ }
4650
+
4651
+ }
4652
+
4653
+
4654
+ template <typename dst_t>
4655
+ __dpct_inline__ static void
4656
+ dequantize_block_iq4_xs(const void *__restrict__ vx, dst_t *__restrict__ yy,
4657
+ const sycl::nd_item<3> &item_ct1) {
4658
+ const int i = item_ct1.get_group(2);
4659
+ const block_iq4_xs * x = (const block_iq4_xs *)vx;
4660
+
4661
+ const int tid = item_ct1.get_local_id(2);
4662
+ const int il = tid/8; // 0...3
4663
+ const int ib = tid%8; // 0...7
4664
+ dst_t * y = yy + i*QK_K + 32*ib + 4*il;
4665
+ const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
4666
+ const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
4667
+ #pragma unroll
4668
+ for (int j = 0; j < 4; ++j) {
4669
+ y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
4670
+ y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
4671
+ }
4672
+ }
4673
+
4674
+
4675
+
4584
4676
  /*
4585
4677
  DPCT1110:4: The total declared local variable size in device function
4586
4678
  dequantize_mul_mat_vec_q2_k exceeds 128 bytes and may cause high register
@@ -7387,6 +7479,58 @@ vec_dot_iq2_xs_q8_1(const void *__restrict__ vbq,
7387
7479
  #endif
7388
7480
  }
7389
7481
 
7482
+ static __dpct_inline__ float
7483
+ vec_dot_iq2_s_q8_1(const void *__restrict__ vbq,
7484
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
7485
+ #if QK_K == 256
7486
+ const block_iq2_s * bq2 = (const block_iq2_s *) vbq;
7487
+
7488
+ const int ib32 = iqs;
7489
+ const int8_t * q8 = bq8_1[ib32].qs;
7490
+ const uint8_t * signs = bq2->qs + QK_K/8 + 4*ib32;
7491
+ const uint8_t ls1 = bq2->scales[ib32] & 0xf;
7492
+ const uint8_t ls2 = bq2->scales[ib32] >> 4;
7493
+ int sumi1 = 0;
7494
+ for (int l = 0; l < 2; ++l) {
7495
+ const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300)));
7496
+ const uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(
7497
+ ((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201,
7498
+ std::equal_to<>());
7499
+ const uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(
7500
+ ((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201,
7501
+ std::equal_to<>());
7502
+ const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
7503
+ grid[0] ^ signs0, signs0, std::minus<>());
7504
+ const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
7505
+ grid[1] ^ signs1, signs1, std::minus<>());
7506
+ sumi1 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi1);
7507
+ sumi1 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi1);
7508
+ q8 += 8;
7509
+ }
7510
+ int sumi2 = 0;
7511
+ for (int l = 2; l < 4; ++l) {
7512
+ const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300)));
7513
+ const uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(
7514
+ ((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201,
7515
+ std::equal_to<>());
7516
+ const uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(
7517
+ ((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201,
7518
+ std::equal_to<>());
7519
+ const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
7520
+ grid[0] ^ signs0, signs0, std::minus<>());
7521
+ const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
7522
+ grid[1] ^ signs1, signs1, std::minus<>());
7523
+ sumi2 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi2);
7524
+ sumi2 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi2);
7525
+ q8 += 8;
7526
+ }
7527
+ const float d = (float)bq2->d * bq8_1[ib32].ds[0] * 0.25f;
7528
+ return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
7529
+ #else
7530
+ assert(false);
7531
+ #endif
7532
+ }
7533
+
7390
7534
  static __dpct_inline__ float
7391
7535
  vec_dot_iq3_xxs_q8_1(const void *__restrict__ vbq,
7392
7536
  const block_q8_1 *__restrict__ bq8_1, const int &iqs,
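The vec_dot_iq2_s_q8_1 kernel added above leans on dpct::dp4a, which multiplies four packed signed 8-bit lanes and accumulates the products into a 32-bit integer. A scalar stand-in, for reference (not the package's implementation):

    #include <cstdint>

    // Multiply the four signed 8-bit lanes of a and b and add the products to acc.
    int32_t dp4a_scalar(int32_t a, int32_t b, int32_t acc) {
        for (int i = 0; i < 4; ++i) {
            const int8_t ai = (int8_t)(((uint32_t)a >> (8 * i)) & 0xff);  // lane i of a
            const int8_t bi = (int8_t)(((uint32_t)b >> (8 * i)) & 0xff);  // lane i of b
            acc += (int32_t)ai * (int32_t)bi;
        }
        return acc;
    }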
@@ -7429,10 +7573,8 @@ vec_dot_iq3_xxs_q8_1(const void *__restrict__ vbq,
7429
7573
 
7430
7574
  static __dpct_inline__ float
7431
7575
  vec_dot_iq3_s_q8_1(const void *__restrict__ vbq,
7432
- const block_q8_1 *__restrict__ bq8_1, const int &iqs,
7433
- const uint32_t *iq3s_grid, const uint64_t *ksigns64) {
7434
- #if DPCT_COMPATIBILITY_TEMP >= \
7435
- MIN_CC_DP4A // lowest compute capability for integer intrinsics
7576
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs,
7577
+ const uint32_t *iq3s_grid) {
7436
7578
  #if QK_K == 256
7437
7579
  const block_iq3_s * bq2 = (const block_iq3_s *) vbq;
7438
7580
 
@@ -7444,9 +7586,11 @@ vec_dot_iq3_s_q8_1(const void *__restrict__ vbq,
7444
7586
  const uint32_t * grid1 = iq3s_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256));
7445
7587
  const uint32_t * grid2 = iq3s_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256));
7446
7588
  uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(
7447
- ((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201, std::equal_to<>());
7589
+ ((bq2->signs[4 * ib32 + l] & 0xf) * 0x01010101) & 0x08040201,
7590
+ 0x08040201, std::equal_to<>());
7448
7591
  uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(
7449
- ((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201, std::equal_to<>());
7592
+ ((bq2->signs[4 * ib32 + l] >> 4) * 0x01010101) & 0x08040201,
7593
+ 0x08040201, std::equal_to<>());
7450
7594
  const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
7451
7595
  grid1[0] ^ signs0, signs0, std::minus<>());
7452
7596
  const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
@@ -7455,45 +7599,142 @@ vec_dot_iq3_s_q8_1(const void *__restrict__ vbq,
7455
7599
  sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
7456
7600
  q8 += 8;
7457
7601
  }
7458
- const float d = (float)bq2->d * (1 + 2*((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * bq8_1[ib32].ds[0];
7602
+ const float d =
7603
+ (float)bq2->d *
7604
+ (1 + 2 * ((bq2->scales[ib32 / 2] >> 4 * (ib32 % 2)) & 0xf)) *
7605
+ bq8_1[ib32].ds[0];
7459
7606
  return d * sumi;
7460
7607
  #else
7461
7608
  assert(false);
7462
- return 0.f;
7463
- #endif
7464
- #else
7465
- assert(false);
7466
- return 0.f;
7467
7609
  #endif
7468
7610
  }
7469
7611
 
7470
7612
  static __dpct_inline__ float
7471
7613
  vec_dot_iq1_s_q8_1(const void *__restrict__ vbq,
7472
- const block_q8_1 *__restrict__ bq8_1, const int &iqs,
7473
- const uint32_t *iq1s_grid, const uint64_t *ksigns64) {
7614
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs,
7615
+ const uint32_t *iq1s_grid_gpu) {
7474
7616
  #if QK_K == 256
7475
7617
  const block_iq1_s * bq1 = (const block_iq1_s *) vbq;
7476
7618
 
7477
7619
  const int ib32 = iqs;
7478
- const uint8_t * qs = bq1->qs + 4*ib32;
7479
- const int8_t * q8 = bq8_1[ib32].qs;
7480
7620
  int sumi = 0;
7621
+ const int * q8 = (const int *)bq8_1[ib32].qs;
7481
7622
  for (int l = 0; l < 4; ++l) {
7482
- const uint32_t * grid = (const uint32_t *)(iq1s_grid + qs[l]);
7483
- const uint32_t * signs = (const uint32_t *)(ksigns64 + (qs[l] >> 8));
7484
- const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
7485
- grid[0] ^ signs[0], signs[0], std::minus<>());
7486
- const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
7487
- grid[1] ^ signs[1], signs[1], std::minus<>());
7488
- sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi);
7489
- sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
7490
- q8 += 8;
7623
+ const int * grid = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8)));
7624
+ int grid0 = grid[0] & 0x0f0f0f0f;
7625
+ int grid1 = (grid[0] >> 4) & 0x0f0f0f0f;
7626
+ sumi = dpct::dp4a(q8[2 * l + 1], grid1,
7627
+ dpct::dp4a(q8[2 * l + 0], grid0, sumi));
7628
+ }
7629
+
7630
+ const float delta = bq1->qh[ib32] & 0x8000 ? -1-IQ1S_DELTA : -1+IQ1S_DELTA;
7631
+ const float d1q = (float)bq1->d * (2*((bq1->qh[ib32] >> 12) & 7) + 1);
7632
+ const float d = d1q * bq8_1[ib32].ds[0];
7633
+ const float m = d1q * bq8_1[ib32].ds[1];
7634
+ return d * sumi + m * delta;
7635
+ #else
7636
+ assert(false);
7637
+ #endif
7638
+ }
7639
+
7640
+ static __dpct_inline__ float
7641
+ vec_dot_iq1_m_q8_1(const void *__restrict__ vbq,
7642
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
7643
+ #if QK_K == 256
7644
+ const block_iq1_m * bq1 = (const block_iq1_m *) vbq;
7645
+
7646
+ const int ib32 = iqs;
7647
+ int sumi[2] = {0, 0};
7648
+ float sumf[2] = {0.f, 0.f};
7649
+
7650
+ const int * q8 = (const int *)bq8_1[ib32].qs;
7651
+ for (int l = 0; l < 4; ++l) {
7652
+ const int * grid = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[2*ib32+l/2] >> 4*(l%2)) & 7) << 8)));
7653
+ int grid0 = grid[0] & 0x0f0f0f0f;
7654
+ int grid1 = (grid[0] >> 4) & 0x0f0f0f0f;
7655
+ sumi[l / 2] = dpct::dp4a(q8[2 * l + 1], grid1,
7656
+ dpct::dp4a(q8[2 * l + 0], grid0, sumi[l / 2]));
7657
+ const float delta = (bq1->qh[2*ib32+l/2] >> 4*(l%2)) & 0x08 ? -1-IQ1M_DELTA : -1+IQ1M_DELTA;
7658
+ const int sumy = dpct::dp4a(q8[2 * l + 1], 0x01010101,
7659
+ dpct::dp4a(q8[2 * l + 0], 0x01010101, 0));
7660
+ sumf[l/2] += delta*sumy;
7661
+ }
7662
+
7663
+ iq1m_scale_t scale;
7664
+ const uint16_t * sc = (const uint16_t *)bq1->scales;
7665
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
7666
+ const float d = (float)scale.f16 * bq8_1[ib32].ds[0];
7667
+ return d * ((sumi[0] + sumf[0]) * (2*((sc[ib32/2] >> 6*(ib32%2)) & 0x7) + 1) + (sumi[1] + sumf[1]) * (2*((sc[ib32/2] >> (6*(ib32%2)+3)) & 0x7) + 1));
7668
+ #else
7669
+ assert(false);
7670
+ #endif
7671
+ }
7672
+
7673
+ static __dpct_inline__ void get_int_from_table_16(const uint32_t &q4,
7674
+ const uint8_t *values,
7675
+ int &val1, int &val2) {
7676
+
7677
+ uint32_t aux32; const uint8_t * q8 = (const uint8_t *)&aux32;
7678
+ aux32 = q4 & 0x0f0f0f0f;
7679
+ uint16_t v1 = values[q8[0]] | (values[q8[1]] << 8);
7680
+ uint16_t v2 = values[q8[2]] | (values[q8[3]] << 8);
7681
+ val1 = v1 | (v2 << 16);
7682
+ aux32 = (q4 >> 4) & 0x0f0f0f0f;
7683
+ v1 = values[q8[0]] | (values[q8[1]] << 8);
7684
+ v2 = values[q8[2]] | (values[q8[3]] << 8);
7685
+ val2 = v1 | (v2 << 16);
7686
+ }
7687
+
7688
+
7689
+ static __dpct_inline__ float
7690
+ vec_dot_iq4_nl_q8_1(const void *__restrict__ vbq,
7691
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
7692
+
7693
+ const block_iq4_nl * bq = (const block_iq4_nl *) vbq;
7694
+
7695
+ const uint16_t * q4 = (const uint16_t *)bq->qs + 2*iqs;
7696
+ const int32_t * q8 = (const int32_t *)bq8_1->qs + iqs;
7697
+
7698
+ const uint8_t * values = (const uint8_t *)kvalues_iq4nl;
7699
+
7700
+ int v1, v2;
7701
+ int sumi1 = 0, sumi2 = 0;
7702
+ for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
7703
+ const uint32_t aux = q4[2*l] | (q4[2*l+1] << 16);
7704
+ get_int_from_table_16(aux, values, v1, v2);
7705
+ sumi1 = dpct::dp4a(v1, q8[l + 0], sumi1);
7706
+ sumi2 = dpct::dp4a(v2, q8[l + 4], sumi2);
7491
7707
  }
7492
- const float d = (float)bq1->d * bq8_1[ib32].ds[0] * 0.25f;
7493
- return d * sumi;
7708
+
7709
+ const float d = (float)bq->d * bq8_1->ds[0];
7710
+ return d * (sumi1 + sumi2);
7711
+ }
7712
+
7713
+
7714
+ static __dpct_inline__ float
7715
+ vec_dot_iq4_xs_q8_1(const void *__restrict__ vbq,
7716
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
7717
+
7718
+ #if QK_K == 256
7719
+ const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq;
7720
+ const uint8_t * values = (const uint8_t *)kvalues_iq4nl;
7721
+
7722
+ // iqs is 0...7
7723
+ const int ib32 = iqs;
7724
+ const int32_t * q8 = (const int *)bq8_1[ib32].qs;
7725
+ const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32;
7726
+ const int8_t ls = ((bq4->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((bq4->scales_h >> 2*ib32) & 3) << 4);
7727
+ const float d = (float)bq4->d * (ls - 32) * bq8_1[ib32].ds[0];
7728
+ int v1, v2;
7729
+ int sumi1 = 0, sumi2 = 0;
7730
+ for (int j = 0; j < 4; ++j) {
7731
+ get_int_from_table_16(q4[j], values, v1, v2);
7732
+ sumi1 = dpct::dp4a(v1, q8[j + 0], sumi1);
7733
+ sumi2 = dpct::dp4a(v2, q8[j + 4], sumi2);
7734
+ }
7735
+ return d * (sumi1 + sumi2);
7494
7736
  #else
7495
7737
  assert(false);
7496
- return 0.f;
7497
7738
  #endif
7498
7739
  }
7499
7740
 
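get_int_from_table_16 above expands eight packed 4-bit indices into eight codebook bytes, returned as two 32-bit ints so the IQ4 dot products can feed them straight into dp4a. A scalar sketch of the same expansion (assuming little-endian byte order, as the kernel does):

    #include <cstdint>

    void table16_expand(uint32_t q4, const uint8_t values[16], int &val1, int &val2) {
        uint32_t lo = 0, hi = 0;
        for (int i = 0; i < 4; ++i) {
            lo |= (uint32_t)values[(q4 >> (8 * i))     & 0x0f] << (8 * i); // low nibble of byte i
            hi |= (uint32_t)values[(q4 >> (8 * i + 4)) & 0x0f] << (8 * i); // high nibble of byte i
        }
        val1 = (int)lo;
        val2 = (int)hi;
    }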
@@ -8078,8 +8319,7 @@ template <bool need_check> static void
8078
8319
 
8079
8320
  template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
8080
8321
  static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
8081
- const sycl::nd_item<3> &item_ct1,
8082
- const uint32_t *iq3xxs_grid_ptr=nullptr, const uint64_t *ksigns64_ptr=nullptr) {
8322
+ const sycl::nd_item<3> &item_ct1) {
8083
8323
  const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
8084
8324
  item_ct1.get_local_id(1);
8085
8325
 
@@ -8123,10 +8363,11 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_
8123
8363
  }
8124
8364
 
8125
8365
  template <int qk, int qi, typename block_q_t, int vdr>
8126
- static void mul_mat_vec_q_iq2_xxs_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
8127
- const sycl::nd_item<3> &item_ct1,
8128
- const uint64_t *iq2xxs_grid_ptr, const uint8_t *ksigns_iq2xs_ptr,
8129
- const uint8_t *kmask_iq2xs_ptr ) {
8366
+ static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx,
8367
+ const void *__restrict__ vy,
8368
+ float *__restrict__ dst, const int ncols,
8369
+ const int nrows,
8370
+ const sycl::nd_item<3> &item_ct1) {
8130
8371
  const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
8131
8372
  item_ct1.get_local_id(1);
8132
8373
 
@@ -8154,7 +8395,7 @@ static void mul_mat_vec_q_iq2_xxs_q8_1(const void * __restrict__ vx, const void
8154
8395
  (item_ct1.get_local_id(2) %
8155
8396
  (qi / vdr)); // x block quant index when casting the quants to int
8156
8397
 
8157
- tmp += vec_dot_iq2_xxs_q8_1(&x[ibx], &y[iby], iqs, iq2xxs_grid_ptr, ksigns_iq2xs_ptr, kmask_iq2xs_ptr);
8398
+ tmp += vec_dot_iq2_xxs_q8_1(&x[ibx], &y[iby], iqs, iq2xxs_grid, ksigns_iq2xs, kmask_iq2xs);
8158
8399
  }
8159
8400
 
8160
8401
  // sum up partial sums and write back result
@@ -8170,9 +8411,11 @@ static void mul_mat_vec_q_iq2_xxs_q8_1(const void * __restrict__ vx, const void
8170
8411
  }
8171
8412
 
8172
8413
  template <int qk, int qi, typename block_q_t, int vdr>
8173
- static void mul_mat_vec_q_iq2_xs_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
8174
- const sycl::nd_item<3> &item_ct1,
8175
- const uint64_t *iq2xs_grid_ptr, const uint64_t *ksigns64_ptr ) {
8414
+ static void mul_mat_vec_q_iq2_xs_q8_1(const void *__restrict__ vx,
8415
+ const void *__restrict__ vy,
8416
+ float *__restrict__ dst, const int ncols,
8417
+ const int nrows,
8418
+ const sycl::nd_item<3> &item_ct1) {
8176
8419
  const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
8177
8420
  item_ct1.get_local_id(1);
8178
8421
 
@@ -8200,7 +8443,7 @@ static void mul_mat_vec_q_iq2_xs_q8_1(const void * __restrict__ vx, const void *
8200
8443
  (item_ct1.get_local_id(2) %
8201
8444
  (qi / vdr)); // x block quant index when casting the quants to int
8202
8445
 
8203
- tmp += vec_dot_iq2_xs_q8_1(&x[ibx], &y[iby], iqs, iq2xs_grid_ptr, ksigns64_ptr);
8446
+ tmp += vec_dot_iq2_xs_q8_1(&x[ibx], &y[iby], iqs, iq2xs_grid, ksigns64);
8204
8447
  }
8205
8448
 
8206
8449
  // sum up partial sums and write back result
@@ -8216,9 +8459,11 @@ static void mul_mat_vec_q_iq2_xs_q8_1(const void * __restrict__ vx, const void *
8216
8459
  }
8217
8460
 
8218
8461
  template <int qk, int qi, typename block_q_t, int vdr>
8219
- static void mul_mat_vec_q_iq3_xxs_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
8220
- const sycl::nd_item<3> &item_ct1,
8221
- const uint32_t *iq3xxs_grid_ptr, const uint64_t *ksigns64_ptr ) {
8462
+ static void mul_mat_vec_q_iq2_s_q8_1(const void *__restrict__ vx,
8463
+ const void *__restrict__ vy,
8464
+ float *__restrict__ dst, const int ncols,
8465
+ const int nrows,
8466
+ const sycl::nd_item<3> &item_ct1) {
8222
8467
  const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
8223
8468
  item_ct1.get_local_id(1);
8224
8469
 
@@ -8246,7 +8491,7 @@ static void mul_mat_vec_q_iq3_xxs_q8_1(const void * __restrict__ vx, const void
8246
8491
  (item_ct1.get_local_id(2) %
8247
8492
  (qi / vdr)); // x block quant index when casting the quants to int
8248
8493
 
8249
- tmp += vec_dot_iq3_xxs_q8_1(&x[ibx], &y[iby], iqs, iq3xxs_grid_ptr, ksigns64_ptr);
8494
+ tmp += vec_dot_iq2_s_q8_1(&x[ibx], &y[iby], iqs);
8250
8495
  }
8251
8496
 
8252
8497
  // sum up partial sums and write back result
@@ -8262,9 +8507,11 @@ static void mul_mat_vec_q_iq3_xxs_q8_1(const void * __restrict__ vx, const void
8262
8507
  }
8263
8508
 
8264
8509
  template <int qk, int qi, typename block_q_t, int vdr>
8265
- static void mul_mat_vec_q_iq3_s_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
8266
- const sycl::nd_item<3> &item_ct1,
8267
- const uint32_t *iq3s_grid_ptr, const uint64_t *ksigns64_ptr ) {
8510
+ static void mul_mat_vec_q_iq3_xxs_q8_1(const void *__restrict__ vx,
8511
+ const void *__restrict__ vy,
8512
+ float *__restrict__ dst, const int ncols,
8513
+ const int nrows,
8514
+ const sycl::nd_item<3> &item_ct1) {
8268
8515
  const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
8269
8516
  item_ct1.get_local_id(1);
8270
8517
 
@@ -8292,7 +8539,7 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void * __restrict__ vx, const void *
8292
8539
  (item_ct1.get_local_id(2) %
8293
8540
  (qi / vdr)); // x block quant index when casting the quants to int
8294
8541
 
8295
- tmp += vec_dot_iq3_s_q8_1(&x[ibx], &y[iby], iqs, iq3s_grid_ptr, ksigns64_ptr);
8542
+ tmp += vec_dot_iq3_xxs_q8_1(&x[ibx], &y[iby], iqs, iq3xxs_grid, ksigns64);
8296
8543
  }
8297
8544
 
8298
8545
  // sum up partial sums and write back result
@@ -8308,9 +8555,11 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void * __restrict__ vx, const void *
8308
8555
  }
8309
8556
 
8310
8557
  template <int qk, int qi, typename block_q_t, int vdr>
8311
- static void mul_mat_vec_q_iq1_s_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
8312
- const sycl::nd_item<3> &item_ct1,
8313
- const uint32_t *iq1s_grid_ptr, const uint64_t *ksigns64_ptr ) {
8558
+ static void mul_mat_vec_q_iq3_s_q8_1(const void *__restrict__ vx,
8559
+ const void *__restrict__ vy,
8560
+ float *__restrict__ dst, const int ncols,
8561
+ const int nrows,
8562
+ const sycl::nd_item<3> &item_ct1) {
8314
8563
  const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
8315
8564
  item_ct1.get_local_id(1);
8316
8565
 
@@ -8338,7 +8587,7 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void * __restrict__ vx, const void *
8338
8587
  (item_ct1.get_local_id(2) %
8339
8588
  (qi / vdr)); // x block quant index when casting the quants to int
8340
8589
 
8341
- tmp += vec_dot_iq1_s_q8_1(&x[ibx], &y[iby], iqs, iq1s_grid_ptr, ksigns64_ptr);
8590
+ tmp += vec_dot_iq3_s_q8_1(&x[ibx], &y[iby], iqs, iq3s_grid);
8342
8591
  }
8343
8592
 
8344
8593
  // sum up partial sums and write back result
@@ -8353,6 +8602,200 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void * __restrict__ vx, const void *
8353
8602
  }
8354
8603
  }
8355
8604
 
8605
+ template <int qk, int qi, typename block_q_t, int vdr>
8606
+ static void mul_mat_vec_q_iq1_s_q8_1(const void *__restrict__ vx,
8607
+ const void *__restrict__ vy,
8608
+ float *__restrict__ dst, const int ncols,
8609
+ const int nrows,
8610
+ const sycl::nd_item<3> &item_ct1) {
8611
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
8612
+ item_ct1.get_local_id(1);
8613
+
8614
+ if (row >= nrows) {
8615
+ return;
8616
+ }
8617
+
8618
+ const int blocks_per_row = ncols / qk;
8619
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
8620
+
8621
+ // partial sum for each thread
8622
+ float tmp = 0.0f;
8623
+
8624
+ const block_q_t * x = (const block_q_t *) vx;
8625
+ const block_q8_1 * y = (const block_q8_1 *) vy;
8626
+
8627
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
8628
+ i += blocks_per_warp) {
8629
+ const int ibx = row*blocks_per_row + i; // x block index
8630
+
8631
+ const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
8632
+
8633
+ const int iqs =
8634
+ vdr *
8635
+ (item_ct1.get_local_id(2) %
8636
+ (qi / vdr)); // x block quant index when casting the quants to int
8637
+
8638
+ tmp += vec_dot_iq1_s_q8_1(&x[ibx], &y[iby], iqs, iq1s_grid_gpu);
8639
+ }
8640
+
8641
+ // sum up partial sums and write back result
8642
+ #pragma unroll
8643
+ for (int mask = 16; mask > 0; mask >>= 1) {
8644
+ tmp +=
8645
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
8646
+ }
8647
+
8648
+ if (item_ct1.get_local_id(2) == 0) {
8649
+ dst[row] = tmp;
8650
+ }
8651
+ }
8652
+
8653
+ template <int qk, int qi, typename block_q_t, int vdr>
8654
+ static void mul_mat_vec_q_iq1_m_q8_1(const void *__restrict__ vx,
8655
+ const void *__restrict__ vy,
8656
+ float *__restrict__ dst, const int ncols,
8657
+ const int nrows,
8658
+ const sycl::nd_item<3> &item_ct1) {
8659
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
8660
+ item_ct1.get_local_id(1);
8661
+
8662
+ if (row >= nrows) {
8663
+ return;
8664
+ }
8665
+
8666
+ const int blocks_per_row = ncols / qk;
8667
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
8668
+
8669
+ // partial sum for each thread
8670
+ float tmp = 0.0f;
8671
+
8672
+ const block_q_t * x = (const block_q_t *) vx;
8673
+ const block_q8_1 * y = (const block_q8_1 *) vy;
8674
+
8675
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
8676
+ i += blocks_per_warp) {
8677
+ const int ibx = row*blocks_per_row + i; // x block index
8678
+
8679
+ const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
8680
+
8681
+ const int iqs =
8682
+ vdr *
8683
+ (item_ct1.get_local_id(2) %
8684
+ (qi / vdr)); // x block quant index when casting the quants to int
8685
+
8686
+ tmp += vec_dot_iq1_m_q8_1(&x[ibx], &y[iby], iqs);
8687
+ }
8688
+
8689
+ // sum up partial sums and write back result
8690
+ #pragma unroll
8691
+ for (int mask = 16; mask > 0; mask >>= 1) {
8692
+ tmp +=
8693
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
8694
+ }
8695
+
8696
+ if (item_ct1.get_local_id(2) == 0) {
8697
+ dst[row] = tmp;
8698
+ }
8699
+ }
8700
+
8701
+ template <int qk, int qi, typename block_q_t, int vdr>
8702
+ static void mul_mat_vec_q_iq4_nl_q8_1(const void *__restrict__ vx,
8703
+ const void *__restrict__ vy,
8704
+ float *__restrict__ dst, const int ncols,
8705
+ const int nrows,
8706
+ const sycl::nd_item<3> &item_ct1) {
8707
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
8708
+ item_ct1.get_local_id(1);
8709
+
8710
+ if (row >= nrows) {
8711
+ return;
8712
+ }
8713
+
8714
+ const int blocks_per_row = ncols / qk;
8715
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
8716
+
8717
+ // partial sum for each thread
8718
+ float tmp = 0.0f;
8719
+
8720
+ const block_q_t * x = (const block_q_t *) vx;
8721
+ const block_q8_1 * y = (const block_q8_1 *) vy;
8722
+
8723
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
8724
+ i += blocks_per_warp) {
8725
+ const int ibx = row*blocks_per_row + i; // x block index
8726
+
8727
+ const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
8728
+
8729
+ const int iqs =
8730
+ vdr *
8731
+ (item_ct1.get_local_id(2) %
8732
+ (qi / vdr)); // x block quant index when casting the quants to int
8733
+
8734
+ tmp += vec_dot_iq4_nl_q8_1(&x[ibx], &y[iby], iqs);
8735
+ }
8736
+
8737
+ // sum up partial sums and write back result
8738
+ #pragma unroll
8739
+ for (int mask = 16; mask > 0; mask >>= 1) {
8740
+ tmp +=
8741
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
8742
+ }
8743
+
8744
+ if (item_ct1.get_local_id(2) == 0) {
8745
+ dst[row] = tmp;
8746
+ }
8747
+ }
8748
+
8749
+
8750
+ template <int qk, int qi, typename block_q_t, int vdr>
8751
+ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
8752
+ const void *__restrict__ vy,
8753
+ float *__restrict__ dst, const int ncols,
8754
+ const int nrows,
8755
+ const sycl::nd_item<3> &item_ct1) {
8756
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
8757
+ item_ct1.get_local_id(1);
8758
+
8759
+ if (row >= nrows) {
8760
+ return;
8761
+ }
8762
+
8763
+ const int blocks_per_row = ncols / qk;
8764
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
8765
+
8766
+ // partial sum for each thread
8767
+ float tmp = 0.0f;
8768
+
8769
+ const block_q_t * x = (const block_q_t *) vx;
8770
+ const block_q8_1 * y = (const block_q8_1 *) vy;
8771
+
8772
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
8773
+ i += blocks_per_warp) {
8774
+ const int ibx = row*blocks_per_row + i; // x block index
8775
+
8776
+ const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
8777
+
8778
+ const int iqs =
8779
+ vdr *
8780
+ (item_ct1.get_local_id(2) %
8781
+ (qi / vdr)); // x block quant index when casting the quants to int
8782
+
8783
+ tmp += vec_dot_iq4_xs_q8_1(&x[ibx], &y[iby], iqs);
8784
+ }
8785
+
8786
+ // sum up partial sums and write back result
8787
+ #pragma unroll
8788
+ for (int mask = 16; mask > 0; mask >>= 1) {
8789
+ tmp +=
8790
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
8791
+ }
8792
+
8793
+ if (item_ct1.get_local_id(2) == 0) {
8794
+ dst[row] = tmp;
8795
+ }
8796
+ }
8797
+
8798
+
8356
8799
  template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
8357
8800
  static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows,
8358
8801
  const sycl::nd_item<3> &item_ct1) {
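All of the mul_mat_vec_q_* kernels added above finish with the same sub-group reduction: each of the 32 work-items holds a partial dot product, and log2(32) XOR-shuffle steps leave the full sum in every lane before lane 0 writes dst[row]. The same butterfly written out over a plain array, as a reference sketch:

    // tmp[lane] holds one partial sum per lane; after the loop every lane has the total.
    void butterfly_reduce(float tmp[32]) {
        for (int mask = 16; mask > 0; mask >>= 1) {
            float shuffled[32];
            for (int lane = 0; lane < 32; ++lane) {
                shuffled[lane] = tmp[lane ^ mask];   // stands in for permute_sub_group_by_xor
            }
            for (int lane = 0; lane < 32; ++lane) {
                tmp[lane] += shuffled[lane];
            }
        }
    }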
@@ -8914,64 +9357,71 @@ static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
8914
9357
  }
8915
9358
  }
8916
9359
 
9360
+
8917
9361
  template<typename T>
8918
- static inline void swap(T & a, T & b) {
9362
+ static inline void ggml_sycl_swap(T & a, T & b) {
8919
9363
  T tmp = a;
8920
9364
  a = b;
8921
9365
  b = tmp;
8922
9366
  }
8923
9367
 
8924
- template<ggml_sort_order order>
8925
- static void k_argsort_f32_i32(const float * x, int * dst, const int ncols,
8926
- const sycl::nd_item<3> &item_ct1) {
9368
+ template <ggml_sort_order order>
9369
+ __dpct_inline__ static void
9370
+ k_argsort_f32_i32(const float *x, int *dst, const int ncols, int ncols_pad,
9371
+ const sycl::nd_item<3> &item_ct1, uint8_t *dpct_local) {
8927
9372
  // bitonic sort
8928
9373
  int col = item_ct1.get_local_id(2);
8929
9374
  int row = item_ct1.get_group(1);
8930
9375
 
8931
- if (col >= ncols) return;
9376
+ if (col >= ncols_pad) {
9377
+ return;
9378
+ }
8932
9379
 
8933
9380
  const float * x_row = x + row * ncols;
8934
- int * dst_row = dst + row * ncols;
9381
+ auto dst_row = (int *)dpct_local;
8935
9382
 
8936
9383
  // initialize indices
8937
- if (col < ncols) {
8938
- dst_row[col] = col;
8939
- }
8940
- /*
8941
- DPCT1065:58: Consider replacing sycl::nd_item::barrier() with
8942
- sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
8943
- performance if there is no access to global memory.
8944
- */
8945
- item_ct1.barrier();
9384
+ dst_row[col] = col;
8946
9385
 
8947
- for (int k = 2; k <= ncols; k *= 2) {
9386
+ item_ct1.barrier(sycl::access::fence_space::local_space);
9387
+
9388
+ for (int k = 2; k <= ncols_pad; k *= 2) {
8948
9389
  for (int j = k / 2; j > 0; j /= 2) {
8949
9390
  int ixj = col ^ j;
8950
9391
  if (ixj > col) {
8951
9392
  if ((col & k) == 0) {
8952
- if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
8953
- swap(dst_row[col], dst_row[ixj]);
9393
+ if (dst_row[col] >= ncols ||
9394
+ (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
9395
+ x_row[dst_row[col]] > x_row[dst_row[ixj]] :
9396
+ x_row[dst_row[col]] < x_row[dst_row[ixj]]))
9397
+ ) {
9398
+ ggml_sycl_swap(dst_row[col], dst_row[ixj]);
8954
9399
  }
8955
9400
  } else {
8956
- if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
8957
- swap(dst_row[col], dst_row[ixj]);
9401
+ if (dst_row[ixj] >= ncols ||
9402
+ (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
9403
+ x_row[dst_row[col]] < x_row[dst_row[ixj]] :
9404
+ x_row[dst_row[col]] > x_row[dst_row[ixj]]))
9405
+ ) {
9406
+ ggml_sycl_swap(dst_row[col], dst_row[ixj]);
8958
9407
  }
8959
9408
  }
8960
9409
  }
8961
9410
  /*
8962
- DPCT1118:11: SYCL group functions and algorithms must be encountered
9411
+ DPCT1118:1: SYCL group functions and algorithms must be encountered
8963
9412
  in converged control flow. You may need to adjust the code.
8964
9413
  */
8965
- /*
8966
- DPCT1065:59: Consider replacing sycl::nd_item::barrier() with
8967
- sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
8968
- better performance if there is no access to global memory.
8969
- */
8970
- item_ct1.barrier();
9414
+ item_ct1.barrier(sycl::access::fence_space::local_space);
8971
9415
  }
8972
9416
  }
9417
+
9418
+ // copy the result to dst without the padding
9419
+ if (col < ncols) {
9420
+ dst[row * ncols + col] = dst_row[col];
9421
+ }
8973
9422
  }
8974
9423
 
9424
+
8975
9425
  static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past,
8976
9426
  const sycl::nd_item<3> &item_ct1) {
8977
9427
  const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
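The reworked k_argsort_f32_i32 above drops the old power-of-two requirement: the launcher pads ncols up to the next power of two, the index array lives in local memory (dpct_local), padded slots always lose comparisons so they sink to the end, and only the first ncols indices are copied back. A host-side reference of the same scheme, assuming ascending order:

    #include <vector>
    #include <utility>

    std::vector<int> argsort_bitonic_asc(const std::vector<float> &x) {
        const int ncols = (int)x.size();
        int ncols_pad = 1;
        while (ncols_pad < ncols) ncols_pad *= 2;        // next power of two

        std::vector<int> idx(ncols_pad);
        for (int i = 0; i < ncols_pad; ++i) idx[i] = i;  // i >= ncols marks a padded slot

        // "a sorts after b": padded slots compare as +infinity
        auto after = [&](int a, int b) {
            if (a >= ncols) return true;
            if (b >= ncols) return false;
            return x[a] > x[b];
        };

        for (int k = 2; k <= ncols_pad; k *= 2) {
            for (int j = k / 2; j > 0; j /= 2) {
                for (int col = 0; col < ncols_pad; ++col) {
                    const int ixj = col ^ j;
                    if (ixj <= col) continue;
                    const bool asc = (col & k) == 0;
                    if (asc ? after(idx[col], idx[ixj]) : after(idx[ixj], idx[col])) {
                        std::swap(idx[col], idx[ixj]);
                    }
                }
            }
        }
        idx.resize(ncols);                               // drop the padding
        return idx;
    }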
@@ -9950,28 +10400,64 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int k,
9950
10400
  #endif
9951
10401
  }
9952
10402
 
9953
-
9954
10403
  template <typename dst_t>
9955
- static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int k,
10404
+ static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
9956
10405
  dpct::queue_ptr stream) {
9957
10406
  const int nb = k / QK_K;
9958
10407
  {
10408
+ dpct::has_capability_or_fail(stream->get_device(),
10409
+ {sycl::aspect::fp16});
10410
+
10411
+ stream->submit([&](sycl::handler &cgh) {
10412
+ cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
10413
+ sycl::range<3>(1, 1, 32),
10414
+ sycl::range<3>(1, 1, 32)),
10415
+ [=](sycl::nd_item<3> item_ct1) {
10416
+ dequantize_block_iq1_s(
10417
+ vx, y, item_ct1, iq1s_grid_gpu
10418
+ );
10419
+ });
10420
+ });
10421
+ }
10422
+ }
9959
10423
 
10424
+ template <typename dst_t>
10425
+ static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int k,
10426
+ dpct::queue_ptr stream) {
10427
+ const int nb = k / QK_K;
10428
+ {
9960
10429
  dpct::has_capability_or_fail(stream->get_device(),
9961
10430
  {sycl::aspect::fp16});
9962
10431
 
9963
10432
  stream->submit([&](sycl::handler &cgh) {
9964
- auto iq2xxs_grid_ptr_ct1 = &iq2xxs_grid[0];
9965
- auto ksigns_iq2xs_ptr_ct1 = &ksigns_iq2xs[0];
9966
- auto kmask_iq2xs_ptr_ct1 = &kmask_iq2xs[0];
10433
+ cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
10434
+ sycl::range<3>(1, 1, 32),
10435
+ sycl::range<3>(1, 1, 32)),
10436
+ [=](sycl::nd_item<3> item_ct1) {
10437
+ dequantize_block_iq1_m(
10438
+ vx, y, item_ct1, iq1s_grid_gpu
10439
+ );
10440
+ });
10441
+ });
10442
+ }
10443
+ }
10444
+
10445
+ template <typename dst_t>
10446
+ static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int k,
10447
+ dpct::queue_ptr stream) {
10448
+ const int nb = k / QK_K;
10449
+ {
10450
+ dpct::has_capability_or_fail(stream->get_device(),
10451
+ {sycl::aspect::fp16});
9967
10452
 
10453
+ stream->submit([&](sycl::handler &cgh) {
9968
10454
  cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
9969
10455
  sycl::range<3>(1, 1, 32),
9970
10456
  sycl::range<3>(1, 1, 32)),
9971
10457
  [=](sycl::nd_item<3> item_ct1) {
9972
10458
  dequantize_block_iq2_xxs(
9973
- vx, y, item_ct1, iq2xxs_grid_ptr_ct1,
9974
- ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1);
10459
+ vx, y, item_ct1, iq2xxs_grid,
10460
+ ksigns_iq2xs, kmask_iq2xs);
9975
10461
  });
9976
10462
  });
9977
10463
  }
@@ -9982,105 +10468,130 @@ static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int k,
9982
10468
  dpct::queue_ptr stream) {
9983
10469
  const int nb = k / QK_K;
9984
10470
  {
9985
-
9986
10471
  dpct::has_capability_or_fail(stream->get_device(),
9987
10472
  {sycl::aspect::fp16});
9988
10473
 
9989
10474
  stream->submit([&](sycl::handler &cgh) {
9990
- auto iq2xs_grid_ptr_ct1 = &iq2xs_grid[0];
9991
- auto ksigns_iq2xs_ptr_ct1 = &ksigns_iq2xs[0];
9992
- auto kmask_iq2xs_ptr_ct1 = &kmask_iq2xs[0];
9993
-
9994
10475
  cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
9995
10476
  sycl::range<3>(1, 1, 32),
9996
10477
  sycl::range<3>(1, 1, 32)),
9997
10478
  [=](sycl::nd_item<3> item_ct1) {
9998
10479
  dequantize_block_iq2_xs(
9999
- vx, y, item_ct1, iq2xs_grid_ptr_ct1,
10000
- ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1);
10480
+ vx, y, item_ct1, iq2xs_grid,
10481
+ ksigns_iq2xs, kmask_iq2xs);
10001
10482
  });
10002
10483
  });
10003
10484
  }
10004
10485
  }
10005
10486
 
10006
10487
  template <typename dst_t>
10007
- static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k,
10008
- dpct::queue_ptr stream) {
10488
+ static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int k,
10489
+ dpct::queue_ptr stream) {
10009
10490
  const int nb = k / QK_K;
10010
10491
  {
10011
-
10012
10492
  dpct::has_capability_or_fail(stream->get_device(),
10013
10493
  {sycl::aspect::fp16});
10014
10494
 
10015
10495
  stream->submit([&](sycl::handler &cgh) {
10016
- auto iq3xxs_grid_ptr_ct1 = &iq3xxs_grid[0];
10017
- auto ksigns_iq2xs_ptr_ct1 = &ksigns_iq2xs[0];
10018
- auto kmask_iq2xs_ptr_ct1 = &kmask_iq2xs[0];
10019
-
10020
10496
  cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
10021
10497
  sycl::range<3>(1, 1, 32),
10022
10498
  sycl::range<3>(1, 1, 32)),
10023
10499
  [=](sycl::nd_item<3> item_ct1) {
10024
- dequantize_block_iq3_xxs(
10025
- vx, y, item_ct1, iq3xxs_grid_ptr_ct1,
10026
- ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1);
10500
+ dequantize_block_iq2_s(vx, y, item_ct1);
10027
10501
  });
10028
10502
  });
10029
10503
  }
10030
10504
  }
10031
10505
 
10506
+
10032
10507
  template <typename dst_t>
10033
- static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int k,
10508
+ static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k,
10034
10509
  dpct::queue_ptr stream) {
10035
10510
  const int nb = k / QK_K;
10036
10511
  {
10037
-
10038
10512
  dpct::has_capability_or_fail(stream->get_device(),
10039
10513
  {sycl::aspect::fp16});
10040
10514
 
10041
10515
  stream->submit([&](sycl::handler &cgh) {
10042
- auto iq3s_grid_ptr_ct1 = &iq3s_grid[0];
10043
- auto ksigns_iq2xs_ptr_ct1 = &ksigns_iq2xs[0];
10044
- auto kmask_iq2xs_ptr_ct1 = &kmask_iq2xs[0];
10045
-
10046
10516
  cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
10047
10517
  sycl::range<3>(1, 1, 32),
10048
10518
  sycl::range<3>(1, 1, 32)),
10049
10519
  [=](sycl::nd_item<3> item_ct1) {
10050
- dequantize_block_iq3_s(
10051
- vx, y, item_ct1, iq3s_grid_ptr_ct1,
10052
- ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1);
10520
+ dequantize_block_iq3_xxs(
10521
+ vx, y, item_ct1, iq3xxs_grid,
10522
+ ksigns_iq2xs, kmask_iq2xs);
10053
10523
  });
10054
10524
  });
10055
10525
  }
10056
10526
  }
10057
10527
 
10058
10528
  template <typename dst_t>
10059
- static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
10529
+ static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int k,
10060
10530
  dpct::queue_ptr stream) {
10061
10531
  const int nb = k / QK_K;
10062
10532
  {
10063
-
10064
10533
  dpct::has_capability_or_fail(stream->get_device(),
10065
10534
  {sycl::aspect::fp16});
10066
10535
 
10067
10536
  stream->submit([&](sycl::handler &cgh) {
10068
- auto iq1s_grid_ptr_ct1 = &iq1s_grid_gpu[0];
10069
- auto ksigns_iq2xs_ptr_ct1 = &ksigns_iq2xs[0];
10070
- auto kmask_iq2xs_ptr_ct1 = &kmask_iq2xs[0];
10071
-
10072
10537
  cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
10073
10538
  sycl::range<3>(1, 1, 32),
10074
10539
  sycl::range<3>(1, 1, 32)),
10075
10540
  [=](sycl::nd_item<3> item_ct1) {
10076
- dequantize_block_iq1_s(
10077
- vx, y, item_ct1, iq1s_grid_ptr_ct1,
10078
- ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1);
10541
+ dequantize_block_iq3_s(
10542
+ vx, y, item_ct1, kmask_iq2xs, iq3s_grid);
10079
10543
  });
10080
10544
  });
10081
10545
  }
10082
10546
  }
10083
10547
 
10548
+ template <typename dst_t>
10549
+ static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int k,
10550
+ dpct::queue_ptr stream) {
10551
+ const int nb = (k + QK_K - 1) / QK_K;
10552
+ #if QK_K == 64
10553
+ dequantize_row_iq4_nl_sycl(vx, y, k, stream);
10554
+ #else
10555
+ {
10556
+ dpct::has_capability_or_fail(stream->get_device(),
10557
+ {sycl::aspect::fp16});
10558
+
10559
+ stream->submit([&](sycl::handler &cgh) {
10560
+ cgh.parallel_for(
10561
+ sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
10562
+ sycl::range<3>(1, 1, 32),
10563
+ sycl::range<3>(1, 1, 32)),
10564
+ [=](sycl::nd_item<3> item_ct1) {
10565
+ dequantize_block_iq4_xs(vx, y, item_ct1);
10566
+ });
10567
+ });
10568
+ }
10569
+ #endif
10570
+ }
10571
+
10572
+
10573
+ template <typename dst_t>
10574
+ static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int k,
10575
+ dpct::queue_ptr stream) {
10576
+ const int nb = (k + QK_K - 1) / QK_K;
10577
+ {
10578
+ dpct::has_capability_or_fail(stream->get_device(),
10579
+ {sycl::aspect::fp16});
10580
+
10581
+ stream->submit([&](sycl::handler &cgh) {
10582
+ cgh.parallel_for(
10583
+ sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
10584
+ sycl::range<3>(1, 1, 32),
10585
+ sycl::range<3>(1, 1, 32)),
10586
+ [=](sycl::nd_item<3> item_ct1) {
10587
+ dequantize_block_iq4_nl(vx, y, item_ct1);
10588
+ });
10589
+ });
10590
+ }
10591
+ }
10592
+
10593
+
10594
+
10084
10595
  template <typename src_t, typename dst_t>
10085
10596
  static void convert_unary_sycl(const void *__restrict__ vx,
10086
10597
  dst_t *__restrict__ y, const int k,
@@ -10125,16 +10636,24 @@ static to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type) try {
10125
10636
  return dequantize_row_q5_K_sycl;
10126
10637
  case GGML_TYPE_Q6_K:
10127
10638
  return dequantize_row_q6_K_sycl;
10639
+ case GGML_TYPE_IQ1_S:
10640
+ return dequantize_row_iq1_s_sycl;
10641
+ case GGML_TYPE_IQ1_M:
10642
+ return dequantize_row_iq1_m_sycl;
10128
10643
  case GGML_TYPE_IQ2_XXS:
10129
10644
  return dequantize_row_iq2_xxs_sycl;
10130
10645
  case GGML_TYPE_IQ2_XS:
10131
10646
  return dequantize_row_iq2_xs_sycl;
10647
+ case GGML_TYPE_IQ2_S:
10648
+ return dequantize_row_iq2_s_sycl;
10132
10649
  case GGML_TYPE_IQ3_XXS:
10133
10650
  return dequantize_row_iq3_xxs_sycl;
10134
10651
  case GGML_TYPE_IQ3_S:
10135
10652
  return dequantize_row_iq3_s_sycl;
10136
- case GGML_TYPE_IQ1_S:
10137
- return dequantize_row_iq1_s_sycl;
10653
+ case GGML_TYPE_IQ4_XS:
10654
+ return dequantize_row_iq4_xs_sycl;
10655
+ case GGML_TYPE_IQ4_NL:
10656
+ return dequantize_row_iq4_nl_sycl;
10138
10657
  case GGML_TYPE_F32:
10139
10658
  return convert_unary_sycl<float>;
10140
10659
  default:
@@ -10169,16 +10688,24 @@ static to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type) {
10169
10688
  return dequantize_row_q5_K_sycl;
10170
10689
  case GGML_TYPE_Q6_K:
10171
10690
  return dequantize_row_q6_K_sycl;
10691
+ case GGML_TYPE_IQ1_S:
10692
+ return dequantize_row_iq1_s_sycl;
10693
+ case GGML_TYPE_IQ1_M:
10694
+ return dequantize_row_iq1_m_sycl;
10172
10695
  case GGML_TYPE_IQ2_XXS:
10173
10696
  return dequantize_row_iq2_xxs_sycl;
10174
10697
  case GGML_TYPE_IQ2_XS:
10175
10698
  return dequantize_row_iq2_xs_sycl;
10699
+ case GGML_TYPE_IQ2_S:
10700
+ return dequantize_row_iq2_s_sycl;
10176
10701
  case GGML_TYPE_IQ3_XXS:
10177
10702
  return dequantize_row_iq3_xxs_sycl;
10178
10703
  case GGML_TYPE_IQ3_S:
10179
10704
  return dequantize_row_iq3_s_sycl;
10180
- case GGML_TYPE_IQ1_S:
10181
- return dequantize_row_iq1_s_sycl;
10705
+ case GGML_TYPE_IQ4_XS:
10706
+ return dequantize_row_iq4_xs_sycl;
10707
+ case GGML_TYPE_IQ4_NL:
10708
+ return dequantize_row_iq4_nl_sycl;
10182
10709
  case GGML_TYPE_F16:
10183
10710
  return convert_unary_sycl<sycl::half>;
10184
10711
  default:
@@ -10641,19 +11168,13 @@ static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy,
10641
11168
  const sycl::range<3> block_nums(1, 1, block_num_y);
10642
11169
  const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
10643
11170
  {
10644
-
10645
11171
  stream->submit([&](sycl::handler &cgh) {
10646
- auto iq2xxs_grid_ptr_ct1 = &iq2xxs_grid[0];
10647
- auto ksigns_iq2xs_ptr_ct1 = &ksigns_iq2xs[0];
10648
- auto kmask_iq2xs_ptr_ct1 = &kmask_iq2xs[0];
10649
-
10650
11172
  cgh.parallel_for(
10651
11173
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
10652
11174
  [=](sycl::nd_item<3> item_ct1)
10653
11175
  [[intel::reqd_sub_group_size(32)]] {
10654
11176
  mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS, block_iq2_xxs, 1>(
10655
- vx, vy, dst, ncols, nrows, item_ct1,
10656
- iq2xxs_grid_ptr_ct1, ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1);
11177
+ vx, vy, dst, ncols, nrows, item_ct1);
10657
11178
  });
10658
11179
  });
10659
11180
  }
@@ -10678,8 +11199,32 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
10678
11199
  [=](sycl::nd_item<3> item_ct1)
10679
11200
  [[intel::reqd_sub_group_size(32)]] {
10680
11201
  mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS, block_iq2_xs, 1>(
10681
- vx, vy, dst, ncols, nrows, item_ct1,
10682
- iq2xs_grid_ptr_ct1, ksigns64_ptr_ct1);
11202
+ vx, vy, dst, ncols, nrows, item_ct1);
11203
+ });
11204
+ });
11205
+ }
11206
+ }
11207
+
11208
+ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
11209
+ float *dst, const int ncols,
11210
+ const int nrows,
11211
+ dpct::queue_ptr stream) {
11212
+ GGML_ASSERT(ncols % QK_K == 0);
11213
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
11214
+ const sycl::range<3> block_nums(1, 1, block_num_y);
11215
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
11216
+ {
11217
+
11218
+ stream->submit([&](sycl::handler &cgh) {
11219
+ auto iq2xs_grid_ptr_ct1 = &iq2xs_grid[0];
11220
+ auto ksigns64_ptr_ct1 = &ksigns64[0];
11221
+
11222
+ cgh.parallel_for(
11223
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
11224
+ [=](sycl::nd_item<3> item_ct1)
11225
+ [[intel::reqd_sub_group_size(32)]] {
11226
+ mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S, block_iq2_s, 1>(
11227
+ vx, vy, dst, ncols, nrows, item_ct1);
10683
11228
  });
10684
11229
  });
10685
11230
  }
@@ -10704,8 +11249,7 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
10704
11249
  [=](sycl::nd_item<3> item_ct1)
10705
11250
  [[intel::reqd_sub_group_size(32)]] {
10706
11251
  mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS, block_iq3_xxs, 1>(
10707
- vx, vy, dst, ncols, nrows, item_ct1,
10708
- iq3xxs_grid_ptr_ct1, ksigns64_ptr_ct1);
11252
+ vx, vy, dst, ncols, nrows, item_ct1);
10709
11253
  });
10710
11254
  });
10711
11255
  }
@@ -10723,15 +11267,13 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
10723
11267
 
10724
11268
  stream->submit([&](sycl::handler &cgh) {
10725
11269
  auto iq3s_grid_ptr_ct1 = &iq3s_grid[0];
10726
- auto ksigns64_ptr_ct1 = &ksigns64[0];
10727
11270
 
10728
11271
  cgh.parallel_for(
10729
11272
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
10730
11273
  [=](sycl::nd_item<3> item_ct1)
10731
11274
  [[intel::reqd_sub_group_size(32)]] {
10732
11275
  mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_XS, block_iq3_s, 1>(
10733
- vx, vy, dst, ncols, nrows, item_ct1,
10734
- iq3s_grid_ptr_ct1, ksigns64_ptr_ct1);
11276
+ vx, vy, dst, ncols, nrows, item_ct1);
10735
11277
  });
10736
11278
  });
10737
11279
  }
@@ -10756,8 +11298,72 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
10756
11298
  [=](sycl::nd_item<3> item_ct1)
10757
11299
  [[intel::reqd_sub_group_size(32)]] {
10758
11300
  mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
10759
- vx, vy, dst, ncols, nrows, item_ct1,
10760
- iq1s_grid_ptr_ct1, ksigns64_ptr_ct1);
11301
+ vx, vy, dst, ncols, nrows, item_ct1);
11302
+ });
11303
+ });
11304
+ }
11305
+ }
11306
+
11307
+ static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy,
11308
+ float *dst, const int ncols,
11309
+ const int nrows,
11310
+ dpct::queue_ptr stream) {
11311
+ GGML_ASSERT(ncols % QK_K == 0);
11312
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
11313
+ const sycl::range<3> block_nums(1, 1, block_num_y);
11314
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
11315
+ {
11316
+ stream->submit([&](sycl::handler &cgh) {
11317
+ cgh.parallel_for(
11318
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
11319
+ [=](sycl::nd_item<3> item_ct1)
11320
+ [[intel::reqd_sub_group_size(32)]] {
11321
+ mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(
11322
+ vx, vy, dst, ncols, nrows, item_ct1);
11323
+ });
11324
+ });
11325
+ }
11326
+ }
11327
+
11328
+ static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy,
11329
+ float *dst, const int ncols,
11330
+ const int nrows,
11331
+ dpct::queue_ptr stream) {
11332
+ GGML_ASSERT(ncols % QK4_NL == 0);
11333
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
11334
+ const sycl::range<3> block_nums(1, 1, block_num_y);
11335
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
11336
+ {
11337
+
11338
+ stream->submit([&](sycl::handler &cgh) {
11339
+ cgh.parallel_for(
11340
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
11341
+ [=](sycl::nd_item<3> item_ct1)
11342
+ [[intel::reqd_sub_group_size(32)]] {
11343
+ mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 1>(
11344
+ vx, vy, dst, ncols, nrows, item_ct1);
11345
+ });
11346
+ });
11347
+ }
11348
+ }
11349
+
11350
+ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
11351
+ float *dst, const int ncols,
11352
+ const int nrows,
11353
+ dpct::queue_ptr stream) {
11354
+ GGML_ASSERT(ncols % QK_K == 0);
11355
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
11356
+ const sycl::range<3> block_nums(1, 1, block_num_y);
11357
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
11358
+ {
11359
+
11360
+ stream->submit([&](sycl::handler &cgh) {
11361
+ cgh.parallel_for(
11362
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
11363
+ [=](sycl::nd_item<3> item_ct1)
11364
+ [[intel::reqd_sub_group_size(32)]] {
11365
+ mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS, block_iq4_xs, 1>(
11366
+ vx, vy, dst, ncols, nrows, item_ct1);
10761
11367
  });
10762
11368
  });
10763
11369
  }
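The iq1_m, iq4_nl and iq4_xs launchers added here are copies of that same shape, differing only in the compile-time quantization constants forwarded to the kernel template (QK_K versus QK4_NL, the QI* values, and the block struct). A hedged sketch of that parameterization; the numeric block sizes below are illustrative, not quoted from the package:

#include <cassert>
#include <cstdio>

template <int QK>
static void launch_mul_mat_vec(int ncols, int nrows) {
    // every launcher above begins with this divisibility check for its block size
    assert(ncols % QK == 0);
    const int rows_per_wg = 1;  // GGML_SYCL_MMV_Y-like constant (assumption)
    const int block_num_y = (nrows + rows_per_wg - 1) / rows_per_wg;
    printf("QK=%d: %d work-groups for %d rows of %d cols\n", QK, block_num_y, nrows, ncols);
}

int main() {
    launch_mul_mat_vec<256>(4096, 32);  // a QK_K-style type (iq1_m, iq4_xs, ...), size illustrative
    launch_mul_mat_vec<32>(4096, 32);   // a QK4_NL-style type (iq4_nl), size illustrative
    return 0;
}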
@@ -12381,36 +12987,54 @@ static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
12381
12987
  });
12382
12988
  }
12383
12989
 
12990
+ static int next_power_of_2(int x) {
12991
+ int n = 1;
12992
+ while (n < x) {
12993
+ n *= 2;
12994
+ }
12995
+ return n;
12996
+ }
12997
+
12384
12998
  static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
12385
12999
  const int nrows, ggml_sort_order order,
12386
13000
  dpct::queue_ptr stream) {
12387
13001
  // bitonic sort requires ncols to be power of 2
12388
- GGML_ASSERT((ncols & (ncols - 1)) == 0);
13002
+ const int ncols_pad = next_power_of_2(ncols);
12389
13003
 
12390
- const sycl::range<3> block_dims(1, 1, ncols);
13004
+ const sycl::range<3> block_dims(1, 1, ncols_pad);
12391
13005
  const sycl::range<3> block_nums(1, nrows, 1);
13006
+ const size_t shared_mem = ncols_pad * sizeof(int);
13007
+
13008
+ // GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
13009
+
12392
13010
  if (order == GGML_SORT_ORDER_ASC) {
12393
- /*
12394
- DPCT1049:44: The work-group size passed to the SYCL kernel may exceed
12395
- the limit. To get the device limit, query
12396
- info::device::max_work_group_size. Adjust the work-group size if needed.
12397
- */
12398
- stream->parallel_for(
12399
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
12400
- [=](sycl::nd_item<3> item_ct1) {
12401
- k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(x, dst, ncols, item_ct1);
12402
- });
13011
+ stream->submit([&](sycl::handler &cgh) {
13012
+ sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
13013
+ sycl::range<1>(shared_mem), cgh);
13014
+
13015
+ cgh.parallel_for(
13016
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
13017
+ [=](sycl::nd_item<3> item_ct1) {
13018
+ k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
13019
+ x, dst, ncols, ncols_pad, item_ct1,
13020
+ dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
13021
+ .get());
13022
+ });
13023
+ });
12403
13024
  } else if (order == GGML_SORT_ORDER_DESC) {
12404
- /*
12405
- DPCT1049:45: The work-group size passed to the SYCL kernel may exceed
12406
- the limit. To get the device limit, query
12407
- info::device::max_work_group_size. Adjust the work-group size if needed.
12408
- */
12409
- stream->parallel_for(
12410
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
12411
- [=](sycl::nd_item<3> item_ct1) {
12412
- k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(x, dst, ncols, item_ct1);
12413
- });
13025
+ stream->submit([&](sycl::handler &cgh) {
13026
+ sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
13027
+ sycl::range<1>(shared_mem), cgh);
13028
+
13029
+ cgh.parallel_for(
13030
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
13031
+ [=](sycl::nd_item<3> item_ct1) {
13032
+ k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
13033
+ x, dst, ncols, ncols_pad, item_ct1,
13034
+ dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
13035
+ .get());
13036
+ });
13037
+ });
12414
13038
  } else {
12415
13039
  GGML_ASSERT(false);
12416
13040
  }
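The argsort change drops the hard requirement that ncols be a power of two: the width is rounded up with next_power_of_2, the padded width sizes both the work-group and the local-memory scratch, and the padded lanes carry sentinel indices that always sort behind real ones. This CPU-side sketch (a serial bitonic sort over a padded index array, not the SYCL kernel) shows why the padding preserves the result for the original ncols entries:

#include <cstdio>
#include <utility>
#include <vector>

static int next_power_of_2(int x) {
    int n = 1;
    while (n < x) n *= 2;
    return n;
}

int main() {
    const std::vector<float> row = {0.3f, -1.0f, 2.5f, 0.0f, 1.5f};  // ncols = 5
    const int ncols     = (int) row.size();
    const int ncols_pad = next_power_of_2(ncols);                    // 8

    // one index per padded slot; indices >= ncols act as "always largest" sentinels
    std::vector<int> idx(ncols_pad);
    for (int i = 0; i < ncols_pad; ++i) idx[i] = i;

    auto less = [&](int a, int b) {
        if (a >= ncols) return false;
        if (b >= ncols) return true;
        return row[a] < row[b];
    };

    // serial bitonic sorting network over the padded, power-of-two range (ascending)
    for (int k = 2; k <= ncols_pad; k *= 2) {
        for (int j = k / 2; j > 0; j /= 2) {
            for (int i = 0; i < ncols_pad; ++i) {
                const int ixj = i ^ j;
                if (ixj > i) {
                    const bool ascending = (i & k) == 0;
                    if (ascending ? less(idx[ixj], idx[i]) : less(idx[i], idx[ixj])) {
                        std::swap(idx[i], idx[ixj]);
                    }
                }
            }
        }
    }

    // only the first ncols entries form the argsort result
    for (int i = 0; i < ncols; ++i) printf("%d ", idx[i]);  // expected: 1 3 0 4 2
    printf("\n");
    return 0;
}

Only the first ncols positions are read back, so the sentinels never leak into dst.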
@@ -12455,11 +13079,13 @@ static void soft_max_f32_sycl(const float * x, const float * mask, const float *
12455
13079
  const int nrows_y, const float scale, const float max_bias,
12456
13080
  dpct::queue_ptr stream) {
12457
13081
  int nth = WARP_SIZE;
12458
- while (nth < ncols_x && nth < SYCL_SOFT_MAX_BLOCK_SIZE) nth *= 2;
13082
+ int max_block_size = g_work_group_size;
13083
+ while (nth < ncols_x && nth < max_block_size) nth *= 2;
13084
+ if (nth>max_block_size) nth = max_block_size;
13085
+
12459
13086
  const sycl::range<3> block_dims(1, 1, nth);
12460
13087
  const sycl::range<3> block_nums(1, 1, nrows_x);
12461
13088
  const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE);
12462
- static_assert(SYCL_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
12463
13089
 
12464
13090
  const uint32_t n_head_kv = nrows_x/nrows_y;
12465
13091
  const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
@@ -12469,6 +13095,12 @@ static void soft_max_f32_sycl(const float * x, const float * mask, const float *
12469
13095
 
12470
13096
  const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
12471
13097
  if (n_local_scratch*sizeof(float) < local_mem_size) {
13098
+ if (ncols_x > max_block_size) {
13099
+ soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13100
+ max_bias, m0, m1, n_head_log2, block_nums,
13101
+ block_dims, n_local_scratch, stream);
13102
+ return;
13103
+ }
12472
13104
  switch (ncols_x) {
12473
13105
  case 32:
12474
13106
  soft_max_f32_submitter<true, 32, 32>(x, mask, pos, dst, ncols_x, nrows_y, scale,
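soft_max_f32_sycl now grows nth in powers of two but caps it at the device's actual work-group limit instead of the removed SYCL_SOFT_MAX_BLOCK_SIZE constant, and routes oversized ncols_x to the generic <true, 0, 0> submitter. The limit can be queried straight from the device, roughly as below (a sketch; g_work_group_size in the diff is presumably populated this way, but that is an assumption):

#include <sycl/sycl.hpp>
#include <cstdio>

int main() {
    sycl::queue q;
    const int max_block_size =
        (int) q.get_device().get_info<sycl::info::device::max_work_group_size>();

    const int ncols_x = 5000;  // example column count
    int nth = 32;              // start from one sub-group (WARP_SIZE-sized)
    while (nth < ncols_x && nth < max_block_size) nth *= 2;
    if (nth > max_block_size) nth = max_block_size;

    printf("max_work_group_size = %d, chosen nth = %d\n", max_block_size, nth);
    return 0;
}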
@@ -13538,8 +14170,12 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_SYC
13538
14170
  case GGML_TYPE_Q5_K:
13539
14171
  case GGML_TYPE_IQ2_XXS:
13540
14172
  case GGML_TYPE_IQ2_XS:
14173
+ case GGML_TYPE_IQ2_S:
13541
14174
  case GGML_TYPE_IQ1_S:
14175
+ case GGML_TYPE_IQ1_M:
13542
14176
  case GGML_TYPE_IQ3_XXS:
14177
+ case GGML_TYPE_IQ4_XS:
14178
+ case GGML_TYPE_IQ4_NL:
13543
14179
  return max_compute_capability >= VER_GEN9 ? 128 : 64;
13544
14180
  case GGML_TYPE_IQ3_S:
13545
14181
  return max_compute_capability >= VER_GEN9 ? 128 : 64;
@@ -13558,11 +14194,20 @@ inline void ggml_sycl_op_mul_mat_vec_q(
13558
14194
  const int64_t src1_ncols, const int64_t src1_padded_row_size,
13559
14195
  const dpct::queue_ptr &stream) {
13560
14196
 
13561
- GGML_ASSERT(ggml_nrows(src1) == 1);
14197
+ const int64_t ne10 = src1->ne[0];
14198
+ GGML_ASSERT(ne10 % QK8_1 == 0);
13562
14199
 
13563
14200
  const int64_t ne00 = src0->ne[0];
13564
14201
  const int64_t row_diff = row_high - row_low;
13565
14202
 
14203
+ int id;
14204
+ SYCL_CHECK(
14205
+ CHECK_TRY_ERROR(id = get_current_device_id()));
14206
+
14207
+ // the main device has a larger memory buffer to hold the results from all GPUs
14208
+ // nrows_dst == nrows of the matrix that the kernel writes into
14209
+ const int64_t nrows_dst = dst->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device ? ne00 : row_diff;
14210
+
13566
14211
  switch (src0->type) {
13567
14212
  case GGML_TYPE_Q4_0:
13568
14213
  mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
@@ -13594,20 +14239,32 @@ inline void ggml_sycl_op_mul_mat_vec_q(
13594
14239
  case GGML_TYPE_Q6_K:
13595
14240
  mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
13596
14241
  break;
14242
+ case GGML_TYPE_IQ1_S:
14243
+ mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
14244
+ break;
14245
+ case GGML_TYPE_IQ1_M:
14246
+ mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
14247
+ break;
13597
14248
  case GGML_TYPE_IQ2_XXS:
13598
14249
  mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
13599
14250
  break;
13600
14251
  case GGML_TYPE_IQ2_XS:
13601
14252
  mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
13602
14253
  break;
14254
+ case GGML_TYPE_IQ2_S:
14255
+ mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
14256
+ break;
13603
14257
  case GGML_TYPE_IQ3_XXS:
13604
14258
  mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
13605
14259
  break;
13606
14260
  case GGML_TYPE_IQ3_S:
13607
14261
  mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
13608
14262
  break;
13609
- case GGML_TYPE_IQ1_S:
13610
- mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
14263
+ case GGML_TYPE_IQ4_NL:
14264
+ mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
14265
+ break;
14266
+ case GGML_TYPE_IQ4_XS:
14267
+ mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
13611
14268
  break;
13612
14269
  default:
13613
14270
  GGML_ASSERT(false);
@@ -13689,6 +14346,7 @@ inline void ggml_sycl_op_dequantize_mul_mat_vec(
13689
14346
  convert_mul_mat_vec_f16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
13690
14347
  break;
13691
14348
  default:
14349
+ printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type);
13692
14350
  GGML_ASSERT(false);
13693
14351
  break;
13694
14352
  }
@@ -14543,8 +15201,8 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0,
14543
15201
  src1_padded_col_size = (i0 * ne11 + src1_col_0) * ne10;
14544
15202
  }
14545
15203
  // do the computation
14546
- op(src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
14547
- dev[i].row_low, dev[i].row_high, src1_ncols, src1_padded_col_size, stream);
15204
+ SYCL_CHECK(CHECK_TRY_ERROR(op(src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
15205
+ dev[i].row_low, dev[i].row_high, src1_ncols, src1_padded_col_size, stream)));
14548
15206
  /*
14549
15207
  DPCT1010:93: SYCL uses exceptions to report errors and does not
14550
15208
  use the error codes. The call was replaced with 0. You need to
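Wrapping the per-device op call in SYCL_CHECK(CHECK_TRY_ERROR(...)) turns exceptions thrown inside the operation into checked error codes at the call site. The general idiom, shown with stand-in macros rather than the package's own SYCL_CHECK/CHECK_TRY_ERROR definitions:

#include <sycl/sycl.hpp>
#include <cstdio>
#include <cstdlib>

// turn a throwing SYCL call into an error code (variadic so commas in the call are fine)
#define TRY_OR_CODE(...)                                      \
    [&]() -> int {                                            \
        try { __VA_ARGS__; return 0; }                        \
        catch (const sycl::exception &e) {                    \
            fprintf(stderr, "SYCL error: %s\n", e.what());    \
            return 1;                                         \
        }                                                     \
    }()

// abort with source location if the wrapped call reported an error
#define CHECK(code)                                                     \
    do {                                                                \
        if ((code) != 0) {                                              \
            fprintf(stderr, "failure at %s:%d\n", __FILE__, __LINE__);  \
            std::exit(1);                                               \
        }                                                               \
    } while (0)

int main() {
    sycl::queue q;
    float *dst = sycl::malloc_device<float>(16, q);
    // any enqueue (kernel submission, memcpy, ...) can be wrapped the same way
    CHECK(TRY_OR_CODE(q.fill(dst, 0.0f, 16).wait()));
    sycl::free(dst, q);
    return 0;
}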
@@ -15125,11 +15783,17 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
15125
15783
  #ifdef GGML_SYCL_FORCE_DMMV
15126
15784
  const bool use_mul_mat_vec_q = false;
15127
15785
  #else
15128
- const bool use_mul_mat_vec_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type) && ggml_nrows(src1) == 1;
15786
+ bool use_mul_mat_vec_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type);
15787
+ use_mul_mat_vec_q = use_mul_mat_vec_q ||
15788
+ (src0->type == GGML_TYPE_IQ2_XXS) || (src0->type == GGML_TYPE_IQ2_XS) || (src0->type == GGML_TYPE_IQ2_S) ||
15789
+ (src0->type == GGML_TYPE_IQ3_XXS) || (src0->type == GGML_TYPE_IQ3_S) ||
15790
+ (src0->type == GGML_TYPE_IQ4_NL) || (src0->type == GGML_TYPE_IQ4_XS) ||
15791
+ (src0->type == GGML_TYPE_IQ1_S) || (src0->type == GGML_TYPE_IQ1_M);
15792
+
15793
+
15129
15794
  #endif // GGML_SYCL_FORCE_DMMV
15130
15795
 
15131
15796
  if (use_mul_mat_vec_q) {
15132
- // NOTE: this kernel does not support ggml_nrows(src1) > 1
15133
15797
  // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_vec_q path\n");
15134
15798
  ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
15135
15799
  } else {
@@ -15332,73 +15996,76 @@ static void ggml_sycl_mul_mat_id_sycl(ggml_tensor * dst) {
15332
15996
  static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
15333
15997
  const ggml_tensor *src1,
15334
15998
  ggml_tensor *dst) try {
15335
- #if 0
15336
- ggml_sycl_mul_mat_id_sycl(dst);
15337
- // TODO: mmq/mmv support
15338
- #endif
15999
+ GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT &&
16000
+ "mul_mat_id does not support split buffers");
16001
+ const ggml_tensor *ids = dst->src[2];
16002
+ const dpct::queue_ptr stream = g_syclStreams[g_main_device][0];
15339
16003
 
15340
- const int64_t nb11 = src1->nb[1];
15341
- const int64_t nb1 = dst->nb[1];
16004
+ const size_t nb11 = src1->nb[1];
16005
+ const size_t nb1 = dst->nb[1];
15342
16006
 
15343
- const struct ggml_tensor * ids = src0;
15344
- const int32_t id = ((int32_t *) dst->op_params)[0];
15345
- const int32_t n_as = ((int32_t *) dst->op_params)[1];
16007
+ const int32_t id = ((int32_t *)dst->op_params)[0];
16008
+ const int32_t n_as = src0->ne[2];
15346
16009
 
15347
16010
  std::vector<char> ids_host(ggml_nbytes(ids));
16011
+ const char *ids_dev = (const char *)ids->data;
15348
16012
 
15349
- const dpct::queue_ptr stream = g_syclStreams[g_main_device][0];
15350
-
15351
- if (ids->backend == GGML_BACKEND_TYPE_GPU) {
15352
- const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
15353
- SYCL_CHECK(CHECK_TRY_ERROR(
15354
- stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids)).wait()));
15355
- // SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
15356
- } else {
15357
- memcpy(ids_host.data(), ids->data, ggml_nbytes(ids));
15358
- }
16013
+ SYCL_CHECK(CHECK_TRY_ERROR(
16014
+ stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
16015
+ SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
15359
16016
 
15360
- const ggml_tensor_extra_gpu * src1_extra = (const ggml_tensor_extra_gpu *) src1->extra;
15361
- const ggml_tensor_extra_gpu * dst_extra = (const ggml_tensor_extra_gpu *) dst->extra;
16017
+ const ggml_tensor_extra_gpu *src0_extra =
16018
+ (const ggml_tensor_extra_gpu *)src0->extra;
16019
+ const ggml_tensor_extra_gpu *src1_extra =
16020
+ (const ggml_tensor_extra_gpu *)src1->extra;
16021
+ const ggml_tensor_extra_gpu *dst_extra =
16022
+ (const ggml_tensor_extra_gpu *)dst->extra;
15362
16023
 
16024
+ ggml_tensor_extra_gpu src0_row_extra;
15363
16025
  ggml_tensor_extra_gpu src1_row_extra;
15364
16026
  ggml_tensor_extra_gpu dst_row_extra;
15365
16027
 
16028
+ ggml_tensor src0_row = *src0;
15366
16029
  ggml_tensor src1_row = *src1;
15367
16030
  ggml_tensor dst_row = *dst;
15368
16031
 
15369
16032
  src1_row.backend = GGML_BACKEND_TYPE_GPU;
15370
16033
  dst_row.backend = GGML_BACKEND_TYPE_GPU;
15371
16034
 
16035
+ src0_row.extra = &src0_row_extra;
15372
16036
  src1_row.extra = &src1_row_extra;
15373
16037
  dst_row.extra = &dst_row_extra;
15374
16038
 
15375
- char * src1_original = src1->backend == GGML_BACKEND_TYPE_CPU ?
15376
- (char *) src1->data : (char *) src1_extra->data_device[g_main_device];
15377
- char * dst_original = dst->backend == GGML_BACKEND_TYPE_CPU ?
15378
- (char *) dst->data : (char *) dst_extra->data_device[g_main_device];
16039
+ char *src0_original = src1->backend == GGML_BACKEND_TYPE_CPU
16040
+ ? (char *)src0->data
16041
+ : (char *)src0_extra->data_device[g_main_device];
16042
+ char *src1_original = src1->backend == GGML_BACKEND_TYPE_CPU
16043
+ ? (char *)src1->data
16044
+ : (char *)src1_extra->data_device[g_main_device];
16045
+ char *dst_original = dst->backend == GGML_BACKEND_TYPE_CPU
16046
+ ? (char *)dst->data
16047
+ : (char *)dst_extra->data_device[g_main_device];
15379
16048
 
15380
- if (src1->ne[1] == 1) {
15381
- GGML_ASSERT(src1->backend == GGML_BACKEND_TYPE_GPU);
15382
- GGML_ASSERT(dst->backend == GGML_BACKEND_TYPE_GPU);
16049
+ src0_row.ne[2] = 1;
16050
+ src0_row.ne[3] = 1;
16051
+ src0_row.nb[3] = src0->nb[2];
15383
16052
 
16053
+ if (src1->ne[1] == 1) {
15384
16054
  for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
15385
- //int32_t row_id;
15386
- //SYCL_CHECK(syclMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), syclMemcpyDeviceToHost, g_syclStreams[g_main_device][0]));
15387
- //SYCL_CHECK(syclStreamSynchronize(g_syclStreams[g_main_device][0]));
15388
-
15389
- const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
16055
+ const int32_t row_id =
16056
+ *(const int32_t *)(ids_host.data() + i01 * ids->nb[1] +
16057
+ id * ids->nb[0]);
15390
16058
 
15391
16059
  GGML_ASSERT(row_id >= 0 && row_id < n_as);
15392
16060
 
15393
- const struct ggml_tensor * src0_row = dst->src[row_id + 2];
15394
-
15395
- src1_row_extra.data_device[g_main_device] = src1_original + i01*src1->nb[1];
15396
- src1_row.data = (char *) src1->data + i01*src1->nb[1]; // TODO why is this set?
16061
+ src0_row_extra.data_device[g_main_device] =
16062
+ src0_original + row_id * src0->nb[2];
16063
+ src1_row_extra.data_device[g_main_device] =
16064
+ src1_original + i01 * src1->nb[1];
16065
+ dst_row_extra.data_device[g_main_device] =
16066
+ dst_original + i01 * dst->nb[1];
15397
16067
 
15398
- dst_row_extra.data_device[g_main_device] = dst_original + i01*dst->nb[1];
15399
- dst_row.data = (char *) dst->data + i01*dst->nb[1]; // TODO why is this set?
15400
-
15401
- ggml_sycl_mul_mat(src0_row, &src1_row, &dst_row);
16068
+ ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row);
15402
16069
  }
15403
16070
  } else {
15404
16071
  sycl_pool_alloc<char> src1_contiguous(sizeof(float)*ggml_nelements(src1));
@@ -15408,8 +16075,6 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
15408
16075
  dst_row_extra.data_device[g_main_device] = dst_contiguous.get();
15409
16076
 
15410
16077
  for (int32_t row_id = 0; row_id < n_as; ++row_id) {
15411
- const struct ggml_tensor * src0_row = dst->src[row_id + 2];
15412
-
15413
16078
  int64_t num_src1_rows = 0;
15414
16079
  for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
15415
16080
  const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
@@ -15422,7 +16087,7 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
15422
16087
 
15423
16088
  SYCL_CHECK(CHECK_TRY_ERROR(
15424
16089
  stream->memcpy(src1_contiguous.get() + num_src1_rows * nb11,
15425
- src1_original + i01 * nb11, nb11).wait()));
16090
+ src1_original + i01 * nb11, nb11)));
15426
16091
  num_src1_rows++;
15427
16092
  }
15428
16093
 
@@ -15430,6 +16095,9 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
15430
16095
  continue;
15431
16096
  }
15432
16097
 
16098
+ src0_row_extra.data_device[g_main_device] =
16099
+ src0_original + row_id * src0->nb[2];
16100
+
15433
16101
  src1_row.ne[1] = num_src1_rows;
15434
16102
  dst_row.ne[1] = num_src1_rows;
15435
16103
 
@@ -15441,7 +16109,7 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
15441
16109
  dst_row.nb[2] = num_src1_rows*nb1;
15442
16110
  dst_row.nb[3] = num_src1_rows*nb1;
15443
16111
 
15444
- ggml_sycl_mul_mat(src0_row, &src1_row, &dst_row);
16112
+ ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row);
15445
16113
 
15446
16114
  num_src1_rows = 0;
15447
16115
  for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
@@ -15455,7 +16123,7 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
15455
16123
 
15456
16124
  SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(
15457
16125
  dst_original + i01 * nb1,
15458
- dst_contiguous.get() + num_src1_rows * nb1, nb1).wait()));
16126
+ dst_contiguous.get() + num_src1_rows * nb1, nb1)));
15459
16127
  num_src1_rows++;
15460
16128
  }
15461
16129
  }
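The rewritten ggml_sycl_mul_mat_id takes the expert ids from dst->src[2], copies them to the host, and then treats src0 as an expert-stacked tensor: the selected expert's weights are addressed as src0_original + row_id * src0->nb[2], wrapped in a temporary src0_row view and fed through the ordinary mul-mat path. A plain CPU analogue of that per-token expert selection (sizes and the naive matvec are invented for illustration):

#include <cstdio>
#include <vector>

int main() {
    const int n_as  = 4;  // number of experts (src0->ne[2] in the diff)
    const int ne00  = 8;  // columns of each expert matrix
    const int ne01  = 8;  // rows of each expert matrix
    const int n_tok = 3;  // number of src1 rows routed through the experts

    // expert-stacked weights: expert e begins at e * ne00 * ne01 elements
    std::vector<float> src0((size_t) n_as * ne00 * ne01, 0.5f);
    std::vector<float> src1((size_t) n_tok * ne00, 1.0f);
    std::vector<float> dst ((size_t) n_tok * ne01, 0.0f);
    const std::vector<int> ids = {2, 0, 3};  // one expert id per token (the ids tensor)

    const size_t nb2 = (size_t) ne00 * ne01;  // per-expert stride (nb[2], in elements here)

    for (int i01 = 0; i01 < n_tok; ++i01) {
        const int row_id = ids[i01];
        const float *expert = src0.data() + (size_t) row_id * nb2;  // src0_original + row_id*nb[2]
        const float *x      = src1.data() + (size_t) i01 * ne00;
        float       *y      = dst.data()  + (size_t) i01 * ne01;
        for (int r = 0; r < ne01; ++r) {          // naive matvec stands in for ggml_sycl_mul_mat
            float acc = 0.0f;
            for (int c = 0; c < ne00; ++c) acc += expert[(size_t) r * ne00 + c] * x[c];
            y[r] = acc;
        }
    }
    printf("dst[0] = %.1f\n", dst[0]);  // 8 * 0.5 * 1.0 = 4.0
    return 0;
}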
@@ -16157,11 +16825,13 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
16157
16825
  const dpct::queue_ptr stream = g_syclStreams[ctx->device][0];
16158
16826
  SYCL_CHECK(
16159
16827
  CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw()));
16160
-
16828
+ char* host_buf = (char*)malloc(size);
16829
+ memcpy(host_buf, data, size);
16161
16830
  SYCL_CHECK(
16162
16831
  CHECK_TRY_ERROR((*stream)
16163
- .memcpy((char *)tensor->data + offset, data, size)
16832
+ .memcpy((char *)tensor->data + offset, host_buf, size)
16164
16833
  .wait()));
16834
+ free(host_buf);
16165
16835
  }
16166
16836
  catch (sycl::exception const &exc) {
16167
16837
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -16985,9 +17655,14 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
16985
17655
  return false;
16986
17656
  }
16987
17657
  ggml_type a_type = a->type;
16988
- if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ2_S ||
16989
- a_type == GGML_TYPE_IQ4_XS) {
16990
- return false;
17658
+ if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ4_XS ||
17659
+ a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ3_S ||
17660
+ a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ2_S ||
17661
+ a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ1_M
17662
+ ) {
17663
+ if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
17664
+ return false;
17665
+ }
16991
17666
  }
16992
17667
  return true;
16993
17668
  } break;
@@ -17077,7 +17752,7 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
17077
17752
 
17078
17753
  GGML_CALL static bool ggml_backend_sycl_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
17079
17754
  const int min_batch_size = 32;
17080
- return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
17755
+ return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS && op->op != GGML_OP_MUL_MAT_ID;
17081
17756
  GGML_UNUSED(backend);
17082
17757
  }
17083
17758