llama_cpp 0.14.4 → 0.14.5

@@ -1664,24 +1664,6 @@ namespace dpct
1664
1664
  const void *alpha, const void *a, int lda, const void *b,
1665
1665
  int ldb, const void *beta, void *c, int ldc)
1666
1666
  {
1667
- #ifndef __INTEL_MKL__
1668
- GGML_UNUSED(q);
1669
- GGML_UNUSED(a_trans);
1670
- GGML_UNUSED(b_trans);
1671
- GGML_UNUSED(m);
1672
- GGML_UNUSED(n);
1673
- GGML_UNUSED(k);
1674
- GGML_UNUSED(alpha);
1675
- GGML_UNUSED(a);
1676
- GGML_UNUSED(lda);
1677
- GGML_UNUSED(b);
1678
- GGML_UNUSED(ldb);
1679
- GGML_UNUSED(beta);
1680
- GGML_UNUSED(c);
1681
- GGML_UNUSED(ldc);
1682
- throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces "
1683
- "Project does not support this API.");
1684
- #else
1685
1667
  Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
1686
1668
  Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
1687
1669
  auto data_a = get_memory<const Ta>(a);
@@ -1690,7 +1672,6 @@ namespace dpct
1690
1672
  oneapi::mkl::blas::column_major::gemm(
1691
1673
  q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
1692
1674
  data_b, ldb, beta_value, data_c, ldc);
1693
- #endif
1694
1675
  }
1695
1676
 
1696
1677
  template <typename VecT, class BinaryOperation, class = void>
@@ -2330,6 +2311,7 @@ namespace dpct
2330
2311
  lda, b, ldb, beta, c, ldc);
2331
2312
  break;
2332
2313
  }
2314
+ #ifdef __INTEL_MKL__
2333
2315
  case detail::get_type_combination_id(
2334
2316
  library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2335
2317
  library_data_t::real_float, library_data_t::real_float):
@@ -2391,6 +2373,7 @@ namespace dpct
2391
2373
  q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc);
2392
2374
  break;
2393
2375
  }
2376
+ #endif // __INTEL_MKL__
2394
2377
  default:
2395
2378
  throw std::runtime_error("the combination of data type is unsupported");
2396
2379
  }
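
The hunk above wraps the bf16 GEMM cases in #ifdef __INTEL_MKL__, so builds against the open-source oneMKL Interfaces fall through to the generic "unsupported combination" error in the dispatcher instead of relying on a guard inside gemm_impl. A minimal stand-alone sketch of that pattern follows; HAVE_VENDOR_BF16_GEMM and the dtype enum are illustrative stand-ins, not names from the diff.

// Sketch only: a type dispatcher where one case is compiled in only when the
// vendor library supports it; otherwise the call lands in the default branch.
#include <stdexcept>
#include <cstdio>

enum class dtype { f32, bf16 };

static void gemm_dispatch(dtype a, dtype c) {
    switch ((int) a * 2 + (int) c) {   // toy "type combination id"
        case 0: // f32 * f32 -> always available
            printf("f32 gemm\n");
            break;
#ifdef HAVE_VENDOR_BF16_GEMM           // assumption: stand-in for __INTEL_MKL__
        case 2: // bf16 * f32 -> only with the vendor library
            printf("bf16 gemm\n");
            break;
#endif
        default:
            throw std::runtime_error("the combination of data type is unsupported");
    }
}

int main() {
    gemm_dispatch(dtype::f32, dtype::f32);
    try {
        gemm_dispatch(dtype::bf16, dtype::f32);  // throws unless the macro is defined
    } catch (const std::exception & e) {
        printf("%s\n", e.what());
    }
}
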
@@ -3055,6 +3038,10 @@ typedef float dfloat; // dequantize float
3055
3038
  typedef sycl::float2 dfloat2;
3056
3039
  #endif //GGML_SYCL_F16
3057
3040
 
3041
+ #define MMVQ_MAX_BATCH_SIZE 8
3042
+
3043
+ static const int8_t kvalues_iq4nl[16]={-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
3044
+
3058
3045
  bool ggml_sycl_loaded(void);
3059
3046
  void * ggml_sycl_host_malloc(size_t size);
3060
3047
  void ggml_sycl_host_free(void * ptr);
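
kvalues_iq4nl, declared above, is the 16-entry non-linear codebook used by the IQ4_NL/IQ4_XS kernels added later in this diff: each 4-bit quant index selects one of the 16 levels, which is then multiplied by the block scale d (see dequantize_block_iq4_nl below). A minimal CPU-side sketch of that mapping; note the real kernel writes an interleaved layout (low nibbles to y[j], high nibbles to y[j+16]), which is simplified here.

// Sketch only: dequantize packed 4-bit indices through the IQ4_NL codebook.
#include <cstdint>
#include <cstdio>

static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10,
                                            1,   13,  25,  38,  53,  69,  89, 113};

// n bytes of quants -> 2*n floats (low nibble first, then high nibble)
static void dequant_iq4_nl(const uint8_t * qs, int n, float d, float * out) {
    for (int j = 0; j < n; ++j) {
        out[2*j + 0] = d * kvalues_iq4nl[qs[j] & 0xf];
        out[2*j + 1] = d * kvalues_iq4nl[qs[j] >> 4];
    }
}

int main() {
    const uint8_t qs[2] = {0x0f, 0xf0};   // nibbles: (15,0) and (0,15)
    float y[4];
    dequant_iq4_nl(qs, 2, 0.5f, y);
    printf("%g %g %g %g\n", y[0], y[1], y[2], y[3]);  // 56.5 -63.5 -63.5 56.5
}
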
@@ -4490,6 +4477,32 @@ static void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __rest
4490
4477
 
4491
4478
  }
4492
4479
 
4480
+ template <typename dst_t>
4481
+ __dpct_inline__ static void
4482
+ dequantize_block_iq2_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
4483
+ const sycl::nd_item<3> &item_ct1) {
4484
+
4485
+ const int i = item_ct1.get_group(2);
4486
+ const block_iq2_s * x = (const block_iq2_s *) vx;
4487
+
4488
+ const int tid = item_ct1.get_local_id(2);
4489
+ #if QK_K == 256
4490
+ const int il = tid/8; // 0...3
4491
+ const int ib = tid%8; // 0...7
4492
+ dst_t * y = yy + i*QK_K + 32*ib + 8*il;
4493
+ const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
4494
+ const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
4495
+ const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];
4496
+ #pragma unroll
4497
+ for (int j = 0; j < 8; ++j)
4498
+ y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
4499
+ #else
4500
+ assert(false);
4501
+
4502
+ #endif
4503
+
4504
+ }
4505
+
4493
4506
  template<typename dst_t>
4494
4507
  static void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy,
4495
4508
  const sycl::nd_item<3> &item_ct1,
@@ -4522,26 +4535,26 @@ static void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __res
4522
4535
 
4523
4536
  }
4524
4537
 
4525
- template<typename dst_t>
4526
- static void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy,
4527
- const sycl::nd_item<3> &item_ct1,
4528
- const uint32_t *iq3s_grid,
4529
- const uint8_t *ksigns_iq2xs,
4530
- const uint8_t *kmask_iq2xs) {
4538
+ template <typename dst_t>
4539
+ __dpct_inline__ static void
4540
+ dequantize_block_iq3_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
4541
+ const sycl::nd_item<3> &item_ct1,
4542
+ const uint8_t *kmask_iq2xs, const uint32_t *iq3s_grid) {
4531
4543
 
4532
4544
  const int i = item_ct1.get_group(2);
4533
- const block_iq3_s * x = (const block_iq3_s *) vx;
4545
+ const block_iq3_s * x = (const block_iq3_s *) vx;
4534
4546
 
4535
4547
  const int tid = item_ct1.get_local_id(2);
4536
4548
  #if QK_K == 256
4537
4549
  const int il = tid/8; // 0...3
4538
4550
  const int ib = tid%8; // 0...7
4539
4551
  dst_t * y = yy + i*QK_K + 32*ib + 8*il;
4540
- const uint8_t * qs = x[i].qs + 8*ib;
4541
- const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + qs[2*il+0]);
4542
- const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + qs[2*il+1]);
4552
+ const uint8_t * qs = x[i].qs + 8*ib;
4553
+ const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
4554
+ const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)));
4543
4555
  const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf));
4544
4556
  const uint8_t signs = x[i].signs[4*ib + il];
4557
+ #pragma unroll
4545
4558
  for (int j = 0; j < 4; ++j) {
4546
4559
  y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
4547
4560
  y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
@@ -4552,12 +4565,12 @@ static void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restr
4552
4565
 
4553
4566
  }
4554
4567
 
4555
- template<typename dst_t>
4556
- static void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy,
4557
- const sycl::nd_item<3> &item_ct1,
4558
- const uint32_t *iq1s_grid,
4559
- const uint8_t *ksigns_iq2xs,
4560
- const uint8_t *kmask_iq2xs) {
4568
+ template <typename dst_t>
4569
+ __dpct_inline__ static void
4570
+ dequantize_block_iq1_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
4571
+ const sycl::nd_item<3> &item_ct1,
4572
+ const uint32_t *iq1s_grid_gpu) {
4573
+
4561
4574
  const int i = item_ct1.get_group(2);
4562
4575
  const block_iq1_s * x = (const block_iq1_s *) vx;
4563
4576
 
@@ -4566,14 +4579,15 @@ static void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restr
4566
4579
  const int il = tid/8; // 0...3
4567
4580
  const int ib = tid%8; // 0...7
4568
4581
  dst_t * y = yy + i*QK_K + 32*ib + 8*il;
4569
- const uint8_t * qs = x[i].qs + 8*ib;
4570
- const uint8_t * grid1 = (const uint8_t *)(iq1s_grid + qs[2*il+0]);
4571
- const uint8_t * grid2 = (const uint8_t *)(iq1s_grid + qs[2*il+1]);
4572
- const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 0xf) + 1);
4573
- const uint8_t signs = ksigns_iq2xs[(x[i].qh[ib] >> 3*il) & 7];
4574
- for (int j = 0; j < 4; ++j) {
4575
- y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
4576
- y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
4582
+ const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
4583
+ const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
4584
+ uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
4585
+ grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)];
4586
+ grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
4587
+ grid32[0] &= 0x0f0f0f0f;
4588
+ #pragma unroll
4589
+ for (int j = 0; j < 8; ++j) {
4590
+ y[j] = d * (q[j] + delta);
4577
4591
  }
4578
4592
  #else
4579
4593
  assert(false);
@@ -4581,6 +4595,85 @@ static void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restr
4581
4595
 
4582
4596
  }
4583
4597
 
4598
+ template <typename dst_t>
4599
+ __dpct_inline__ static void
4600
+ dequantize_block_iq1_m(const void *__restrict__ vx, dst_t *__restrict__ yy,
4601
+ const sycl::nd_item<3> &item_ct1,
4602
+ const uint32_t *iq1s_grid_gpu) {
4603
+
4604
+ const int i = item_ct1.get_group(2);
4605
+ const block_iq1_m * x = (const block_iq1_m *) vx;
4606
+
4607
+ const int tid = item_ct1.get_local_id(2);
4608
+ #if QK_K == 256
4609
+ const int il = tid/8; // 0...3
4610
+ const int ib = tid%8; // 0...7
4611
+ dst_t * y = yy + i*QK_K + 32*ib + 8*il;
4612
+ const uint16_t * sc = (const uint16_t *)x[i].scales;
4613
+ iq1m_scale_t scale;
4614
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
4615
+ const int ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
4616
+ const float d = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);
4617
+ const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
4618
+ uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
4619
+ grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[2*ib+il/2] >> 4*(il%2)) & 7) << 8)];
4620
+ grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
4621
+ grid32[0] &= 0x0f0f0f0f;
4622
+ #pragma unroll
4623
+ for (int j = 0; j < 8; ++j) {
4624
+ y[j] = d * (q[j] + delta);
4625
+ }
4626
+ #else
4627
+ assert(false);
4628
+ #endif
4629
+
4630
+ }
4631
+
4632
+ template <typename dst_t>
4633
+ __dpct_inline__ static void
4634
+ dequantize_block_iq4_nl(const void *__restrict__ vx, dst_t *__restrict__ yy,
4635
+ const sycl::nd_item<3> &item_ct1) {
4636
+
4637
+ const int i = item_ct1.get_group(2);
4638
+ const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
4639
+
4640
+ const int tid = item_ct1.get_local_id(2);
4641
+ const int il = tid/8; // 0...3
4642
+ const int ib = tid%8; // 0...7
4643
+ dst_t * y = yy + i*QK_K + 32*ib + 4*il;
4644
+ const uint8_t * q4 = x[ib].qs + 4*il;
4645
+ const float d = (float)x[ib].d;
4646
+ #pragma unroll
4647
+ for (int j = 0; j < 4; ++j) {
4648
+ y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
4649
+ y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
4650
+ }
4651
+
4652
+ }
4653
+
4654
+
4655
+ template <typename dst_t>
4656
+ __dpct_inline__ static void
4657
+ dequantize_block_iq4_xs(const void *__restrict__ vx, dst_t *__restrict__ yy,
4658
+ const sycl::nd_item<3> &item_ct1) {
4659
+ const int i = item_ct1.get_group(2);
4660
+ const block_iq4_xs * x = (const block_iq4_xs *)vx;
4661
+
4662
+ const int tid = item_ct1.get_local_id(2);
4663
+ const int il = tid/8; // 0...3
4664
+ const int ib = tid%8; // 0...7
4665
+ dst_t * y = yy + i*QK_K + 32*ib + 4*il;
4666
+ const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
4667
+ const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
4668
+ #pragma unroll
4669
+ for (int j = 0; j < 4; ++j) {
4670
+ y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
4671
+ y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
4672
+ }
4673
+ }
4674
+
4675
+
4676
+
4584
4677
  /*
4585
4678
  DPCT1110:4: The total declared local variable size in device function
4586
4679
  dequantize_mul_mat_vec_q2_k exceeds 128 bytes and may cause high register
@@ -7387,6 +7480,58 @@ vec_dot_iq2_xs_q8_1(const void *__restrict__ vbq,
7387
7480
  #endif
7388
7481
  }
7389
7482
 
7483
+ static __dpct_inline__ float
7484
+ vec_dot_iq2_s_q8_1(const void *__restrict__ vbq,
7485
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
7486
+ #if QK_K == 256
7487
+ const block_iq2_s * bq2 = (const block_iq2_s *) vbq;
7488
+
7489
+ const int ib32 = iqs;
7490
+ const int8_t * q8 = bq8_1[ib32].qs;
7491
+ const uint8_t * signs = bq2->qs + QK_K/8 + 4*ib32;
7492
+ const uint8_t ls1 = bq2->scales[ib32] & 0xf;
7493
+ const uint8_t ls2 = bq2->scales[ib32] >> 4;
7494
+ int sumi1 = 0;
7495
+ for (int l = 0; l < 2; ++l) {
7496
+ const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300)));
7497
+ const uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(
7498
+ ((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201,
7499
+ std::equal_to<>());
7500
+ const uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(
7501
+ ((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201,
7502
+ std::equal_to<>());
7503
+ const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
7504
+ grid[0] ^ signs0, signs0, std::minus<>());
7505
+ const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
7506
+ grid[1] ^ signs1, signs1, std::minus<>());
7507
+ sumi1 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi1);
7508
+ sumi1 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi1);
7509
+ q8 += 8;
7510
+ }
7511
+ int sumi2 = 0;
7512
+ for (int l = 2; l < 4; ++l) {
7513
+ const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300)));
7514
+ const uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(
7515
+ ((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201,
7516
+ std::equal_to<>());
7517
+ const uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(
7518
+ ((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201,
7519
+ std::equal_to<>());
7520
+ const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
7521
+ grid[0] ^ signs0, signs0, std::minus<>());
7522
+ const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
7523
+ grid[1] ^ signs1, signs1, std::minus<>());
7524
+ sumi2 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi2);
7525
+ sumi2 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi2);
7526
+ q8 += 8;
7527
+ }
7528
+ const float d = (float)bq2->d * bq8_1[ib32].ds[0] * 0.25f;
7529
+ return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
7530
+ #else
7531
+ assert(false);
7532
+ #endif
7533
+ }
7534
+
7390
7535
  static __dpct_inline__ float
7391
7536
  vec_dot_iq3_xxs_q8_1(const void *__restrict__ vbq,
7392
7537
  const block_q8_1 *__restrict__ bq8_1, const int &iqs,
@@ -7429,10 +7574,8 @@ vec_dot_iq3_xxs_q8_1(const void *__restrict__ vbq,
7429
7574
 
7430
7575
  static __dpct_inline__ float
7431
7576
  vec_dot_iq3_s_q8_1(const void *__restrict__ vbq,
7432
- const block_q8_1 *__restrict__ bq8_1, const int &iqs,
7433
- const uint32_t *iq3s_grid, const uint64_t *ksigns64) {
7434
- #if DPCT_COMPATIBILITY_TEMP >= \
7435
- MIN_CC_DP4A // lowest compute capability for integer intrinsics
7577
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs,
7578
+ const uint32_t *iq3s_grid) {
7436
7579
  #if QK_K == 256
7437
7580
  const block_iq3_s * bq2 = (const block_iq3_s *) vbq;
7438
7581
 
@@ -7444,9 +7587,11 @@ vec_dot_iq3_s_q8_1(const void *__restrict__ vbq,
7444
7587
  const uint32_t * grid1 = iq3s_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256));
7445
7588
  const uint32_t * grid2 = iq3s_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256));
7446
7589
  uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(
7447
- ((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201, std::equal_to<>());
7590
+ ((bq2->signs[4 * ib32 + l] & 0xf) * 0x01010101) & 0x08040201,
7591
+ 0x08040201, std::equal_to<>());
7448
7592
  uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(
7449
- ((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201, std::equal_to<>());
7593
+ ((bq2->signs[4 * ib32 + l] >> 4) * 0x01010101) & 0x08040201,
7594
+ 0x08040201, std::equal_to<>());
7450
7595
  const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
7451
7596
  grid1[0] ^ signs0, signs0, std::minus<>());
7452
7597
  const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
@@ -7455,45 +7600,142 @@ vec_dot_iq3_s_q8_1(const void *__restrict__ vbq,
7455
7600
  sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
7456
7601
  q8 += 8;
7457
7602
  }
7458
- const float d = (float)bq2->d * (1 + 2*((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * bq8_1[ib32].ds[0];
7603
+ const float d =
7604
+ (float)bq2->d *
7605
+ (1 + 2 * ((bq2->scales[ib32 / 2] >> 4 * (ib32 % 2)) & 0xf)) *
7606
+ bq8_1[ib32].ds[0];
7459
7607
  return d * sumi;
7460
7608
  #else
7461
7609
  assert(false);
7462
- return 0.f;
7463
- #endif
7464
- #else
7465
- assert(false);
7466
- return 0.f;
7467
7610
  #endif
7468
7611
  }
7469
7612
 
7470
7613
  static __dpct_inline__ float
7471
7614
  vec_dot_iq1_s_q8_1(const void *__restrict__ vbq,
7472
- const block_q8_1 *__restrict__ bq8_1, const int &iqs,
7473
- const uint32_t *iq1s_grid, const uint64_t *ksigns64) {
7615
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs,
7616
+ const uint32_t *iq1s_grid_gpu) {
7474
7617
  #if QK_K == 256
7475
7618
  const block_iq1_s * bq1 = (const block_iq1_s *) vbq;
7476
7619
 
7477
7620
  const int ib32 = iqs;
7478
- const uint8_t * qs = bq1->qs + 4*ib32;
7479
- const int8_t * q8 = bq8_1[ib32].qs;
7480
7621
  int sumi = 0;
7622
+ const int * q8 = (const int *)bq8_1[ib32].qs;
7481
7623
  for (int l = 0; l < 4; ++l) {
7482
- const uint32_t * grid = (const uint32_t *)(iq1s_grid + qs[l]);
7483
- const uint32_t * signs = (const uint32_t *)(ksigns64 + (qs[l] >> 8));
7484
- const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
7485
- grid[0] ^ signs[0], signs[0], std::minus<>());
7486
- const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
7487
- grid[1] ^ signs[1], signs[1], std::minus<>());
7488
- sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi);
7489
- sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
7490
- q8 += 8;
7624
+ const int * grid = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8)));
7625
+ int grid0 = grid[0] & 0x0f0f0f0f;
7626
+ int grid1 = (grid[0] >> 4) & 0x0f0f0f0f;
7627
+ sumi = dpct::dp4a(q8[2 * l + 1], grid1,
7628
+ dpct::dp4a(q8[2 * l + 0], grid0, sumi));
7629
+ }
7630
+
7631
+ const float delta = bq1->qh[ib32] & 0x8000 ? -1-IQ1S_DELTA : -1+IQ1S_DELTA;
7632
+ const float d1q = (float)bq1->d * (2*((bq1->qh[ib32] >> 12) & 7) + 1);
7633
+ const float d = d1q * bq8_1[ib32].ds[0];
7634
+ const float m = d1q * bq8_1[ib32].ds[1];
7635
+ return d * sumi + m * delta;
7636
+ #else
7637
+ assert(false);
7638
+ #endif
7639
+ }
7640
+
7641
+ static __dpct_inline__ float
7642
+ vec_dot_iq1_m_q8_1(const void *__restrict__ vbq,
7643
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
7644
+ #if QK_K == 256
7645
+ const block_iq1_m * bq1 = (const block_iq1_m *) vbq;
7646
+
7647
+ const int ib32 = iqs;
7648
+ int sumi[2] = {0, 0};
7649
+ float sumf[2] = {0.f, 0.f};
7650
+
7651
+ const int * q8 = (const int *)bq8_1[ib32].qs;
7652
+ for (int l = 0; l < 4; ++l) {
7653
+ const int * grid = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[2*ib32+l/2] >> 4*(l%2)) & 7) << 8)));
7654
+ int grid0 = grid[0] & 0x0f0f0f0f;
7655
+ int grid1 = (grid[0] >> 4) & 0x0f0f0f0f;
7656
+ sumi[l / 2] = dpct::dp4a(q8[2 * l + 1], grid1,
7657
+ dpct::dp4a(q8[2 * l + 0], grid0, sumi[l / 2]));
7658
+ const float delta = (bq1->qh[2*ib32+l/2] >> 4*(l%2)) & 0x08 ? -1-IQ1M_DELTA : -1+IQ1M_DELTA;
7659
+ const int sumy = dpct::dp4a(q8[2 * l + 1], 0x01010101,
7660
+ dpct::dp4a(q8[2 * l + 0], 0x01010101, 0));
7661
+ sumf[l/2] += delta*sumy;
7662
+ }
7663
+
7664
+ iq1m_scale_t scale;
7665
+ const uint16_t * sc = (const uint16_t *)bq1->scales;
7666
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
7667
+ const float d = (float)scale.f16 * bq8_1[ib32].ds[0];
7668
+ return d * ((sumi[0] + sumf[0]) * (2*((sc[ib32/2] >> 6*(ib32%2)) & 0x7) + 1) + (sumi[1] + sumf[1]) * (2*((sc[ib32/2] >> (6*(ib32%2)+3)) & 0x7) + 1));
7669
+ #else
7670
+ assert(false);
7671
+ #endif
7672
+ }
7673
+
7674
+ static __dpct_inline__ void get_int_from_table_16(const uint32_t &q4,
7675
+ const uint8_t *values,
7676
+ int &val1, int &val2) {
7677
+
7678
+ uint32_t aux32; const uint8_t * q8 = (const uint8_t *)&aux32;
7679
+ aux32 = q4 & 0x0f0f0f0f;
7680
+ uint16_t v1 = values[q8[0]] | (values[q8[1]] << 8);
7681
+ uint16_t v2 = values[q8[2]] | (values[q8[3]] << 8);
7682
+ val1 = v1 | (v2 << 16);
7683
+ aux32 = (q4 >> 4) & 0x0f0f0f0f;
7684
+ v1 = values[q8[0]] | (values[q8[1]] << 8);
7685
+ v2 = values[q8[2]] | (values[q8[3]] << 8);
7686
+ val2 = v1 | (v2 << 16);
7687
+ }
7688
+
7689
+
7690
+ static __dpct_inline__ float
7691
+ vec_dot_iq4_nl_q8_1(const void *__restrict__ vbq,
7692
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
7693
+
7694
+ const block_iq4_nl * bq = (const block_iq4_nl *) vbq;
7695
+
7696
+ const uint16_t * q4 = (const uint16_t *)bq->qs + 2*iqs;
7697
+ const int32_t * q8 = (const int32_t *)bq8_1->qs + iqs;
7698
+
7699
+ const uint8_t * values = (const uint8_t *)kvalues_iq4nl;
7700
+
7701
+ int v1, v2;
7702
+ int sumi1 = 0, sumi2 = 0;
7703
+ for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
7704
+ const uint32_t aux = q4[2*l] | (q4[2*l+1] << 16);
7705
+ get_int_from_table_16(aux, values, v1, v2);
7706
+ sumi1 = dpct::dp4a(v1, q8[l + 0], sumi1);
7707
+ sumi2 = dpct::dp4a(v2, q8[l + 4], sumi2);
7491
7708
  }
7492
- const float d = (float)bq1->d * bq8_1[ib32].ds[0] * 0.25f;
7493
- return d * sumi;
7709
+
7710
+ const float d = (float)bq->d * bq8_1->ds[0];
7711
+ return d * (sumi1 + sumi2);
7712
+ }
7713
+
7714
+
7715
+ static __dpct_inline__ float
7716
+ vec_dot_iq4_xs_q8_1(const void *__restrict__ vbq,
7717
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
7718
+
7719
+ #if QK_K == 256
7720
+ const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq;
7721
+ const uint8_t * values = (const uint8_t *)kvalues_iq4nl;
7722
+
7723
+ // iqs is 0...7
7724
+ const int ib32 = iqs;
7725
+ const int32_t * q8 = (const int *)bq8_1[ib32].qs;
7726
+ const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32;
7727
+ const int8_t ls = ((bq4->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((bq4->scales_h >> 2*ib32) & 3) << 4);
7728
+ const float d = (float)bq4->d * (ls - 32) * bq8_1[ib32].ds[0];
7729
+ int v1, v2;
7730
+ int sumi1 = 0, sumi2 = 0;
7731
+ for (int j = 0; j < 4; ++j) {
7732
+ get_int_from_table_16(q4[j], values, v1, v2);
7733
+ sumi1 = dpct::dp4a(v1, q8[j + 0], sumi1);
7734
+ sumi2 = dpct::dp4a(v2, q8[j + 4], sumi2);
7735
+ }
7736
+ return d * (sumi1 + sumi2);
7494
7737
  #else
7495
7738
  assert(false);
7496
- return 0.f;
7497
7739
  #endif
7498
7740
  }
7499
7741
 
@@ -8078,8 +8320,199 @@ template <bool need_check> static void
8078
8320
 
8079
8321
  template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
8080
8322
  static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
8081
- const sycl::nd_item<3> &item_ct1,
8082
- const uint32_t *iq3xxs_grid_ptr=nullptr, const uint64_t *ksigns64_ptr=nullptr) {
8323
+ const sycl::nd_item<3> &item_ct1) {
8324
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
8325
+ item_ct1.get_local_id(1);
8326
+
8327
+ if (row >= nrows) {
8328
+ return;
8329
+ }
8330
+
8331
+ const int blocks_per_row = ncols / qk;
8332
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
8333
+
8334
+ // partial sum for each thread
8335
+ float tmp = 0.0f;
8336
+
8337
+ const block_q_t * x = (const block_q_t *) vx;
8338
+ const block_q8_1 * y = (const block_q8_1 *) vy;
8339
+
8340
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
8341
+ i += blocks_per_warp) {
8342
+ const int ibx = row*blocks_per_row + i; // x block index
8343
+
8344
+ const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
8345
+
8346
+ const int iqs =
8347
+ vdr *
8348
+ (item_ct1.get_local_id(2) %
8349
+ (qi / vdr)); // x block quant index when casting the quants to int
8350
+
8351
+ tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs);
8352
+ }
8353
+
8354
+ // sum up partial sums and write back result
8355
+ #pragma unroll
8356
+ for (int mask = 16; mask > 0; mask >>= 1) {
8357
+ tmp +=
8358
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
8359
+ }
8360
+
8361
+ if (item_ct1.get_local_id(2) == 0) {
8362
+ dst[row] = tmp;
8363
+ }
8364
+ }
8365
+
8366
+ template <int qk, int qi, typename block_q_t, int vdr>
8367
+ static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx,
8368
+ const void *__restrict__ vy,
8369
+ float *__restrict__ dst, const int ncols,
8370
+ const int nrows,
8371
+ const sycl::nd_item<3> &item_ct1) {
8372
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
8373
+ item_ct1.get_local_id(1);
8374
+
8375
+ if (row >= nrows) {
8376
+ return;
8377
+ }
8378
+
8379
+ const int blocks_per_row = ncols / qk;
8380
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
8381
+
8382
+ // partial sum for each thread
8383
+ float tmp = 0.0f;
8384
+
8385
+ const block_q_t * x = (const block_q_t *) vx;
8386
+ const block_q8_1 * y = (const block_q8_1 *) vy;
8387
+
8388
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
8389
+ i += blocks_per_warp) {
8390
+ const int ibx = row*blocks_per_row + i; // x block index
8391
+
8392
+ const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
8393
+
8394
+ const int iqs =
8395
+ vdr *
8396
+ (item_ct1.get_local_id(2) %
8397
+ (qi / vdr)); // x block quant index when casting the quants to int
8398
+
8399
+ tmp += vec_dot_iq2_xxs_q8_1(&x[ibx], &y[iby], iqs, iq2xxs_grid, ksigns_iq2xs, kmask_iq2xs);
8400
+ }
8401
+
8402
+ // sum up partial sums and write back result
8403
+ #pragma unroll
8404
+ for (int mask = 16; mask > 0; mask >>= 1) {
8405
+ tmp +=
8406
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
8407
+ }
8408
+
8409
+ if (item_ct1.get_local_id(2) == 0) {
8410
+ dst[row] = tmp;
8411
+ }
8412
+ }
8413
+
8414
+ template <int qk, int qi, typename block_q_t, int vdr>
8415
+ static void mul_mat_vec_q_iq2_xs_q8_1(const void *__restrict__ vx,
8416
+ const void *__restrict__ vy,
8417
+ float *__restrict__ dst, const int ncols,
8418
+ const int nrows,
8419
+ const sycl::nd_item<3> &item_ct1) {
8420
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
8421
+ item_ct1.get_local_id(1);
8422
+
8423
+ if (row >= nrows) {
8424
+ return;
8425
+ }
8426
+
8427
+ const int blocks_per_row = ncols / qk;
8428
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
8429
+
8430
+ // partial sum for each thread
8431
+ float tmp = 0.0f;
8432
+
8433
+ const block_q_t * x = (const block_q_t *) vx;
8434
+ const block_q8_1 * y = (const block_q8_1 *) vy;
8435
+
8436
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
8437
+ i += blocks_per_warp) {
8438
+ const int ibx = row*blocks_per_row + i; // x block index
8439
+
8440
+ const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
8441
+
8442
+ const int iqs =
8443
+ vdr *
8444
+ (item_ct1.get_local_id(2) %
8445
+ (qi / vdr)); // x block quant index when casting the quants to int
8446
+
8447
+ tmp += vec_dot_iq2_xs_q8_1(&x[ibx], &y[iby], iqs, iq2xs_grid, ksigns64);
8448
+ }
8449
+
8450
+ // sum up partial sums and write back result
8451
+ #pragma unroll
8452
+ for (int mask = 16; mask > 0; mask >>= 1) {
8453
+ tmp +=
8454
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
8455
+ }
8456
+
8457
+ if (item_ct1.get_local_id(2) == 0) {
8458
+ dst[row] = tmp;
8459
+ }
8460
+ }
8461
+
8462
+ template <int qk, int qi, typename block_q_t, int vdr>
8463
+ static void mul_mat_vec_q_iq2_s_q8_1(const void *__restrict__ vx,
8464
+ const void *__restrict__ vy,
8465
+ float *__restrict__ dst, const int ncols,
8466
+ const int nrows,
8467
+ const sycl::nd_item<3> &item_ct1) {
8468
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
8469
+ item_ct1.get_local_id(1);
8470
+
8471
+ if (row >= nrows) {
8472
+ return;
8473
+ }
8474
+
8475
+ const int blocks_per_row = ncols / qk;
8476
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
8477
+
8478
+ // partial sum for each thread
8479
+ float tmp = 0.0f;
8480
+
8481
+ const block_q_t * x = (const block_q_t *) vx;
8482
+ const block_q8_1 * y = (const block_q8_1 *) vy;
8483
+
8484
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
8485
+ i += blocks_per_warp) {
8486
+ const int ibx = row*blocks_per_row + i; // x block index
8487
+
8488
+ const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
8489
+
8490
+ const int iqs =
8491
+ vdr *
8492
+ (item_ct1.get_local_id(2) %
8493
+ (qi / vdr)); // x block quant index when casting the quants to int
8494
+
8495
+ tmp += vec_dot_iq2_s_q8_1(&x[ibx], &y[iby], iqs);
8496
+ }
8497
+
8498
+ // sum up partial sums and write back result
8499
+ #pragma unroll
8500
+ for (int mask = 16; mask > 0; mask >>= 1) {
8501
+ tmp +=
8502
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
8503
+ }
8504
+
8505
+ if (item_ct1.get_local_id(2) == 0) {
8506
+ dst[row] = tmp;
8507
+ }
8508
+ }
8509
+
8510
+ template <int qk, int qi, typename block_q_t, int vdr>
8511
+ static void mul_mat_vec_q_iq3_xxs_q8_1(const void *__restrict__ vx,
8512
+ const void *__restrict__ vy,
8513
+ float *__restrict__ dst, const int ncols,
8514
+ const int nrows,
8515
+ const sycl::nd_item<3> &item_ct1) {
8083
8516
  const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
8084
8517
  item_ct1.get_local_id(1);
8085
8518
 
@@ -8107,7 +8540,7 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_
8107
8540
  (item_ct1.get_local_id(2) %
8108
8541
  (qi / vdr)); // x block quant index when casting the quants to int
8109
8542
 
8110
- tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs);
8543
+ tmp += vec_dot_iq3_xxs_q8_1(&x[ibx], &y[iby], iqs, iq3xxs_grid, ksigns64);
8111
8544
  }
8112
8545
 
8113
8546
  // sum up partial sums and write back result
@@ -8123,10 +8556,11 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_
8123
8556
  }
8124
8557
 
8125
8558
  template <int qk, int qi, typename block_q_t, int vdr>
8126
- static void mul_mat_vec_q_iq2_xxs_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
8127
- const sycl::nd_item<3> &item_ct1,
8128
- const uint64_t *iq2xxs_grid_ptr, const uint8_t *ksigns_iq2xs_ptr,
8129
- const uint8_t *kmask_iq2xs_ptr ) {
8559
+ static void mul_mat_vec_q_iq3_s_q8_1(const void *__restrict__ vx,
8560
+ const void *__restrict__ vy,
8561
+ float *__restrict__ dst, const int ncols,
8562
+ const int nrows,
8563
+ const sycl::nd_item<3> &item_ct1) {
8130
8564
  const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
8131
8565
  item_ct1.get_local_id(1);
8132
8566
 
@@ -8154,7 +8588,7 @@ static void mul_mat_vec_q_iq2_xxs_q8_1(const void * __restrict__ vx, const void
8154
8588
  (item_ct1.get_local_id(2) %
8155
8589
  (qi / vdr)); // x block quant index when casting the quants to int
8156
8590
 
8157
- tmp += vec_dot_iq2_xxs_q8_1(&x[ibx], &y[iby], iqs, iq2xxs_grid_ptr, ksigns_iq2xs_ptr, kmask_iq2xs_ptr);
8591
+ tmp += vec_dot_iq3_s_q8_1(&x[ibx], &y[iby], iqs, iq3s_grid);
8158
8592
  }
8159
8593
 
8160
8594
  // sum up partial sums and write back result
@@ -8170,9 +8604,11 @@ static void mul_mat_vec_q_iq2_xxs_q8_1(const void * __restrict__ vx, const void
8170
8604
  }
8171
8605
 
8172
8606
  template <int qk, int qi, typename block_q_t, int vdr>
8173
- static void mul_mat_vec_q_iq2_xs_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
8174
- const sycl::nd_item<3> &item_ct1,
8175
- const uint64_t *iq2xs_grid_ptr, const uint64_t *ksigns64_ptr ) {
8607
+ static void mul_mat_vec_q_iq1_s_q8_1(const void *__restrict__ vx,
8608
+ const void *__restrict__ vy,
8609
+ float *__restrict__ dst, const int ncols,
8610
+ const int nrows,
8611
+ const sycl::nd_item<3> &item_ct1) {
8176
8612
  const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
8177
8613
  item_ct1.get_local_id(1);
8178
8614
 
@@ -8200,7 +8636,7 @@ static void mul_mat_vec_q_iq2_xs_q8_1(const void * __restrict__ vx, const void *
8200
8636
  (item_ct1.get_local_id(2) %
8201
8637
  (qi / vdr)); // x block quant index when casting the quants to int
8202
8638
 
8203
- tmp += vec_dot_iq2_xs_q8_1(&x[ibx], &y[iby], iqs, iq2xs_grid_ptr, ksigns64_ptr);
8639
+ tmp += vec_dot_iq1_s_q8_1(&x[ibx], &y[iby], iqs, iq1s_grid_gpu);
8204
8640
  }
8205
8641
 
8206
8642
  // sum up partial sums and write back result
@@ -8216,9 +8652,11 @@ static void mul_mat_vec_q_iq2_xs_q8_1(const void * __restrict__ vx, const void *
8216
8652
  }
8217
8653
 
8218
8654
  template <int qk, int qi, typename block_q_t, int vdr>
8219
- static void mul_mat_vec_q_iq3_xxs_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
8220
- const sycl::nd_item<3> &item_ct1,
8221
- const uint32_t *iq3xxs_grid_ptr, const uint64_t *ksigns64_ptr ) {
8655
+ static void mul_mat_vec_q_iq1_m_q8_1(const void *__restrict__ vx,
8656
+ const void *__restrict__ vy,
8657
+ float *__restrict__ dst, const int ncols,
8658
+ const int nrows,
8659
+ const sycl::nd_item<3> &item_ct1) {
8222
8660
  const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
8223
8661
  item_ct1.get_local_id(1);
8224
8662
 
@@ -8246,7 +8684,7 @@ static void mul_mat_vec_q_iq3_xxs_q8_1(const void * __restrict__ vx, const void
8246
8684
  (item_ct1.get_local_id(2) %
8247
8685
  (qi / vdr)); // x block quant index when casting the quants to int
8248
8686
 
8249
- tmp += vec_dot_iq3_xxs_q8_1(&x[ibx], &y[iby], iqs, iq3xxs_grid_ptr, ksigns64_ptr);
8687
+ tmp += vec_dot_iq1_m_q8_1(&x[ibx], &y[iby], iqs);
8250
8688
  }
8251
8689
 
8252
8690
  // sum up partial sums and write back result
@@ -8262,9 +8700,11 @@ static void mul_mat_vec_q_iq3_xxs_q8_1(const void * __restrict__ vx, const void
8262
8700
  }
8263
8701
 
8264
8702
  template <int qk, int qi, typename block_q_t, int vdr>
8265
- static void mul_mat_vec_q_iq3_s_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
8266
- const sycl::nd_item<3> &item_ct1,
8267
- const uint32_t *iq3s_grid_ptr, const uint64_t *ksigns64_ptr ) {
8703
+ static void mul_mat_vec_q_iq4_nl_q8_1(const void *__restrict__ vx,
8704
+ const void *__restrict__ vy,
8705
+ float *__restrict__ dst, const int ncols,
8706
+ const int nrows,
8707
+ const sycl::nd_item<3> &item_ct1) {
8268
8708
  const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
8269
8709
  item_ct1.get_local_id(1);
8270
8710
 
@@ -8292,7 +8732,7 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void * __restrict__ vx, const void *
8292
8732
  (item_ct1.get_local_id(2) %
8293
8733
  (qi / vdr)); // x block quant index when casting the quants to int
8294
8734
 
8295
- tmp += vec_dot_iq3_s_q8_1(&x[ibx], &y[iby], iqs, iq3s_grid_ptr, ksigns64_ptr);
8735
+ tmp += vec_dot_iq4_nl_q8_1(&x[ibx], &y[iby], iqs);
8296
8736
  }
8297
8737
 
8298
8738
  // sum up partial sums and write back result
@@ -8307,10 +8747,13 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void * __restrict__ vx, const void *
8307
8747
  }
8308
8748
  }
8309
8749
 
8750
+
8310
8751
  template <int qk, int qi, typename block_q_t, int vdr>
8311
- static void mul_mat_vec_q_iq1_s_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
8312
- const sycl::nd_item<3> &item_ct1,
8313
- const uint32_t *iq1s_grid_ptr, const uint64_t *ksigns64_ptr ) {
8752
+ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
8753
+ const void *__restrict__ vy,
8754
+ float *__restrict__ dst, const int ncols,
8755
+ const int nrows,
8756
+ const sycl::nd_item<3> &item_ct1) {
8314
8757
  const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
8315
8758
  item_ct1.get_local_id(1);
8316
8759
 
@@ -8338,7 +8781,7 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void * __restrict__ vx, const void *
8338
8781
  (item_ct1.get_local_id(2) %
8339
8782
  (qi / vdr)); // x block quant index when casting the quants to int
8340
8783
 
8341
- tmp += vec_dot_iq1_s_q8_1(&x[ibx], &y[iby], iqs, iq1s_grid_ptr, ksigns64_ptr);
8784
+ tmp += vec_dot_iq4_xs_q8_1(&x[ibx], &y[iby], iqs);
8342
8785
  }
8343
8786
 
8344
8787
  // sum up partial sums and write back result
@@ -8353,6 +8796,7 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void * __restrict__ vx, const void *
8353
8796
  }
8354
8797
  }
8355
8798
 
8799
+
8356
8800
  template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
8357
8801
  static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows,
8358
8802
  const sycl::nd_item<3> &item_ct1) {
@@ -8914,64 +9358,71 @@ static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
8914
9358
  }
8915
9359
  }
8916
9360
 
9361
+
8917
9362
  template<typename T>
8918
- static inline void swap(T & a, T & b) {
9363
+ static inline void ggml_sycl_swap(T & a, T & b) {
8919
9364
  T tmp = a;
8920
9365
  a = b;
8921
9366
  b = tmp;
8922
9367
  }
8923
9368
 
8924
- template<ggml_sort_order order>
8925
- static void k_argsort_f32_i32(const float * x, int * dst, const int ncols,
8926
- const sycl::nd_item<3> &item_ct1) {
9369
+ template <ggml_sort_order order>
9370
+ __dpct_inline__ static void
9371
+ k_argsort_f32_i32(const float *x, int *dst, const int ncols, int ncols_pad,
9372
+ const sycl::nd_item<3> &item_ct1, uint8_t *dpct_local) {
8927
9373
  // bitonic sort
8928
9374
  int col = item_ct1.get_local_id(2);
8929
9375
  int row = item_ct1.get_group(1);
8930
9376
 
8931
- if (col >= ncols) return;
9377
+ if (col >= ncols_pad) {
9378
+ return;
9379
+ }
8932
9380
 
8933
9381
  const float * x_row = x + row * ncols;
8934
- int * dst_row = dst + row * ncols;
9382
+ auto dst_row = (int *)dpct_local;
8935
9383
 
8936
9384
  // initialize indices
8937
- if (col < ncols) {
8938
- dst_row[col] = col;
8939
- }
8940
- /*
8941
- DPCT1065:58: Consider replacing sycl::nd_item::barrier() with
8942
- sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
8943
- performance if there is no access to global memory.
8944
- */
8945
- item_ct1.barrier();
9385
+ dst_row[col] = col;
9386
+
9387
+ item_ct1.barrier(sycl::access::fence_space::local_space);
8946
9388
 
8947
- for (int k = 2; k <= ncols; k *= 2) {
9389
+ for (int k = 2; k <= ncols_pad; k *= 2) {
8948
9390
  for (int j = k / 2; j > 0; j /= 2) {
8949
9391
  int ixj = col ^ j;
8950
9392
  if (ixj > col) {
8951
9393
  if ((col & k) == 0) {
8952
- if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
8953
- swap(dst_row[col], dst_row[ixj]);
9394
+ if (dst_row[col] >= ncols ||
9395
+ (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
9396
+ x_row[dst_row[col]] > x_row[dst_row[ixj]] :
9397
+ x_row[dst_row[col]] < x_row[dst_row[ixj]]))
9398
+ ) {
9399
+ ggml_sycl_swap(dst_row[col], dst_row[ixj]);
8954
9400
  }
8955
9401
  } else {
8956
- if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
8957
- swap(dst_row[col], dst_row[ixj]);
9402
+ if (dst_row[ixj] >= ncols ||
9403
+ (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
9404
+ x_row[dst_row[col]] < x_row[dst_row[ixj]] :
9405
+ x_row[dst_row[col]] > x_row[dst_row[ixj]]))
9406
+ ) {
9407
+ ggml_sycl_swap(dst_row[col], dst_row[ixj]);
8958
9408
  }
8959
9409
  }
8960
9410
  }
8961
9411
  /*
8962
- DPCT1118:11: SYCL group functions and algorithms must be encountered
9412
+ DPCT1118:1: SYCL group functions and algorithms must be encountered
8963
9413
  in converged control flow. You may need to adjust the code.
8964
9414
  */
8965
- /*
8966
- DPCT1065:59: Consider replacing sycl::nd_item::barrier() with
8967
- sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
8968
- better performance if there is no access to global memory.
8969
- */
8970
- item_ct1.barrier();
9415
+ item_ct1.barrier(sycl::access::fence_space::local_space);
8971
9416
  }
8972
9417
  }
9418
+
9419
+ // copy the result to dst without the padding
9420
+ if (col < ncols) {
9421
+ dst[row * ncols + col] = dst_row[col];
9422
+ }
8973
9423
  }
8974
9424
 
9425
+
8975
9426
  static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past,
8976
9427
  const sycl::nd_item<3> &item_ct1) {
8977
9428
  const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
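
k_argsort_f32_i32 in the hunk above now pads ncols to the next power of two, sorts indices in local memory with sentinel entries (an index >= ncols always compares as larger), and copies only the first ncols results back. A serial CPU sketch of the same padded bitonic argsort, for illustration only; next_pow2 here mirrors the next_power_of_2 helper added further down in this diff.

// Sketch only: bitonic argsort over indices padded to a power of two.
#include <vector>
#include <utility>
#include <cstdio>

static int next_pow2(int x) { int n = 1; while (n < x) n *= 2; return n; }

static std::vector<int> argsort_asc(const std::vector<float> & x) {
    const int ncols = (int) x.size();
    const int pad   = next_pow2(ncols);
    std::vector<int> idx(pad);
    for (int i = 0; i < pad; ++i) idx[i] = i;      // indices >= ncols act as sentinels

    // out-of-range indices always sort to the end, as in the kernel's comparisons
    auto greater = [&](int a, int b) {
        if (a >= ncols) return true;
        if (b >= ncols) return false;
        return x[a] > x[b];
    };

    for (int k = 2; k <= pad; k *= 2) {
        for (int j = k / 2; j > 0; j /= 2) {
            for (int col = 0; col < pad; ++col) {
                const int ixj = col ^ j;
                if (ixj <= col) continue;          // each pair handled once per step
                const bool ascending = (col & k) == 0;
                if (ascending ? greater(idx[col], idx[ixj])
                              : greater(idx[ixj], idx[col])) {
                    std::swap(idx[col], idx[ixj]);
                }
            }
        }
    }
    idx.resize(ncols);                             // drop the padding
    return idx;
}

int main() {
    std::vector<float> x = {0.3f, -1.0f, 2.5f, 0.0f, 1.5f};  // ncols = 5, padded to 8
    for (int i : argsort_asc(x)) printf("%d ", i);           // expected: 1 3 0 4 2
    printf("\n");
}
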
@@ -9950,28 +10401,64 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int k,
9950
10401
  #endif
9951
10402
  }
9952
10403
 
9953
-
9954
10404
  template <typename dst_t>
9955
- static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int k,
10405
+ static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
9956
10406
  dpct::queue_ptr stream) {
9957
10407
  const int nb = k / QK_K;
9958
10408
  {
10409
+ dpct::has_capability_or_fail(stream->get_device(),
10410
+ {sycl::aspect::fp16});
10411
+
10412
+ stream->submit([&](sycl::handler &cgh) {
10413
+ cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
10414
+ sycl::range<3>(1, 1, 32),
10415
+ sycl::range<3>(1, 1, 32)),
10416
+ [=](sycl::nd_item<3> item_ct1) {
10417
+ dequantize_block_iq1_s(
10418
+ vx, y, item_ct1, iq1s_grid_gpu
10419
+ );
10420
+ });
10421
+ });
10422
+ }
10423
+ }
9959
10424
 
10425
+ template <typename dst_t>
10426
+ static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int k,
10427
+ dpct::queue_ptr stream) {
10428
+ const int nb = k / QK_K;
10429
+ {
9960
10430
  dpct::has_capability_or_fail(stream->get_device(),
9961
10431
  {sycl::aspect::fp16});
9962
10432
 
9963
10433
  stream->submit([&](sycl::handler &cgh) {
9964
- auto iq2xxs_grid_ptr_ct1 = &iq2xxs_grid[0];
9965
- auto ksigns_iq2xs_ptr_ct1 = &ksigns_iq2xs[0];
9966
- auto kmask_iq2xs_ptr_ct1 = &kmask_iq2xs[0];
10434
+ cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
10435
+ sycl::range<3>(1, 1, 32),
10436
+ sycl::range<3>(1, 1, 32)),
10437
+ [=](sycl::nd_item<3> item_ct1) {
10438
+ dequantize_block_iq1_m(
10439
+ vx, y, item_ct1, iq1s_grid_gpu
10440
+ );
10441
+ });
10442
+ });
10443
+ }
10444
+ }
10445
+
10446
+ template <typename dst_t>
10447
+ static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int k,
10448
+ dpct::queue_ptr stream) {
10449
+ const int nb = k / QK_K;
10450
+ {
10451
+ dpct::has_capability_or_fail(stream->get_device(),
10452
+ {sycl::aspect::fp16});
9967
10453
 
10454
+ stream->submit([&](sycl::handler &cgh) {
9968
10455
  cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
9969
10456
  sycl::range<3>(1, 1, 32),
9970
10457
  sycl::range<3>(1, 1, 32)),
9971
10458
  [=](sycl::nd_item<3> item_ct1) {
9972
10459
  dequantize_block_iq2_xxs(
9973
- vx, y, item_ct1, iq2xxs_grid_ptr_ct1,
9974
- ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1);
10460
+ vx, y, item_ct1, iq2xxs_grid,
10461
+ ksigns_iq2xs, kmask_iq2xs);
9975
10462
  });
9976
10463
  });
9977
10464
  }
@@ -9982,105 +10469,130 @@ static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int k,
9982
10469
  dpct::queue_ptr stream) {
9983
10470
  const int nb = k / QK_K;
9984
10471
  {
9985
-
9986
10472
  dpct::has_capability_or_fail(stream->get_device(),
9987
10473
  {sycl::aspect::fp16});
9988
10474
 
9989
10475
  stream->submit([&](sycl::handler &cgh) {
9990
- auto iq2xs_grid_ptr_ct1 = &iq2xs_grid[0];
9991
- auto ksigns_iq2xs_ptr_ct1 = &ksigns_iq2xs[0];
9992
- auto kmask_iq2xs_ptr_ct1 = &kmask_iq2xs[0];
9993
-
9994
10476
  cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
9995
10477
  sycl::range<3>(1, 1, 32),
9996
10478
  sycl::range<3>(1, 1, 32)),
9997
10479
  [=](sycl::nd_item<3> item_ct1) {
9998
10480
  dequantize_block_iq2_xs(
9999
- vx, y, item_ct1, iq2xs_grid_ptr_ct1,
10000
- ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1);
10481
+ vx, y, item_ct1, iq2xs_grid,
10482
+ ksigns_iq2xs, kmask_iq2xs);
10001
10483
  });
10002
10484
  });
10003
10485
  }
10004
10486
  }
10005
10487
 
10006
10488
  template <typename dst_t>
10007
- static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k,
10008
- dpct::queue_ptr stream) {
10489
+ static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int k,
10490
+ dpct::queue_ptr stream) {
10009
10491
  const int nb = k / QK_K;
10010
10492
  {
10011
-
10012
10493
  dpct::has_capability_or_fail(stream->get_device(),
10013
10494
  {sycl::aspect::fp16});
10014
10495
 
10015
10496
  stream->submit([&](sycl::handler &cgh) {
10016
- auto iq3xxs_grid_ptr_ct1 = &iq3xxs_grid[0];
10017
- auto ksigns_iq2xs_ptr_ct1 = &ksigns_iq2xs[0];
10018
- auto kmask_iq2xs_ptr_ct1 = &kmask_iq2xs[0];
10019
-
10020
10497
  cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
10021
10498
  sycl::range<3>(1, 1, 32),
10022
10499
  sycl::range<3>(1, 1, 32)),
10023
10500
  [=](sycl::nd_item<3> item_ct1) {
10024
- dequantize_block_iq3_xxs(
10025
- vx, y, item_ct1, iq3xxs_grid_ptr_ct1,
10026
- ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1);
10501
+ dequantize_block_iq2_s(vx, y, item_ct1);
10027
10502
  });
10028
10503
  });
10029
10504
  }
10030
10505
  }
10031
10506
 
10507
+
10032
10508
  template <typename dst_t>
10033
- static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int k,
10509
+ static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k,
10034
10510
  dpct::queue_ptr stream) {
10035
10511
  const int nb = k / QK_K;
10036
10512
  {
10037
-
10038
10513
  dpct::has_capability_or_fail(stream->get_device(),
10039
10514
  {sycl::aspect::fp16});
10040
10515
 
10041
10516
  stream->submit([&](sycl::handler &cgh) {
10042
- auto iq3s_grid_ptr_ct1 = &iq3s_grid[0];
10043
- auto ksigns_iq2xs_ptr_ct1 = &ksigns_iq2xs[0];
10044
- auto kmask_iq2xs_ptr_ct1 = &kmask_iq2xs[0];
10045
-
10046
10517
  cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
10047
10518
  sycl::range<3>(1, 1, 32),
10048
10519
  sycl::range<3>(1, 1, 32)),
10049
10520
  [=](sycl::nd_item<3> item_ct1) {
10050
- dequantize_block_iq3_s(
10051
- vx, y, item_ct1, iq3s_grid_ptr_ct1,
10052
- ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1);
10521
+ dequantize_block_iq3_xxs(
10522
+ vx, y, item_ct1, iq3xxs_grid,
10523
+ ksigns_iq2xs, kmask_iq2xs);
10053
10524
  });
10054
10525
  });
10055
10526
  }
10056
10527
  }
10057
10528
 
10058
10529
  template <typename dst_t>
10059
- static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
10530
+ static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int k,
10060
10531
  dpct::queue_ptr stream) {
10061
10532
  const int nb = k / QK_K;
10062
10533
  {
10063
-
10064
10534
  dpct::has_capability_or_fail(stream->get_device(),
10065
10535
  {sycl::aspect::fp16});
10066
10536
 
10067
10537
  stream->submit([&](sycl::handler &cgh) {
10068
- auto iq1s_grid_ptr_ct1 = &iq1s_grid_gpu[0];
10069
- auto ksigns_iq2xs_ptr_ct1 = &ksigns_iq2xs[0];
10070
- auto kmask_iq2xs_ptr_ct1 = &kmask_iq2xs[0];
10071
-
10072
10538
  cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
10073
10539
  sycl::range<3>(1, 1, 32),
10074
10540
  sycl::range<3>(1, 1, 32)),
10075
10541
  [=](sycl::nd_item<3> item_ct1) {
10076
- dequantize_block_iq1_s(
10077
- vx, y, item_ct1, iq1s_grid_ptr_ct1,
10078
- ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1);
10542
+ dequantize_block_iq3_s(
10543
+ vx, y, item_ct1, kmask_iq2xs, iq3s_grid);
10079
10544
  });
10080
10545
  });
10081
10546
  }
10082
10547
  }
10083
10548
 
10549
+ template <typename dst_t>
10550
+ static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int k,
10551
+ dpct::queue_ptr stream) {
10552
+ const int nb = (k + QK_K - 1) / QK_K;
10553
+ #if QK_K == 64
10554
+ dequantize_row_iq4_nl_sycl(vx, y, k, stream);
10555
+ #else
10556
+ {
10557
+ dpct::has_capability_or_fail(stream->get_device(),
10558
+ {sycl::aspect::fp16});
10559
+
10560
+ stream->submit([&](sycl::handler &cgh) {
10561
+ cgh.parallel_for(
10562
+ sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
10563
+ sycl::range<3>(1, 1, 32),
10564
+ sycl::range<3>(1, 1, 32)),
10565
+ [=](sycl::nd_item<3> item_ct1) {
10566
+ dequantize_block_iq4_xs(vx, y, item_ct1);
10567
+ });
10568
+ });
10569
+ }
10570
+ #endif
10571
+ }
10572
+
10573
+
10574
+ template <typename dst_t>
10575
+ static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int k,
10576
+ dpct::queue_ptr stream) {
10577
+ const int nb = (k + QK_K - 1) / QK_K;
10578
+ {
10579
+ dpct::has_capability_or_fail(stream->get_device(),
10580
+ {sycl::aspect::fp16});
10581
+
10582
+ stream->submit([&](sycl::handler &cgh) {
10583
+ cgh.parallel_for(
10584
+ sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
10585
+ sycl::range<3>(1, 1, 32),
10586
+ sycl::range<3>(1, 1, 32)),
10587
+ [=](sycl::nd_item<3> item_ct1) {
10588
+ dequantize_block_iq4_nl(vx, y, item_ct1);
10589
+ });
10590
+ });
10591
+ }
10592
+ }
10593
+
10594
+
10595
+
10084
10596
  template <typename src_t, typename dst_t>
10085
10597
  static void convert_unary_sycl(const void *__restrict__ vx,
10086
10598
  dst_t *__restrict__ y, const int k,
@@ -10125,16 +10637,24 @@ static to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type) try {
10125
10637
  return dequantize_row_q5_K_sycl;
10126
10638
  case GGML_TYPE_Q6_K:
10127
10639
  return dequantize_row_q6_K_sycl;
10640
+ case GGML_TYPE_IQ1_S:
10641
+ return dequantize_row_iq1_s_sycl;
10642
+ case GGML_TYPE_IQ1_M:
10643
+ return dequantize_row_iq1_m_sycl;
10128
10644
  case GGML_TYPE_IQ2_XXS:
10129
10645
  return dequantize_row_iq2_xxs_sycl;
10130
10646
  case GGML_TYPE_IQ2_XS:
10131
10647
  return dequantize_row_iq2_xs_sycl;
10648
+ case GGML_TYPE_IQ2_S:
10649
+ return dequantize_row_iq2_s_sycl;
10132
10650
  case GGML_TYPE_IQ3_XXS:
10133
10651
  return dequantize_row_iq3_xxs_sycl;
10134
10652
  case GGML_TYPE_IQ3_S:
10135
10653
  return dequantize_row_iq3_s_sycl;
10136
- case GGML_TYPE_IQ1_S:
10137
- return dequantize_row_iq1_s_sycl;
10654
+ case GGML_TYPE_IQ4_XS:
10655
+ return dequantize_row_iq4_xs_sycl;
10656
+ case GGML_TYPE_IQ4_NL:
10657
+ return dequantize_row_iq4_nl_sycl;
10138
10658
  case GGML_TYPE_F32:
10139
10659
  return convert_unary_sycl<float>;
10140
10660
  default:
@@ -10169,16 +10689,24 @@ static to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type) {
10169
10689
  return dequantize_row_q5_K_sycl;
10170
10690
  case GGML_TYPE_Q6_K:
10171
10691
  return dequantize_row_q6_K_sycl;
10692
+ case GGML_TYPE_IQ1_S:
10693
+ return dequantize_row_iq1_s_sycl;
10694
+ case GGML_TYPE_IQ1_M:
10695
+ return dequantize_row_iq1_m_sycl;
10172
10696
  case GGML_TYPE_IQ2_XXS:
10173
10697
  return dequantize_row_iq2_xxs_sycl;
10174
10698
  case GGML_TYPE_IQ2_XS:
10175
10699
  return dequantize_row_iq2_xs_sycl;
10700
+ case GGML_TYPE_IQ2_S:
10701
+ return dequantize_row_iq2_s_sycl;
10176
10702
  case GGML_TYPE_IQ3_XXS:
10177
10703
  return dequantize_row_iq3_xxs_sycl;
10178
10704
  case GGML_TYPE_IQ3_S:
10179
10705
  return dequantize_row_iq3_s_sycl;
10180
- case GGML_TYPE_IQ1_S:
10181
- return dequantize_row_iq1_s_sycl;
10706
+ case GGML_TYPE_IQ4_XS:
10707
+ return dequantize_row_iq4_xs_sycl;
10708
+ case GGML_TYPE_IQ4_NL:
10709
+ return dequantize_row_iq4_nl_sycl;
10182
10710
  case GGML_TYPE_F16:
10183
10711
  return convert_unary_sycl<sycl::half>;
10184
10712
  default:
@@ -10641,19 +11169,13 @@ static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy,
10641
11169
  const sycl::range<3> block_nums(1, 1, block_num_y);
10642
11170
  const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
10643
11171
  {
10644
-
10645
11172
  stream->submit([&](sycl::handler &cgh) {
10646
- auto iq2xxs_grid_ptr_ct1 = &iq2xxs_grid[0];
10647
- auto ksigns_iq2xs_ptr_ct1 = &ksigns_iq2xs[0];
10648
- auto kmask_iq2xs_ptr_ct1 = &kmask_iq2xs[0];
10649
-
10650
11173
  cgh.parallel_for(
10651
11174
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
10652
11175
  [=](sycl::nd_item<3> item_ct1)
10653
11176
  [[intel::reqd_sub_group_size(32)]] {
10654
11177
  mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS, block_iq2_xxs, 1>(
10655
- vx, vy, dst, ncols, nrows, item_ct1,
10656
- iq2xxs_grid_ptr_ct1, ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1);
11178
+ vx, vy, dst, ncols, nrows, item_ct1);
10657
11179
  });
10658
11180
  });
10659
11181
  }
@@ -10678,8 +11200,32 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
10678
11200
  [=](sycl::nd_item<3> item_ct1)
10679
11201
  [[intel::reqd_sub_group_size(32)]] {
10680
11202
  mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS, block_iq2_xs, 1>(
10681
- vx, vy, dst, ncols, nrows, item_ct1,
10682
- iq2xs_grid_ptr_ct1, ksigns64_ptr_ct1);
11203
+ vx, vy, dst, ncols, nrows, item_ct1);
11204
+ });
11205
+ });
11206
+ }
11207
+ }
11208
+
11209
+ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
11210
+ float *dst, const int ncols,
11211
+ const int nrows,
11212
+ dpct::queue_ptr stream) {
11213
+ GGML_ASSERT(ncols % QK_K == 0);
11214
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
11215
+ const sycl::range<3> block_nums(1, 1, block_num_y);
11216
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
11217
+ {
11218
+
11219
+ stream->submit([&](sycl::handler &cgh) {
11220
+ auto iq2xs_grid_ptr_ct1 = &iq2xs_grid[0];
11221
+ auto ksigns64_ptr_ct1 = &ksigns64[0];
11222
+
11223
+ cgh.parallel_for(
11224
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
11225
+ [=](sycl::nd_item<3> item_ct1)
11226
+ [[intel::reqd_sub_group_size(32)]] {
11227
+ mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S, block_iq2_s, 1>(
11228
+ vx, vy, dst, ncols, nrows, item_ct1);
10683
11229
  });
10684
11230
  });
10685
11231
  }
@@ -10704,8 +11250,7 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
10704
11250
  [=](sycl::nd_item<3> item_ct1)
10705
11251
  [[intel::reqd_sub_group_size(32)]] {
10706
11252
  mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS, block_iq3_xxs, 1>(
10707
- vx, vy, dst, ncols, nrows, item_ct1,
10708
- iq3xxs_grid_ptr_ct1, ksigns64_ptr_ct1);
11253
+ vx, vy, dst, ncols, nrows, item_ct1);
10709
11254
  });
10710
11255
  });
10711
11256
  }
@@ -10723,15 +11268,13 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
10723
11268
 
10724
11269
  stream->submit([&](sycl::handler &cgh) {
10725
11270
  auto iq3s_grid_ptr_ct1 = &iq3s_grid[0];
10726
- auto ksigns64_ptr_ct1 = &ksigns64[0];
10727
11271
 
10728
11272
  cgh.parallel_for(
10729
11273
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
10730
11274
  [=](sycl::nd_item<3> item_ct1)
10731
11275
  [[intel::reqd_sub_group_size(32)]] {
10732
11276
  mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_XS, block_iq3_s, 1>(
10733
- vx, vy, dst, ncols, nrows, item_ct1,
10734
- iq3s_grid_ptr_ct1, ksigns64_ptr_ct1);
11277
+ vx, vy, dst, ncols, nrows, item_ct1);
10735
11278
  });
10736
11279
  });
10737
11280
  }
@@ -10756,8 +11299,72 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
10756
11299
  [=](sycl::nd_item<3> item_ct1)
10757
11300
  [[intel::reqd_sub_group_size(32)]] {
10758
11301
  mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
10759
- vx, vy, dst, ncols, nrows, item_ct1,
10760
- iq1s_grid_ptr_ct1, ksigns64_ptr_ct1);
11302
+ vx, vy, dst, ncols, nrows, item_ct1);
11303
+ });
11304
+ });
11305
+ }
11306
+ }
11307
+
11308
+ static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy,
11309
+ float *dst, const int ncols,
11310
+ const int nrows,
11311
+ dpct::queue_ptr stream) {
11312
+ GGML_ASSERT(ncols % QK_K == 0);
11313
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
11314
+ const sycl::range<3> block_nums(1, 1, block_num_y);
11315
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
11316
+ {
11317
+ stream->submit([&](sycl::handler &cgh) {
11318
+ cgh.parallel_for(
11319
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
11320
+ [=](sycl::nd_item<3> item_ct1)
11321
+ [[intel::reqd_sub_group_size(32)]] {
11322
+ mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(
11323
+ vx, vy, dst, ncols, nrows, item_ct1);
11324
+ });
11325
+ });
11326
+ }
11327
+ }
11328
+
11329
+ static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy,
11330
+ float *dst, const int ncols,
11331
+ const int nrows,
11332
+ dpct::queue_ptr stream) {
11333
+ GGML_ASSERT(ncols % QK4_NL == 0);
11334
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
11335
+ const sycl::range<3> block_nums(1, 1, block_num_y);
11336
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
11337
+ {
11338
+
11339
+ stream->submit([&](sycl::handler &cgh) {
11340
+ cgh.parallel_for(
11341
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
11342
+ [=](sycl::nd_item<3> item_ct1)
11343
+ [[intel::reqd_sub_group_size(32)]] {
11344
+ mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 1>(
11345
+ vx, vy, dst, ncols, nrows, item_ct1);
11346
+ });
11347
+ });
11348
+ }
11349
+ }
11350
+
11351
+ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
11352
+ float *dst, const int ncols,
11353
+ const int nrows,
11354
+ dpct::queue_ptr stream) {
11355
+ GGML_ASSERT(ncols % QK_K == 0);
11356
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
11357
+ const sycl::range<3> block_nums(1, 1, block_num_y);
11358
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
11359
+ {
11360
+
11361
+ stream->submit([&](sycl::handler &cgh) {
11362
+ cgh.parallel_for(
11363
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
11364
+ [=](sycl::nd_item<3> item_ct1)
11365
+ [[intel::reqd_sub_group_size(32)]] {
11366
+ mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS, block_iq4_xs, 1>(
11367
+ vx, vy, dst, ncols, nrows, item_ct1);
10761
11368
  });
10762
11369
  });
10763
11370
  }
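The new iq1_m, iq4_nl and iq4_xs launchers above all use the same host-side geometry: one sub-group of WARP_SIZE work-items per GGML_SYCL_MMV_Y rows, with the number of work-groups rounded up. A minimal sketch of that sizing, assuming GGML_SYCL_MMV_Y = 1 and WARP_SIZE = 32 (values implied by the reqd_sub_group_size(32) attribute but not restated in this hunk):

#include <sycl/sycl.hpp>

// Sketch only (not part of the diff): mirrors the block_nums/block_dims
// computation in the launchers above. mmv_y and warp_size stand in for
// GGML_SYCL_MMV_Y and WARP_SIZE.
static sycl::nd_range<3> mmv_q_launch_range(int nrows, int mmv_y = 1, int warp_size = 32) {
    const int block_num_y = (nrows + mmv_y - 1) / mmv_y;   // round rows up to whole groups
    const sycl::range<3> block_nums(1, 1, block_num_y);
    const sycl::range<3> block_dims(1, mmv_y, warp_size);
    return sycl::nd_range<3>(block_nums * block_dims, block_dims);
}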
@@ -12381,36 +12988,54 @@ static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
12381
12988
  });
12382
12989
  }
12383
12990
 
12991
+ static int next_power_of_2(int x) {
12992
+ int n = 1;
12993
+ while (n < x) {
12994
+ n *= 2;
12995
+ }
12996
+ return n;
12997
+ }
12998
+
12384
12999
  static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
12385
13000
  const int nrows, ggml_sort_order order,
12386
13001
  dpct::queue_ptr stream) {
12387
13002
  // bitonic sort requires ncols to be power of 2
12388
- GGML_ASSERT((ncols & (ncols - 1)) == 0);
13003
+ const int ncols_pad = next_power_of_2(ncols);
12389
13004
 
12390
- const sycl::range<3> block_dims(1, 1, ncols);
13005
+ const sycl::range<3> block_dims(1, 1, ncols_pad);
12391
13006
  const sycl::range<3> block_nums(1, nrows, 1);
13007
+ const size_t shared_mem = ncols_pad * sizeof(int);
13008
+
13009
+ // GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
13010
+
12392
13011
  if (order == GGML_SORT_ORDER_ASC) {
12393
- /*
12394
- DPCT1049:44: The work-group size passed to the SYCL kernel may exceed
12395
- the limit. To get the device limit, query
12396
- info::device::max_work_group_size. Adjust the work-group size if needed.
12397
- */
12398
- stream->parallel_for(
12399
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
12400
- [=](sycl::nd_item<3> item_ct1) {
12401
- k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(x, dst, ncols, item_ct1);
12402
- });
13012
+ stream->submit([&](sycl::handler &cgh) {
13013
+ sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
13014
+ sycl::range<1>(shared_mem), cgh);
13015
+
13016
+ cgh.parallel_for(
13017
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
13018
+ [=](sycl::nd_item<3> item_ct1) {
13019
+ k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
13020
+ x, dst, ncols, ncols_pad, item_ct1,
13021
+ dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
13022
+ .get());
13023
+ });
13024
+ });
12403
13025
  } else if (order == GGML_SORT_ORDER_DESC) {
12404
- /*
12405
- DPCT1049:45: The work-group size passed to the SYCL kernel may exceed
12406
- the limit. To get the device limit, query
12407
- info::device::max_work_group_size. Adjust the work-group size if needed.
12408
- */
12409
- stream->parallel_for(
12410
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
12411
- [=](sycl::nd_item<3> item_ct1) {
12412
- k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(x, dst, ncols, item_ct1);
12413
- });
13026
+ stream->submit([&](sycl::handler &cgh) {
13027
+ sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
13028
+ sycl::range<1>(shared_mem), cgh);
13029
+
13030
+ cgh.parallel_for(
13031
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
13032
+ [=](sycl::nd_item<3> item_ct1) {
13033
+ k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
13034
+ x, dst, ncols, ncols_pad, item_ct1,
13035
+ dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
13036
+ .get());
13037
+ });
13038
+ });
12414
13039
  } else {
12415
13040
  GGML_ASSERT(false);
12416
13041
  }
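The argsort path no longer asserts that ncols is a power of two: it pads the work-group to the next power of two and reserves ncols_pad ints of local memory per row; the padded slots are handled inside k_argsort_f32_i32, which is not shown in this hunk. A small host-side sketch of that sizing rule, using a hypothetical row width of 1000:

#include <cassert>
#include <cstddef>

// Same helper as in the hunk above.
static int next_power_of_2(int x) {
    int n = 1;
    while (n < x) {
        n *= 2;
    }
    return n;
}

int main() {
    const int ncols = 1000;                              // hypothetical, for illustration
    const int ncols_pad = next_power_of_2(ncols);        // 1024
    const size_t shared_mem = ncols_pad * sizeof(int);   // one index slot per padded column
    assert(ncols_pad == 1024);
    (void)shared_mem;
    return 0;
}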
@@ -13538,8 +14163,12 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_SYC
13538
14163
  case GGML_TYPE_Q5_K:
13539
14164
  case GGML_TYPE_IQ2_XXS:
13540
14165
  case GGML_TYPE_IQ2_XS:
14166
+ case GGML_TYPE_IQ2_S:
13541
14167
  case GGML_TYPE_IQ1_S:
14168
+ case GGML_TYPE_IQ1_M:
13542
14169
  case GGML_TYPE_IQ3_XXS:
14170
+ case GGML_TYPE_IQ4_XS:
14171
+ case GGML_TYPE_IQ4_NL:
13543
14172
  return max_compute_capability >= VER_GEN9 ? 128 : 64;
13544
14173
  case GGML_TYPE_IQ3_S:
13545
14174
  return max_compute_capability >= VER_GEN9 ? 128 : 64;
@@ -13558,11 +14187,20 @@ inline void ggml_sycl_op_mul_mat_vec_q(
13558
14187
  const int64_t src1_ncols, const int64_t src1_padded_row_size,
13559
14188
  const dpct::queue_ptr &stream) {
13560
14189
 
13561
- GGML_ASSERT(ggml_nrows(src1) == 1);
14190
+ const int64_t ne10 = src1->ne[0];
14191
+ GGML_ASSERT(ne10 % QK8_1 == 0);
13562
14192
 
13563
14193
  const int64_t ne00 = src0->ne[0];
13564
14194
  const int64_t row_diff = row_high - row_low;
13565
14195
 
14196
+ int id;
14197
+ SYCL_CHECK(
14198
+ CHECK_TRY_ERROR(id = get_current_device_id()));
14199
+
14200
+ // the main device has a larger memory buffer to hold the results from all GPUs
14201
+ // nrows_dst == nrows of the matrix that the kernel writes into
14202
+ const int64_t nrows_dst = dst->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device ? ne00 : row_diff;
14203
+
13566
14204
  switch (src0->type) {
13567
14205
  case GGML_TYPE_Q4_0:
13568
14206
  mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
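This hunk drops the old single-row restriction on src1 and instead checks that its width is a multiple of QK8_1; it also recomputes nrows_dst so that only the main device writes into the full ne00-row result buffer while every other device writes just its own row slice. A minimal sketch of that selection (the enum below is an illustrative stand-in; the real code uses GGML_BACKEND_TYPE_GPU and g_main_device):

#include <cstdint>

enum class backend_kind { cpu, gpu };   // stand-in for GGML_BACKEND_TYPE_*

// Sketch only: the nrows_dst ternary from the hunk above, as a helper.
static int64_t pick_nrows_dst(backend_kind dst_backend, int device_id, int main_device,
                              int64_t ne00, int64_t row_diff) {
    // the main device holds the full-width result gathered from all GPUs
    return (dst_backend == backend_kind::gpu && device_id == main_device) ? ne00 : row_diff;
}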
@@ -13594,20 +14232,32 @@ inline void ggml_sycl_op_mul_mat_vec_q(
13594
14232
  case GGML_TYPE_Q6_K:
13595
14233
  mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
13596
14234
  break;
14235
+ case GGML_TYPE_IQ1_S:
14236
+ mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
14237
+ break;
14238
+ case GGML_TYPE_IQ1_M:
14239
+ mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
14240
+ break;
13597
14241
  case GGML_TYPE_IQ2_XXS:
13598
14242
  mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
13599
14243
  break;
13600
14244
  case GGML_TYPE_IQ2_XS:
13601
14245
  mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
13602
14246
  break;
14247
+ case GGML_TYPE_IQ2_S:
14248
+ mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
14249
+ break;
13603
14250
  case GGML_TYPE_IQ3_XXS:
13604
14251
  mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
13605
14252
  break;
13606
14253
  case GGML_TYPE_IQ3_S:
13607
14254
  mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
13608
14255
  break;
13609
- case GGML_TYPE_IQ1_S:
13610
- mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
14256
+ case GGML_TYPE_IQ4_NL:
14257
+ mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
14258
+ break;
14259
+ case GGML_TYPE_IQ4_XS:
14260
+ mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
13611
14261
  break;
13612
14262
  default:
13613
14263
  GGML_ASSERT(false);
@@ -13689,6 +14339,7 @@ inline void ggml_sycl_op_dequantize_mul_mat_vec(
13689
14339
  convert_mul_mat_vec_f16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
13690
14340
  break;
13691
14341
  default:
14342
+ printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type);
13692
14343
  GGML_ASSERT(false);
13693
14344
  break;
13694
14345
  }
@@ -14543,8 +15194,8 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0,
14543
15194
  src1_padded_col_size = (i0 * ne11 + src1_col_0) * ne10;
14544
15195
  }
14545
15196
  // do the computation
14546
- op(src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
14547
- dev[i].row_low, dev[i].row_high, src1_ncols, src1_padded_col_size, stream);
15197
+ SYCL_CHECK(CHECK_TRY_ERROR(op(src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
15198
+ dev[i].row_low, dev[i].row_high, src1_ncols, src1_padded_col_size, stream)));
14548
15199
  /*
14549
15200
  DPCT1010:93: SYCL uses exceptions to report errors and does not
14550
15201
  use the error codes. The call was replaced with 0. You need to
@@ -15125,11 +15776,17 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
15125
15776
  #ifdef GGML_SYCL_FORCE_DMMV
15126
15777
  const bool use_mul_mat_vec_q = false;
15127
15778
  #else
15128
- const bool use_mul_mat_vec_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type) && ggml_nrows(src1) == 1;
15779
+ bool use_mul_mat_vec_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type);
15780
+ use_mul_mat_vec_q = use_mul_mat_vec_q ||
15781
+ (src0->type == GGML_TYPE_IQ2_XXS) || (src0->type == GGML_TYPE_IQ2_XS) || (src0->type == GGML_TYPE_IQ2_S) ||
15782
+ (src0->type == GGML_TYPE_IQ3_XXS) || (src0->type == GGML_TYPE_IQ3_S) ||
15783
+ (src0->type == GGML_TYPE_IQ4_NL) || (src0->type == GGML_TYPE_IQ4_XS) ||
15784
+ (src0->type == GGML_TYPE_IQ1_S) || (src0->type == GGML_TYPE_IQ1_M);
15785
+
15786
+
15129
15787
  #endif // GGML_SYCL_FORCE_DMMV
15130
15788
 
15131
15789
  if (use_mul_mat_vec_q) {
15132
- // NOTE: this kernel does not support ggml_nrows(src1) > 1
15133
15790
  // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_vec_q path\n");
15134
15791
  ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
15135
15792
  } else {
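With the batched mul-mat-vec kernels in place, the selection above no longer requires ggml_nrows(src1) == 1, and the IQ quantizations are routed to the q8_1 vec path explicitly. A sketch of the same predicate factored into a helper (VER_4VEC, ggml_type and ggml_is_quantized are taken from the surrounding file, not redefined here):

// Sketch only: equivalent to the use_mul_mat_vec_q computation above.
static bool want_mul_mat_vec_q(int min_compute_capability, ggml_type t) {
    bool use = min_compute_capability >= VER_4VEC && ggml_is_quantized(t);
    // the IQ families take the vec-q kernels independent of the capability check
    return use ||
           t == GGML_TYPE_IQ2_XXS || t == GGML_TYPE_IQ2_XS || t == GGML_TYPE_IQ2_S ||
           t == GGML_TYPE_IQ3_XXS || t == GGML_TYPE_IQ3_S  ||
           t == GGML_TYPE_IQ4_NL  || t == GGML_TYPE_IQ4_XS ||
           t == GGML_TYPE_IQ1_S   || t == GGML_TYPE_IQ1_M;
}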
@@ -16985,9 +17642,14 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
16985
17642
  return false;
16986
17643
  }
16987
17644
  ggml_type a_type = a->type;
16988
- if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ2_S ||
16989
- a_type == GGML_TYPE_IQ4_XS) {
16990
- return false;
17645
+ if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ4_XS ||
17646
+ a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ3_S ||
17647
+ a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ2_S ||
17648
+ a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ1_M
17649
+ ) {
17650
+ if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
17651
+ return false;
17652
+ }
16991
17653
  }
16992
17654
  return true;
16993
17655
  } break;
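ggml_backend_sycl_supports_op now accepts the IQ quantizations for MUL_MAT and only rejects the one shape the new kernels do not cover: a src1 with a single column per matrix but more than one matrix in the batch. Since ggml_nrows(b) is ne[1]*ne[2]*ne[3], the added condition reduces to the sketch below:

// Sketch only: the rejection test added above, written as a helper.
// b is the src1 operand of the MUL_MAT op.
static bool iq_mul_mat_shape_unsupported(const ggml_tensor * b) {
    return b->ne[1] == 1 && ggml_nrows(b) > 1;   // batched single-column multiply
}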