llama_cpp 0.14.4 → 0.14.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
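Much of the diff below adds SYCL kernels for the newer i-quant formats (IQ1_M, IQ2_S, IQ4_NL, IQ4_XS). As orientation for the IQ4 hunks, here is a minimal host-side C++ sketch of how the kvalues_iq4nl codebook introduced in this version is applied: each quantized byte packs two 4-bit indices into a 16-entry non-linear table, and the decoded values are scaled by a per-block factor d. This is illustrative only — the helper name dequantize_iq4_nl_ref is invented for this sketch, and the output layout is simplified to adjacent pairs (the SYCL kernels in the diff write the low and high nibbles 16 elements apart).

#include <cstdint>

// 16-entry non-linear 4-bit codebook (same values as kvalues_iq4nl in the diff).
static const int8_t kvalues_iq4nl_ref[16] = {-127, -104, -83, -65, -49, -35, -22, -10,
                                             1, 13, 25, 38, 53, 69, 89, 113};

// Hypothetical reference helper: dequantize n packed bytes (two 4-bit
// indices per byte) with block scale d; layout simplified for clarity.
static void dequantize_iq4_nl_ref(const uint8_t *qs, float d, int n, float *y) {
    for (int j = 0; j < n; ++j) {
        y[2*j + 0] = d * kvalues_iq4nl_ref[qs[j] & 0xf]; // low nibble -> codebook
        y[2*j + 1] = d * kvalues_iq4nl_ref[qs[j] >> 4];  // high nibble -> codebook
    }
}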
@@ -1664,24 +1664,6 @@ namespace dpct
  const void *alpha, const void *a, int lda, const void *b,
  int ldb, const void *beta, void *c, int ldc)
  {
- #ifndef __INTEL_MKL__
- GGML_UNUSED(q);
- GGML_UNUSED(a_trans);
- GGML_UNUSED(b_trans);
- GGML_UNUSED(m);
- GGML_UNUSED(n);
- GGML_UNUSED(k);
- GGML_UNUSED(alpha);
- GGML_UNUSED(a);
- GGML_UNUSED(lda);
- GGML_UNUSED(b);
- GGML_UNUSED(ldb);
- GGML_UNUSED(beta);
- GGML_UNUSED(c);
- GGML_UNUSED(ldc);
- throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces "
- "Project does not support this API.");
- #else
  Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
  Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
  auto data_a = get_memory<const Ta>(a);
@@ -1690,7 +1672,6 @@ namespace dpct
  oneapi::mkl::blas::column_major::gemm(
  q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
  data_b, ldb, beta_value, data_c, ldc);
- #endif
  }

  template <typename VecT, class BinaryOperation, class = void>
@@ -2330,6 +2311,7 @@ namespace dpct
  lda, b, ldb, beta, c, ldc);
  break;
  }
+ #ifdef __INTEL_MKL__
  case detail::get_type_combination_id(
  library_data_t::real_bfloat16, library_data_t::real_bfloat16,
  library_data_t::real_float, library_data_t::real_float):
@@ -2391,6 +2373,7 @@ namespace dpct
  q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc);
  break;
  }
+ #endif // __INTEL_MKL__
  default:
  throw std::runtime_error("the combination of data type is unsupported");
  }
@@ -3055,6 +3038,10 @@ typedef float dfloat; // dequantize float
  typedef sycl::float2 dfloat2;
  #endif //GGML_SYCL_F16

+ #define MMVQ_MAX_BATCH_SIZE 8
+
+ static const int8_t kvalues_iq4nl[16]={-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
+
  bool ggml_sycl_loaded(void);
  void * ggml_sycl_host_malloc(size_t size);
  void ggml_sycl_host_free(void * ptr);
@@ -4490,6 +4477,32 @@ static void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __rest

  }

+ template <typename dst_t>
+ __dpct_inline__ static void
+ dequantize_block_iq2_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
+ const sycl::nd_item<3> &item_ct1) {
+
+ const int i = item_ct1.get_group(2);
+ const block_iq2_s * x = (const block_iq2_s *) vx;
+
+ const int tid = item_ct1.get_local_id(2);
+ #if QK_K == 256
+ const int il = tid/8; // 0...3
+ const int ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+ const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
+ const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
+ const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];
+ #pragma unroll
+ for (int j = 0; j < 8; ++j)
+ y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+ #else
+ assert(false);
+
+ #endif
+
+ }
+
  template<typename dst_t>
  static void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy,
  const sycl::nd_item<3> &item_ct1,
@@ -4522,26 +4535,26 @@ static void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __res

  }

- template<typename dst_t>
- static void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy,
- const sycl::nd_item<3> &item_ct1,
- const uint32_t *iq3s_grid,
- const uint8_t *ksigns_iq2xs,
- const uint8_t *kmask_iq2xs) {
+ template <typename dst_t>
+ __dpct_inline__ static void
+ dequantize_block_iq3_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
+ const sycl::nd_item<3> &item_ct1,
+ const uint8_t *kmask_iq2xs, const uint32_t *iq3s_grid) {

  const int i = item_ct1.get_group(2);
- const block_iq3_s * x = (const block_iq3_s *) vx;
+ const block_iq3_s * x = (const block_iq3_s *) vx;

  const int tid = item_ct1.get_local_id(2);
  #if QK_K == 256
  const int il = tid/8; // 0...3
  const int ib = tid%8; // 0...7
  dst_t * y = yy + i*QK_K + 32*ib + 8*il;
- const uint8_t * qs = x[i].qs + 8*ib;
- const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + qs[2*il+0]);
- const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + qs[2*il+1]);
+ const uint8_t * qs = x[i].qs + 8*ib;
+ const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
+ const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)));
  const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf));
  const uint8_t signs = x[i].signs[4*ib + il];
+ #pragma unroll
  for (int j = 0; j < 4; ++j) {
  y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
  y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
@@ -4552,12 +4565,12 @@ static void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restr

  }

- template<typename dst_t>
- static void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy,
- const sycl::nd_item<3> &item_ct1,
- const uint32_t *iq1s_grid,
- const uint8_t *ksigns_iq2xs,
- const uint8_t *kmask_iq2xs) {
+ template <typename dst_t>
+ __dpct_inline__ static void
+ dequantize_block_iq1_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
+ const sycl::nd_item<3> &item_ct1,
+ const uint32_t *iq1s_grid_gpu) {
+
  const int i = item_ct1.get_group(2);
  const block_iq1_s * x = (const block_iq1_s *) vx;

@@ -4566,14 +4579,15 @@ static void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restr
  const int il = tid/8; // 0...3
  const int ib = tid%8; // 0...7
  dst_t * y = yy + i*QK_K + 32*ib + 8*il;
- const uint8_t * qs = x[i].qs + 8*ib;
- const uint8_t * grid1 = (const uint8_t *)(iq1s_grid + qs[2*il+0]);
- const uint8_t * grid2 = (const uint8_t *)(iq1s_grid + qs[2*il+1]);
- const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 0xf) + 1);
- const uint8_t signs = ksigns_iq2xs[(x[i].qh[ib] >> 3*il) & 7];
- for (int j = 0; j < 4; ++j) {
- y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
- y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
+ const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
+ const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
+ uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
+ grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)];
+ grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
+ grid32[0] &= 0x0f0f0f0f;
+ #pragma unroll
+ for (int j = 0; j < 8; ++j) {
+ y[j] = d * (q[j] + delta);
  }
  #else
  assert(false);
@@ -4581,6 +4595,85 @@ static void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restr

  }

+ template <typename dst_t>
+ __dpct_inline__ static void
+ dequantize_block_iq1_m(const void *__restrict__ vx, dst_t *__restrict__ yy,
+ const sycl::nd_item<3> &item_ct1,
+ const uint32_t *iq1s_grid_gpu) {
+
+ const int i = item_ct1.get_group(2);
+ const block_iq1_m * x = (const block_iq1_m *) vx;
+
+ const int tid = item_ct1.get_local_id(2);
+ #if QK_K == 256
+ const int il = tid/8; // 0...3
+ const int ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+ const uint16_t * sc = (const uint16_t *)x[i].scales;
+ iq1m_scale_t scale;
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+ const int ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
+ const float d = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);
+ const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
+ uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
+ grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[2*ib+il/2] >> 4*(il%2)) & 7) << 8)];
+ grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
+ grid32[0] &= 0x0f0f0f0f;
+ #pragma unroll
+ for (int j = 0; j < 8; ++j) {
+ y[j] = d * (q[j] + delta);
+ }
+ #else
+ assert(false);
+ #endif
+
+ }
+
+ template <typename dst_t>
+ __dpct_inline__ static void
+ dequantize_block_iq4_nl(const void *__restrict__ vx, dst_t *__restrict__ yy,
+ const sycl::nd_item<3> &item_ct1) {
+
+ const int i = item_ct1.get_group(2);
+ const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
+
+ const int tid = item_ct1.get_local_id(2);
+ const int il = tid/8; // 0...3
+ const int ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+ const uint8_t * q4 = x[ib].qs + 4*il;
+ const float d = (float)x[ib].d;
+ #pragma unroll
+ for (int j = 0; j < 4; ++j) {
+ y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
+ y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
+ }
+
+ }
+
+
+ template <typename dst_t>
+ __dpct_inline__ static void
+ dequantize_block_iq4_xs(const void *__restrict__ vx, dst_t *__restrict__ yy,
+ const sycl::nd_item<3> &item_ct1) {
+ const int i = item_ct1.get_group(2);
+ const block_iq4_xs * x = (const block_iq4_xs *)vx;
+
+ const int tid = item_ct1.get_local_id(2);
+ const int il = tid/8; // 0...3
+ const int ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+ const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
+ const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
+ #pragma unroll
+ for (int j = 0; j < 4; ++j) {
+ y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
+ y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
+ }
+ }
+
+
+
  /*
  DPCT1110:4: The total declared local variable size in device function
  dequantize_mul_mat_vec_q2_k exceeds 128 bytes and may cause high register
@@ -7387,6 +7480,58 @@ vec_dot_iq2_xs_q8_1(const void *__restrict__ vbq,
  #endif
  }

+ static __dpct_inline__ float
+ vec_dot_iq2_s_q8_1(const void *__restrict__ vbq,
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+ #if QK_K == 256
+ const block_iq2_s * bq2 = (const block_iq2_s *) vbq;
+
+ const int ib32 = iqs;
+ const int8_t * q8 = bq8_1[ib32].qs;
+ const uint8_t * signs = bq2->qs + QK_K/8 + 4*ib32;
+ const uint8_t ls1 = bq2->scales[ib32] & 0xf;
+ const uint8_t ls2 = bq2->scales[ib32] >> 4;
+ int sumi1 = 0;
+ for (int l = 0; l < 2; ++l) {
+ const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300)));
+ const uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(
+ ((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201,
+ std::equal_to<>());
+ const uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(
+ ((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201,
+ std::equal_to<>());
+ const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
+ grid[0] ^ signs0, signs0, std::minus<>());
+ const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
+ grid[1] ^ signs1, signs1, std::minus<>());
+ sumi1 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi1);
+ sumi1 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi1);
+ q8 += 8;
+ }
+ int sumi2 = 0;
+ for (int l = 2; l < 4; ++l) {
+ const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300)));
+ const uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(
+ ((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201,
+ std::equal_to<>());
+ const uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(
+ ((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201,
+ std::equal_to<>());
+ const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
+ grid[0] ^ signs0, signs0, std::minus<>());
+ const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
+ grid[1] ^ signs1, signs1, std::minus<>());
+ sumi2 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi2);
+ sumi2 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi2);
+ q8 += 8;
+ }
+ const float d = (float)bq2->d * bq8_1[ib32].ds[0] * 0.25f;
+ return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
+ #else
+ assert(false);
+ #endif
+ }
+
  static __dpct_inline__ float
  vec_dot_iq3_xxs_q8_1(const void *__restrict__ vbq,
  const block_q8_1 *__restrict__ bq8_1, const int &iqs,
@@ -7429,10 +7574,8 @@ vec_dot_iq3_xxs_q8_1(const void *__restrict__ vbq,

  static __dpct_inline__ float
  vec_dot_iq3_s_q8_1(const void *__restrict__ vbq,
- const block_q8_1 *__restrict__ bq8_1, const int &iqs,
- const uint32_t *iq3s_grid, const uint64_t *ksigns64) {
- #if DPCT_COMPATIBILITY_TEMP >= \
- MIN_CC_DP4A // lowest compute capability for integer intrinsics
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs,
+ const uint32_t *iq3s_grid) {
  #if QK_K == 256
  const block_iq3_s * bq2 = (const block_iq3_s *) vbq;

@@ -7444,9 +7587,11 @@ vec_dot_iq3_s_q8_1(const void *__restrict__ vbq,
  const uint32_t * grid1 = iq3s_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256));
  const uint32_t * grid2 = iq3s_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256));
  uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(
- ((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201, std::equal_to<>());
+ ((bq2->signs[4 * ib32 + l] & 0xf) * 0x01010101) & 0x08040201,
+ 0x08040201, std::equal_to<>());
  uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(
- ((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201, std::equal_to<>());
+ ((bq2->signs[4 * ib32 + l] >> 4) * 0x01010101) & 0x08040201,
+ 0x08040201, std::equal_to<>());
  const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
  grid1[0] ^ signs0, signs0, std::minus<>());
  const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
@@ -7455,45 +7600,142 @@ vec_dot_iq3_s_q8_1(const void *__restrict__ vbq,
  sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
  q8 += 8;
  }
- const float d = (float)bq2->d * (1 + 2*((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * bq8_1[ib32].ds[0];
+ const float d =
+ (float)bq2->d *
+ (1 + 2 * ((bq2->scales[ib32 / 2] >> 4 * (ib32 % 2)) & 0xf)) *
+ bq8_1[ib32].ds[0];
  return d * sumi;
  #else
  assert(false);
- return 0.f;
- #endif
- #else
- assert(false);
- return 0.f;
  #endif
  }

  static __dpct_inline__ float
  vec_dot_iq1_s_q8_1(const void *__restrict__ vbq,
- const block_q8_1 *__restrict__ bq8_1, const int &iqs,
- const uint32_t *iq1s_grid, const uint64_t *ksigns64) {
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs,
+ const uint32_t *iq1s_grid_gpu) {
  #if QK_K == 256
  const block_iq1_s * bq1 = (const block_iq1_s *) vbq;

  const int ib32 = iqs;
- const uint8_t * qs = bq1->qs + 4*ib32;
- const int8_t * q8 = bq8_1[ib32].qs;
  int sumi = 0;
+ const int * q8 = (const int *)bq8_1[ib32].qs;
  for (int l = 0; l < 4; ++l) {
- const uint32_t * grid = (const uint32_t *)(iq1s_grid + qs[l]);
- const uint32_t * signs = (const uint32_t *)(ksigns64 + (qs[l] >> 8));
- const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
- grid[0] ^ signs[0], signs[0], std::minus<>());
- const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
- grid[1] ^ signs[1], signs[1], std::minus<>());
- sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi);
- sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
- q8 += 8;
+ const int * grid = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8)));
+ int grid0 = grid[0] & 0x0f0f0f0f;
+ int grid1 = (grid[0] >> 4) & 0x0f0f0f0f;
+ sumi = dpct::dp4a(q8[2 * l + 1], grid1,
+ dpct::dp4a(q8[2 * l + 0], grid0, sumi));
+ }
+
+ const float delta = bq1->qh[ib32] & 0x8000 ? -1-IQ1S_DELTA : -1+IQ1S_DELTA;
+ const float d1q = (float)bq1->d * (2*((bq1->qh[ib32] >> 12) & 7) + 1);
+ const float d = d1q * bq8_1[ib32].ds[0];
+ const float m = d1q * bq8_1[ib32].ds[1];
+ return d * sumi + m * delta;
+ #else
+ assert(false);
+ #endif
+ }
+
+ static __dpct_inline__ float
+ vec_dot_iq1_m_q8_1(const void *__restrict__ vbq,
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+ #if QK_K == 256
+ const block_iq1_m * bq1 = (const block_iq1_m *) vbq;
+
+ const int ib32 = iqs;
+ int sumi[2] = {0, 0};
+ float sumf[2] = {0.f, 0.f};
+
+ const int * q8 = (const int *)bq8_1[ib32].qs;
+ for (int l = 0; l < 4; ++l) {
+ const int * grid = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[2*ib32+l/2] >> 4*(l%2)) & 7) << 8)));
+ int grid0 = grid[0] & 0x0f0f0f0f;
+ int grid1 = (grid[0] >> 4) & 0x0f0f0f0f;
+ sumi[l / 2] = dpct::dp4a(q8[2 * l + 1], grid1,
+ dpct::dp4a(q8[2 * l + 0], grid0, sumi[l / 2]));
+ const float delta = (bq1->qh[2*ib32+l/2] >> 4*(l%2)) & 0x08 ? -1-IQ1M_DELTA : -1+IQ1M_DELTA;
+ const int sumy = dpct::dp4a(q8[2 * l + 1], 0x01010101,
+ dpct::dp4a(q8[2 * l + 0], 0x01010101, 0));
+ sumf[l/2] += delta*sumy;
+ }
+
+ iq1m_scale_t scale;
+ const uint16_t * sc = (const uint16_t *)bq1->scales;
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+ const float d = (float)scale.f16 * bq8_1[ib32].ds[0];
+ return d * ((sumi[0] + sumf[0]) * (2*((sc[ib32/2] >> 6*(ib32%2)) & 0x7) + 1) + (sumi[1] + sumf[1]) * (2*((sc[ib32/2] >> (6*(ib32%2)+3)) & 0x7) + 1));
+ #else
+ assert(false);
+ #endif
+ }
+
+ static __dpct_inline__ void get_int_from_table_16(const uint32_t &q4,
+ const uint8_t *values,
+ int &val1, int &val2) {
+
+ uint32_t aux32; const uint8_t * q8 = (const uint8_t *)&aux32;
+ aux32 = q4 & 0x0f0f0f0f;
+ uint16_t v1 = values[q8[0]] | (values[q8[1]] << 8);
+ uint16_t v2 = values[q8[2]] | (values[q8[3]] << 8);
+ val1 = v1 | (v2 << 16);
+ aux32 = (q4 >> 4) & 0x0f0f0f0f;
+ v1 = values[q8[0]] | (values[q8[1]] << 8);
+ v2 = values[q8[2]] | (values[q8[3]] << 8);
+ val2 = v1 | (v2 << 16);
+ }
+
+
+ static __dpct_inline__ float
+ vec_dot_iq4_nl_q8_1(const void *__restrict__ vbq,
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+ const block_iq4_nl * bq = (const block_iq4_nl *) vbq;
+
+ const uint16_t * q4 = (const uint16_t *)bq->qs + 2*iqs;
+ const int32_t * q8 = (const int32_t *)bq8_1->qs + iqs;
+
+ const uint8_t * values = (const uint8_t *)kvalues_iq4nl;
+
+ int v1, v2;
+ int sumi1 = 0, sumi2 = 0;
+ for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
+ const uint32_t aux = q4[2*l] | (q4[2*l+1] << 16);
+ get_int_from_table_16(aux, values, v1, v2);
+ sumi1 = dpct::dp4a(v1, q8[l + 0], sumi1);
+ sumi2 = dpct::dp4a(v2, q8[l + 4], sumi2);
  }
- const float d = (float)bq1->d * bq8_1[ib32].ds[0] * 0.25f;
- return d * sumi;
+
+ const float d = (float)bq->d * bq8_1->ds[0];
+ return d * (sumi1 + sumi2);
+ }
+
+
+ static __dpct_inline__ float
+ vec_dot_iq4_xs_q8_1(const void *__restrict__ vbq,
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+ #if QK_K == 256
+ const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq;
+ const uint8_t * values = (const uint8_t *)kvalues_iq4nl;
+
+ // iqs is 0...7
+ const int ib32 = iqs;
+ const int32_t * q8 = (const int *)bq8_1[ib32].qs;
+ const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32;
+ const int8_t ls = ((bq4->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((bq4->scales_h >> 2*ib32) & 3) << 4);
+ const float d = (float)bq4->d * (ls - 32) * bq8_1[ib32].ds[0];
+ int v1, v2;
+ int sumi1 = 0, sumi2 = 0;
+ for (int j = 0; j < 4; ++j) {
+ get_int_from_table_16(q4[j], values, v1, v2);
+ sumi1 = dpct::dp4a(v1, q8[j + 0], sumi1);
+ sumi2 = dpct::dp4a(v2, q8[j + 4], sumi2);
+ }
+ return d * (sumi1 + sumi2);
  #else
  assert(false);
- return 0.f;
  #endif
  }

@@ -8078,8 +8320,199 @@ template <bool need_check> static void

  template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
  static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
- const sycl::nd_item<3> &item_ct1,
- const uint32_t *iq3xxs_grid_ptr=nullptr, const uint64_t *ksigns64_ptr=nullptr) {
+ const sycl::nd_item<3> &item_ct1) {
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+ item_ct1.get_local_id(1);
+
+ if (row >= nrows) {
+ return;
+ }
+
+ const int blocks_per_row = ncols / qk;
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
+
+ // partial sum for each thread
+ float tmp = 0.0f;
+
+ const block_q_t * x = (const block_q_t *) vx;
+ const block_q8_1 * y = (const block_q8_1 *) vy;
+
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
+ i += blocks_per_warp) {
+ const int ibx = row*blocks_per_row + i; // x block index
+
+ const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+
+ const int iqs =
+ vdr *
+ (item_ct1.get_local_id(2) %
+ (qi / vdr)); // x block quant index when casting the quants to int
+
+ tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs);
+ }
+
+ // sum up partial sums and write back result
+ #pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ tmp +=
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+ }
+
+ if (item_ct1.get_local_id(2) == 0) {
+ dst[row] = tmp;
+ }
+ }
+
+ template <int qk, int qi, typename block_q_t, int vdr>
+ static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx,
+ const void *__restrict__ vy,
+ float *__restrict__ dst, const int ncols,
+ const int nrows,
+ const sycl::nd_item<3> &item_ct1) {
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+ item_ct1.get_local_id(1);
+
+ if (row >= nrows) {
+ return;
+ }
+
+ const int blocks_per_row = ncols / qk;
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
+
+ // partial sum for each thread
+ float tmp = 0.0f;
+
+ const block_q_t * x = (const block_q_t *) vx;
+ const block_q8_1 * y = (const block_q8_1 *) vy;
+
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
+ i += blocks_per_warp) {
+ const int ibx = row*blocks_per_row + i; // x block index
+
+ const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+
+ const int iqs =
+ vdr *
+ (item_ct1.get_local_id(2) %
+ (qi / vdr)); // x block quant index when casting the quants to int
+
+ tmp += vec_dot_iq2_xxs_q8_1(&x[ibx], &y[iby], iqs, iq2xxs_grid, ksigns_iq2xs, kmask_iq2xs);
+ }
+
+ // sum up partial sums and write back result
+ #pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ tmp +=
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+ }
+
+ if (item_ct1.get_local_id(2) == 0) {
+ dst[row] = tmp;
+ }
+ }
+
+ template <int qk, int qi, typename block_q_t, int vdr>
+ static void mul_mat_vec_q_iq2_xs_q8_1(const void *__restrict__ vx,
+ const void *__restrict__ vy,
+ float *__restrict__ dst, const int ncols,
+ const int nrows,
+ const sycl::nd_item<3> &item_ct1) {
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+ item_ct1.get_local_id(1);
+
+ if (row >= nrows) {
+ return;
+ }
+
+ const int blocks_per_row = ncols / qk;
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
+
+ // partial sum for each thread
+ float tmp = 0.0f;
+
+ const block_q_t * x = (const block_q_t *) vx;
+ const block_q8_1 * y = (const block_q8_1 *) vy;
+
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
+ i += blocks_per_warp) {
+ const int ibx = row*blocks_per_row + i; // x block index
+
+ const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+
+ const int iqs =
+ vdr *
+ (item_ct1.get_local_id(2) %
+ (qi / vdr)); // x block quant index when casting the quants to int
+
+ tmp += vec_dot_iq2_xs_q8_1(&x[ibx], &y[iby], iqs, iq2xs_grid, ksigns64);
+ }
+
+ // sum up partial sums and write back result
+ #pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ tmp +=
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+ }
+
+ if (item_ct1.get_local_id(2) == 0) {
+ dst[row] = tmp;
+ }
+ }
+
+ template <int qk, int qi, typename block_q_t, int vdr>
+ static void mul_mat_vec_q_iq2_s_q8_1(const void *__restrict__ vx,
+ const void *__restrict__ vy,
+ float *__restrict__ dst, const int ncols,
+ const int nrows,
+ const sycl::nd_item<3> &item_ct1) {
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+ item_ct1.get_local_id(1);
+
+ if (row >= nrows) {
+ return;
+ }
+
+ const int blocks_per_row = ncols / qk;
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
+
+ // partial sum for each thread
+ float tmp = 0.0f;
+
+ const block_q_t * x = (const block_q_t *) vx;
+ const block_q8_1 * y = (const block_q8_1 *) vy;
+
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
+ i += blocks_per_warp) {
+ const int ibx = row*blocks_per_row + i; // x block index
+
+ const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+
+ const int iqs =
+ vdr *
+ (item_ct1.get_local_id(2) %
+ (qi / vdr)); // x block quant index when casting the quants to int
+
+ tmp += vec_dot_iq2_s_q8_1(&x[ibx], &y[iby], iqs);
+ }
+
+ // sum up partial sums and write back result
+ #pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ tmp +=
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+ }
+
+ if (item_ct1.get_local_id(2) == 0) {
+ dst[row] = tmp;
+ }
+ }
+
+ template <int qk, int qi, typename block_q_t, int vdr>
+ static void mul_mat_vec_q_iq3_xxs_q8_1(const void *__restrict__ vx,
+ const void *__restrict__ vy,
+ float *__restrict__ dst, const int ncols,
+ const int nrows,
+ const sycl::nd_item<3> &item_ct1) {
  const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
  item_ct1.get_local_id(1);

@@ -8107,7 +8540,7 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_
  (item_ct1.get_local_id(2) %
  (qi / vdr)); // x block quant index when casting the quants to int

- tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs);
+ tmp += vec_dot_iq3_xxs_q8_1(&x[ibx], &y[iby], iqs, iq3xxs_grid, ksigns64);
  }

  // sum up partial sums and write back result
@@ -8123,10 +8556,11 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_
  }

  template <int qk, int qi, typename block_q_t, int vdr>
- static void mul_mat_vec_q_iq2_xxs_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
- const sycl::nd_item<3> &item_ct1,
- const uint64_t *iq2xxs_grid_ptr, const uint8_t *ksigns_iq2xs_ptr,
- const uint8_t *kmask_iq2xs_ptr ) {
+ static void mul_mat_vec_q_iq3_s_q8_1(const void *__restrict__ vx,
+ const void *__restrict__ vy,
+ float *__restrict__ dst, const int ncols,
+ const int nrows,
+ const sycl::nd_item<3> &item_ct1) {
  const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
  item_ct1.get_local_id(1);

@@ -8154,7 +8588,7 @@ static void mul_mat_vec_q_iq2_xxs_q8_1(const void * __restrict__ vx, const void
  (item_ct1.get_local_id(2) %
  (qi / vdr)); // x block quant index when casting the quants to int

- tmp += vec_dot_iq2_xxs_q8_1(&x[ibx], &y[iby], iqs, iq2xxs_grid_ptr, ksigns_iq2xs_ptr, kmask_iq2xs_ptr);
+ tmp += vec_dot_iq3_s_q8_1(&x[ibx], &y[iby], iqs, iq3s_grid);
  }

  // sum up partial sums and write back result
@@ -8170,9 +8604,11 @@ static void mul_mat_vec_q_iq2_xxs_q8_1(const void * __restrict__ vx, const void
  }

  template <int qk, int qi, typename block_q_t, int vdr>
- static void mul_mat_vec_q_iq2_xs_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
- const sycl::nd_item<3> &item_ct1,
- const uint64_t *iq2xs_grid_ptr, const uint64_t *ksigns64_ptr ) {
+ static void mul_mat_vec_q_iq1_s_q8_1(const void *__restrict__ vx,
+ const void *__restrict__ vy,
+ float *__restrict__ dst, const int ncols,
+ const int nrows,
+ const sycl::nd_item<3> &item_ct1) {
  const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
  item_ct1.get_local_id(1);

@@ -8200,7 +8636,7 @@ static void mul_mat_vec_q_iq2_xs_q8_1(const void * __restrict__ vx, const void *
  (item_ct1.get_local_id(2) %
  (qi / vdr)); // x block quant index when casting the quants to int

- tmp += vec_dot_iq2_xs_q8_1(&x[ibx], &y[iby], iqs, iq2xs_grid_ptr, ksigns64_ptr);
+ tmp += vec_dot_iq1_s_q8_1(&x[ibx], &y[iby], iqs, iq1s_grid_gpu);
  }

  // sum up partial sums and write back result
@@ -8216,9 +8652,11 @@ static void mul_mat_vec_q_iq2_xs_q8_1(const void * __restrict__ vx, const void *
  }

  template <int qk, int qi, typename block_q_t, int vdr>
- static void mul_mat_vec_q_iq3_xxs_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
- const sycl::nd_item<3> &item_ct1,
- const uint32_t *iq3xxs_grid_ptr, const uint64_t *ksigns64_ptr ) {
+ static void mul_mat_vec_q_iq1_m_q8_1(const void *__restrict__ vx,
+ const void *__restrict__ vy,
+ float *__restrict__ dst, const int ncols,
+ const int nrows,
+ const sycl::nd_item<3> &item_ct1) {
  const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
  item_ct1.get_local_id(1);

@@ -8246,7 +8684,7 @@ static void mul_mat_vec_q_iq3_xxs_q8_1(const void * __restrict__ vx, const void
  (item_ct1.get_local_id(2) %
  (qi / vdr)); // x block quant index when casting the quants to int

- tmp += vec_dot_iq3_xxs_q8_1(&x[ibx], &y[iby], iqs, iq3xxs_grid_ptr, ksigns64_ptr);
+ tmp += vec_dot_iq1_m_q8_1(&x[ibx], &y[iby], iqs);
  }

  // sum up partial sums and write back result
@@ -8262,9 +8700,11 @@ static void mul_mat_vec_q_iq3_xxs_q8_1(const void * __restrict__ vx, const void
  }

  template <int qk, int qi, typename block_q_t, int vdr>
- static void mul_mat_vec_q_iq3_s_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
- const sycl::nd_item<3> &item_ct1,
- const uint32_t *iq3s_grid_ptr, const uint64_t *ksigns64_ptr ) {
+ static void mul_mat_vec_q_iq4_nl_q8_1(const void *__restrict__ vx,
+ const void *__restrict__ vy,
+ float *__restrict__ dst, const int ncols,
+ const int nrows,
+ const sycl::nd_item<3> &item_ct1) {
  const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
  item_ct1.get_local_id(1);

@@ -8292,7 +8732,7 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void * __restrict__ vx, const void *
  (item_ct1.get_local_id(2) %
  (qi / vdr)); // x block quant index when casting the quants to int

- tmp += vec_dot_iq3_s_q8_1(&x[ibx], &y[iby], iqs, iq3s_grid_ptr, ksigns64_ptr);
+ tmp += vec_dot_iq4_nl_q8_1(&x[ibx], &y[iby], iqs);
  }

  // sum up partial sums and write back result
@@ -8307,10 +8747,13 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void * __restrict__ vx, const void *
  }
  }

+
  template <int qk, int qi, typename block_q_t, int vdr>
- static void mul_mat_vec_q_iq1_s_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
- const sycl::nd_item<3> &item_ct1,
- const uint32_t *iq1s_grid_ptr, const uint64_t *ksigns64_ptr ) {
+ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
+ const void *__restrict__ vy,
+ float *__restrict__ dst, const int ncols,
+ const int nrows,
+ const sycl::nd_item<3> &item_ct1) {
  const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
  item_ct1.get_local_id(1);

@@ -8338,7 +8781,7 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void * __restrict__ vx, const void *
  (item_ct1.get_local_id(2) %
  (qi / vdr)); // x block quant index when casting the quants to int

- tmp += vec_dot_iq1_s_q8_1(&x[ibx], &y[iby], iqs, iq1s_grid_ptr, ksigns64_ptr);
+ tmp += vec_dot_iq4_xs_q8_1(&x[ibx], &y[iby], iqs);
  }

  // sum up partial sums and write back result
@@ -8353,6 +8796,7 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void * __restrict__ vx, const void *
  }
  }

+
  template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
  static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows,
  const sycl::nd_item<3> &item_ct1) {
@@ -8914,64 +9358,71 @@ static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
  }
  }

+
  template<typename T>
- static inline void swap(T & a, T & b) {
+ static inline void ggml_sycl_swap(T & a, T & b) {
  T tmp = a;
  a = b;
  b = tmp;
  }

- template<ggml_sort_order order>
- static void k_argsort_f32_i32(const float * x, int * dst, const int ncols,
- const sycl::nd_item<3> &item_ct1) {
+ template <ggml_sort_order order>
+ __dpct_inline__ static void
+ k_argsort_f32_i32(const float *x, int *dst, const int ncols, int ncols_pad,
+ const sycl::nd_item<3> &item_ct1, uint8_t *dpct_local) {
  // bitonic sort
  int col = item_ct1.get_local_id(2);
  int row = item_ct1.get_group(1);

- if (col >= ncols) return;
+ if (col >= ncols_pad) {
+ return;
+ }

  const float * x_row = x + row * ncols;
- int * dst_row = dst + row * ncols;
+ auto dst_row = (int *)dpct_local;

  // initialize indices
- if (col < ncols) {
- dst_row[col] = col;
- }
- /*
- DPCT1065:58: Consider replacing sycl::nd_item::barrier() with
- sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
- performance if there is no access to global memory.
- */
- item_ct1.barrier();
+ dst_row[col] = col;
+
+ item_ct1.barrier(sycl::access::fence_space::local_space);

- for (int k = 2; k <= ncols; k *= 2) {
+ for (int k = 2; k <= ncols_pad; k *= 2) {
  for (int j = k / 2; j > 0; j /= 2) {
  int ixj = col ^ j;
  if (ixj > col) {
  if ((col & k) == 0) {
- if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
- swap(dst_row[col], dst_row[ixj]);
+ if (dst_row[col] >= ncols ||
+ (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
+ x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+ x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+ ) {
+ ggml_sycl_swap(dst_row[col], dst_row[ixj]);
  }
  } else {
- if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
- swap(dst_row[col], dst_row[ixj]);
+ if (dst_row[ixj] >= ncols ||
+ (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
+ x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+ x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+ ) {
+ ggml_sycl_swap(dst_row[col], dst_row[ixj]);
  }
  }
  }
  /*
- DPCT1118:11: SYCL group functions and algorithms must be encountered
+ DPCT1118:1: SYCL group functions and algorithms must be encountered
  in converged control flow. You may need to adjust the code.
  */
- /*
- DPCT1065:59: Consider replacing sycl::nd_item::barrier() with
- sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
- better performance if there is no access to global memory.
- */
- item_ct1.barrier();
+ item_ct1.barrier(sycl::access::fence_space::local_space);
  }
  }
+
+ // copy the result to dst without the padding
+ if (col < ncols) {
+ dst[row * ncols + col] = dst_row[col];
+ }
  }

+
  static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past,
  const sycl::nd_item<3> &item_ct1) {
  const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
@@ -9950,28 +10401,64 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int k,
  #endif
  }

-
  template <typename dst_t>
- static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int k,
+ static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
  dpct::queue_ptr stream) {
  const int nb = k / QK_K;
  {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->submit([&](sycl::handler &cgh) {
+ cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+ sycl::range<3>(1, 1, 32),
+ sycl::range<3>(1, 1, 32)),
+ [=](sycl::nd_item<3> item_ct1) {
+ dequantize_block_iq1_s(
+ vx, y, item_ct1, iq1s_grid_gpu
+ );
+ });
+ });
+ }
+ }

+ template <typename dst_t>
+ static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int k,
+ dpct::queue_ptr stream) {
+ const int nb = k / QK_K;
+ {
  dpct::has_capability_or_fail(stream->get_device(),
  {sycl::aspect::fp16});

  stream->submit([&](sycl::handler &cgh) {
- auto iq2xxs_grid_ptr_ct1 = &iq2xxs_grid[0];
- auto ksigns_iq2xs_ptr_ct1 = &ksigns_iq2xs[0];
- auto kmask_iq2xs_ptr_ct1 = &kmask_iq2xs[0];
+ cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+ sycl::range<3>(1, 1, 32),
+ sycl::range<3>(1, 1, 32)),
+ [=](sycl::nd_item<3> item_ct1) {
+ dequantize_block_iq1_m(
+ vx, y, item_ct1, iq1s_grid_gpu
+ );
+ });
+ });
+ }
+ }
+
+ template <typename dst_t>
+ static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int k,
+ dpct::queue_ptr stream) {
+ const int nb = k / QK_K;
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});

+ stream->submit([&](sycl::handler &cgh) {
  cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
  sycl::range<3>(1, 1, 32),
  sycl::range<3>(1, 1, 32)),
  [=](sycl::nd_item<3> item_ct1) {
  dequantize_block_iq2_xxs(
- vx, y, item_ct1, iq2xxs_grid_ptr_ct1,
- ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1);
+ vx, y, item_ct1, iq2xxs_grid,
+ ksigns_iq2xs, kmask_iq2xs);
  });
  });
  }
@@ -9982,105 +10469,130 @@ static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int k,
  dpct::queue_ptr stream) {
  const int nb = k / QK_K;
  {
-
  dpct::has_capability_or_fail(stream->get_device(),
  {sycl::aspect::fp16});

  stream->submit([&](sycl::handler &cgh) {
- auto iq2xs_grid_ptr_ct1 = &iq2xs_grid[0];
- auto ksigns_iq2xs_ptr_ct1 = &ksigns_iq2xs[0];
- auto kmask_iq2xs_ptr_ct1 = &kmask_iq2xs[0];
-
  cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
  sycl::range<3>(1, 1, 32),
  sycl::range<3>(1, 1, 32)),
  [=](sycl::nd_item<3> item_ct1) {
  dequantize_block_iq2_xs(
- vx, y, item_ct1, iq2xs_grid_ptr_ct1,
- ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1);
+ vx, y, item_ct1, iq2xs_grid,
+ ksigns_iq2xs, kmask_iq2xs);
  });
  });
  }
  }

  template <typename dst_t>
- static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k,
- dpct::queue_ptr stream) {
+ static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int k,
+ dpct::queue_ptr stream) {
  const int nb = k / QK_K;
  {
-
  dpct::has_capability_or_fail(stream->get_device(),
  {sycl::aspect::fp16});

  stream->submit([&](sycl::handler &cgh) {
- auto iq3xxs_grid_ptr_ct1 = &iq3xxs_grid[0];
- auto ksigns_iq2xs_ptr_ct1 = &ksigns_iq2xs[0];
- auto kmask_iq2xs_ptr_ct1 = &kmask_iq2xs[0];
-
  cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
  sycl::range<3>(1, 1, 32),
  sycl::range<3>(1, 1, 32)),
  [=](sycl::nd_item<3> item_ct1) {
- dequantize_block_iq3_xxs(
- vx, y, item_ct1, iq3xxs_grid_ptr_ct1,
- ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1);
+ dequantize_block_iq2_s(vx, y, item_ct1);
  });
  });
  }
  }

+
  template <typename dst_t>
- static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int k,
+ static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k,
  dpct::queue_ptr stream) {
  const int nb = k / QK_K;
  {
-
  dpct::has_capability_or_fail(stream->get_device(),
  {sycl::aspect::fp16});

  stream->submit([&](sycl::handler &cgh) {
- auto iq3s_grid_ptr_ct1 = &iq3s_grid[0];
- auto ksigns_iq2xs_ptr_ct1 = &ksigns_iq2xs[0];
- auto kmask_iq2xs_ptr_ct1 = &kmask_iq2xs[0];
-
  cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
  sycl::range<3>(1, 1, 32),
  sycl::range<3>(1, 1, 32)),
  [=](sycl::nd_item<3> item_ct1) {
- dequantize_block_iq3_s(
- vx, y, item_ct1, iq3s_grid_ptr_ct1,
- ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1);
+ dequantize_block_iq3_xxs(
+ vx, y, item_ct1, iq3xxs_grid,
+ ksigns_iq2xs, kmask_iq2xs);
  });
  });
  }
  }

  template <typename dst_t>
- static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
+ static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int k,
  dpct::queue_ptr stream) {
  const int nb = k / QK_K;
  {
-
  dpct::has_capability_or_fail(stream->get_device(),
  {sycl::aspect::fp16});

  stream->submit([&](sycl::handler &cgh) {
- auto iq1s_grid_ptr_ct1 = &iq1s_grid_gpu[0];
- auto ksigns_iq2xs_ptr_ct1 = &ksigns_iq2xs[0];
- auto kmask_iq2xs_ptr_ct1 = &kmask_iq2xs[0];
-
  cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
  sycl::range<3>(1, 1, 32),
  sycl::range<3>(1, 1, 32)),
  [=](sycl::nd_item<3> item_ct1) {
- dequantize_block_iq1_s(
- vx, y, item_ct1, iq1s_grid_ptr_ct1,
- ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1);
+ dequantize_block_iq3_s(
+ vx, y, item_ct1, kmask_iq2xs, iq3s_grid);
  });
  });
  }
  }

+ template <typename dst_t>
+ static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int k,
+ dpct::queue_ptr stream) {
+ const int nb = (k + QK_K - 1) / QK_K;
+ #if QK_K == 64
+ dequantize_row_iq4_nl_sycl(vx, y, k, stream);
+ #else
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->submit([&](sycl::handler &cgh) {
+ cgh.parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+ sycl::range<3>(1, 1, 32),
+ sycl::range<3>(1, 1, 32)),
+ [=](sycl::nd_item<3> item_ct1) {
+ dequantize_block_iq4_xs(vx, y, item_ct1);
+ });
+ });
+ }
+ #endif
+ }
+
+
+ template <typename dst_t>
+ static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int k,
+ dpct::queue_ptr stream) {
+ const int nb = (k + QK_K - 1) / QK_K;
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->submit([&](sycl::handler &cgh) {
+ cgh.parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+ sycl::range<3>(1, 1, 32),
+ sycl::range<3>(1, 1, 32)),
+ [=](sycl::nd_item<3> item_ct1) {
+ dequantize_block_iq4_nl(vx, y, item_ct1);
+ });
+ });
+ }
+ }
+
+
+
  template <typename src_t, typename dst_t>
  static void convert_unary_sycl(const void *__restrict__ vx,
  dst_t *__restrict__ y, const int k,
@@ -10125,16 +10637,24 @@ static to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type) try {
  return dequantize_row_q5_K_sycl;
  case GGML_TYPE_Q6_K:
  return dequantize_row_q6_K_sycl;
+ case GGML_TYPE_IQ1_S:
+ return dequantize_row_iq1_s_sycl;
+ case GGML_TYPE_IQ1_M:
+ return dequantize_row_iq1_m_sycl;
  case GGML_TYPE_IQ2_XXS:
  return dequantize_row_iq2_xxs_sycl;
  case GGML_TYPE_IQ2_XS:
  return dequantize_row_iq2_xs_sycl;
+ case GGML_TYPE_IQ2_S:
+ return dequantize_row_iq2_s_sycl;
  case GGML_TYPE_IQ3_XXS:
  return dequantize_row_iq3_xxs_sycl;
  case GGML_TYPE_IQ3_S:
  return dequantize_row_iq3_s_sycl;
- case GGML_TYPE_IQ1_S:
- return dequantize_row_iq1_s_sycl;
+ case GGML_TYPE_IQ4_XS:
+ return dequantize_row_iq4_xs_sycl;
+ case GGML_TYPE_IQ4_NL:
+ return dequantize_row_iq4_nl_sycl;
  case GGML_TYPE_F32:
  return convert_unary_sycl<float>;
  default:
@@ -10169,16 +10689,24 @@ static to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type) {
  return dequantize_row_q5_K_sycl;
  case GGML_TYPE_Q6_K:
  return dequantize_row_q6_K_sycl;
+ case GGML_TYPE_IQ1_S:
+ return dequantize_row_iq1_s_sycl;
+ case GGML_TYPE_IQ1_M:
+ return dequantize_row_iq1_m_sycl;
  case GGML_TYPE_IQ2_XXS:
  return dequantize_row_iq2_xxs_sycl;
  case GGML_TYPE_IQ2_XS:
  return dequantize_row_iq2_xs_sycl;
+ case GGML_TYPE_IQ2_S:
+ return dequantize_row_iq2_s_sycl;
  case GGML_TYPE_IQ3_XXS:
  return dequantize_row_iq3_xxs_sycl;
  case GGML_TYPE_IQ3_S:
  return dequantize_row_iq3_s_sycl;
- case GGML_TYPE_IQ1_S:
- return dequantize_row_iq1_s_sycl;
+ case GGML_TYPE_IQ4_XS:
+ return dequantize_row_iq4_xs_sycl;
+ case GGML_TYPE_IQ4_NL:
+ return dequantize_row_iq4_nl_sycl;
  case GGML_TYPE_F16:
  return convert_unary_sycl<sycl::half>;
  default:
@@ -10641,19 +11169,13 @@ static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy,
10641
11169
  const sycl::range<3> block_nums(1, 1, block_num_y);
10642
11170
  const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
10643
11171
  {
10644
-
10645
11172
  stream->submit([&](sycl::handler &cgh) {
10646
- auto iq2xxs_grid_ptr_ct1 = &iq2xxs_grid[0];
10647
- auto ksigns_iq2xs_ptr_ct1 = &ksigns_iq2xs[0];
10648
- auto kmask_iq2xs_ptr_ct1 = &kmask_iq2xs[0];
10649
-
10650
11173
  cgh.parallel_for(
10651
11174
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
10652
11175
  [=](sycl::nd_item<3> item_ct1)
10653
11176
  [[intel::reqd_sub_group_size(32)]] {
10654
11177
  mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS, block_iq2_xxs, 1>(
10655
- vx, vy, dst, ncols, nrows, item_ct1,
10656
- iq2xxs_grid_ptr_ct1, ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1);
11178
+ vx, vy, dst, ncols, nrows, item_ct1);
10657
11179
  });
10658
11180
  });
10659
11181
  }
@@ -10678,8 +11200,32 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
10678
11200
  [=](sycl::nd_item<3> item_ct1)
10679
11201
  [[intel::reqd_sub_group_size(32)]] {
10680
11202
  mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS, block_iq2_xs, 1>(
10681
- vx, vy, dst, ncols, nrows, item_ct1,
10682
- iq2xs_grid_ptr_ct1, ksigns64_ptr_ct1);
11203
+ vx, vy, dst, ncols, nrows, item_ct1);
11204
+ });
11205
+ });
11206
+ }
11207
+ }
11208
+
11209
+ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
11210
+ float *dst, const int ncols,
11211
+ const int nrows,
11212
+ dpct::queue_ptr stream) {
11213
+ GGML_ASSERT(ncols % QK_K == 0);
11214
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
11215
+ const sycl::range<3> block_nums(1, 1, block_num_y);
11216
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
11217
+ {
11218
+
11219
+ stream->submit([&](sycl::handler &cgh) {
11220
+ auto iq2xs_grid_ptr_ct1 = &iq2xs_grid[0];
11221
+ auto ksigns64_ptr_ct1 = &ksigns64[0];
11222
+
11223
+ cgh.parallel_for(
11224
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
11225
+ [=](sycl::nd_item<3> item_ct1)
11226
+ [[intel::reqd_sub_group_size(32)]] {
11227
+ mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S, block_iq2_s, 1>(
11228
+ vx, vy, dst, ncols, nrows, item_ct1);
10683
11229
  });
10684
11230
  });
10685
11231
  }
@@ -10704,8 +11250,7 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
10704
11250
  [=](sycl::nd_item<3> item_ct1)
10705
11251
  [[intel::reqd_sub_group_size(32)]] {
10706
11252
  mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS, block_iq3_xxs, 1>(
10707
- vx, vy, dst, ncols, nrows, item_ct1,
10708
- iq3xxs_grid_ptr_ct1, ksigns64_ptr_ct1);
11253
+ vx, vy, dst, ncols, nrows, item_ct1);
10709
11254
  });
10710
11255
  });
10711
11256
  }
@@ -10723,15 +11268,13 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
10723
11268
 
10724
11269
  stream->submit([&](sycl::handler &cgh) {
10725
11270
  auto iq3s_grid_ptr_ct1 = &iq3s_grid[0];
10726
- auto ksigns64_ptr_ct1 = &ksigns64[0];
10727
11271
 
10728
11272
  cgh.parallel_for(
10729
11273
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
10730
11274
  [=](sycl::nd_item<3> item_ct1)
10731
11275
  [[intel::reqd_sub_group_size(32)]] {
10732
11276
  mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_XS, block_iq3_s, 1>(
10733
- vx, vy, dst, ncols, nrows, item_ct1,
10734
- iq3s_grid_ptr_ct1, ksigns64_ptr_ct1);
11277
+ vx, vy, dst, ncols, nrows, item_ct1);
10735
11278
  });
10736
11279
  });
10737
11280
  }
@@ -10756,8 +11299,72 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
                 [=](sycl::nd_item<3> item_ct1)
                     [[intel::reqd_sub_group_size(32)]] {
                         mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
-                            vx, vy, dst, ncols, nrows, item_ct1,
-                            iq1s_grid_ptr_ct1, ksigns64_ptr_ct1);
+                            vx, vy, dst, ncols, nrows, item_ct1);
+                    });
+        });
+    }
+}
+
+static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy,
+                                        float *dst, const int ncols,
+                                        const int nrows,
+                                        dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
+                    });
+        });
+    }
+}
+
+static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy,
+                                         float *dst, const int ncols,
+                                         const int nrows,
+                                         dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK4_NL == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 1>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
+                    });
+        });
+    }
+}
+
+static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
+                                         float *dst, const int ncols,
+                                         const int nrows,
+                                         dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS, block_iq4_xs, 1>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
                     });
         });
     }
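
All of the new mul_mat_vec_*_q8_1_sycl launchers above share one launch geometry: one 32-wide sub-group per GGML_SYCL_MMV_Y output rows, with the work-group count rounded up by ceil division so a row count that does not divide evenly is still covered. Below is a minimal standalone C++ sketch of that arithmetic; the constant values are assumptions chosen for illustration, not the file's actual definitions.

    #include <cstdio>

    // Assumed values for illustration; the real constants are defined
    // elsewhere in ggml-sycl.
    constexpr int GGML_SYCL_MMV_Y = 2; // output rows handled per work-group
    constexpr int WARP_SIZE = 32;      // matches [[intel::reqd_sub_group_size(32)]]

    int main() {
        const int nrows = 4097; // deliberately not a multiple of GGML_SYCL_MMV_Y
        // Ceil division: enough work-groups that every row is covered even
        // when nrows does not divide evenly.
        const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
        // The nd_range is block_nums * block_dims, so the global size is:
        const int global_items = block_num_y * GGML_SYCL_MMV_Y * WARP_SIZE;
        printf("work-groups: %d, global work-items: %d\n", block_num_y, global_items);
        return 0;
    }
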
@@ -12381,36 +12988,54 @@ static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
         });
 }

+static int next_power_of_2(int x) {
+    int n = 1;
+    while (n < x) {
+        n *= 2;
+    }
+    return n;
+}
+
 static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
                                  const int nrows, ggml_sort_order order,
                                  dpct::queue_ptr stream) {
     // bitonic sort requires ncols to be power of 2
-    GGML_ASSERT((ncols & (ncols - 1)) == 0);
+    const int ncols_pad = next_power_of_2(ncols);

-    const sycl::range<3> block_dims(1, 1, ncols);
+    const sycl::range<3> block_dims(1, 1, ncols_pad);
     const sycl::range<3> block_nums(1, nrows, 1);
+    const size_t shared_mem = ncols_pad * sizeof(int);
+
+    // GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
+
     if (order == GGML_SORT_ORDER_ASC) {
-        /*
-        DPCT1049:44: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(x, dst, ncols, item_ct1);
-            });
+        stream->submit([&](sycl::handler &cgh) {
+            sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
+                sycl::range<1>(shared_mem), cgh);
+
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
+                    k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
+                        x, dst, ncols, ncols_pad, item_ct1,
+                        dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
+                            .get());
+                });
+        });
     } else if (order == GGML_SORT_ORDER_DESC) {
-        /*
-        DPCT1049:45: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(x, dst, ncols, item_ct1);
-            });
+        stream->submit([&](sycl::handler &cgh) {
+            sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
+                sycl::range<1>(shared_mem), cgh);
+
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
+                    k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
+                        x, dst, ncols, ncols_pad, item_ct1,
+                        dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
+                            .get());
+                });
+        });
     } else {
         GGML_ASSERT(false);
     }
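
The argsort change above removes the power-of-two restriction by padding: next_power_of_2 rounds ncols up, the kernel sorts ncols_pad indices in local memory (shared_mem bytes), and out-of-range indices are expected to sort past every real element. The CPU model below sketches that scheme; the sentinel comparator is an assumption about k_argsort_f32_i32, whose body is not shown in this diff.

    #include <algorithm>
    #include <cstdio>
    #include <limits>
    #include <vector>

    static int next_power_of_2(int x) {
        int n = 1;
        while (n < x) n *= 2;
        return n;
    }

    int main() {
        const float x[] = {0.7f, -1.2f, 3.4f, 0.0f, 2.2f};
        const int ncols = 5;
        const int ncols_pad = next_power_of_2(ncols); // 8

        // Indices ncols..ncols_pad-1 are padding; they compare as +inf so
        // the bitonic network pushes them past every real element.
        std::vector<int> idx(ncols_pad);
        for (int i = 0; i < ncols_pad; ++i) idx[i] = i;

        auto val = [&](int a) {
            return a < ncols ? x[a] : std::numeric_limits<float>::infinity();
        };

        // Plain bitonic sort network over the padded range (ascending);
        // the GPU kernel runs each i in its own work-item with barriers.
        for (int k = 2; k <= ncols_pad; k *= 2) {
            for (int j = k / 2; j > 0; j /= 2) {
                for (int i = 0; i < ncols_pad; ++i) {
                    const int ixj = i ^ j;
                    if (ixj > i) {
                        const bool ascending = (i & k) == 0;
                        if ((val(idx[i]) > val(idx[ixj])) == ascending) {
                            std::swap(idx[i], idx[ixj]);
                        }
                    }
                }
            }
        }

        // Only the first ncols entries are meaningful output: 1 3 0 4 2
        for (int i = 0; i < ncols; ++i) printf("%d ", idx[i]);
        printf("\n");
        return 0;
    }
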
@@ -13538,8 +14163,12 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_SYC
     case GGML_TYPE_Q5_K:
     case GGML_TYPE_IQ2_XXS:
     case GGML_TYPE_IQ2_XS:
+    case GGML_TYPE_IQ2_S:
     case GGML_TYPE_IQ1_S:
+    case GGML_TYPE_IQ1_M:
     case GGML_TYPE_IQ3_XXS:
+    case GGML_TYPE_IQ4_XS:
+    case GGML_TYPE_IQ4_NL:
         return max_compute_capability >= VER_GEN9 ? 128 : 64;
     case GGML_TYPE_IQ3_S:
         return max_compute_capability >= VER_GEN9 ? 128 : 64;
@@ -13558,11 +14187,20 @@ inline void ggml_sycl_op_mul_mat_vec_q(
     const int64_t src1_ncols, const int64_t src1_padded_row_size,
     const dpct::queue_ptr &stream) {

-    GGML_ASSERT(ggml_nrows(src1) == 1);
+    const int64_t ne10 = src1->ne[0];
+    GGML_ASSERT(ne10 % QK8_1 == 0);

     const int64_t ne00 = src0->ne[0];
     const int64_t row_diff = row_high - row_low;

+    int id;
+    SYCL_CHECK(
+        CHECK_TRY_ERROR(id = get_current_device_id()));
+
+    // the main device has a larger memory buffer to hold the results from all GPUs
+    // nrows_dst == nrows of the matrix that the kernel writes into
+    const int64_t nrows_dst = dst->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device ? ne00 : row_diff;
+
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
             mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
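
Two things change at the top of ggml_sycl_op_mul_mat_vec_q: the single-row restriction on src1 is replaced by an assertion that its row length is a whole number of QK8_1-sized q8_1 blocks, and nrows_dst records that the main device writes into a full-size result buffer while the other devices only hold their row slice. A minimal model follows; the shapes and device role are made-up example values, and QK8_1 == 32 matches ggml's q8_1 block size.

    #include <cassert>
    #include <cstdint>
    #include <cstdio>

    constexpr int64_t QK8_1 = 32; // q8_1 block size in ggml

    int main() {
        const int64_t ne10 = 4096;   // src1 row length (example value)
        assert(ne10 % QK8_1 == 0);   // src1 is quantized in whole q8_1 blocks

        const int64_t ne00 = 4096;                     // total output rows
        const int64_t row_low = 1024, row_high = 3072; // this device's slice
        const int64_t row_diff = row_high - row_low;

        // The main device keeps a buffer large enough for every device's
        // results, so its kernels use the full-matrix row count as stride;
        // secondary devices write only their own slice.
        const bool is_main_device = false; // example role
        const int64_t nrows_dst = is_main_device ? ne00 : row_diff;
        printf("nrows_dst = %lld\n", (long long) nrows_dst);
        return 0;
    }
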
@@ -13594,20 +14232,32 @@ inline void ggml_sycl_op_mul_mat_vec_q(
         case GGML_TYPE_Q6_K:
             mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
             break;
+        case GGML_TYPE_IQ1_S:
+            mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_IQ1_M:
+            mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
         case GGML_TYPE_IQ2_XXS:
             mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
             break;
         case GGML_TYPE_IQ2_XS:
             mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
             break;
+        case GGML_TYPE_IQ2_S:
+            mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
         case GGML_TYPE_IQ3_XXS:
             mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
             break;
         case GGML_TYPE_IQ3_S:
             mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
             break;
-        case GGML_TYPE_IQ1_S:
-            mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+        case GGML_TYPE_IQ4_NL:
+            mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_IQ4_XS:
+            mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
             break;
         default:
             GGML_ASSERT(false);
@@ -13689,6 +14339,7 @@ inline void ggml_sycl_op_dequantize_mul_mat_vec(
             convert_mul_mat_vec_f16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
             break;
         default:
+            printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type);
             GGML_ASSERT(false);
             break;
     }
@@ -14543,8 +15194,8 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0,
             src1_padded_col_size = (i0 * ne11 + src1_col_0) * ne10;
         }
         // do the computation
-        op(src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
-            dev[i].row_low, dev[i].row_high, src1_ncols, src1_padded_col_size, stream);
+        SYCL_CHECK(CHECK_TRY_ERROR(op(src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
+            dev[i].row_low, dev[i].row_high, src1_ncols, src1_padded_col_size, stream)));
         /*
         DPCT1010:93: SYCL uses exceptions to report errors and does not
         use the error codes. The call was replaced with 0. You need to
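
Wrapping the op callback in SYCL_CHECK(CHECK_TRY_ERROR(...)) converts SYCL's exception-based error reporting into a checked call at the call site. The sketch below shows the general shape of that pattern; the macro bodies are illustrative stand-ins, not the definitions this file actually uses.

    #include <cstdio>
    #include <exception>

    // Run an expression, translating any thrown exception into a nonzero
    // error code (illustrative stand-in for the dpct-style macro).
    #define CHECK_TRY_ERROR(expr)                              \
        [&]() -> int {                                         \
            try {                                              \
                expr;                                          \
                return 0;                                      \
            } catch (const std::exception &e) {                \
                fprintf(stderr, "SYCL error: %s\n", e.what()); \
                return 1;                                      \
            }                                                  \
        }()

    // Check the resulting error code (illustrative stand-in).
    #define SYCL_CHECK(code)                      \
        do {                                      \
            if ((code) != 0) {                    \
                fprintf(stderr, "sycl op failed\n"); \
            }                                     \
        } while (0)

    static void op_that_may_throw() { /* stand-in for the mul-mat op */ }

    int main() {
        SYCL_CHECK(CHECK_TRY_ERROR(op_that_may_throw()));
        return 0;
    }
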
@@ -15125,11 +15776,17 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
 #ifdef GGML_SYCL_FORCE_DMMV
     const bool use_mul_mat_vec_q = false;
 #else
-    const bool use_mul_mat_vec_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type) && ggml_nrows(src1) == 1;
+    bool use_mul_mat_vec_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type);
+    use_mul_mat_vec_q = use_mul_mat_vec_q ||
+        (src0->type == GGML_TYPE_IQ2_XXS) || (src0->type == GGML_TYPE_IQ2_XS) || (src0->type == GGML_TYPE_IQ2_S) ||
+        (src0->type == GGML_TYPE_IQ3_XXS) || (src0->type == GGML_TYPE_IQ3_S) ||
+        (src0->type == GGML_TYPE_IQ4_NL) || (src0->type == GGML_TYPE_IQ4_XS) ||
+        (src0->type == GGML_TYPE_IQ1_S) || (src0->type == GGML_TYPE_IQ1_M);
+
+
 #endif // GGML_SYCL_FORCE_DMMV

     if (use_mul_mat_vec_q) {
-        // NOTE: this kernel does not support ggml_nrows(src1) > 1
         // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_vec_q path\n");
         ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
     } else {
@@ -16985,9 +17642,14 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
                     return false;
                 }
                 ggml_type a_type = a->type;
-                if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ2_S ||
-                    a_type == GGML_TYPE_IQ4_XS) {
-                    return false;
+                if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ4_XS ||
+                    a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ3_S ||
+                    a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ2_S ||
+                    a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ1_M
+                    ) {
+                    if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
+                        return false;
+                    }
                 }
                 return true;
             } break;
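
The relaxed supports_op check above no longer rejects the IQ quantizations outright; it only rejects the batched mat-vec shape (b->ne[1] == 1 but more than one row overall), the case the new MMVQ path does not handle. A small model of that predicate follows, using ggml's definition nrows = ne[1] * ne[2] * ne[3]; the shapes are example values.

    #include <cstdint>
    #include <cstdio>

    struct tensor_shape { int64_t ne[4]; };

    // ggml_nrows counts every row across the batch dimensions.
    static int64_t nrows(const tensor_shape &t) {
        return t.ne[1] * t.ne[2] * t.ne[3];
    }

    // IQ-type A matrices are accepted unless B is a batch of single-row
    // vectors (batched mat-vec).
    static bool iq_mul_mat_supported(const tensor_shape &b) {
        return !(b.ne[1] == 1 && nrows(b) > 1);
    }

    int main() {
        tensor_shape single  = {{4096, 1, 1, 1}};  // plain mat-vec: supported
        tensor_shape batched = {{4096, 1, 8, 1}};  // batched mat-vec: rejected
        tensor_shape matmat  = {{4096, 32, 1, 1}}; // mat-mat: supported
        printf("%d %d %d\n", iq_mul_mat_supported(single),
               iq_mul_mat_supported(batched), iq_mul_mat_supported(matmat));
        return 0;
    }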