llama_cpp 0.14.4 → 0.14.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -1
- data/examples/chat.rb +2 -4
- data/ext/llama_cpp/extconf.rb +1 -0
- data/ext/llama_cpp/llama_cpp.cpp +23 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +10 -0
- data/vendor/tmp/llama.cpp/LICENSE +1 -1
- data/vendor/tmp/llama.cpp/Makefile +11 -3
- data/vendor/tmp/llama.cpp/ggml-backend.h +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +7 -3
- data/vendor/tmp/llama.cpp/ggml-quants.c +155 -155
- data/vendor/tmp/llama.cpp/ggml-quants.h +82 -82
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +878 -216
- data/vendor/tmp/llama.cpp/ggml.c +8 -8
- data/vendor/tmp/llama.cpp/ggml.h +7 -7
- data/vendor/tmp/llama.cpp/llama.cpp +686 -124
- data/vendor/tmp/llama.cpp/llama.h +81 -13
- metadata +2 -2
@@ -1664,24 +1664,6 @@ namespace dpct
                               const void *alpha, const void *a, int lda, const void *b,
                               int ldb, const void *beta, void *c, int ldc)
     {
-#ifndef __INTEL_MKL__
-        GGML_UNUSED(q);
-        GGML_UNUSED(a_trans);
-        GGML_UNUSED(b_trans);
-        GGML_UNUSED(m);
-        GGML_UNUSED(n);
-        GGML_UNUSED(k);
-        GGML_UNUSED(alpha);
-        GGML_UNUSED(a);
-        GGML_UNUSED(lda);
-        GGML_UNUSED(b);
-        GGML_UNUSED(ldb);
-        GGML_UNUSED(beta);
-        GGML_UNUSED(c);
-        GGML_UNUSED(ldc);
-        throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces "
-                                 "Project does not support this API.");
-#else
         Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
         Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
         auto data_a = get_memory<const Ta>(a);
@@ -1690,7 +1672,6 @@ namespace dpct
         oneapi::mkl::blas::column_major::gemm(
             q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
             data_b, ldb, beta_value, data_c, ldc);
-#endif
     }

     template <typename VecT, class BinaryOperation, class = void>
@@ -2330,6 +2311,7 @@ namespace dpct
                                       lda, b, ldb, beta, c, ldc);
             break;
         }
+#ifdef __INTEL_MKL__
         case detail::get_type_combination_id(
             library_data_t::real_bfloat16, library_data_t::real_bfloat16,
             library_data_t::real_float, library_data_t::real_float):
@@ -2391,6 +2373,7 @@ namespace dpct
                 q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc);
             break;
         }
+#endif // __INTEL_MKL__
         default:
             throw std::runtime_error("the combination of data type is unsupported");
         }
@@ -3055,6 +3038,10 @@ typedef float dfloat; // dequantize float
 typedef sycl::float2 dfloat2;
 #endif //GGML_SYCL_F16

+#define MMVQ_MAX_BATCH_SIZE 8
+
+static const int8_t kvalues_iq4nl[16]={-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
+
 bool   ggml_sycl_loaded(void);
 void * ggml_sycl_host_malloc(size_t size);
 void   ggml_sycl_host_free(void * ptr);
@@ -4490,6 +4477,32 @@ static void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy,

 }

+template <typename dst_t>
+__dpct_inline__ static void
+dequantize_block_iq2_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
+                       const sycl::nd_item<3> &item_ct1) {
+
+    const int i = item_ct1.get_group(2);
+    const block_iq2_s * x = (const block_iq2_s *) vx;
+
+    const int tid = item_ct1.get_local_id(2);
+#if QK_K == 256
+    const int il = tid/8; // 0...3
+    const int ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
+    const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
+    const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];
+#pragma unroll
+    for (int j = 0; j < 8; ++j)
+        y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+#else
+    assert(false);
+
+#endif
+
+}
+
 template<typename dst_t>
 static void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy,
                                      const sycl::nd_item<3> &item_ct1,
@@ -4522,26 +4535,26 @@ static void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy,

 }

-template<typename dst_t>
-static void
-
-
-
-                                    const uint8_t *kmask_iq2xs) {
+template <typename dst_t>
+__dpct_inline__ static void
+dequantize_block_iq3_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
+                       const sycl::nd_item<3> &item_ct1,
+                       const uint8_t *kmask_iq2xs, const uint32_t *iq3s_grid) {

     const int i = item_ct1.get_group(2);
-    const block_iq3_s * x = (const block_iq3_s
+    const block_iq3_s * x = (const block_iq3_s *) vx;

     const int tid = item_ct1.get_local_id(2);
 #if QK_K == 256
     const int il = tid/8; // 0...3
     const int ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
-    const uint8_t
-    const uint8_t
-    const uint8_t
+    const uint8_t * qs = x[i].qs + 8*ib;
+    const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
+    const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)));
     const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf));
     const uint8_t signs = x[i].signs[4*ib + il];
+#pragma unroll
     for (int j = 0; j < 4; ++j) {
         y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
         y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
@@ -4552,12 +4565,12 @@ static void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy,

 }

-template<typename dst_t>
-static void
-
-
-
-
+template <typename dst_t>
+__dpct_inline__ static void
+dequantize_block_iq1_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
+                       const sycl::nd_item<3> &item_ct1,
+                       const uint32_t *iq1s_grid_gpu) {
+
     const int i = item_ct1.get_group(2);
     const block_iq1_s * x = (const block_iq1_s *) vx;

@@ -4566,14 +4579,15 @@ static void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy,
     const int il = tid/8; // 0...3
     const int ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
-    const
-    const
-    const
-
-
-
-
-
+    const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
+    const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
+    uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
+    grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)];
+    grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
+    grid32[0] &= 0x0f0f0f0f;
+#pragma unroll
+    for (int j = 0; j < 8; ++j) {
+        y[j] = d * (q[j] + delta);
     }
 #else
     assert(false);
@@ -4581,6 +4595,85 @@ static void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy,

 }

+template <typename dst_t>
+__dpct_inline__ static void
+dequantize_block_iq1_m(const void *__restrict__ vx, dst_t *__restrict__ yy,
+                       const sycl::nd_item<3> &item_ct1,
+                       const uint32_t *iq1s_grid_gpu) {
+
+    const int i = item_ct1.get_group(2);
+    const block_iq1_m * x = (const block_iq1_m *) vx;
+
+    const int tid = item_ct1.get_local_id(2);
+#if QK_K == 256
+    const int il = tid/8; // 0...3
+    const int ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const uint16_t * sc = (const uint16_t *)x[i].scales;
+    iq1m_scale_t scale;
+    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+    const int ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
+    const float d = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);
+    const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
+    uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
+    grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[2*ib+il/2] >> 4*(il%2)) & 7) << 8)];
+    grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
+    grid32[0] &= 0x0f0f0f0f;
+#pragma unroll
+    for (int j = 0; j < 8; ++j) {
+        y[j] = d * (q[j] + delta);
+    }
+#else
+    assert(false);
+#endif
+
+}
+
+template <typename dst_t>
+__dpct_inline__ static void
+dequantize_block_iq4_nl(const void *__restrict__ vx, dst_t *__restrict__ yy,
+                        const sycl::nd_item<3> &item_ct1) {
+
+    const int i = item_ct1.get_group(2);
+    const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
+
+    const int tid = item_ct1.get_local_id(2);
+    const int il = tid/8; // 0...3
+    const int ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+    const uint8_t * q4 = x[ib].qs + 4*il;
+    const float d = (float)x[ib].d;
+#pragma unroll
+    for (int j = 0; j < 4; ++j) {
+        y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
+        y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
+    }
+
+}
+
+
+template <typename dst_t>
+__dpct_inline__ static void
+dequantize_block_iq4_xs(const void *__restrict__ vx, dst_t *__restrict__ yy,
+                        const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_group(2);
+    const block_iq4_xs * x = (const block_iq4_xs *)vx;
+
+    const int tid = item_ct1.get_local_id(2);
+    const int il = tid/8; // 0...3
+    const int ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+    const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
+    const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
+#pragma unroll
+    for (int j = 0; j < 4; ++j) {
+        y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
+        y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
+    }
+}
+
+
+
 /*
     DPCT1110:4: The total declared local variable size in device function
     dequantize_mul_mat_vec_q2_k exceeds 128 bytes and may cause high register
@@ -7387,6 +7480,58 @@ vec_dot_iq2_xs_q8_1(const void *__restrict__ vbq,
 #endif
 }

+static __dpct_inline__ float
+vec_dot_iq2_s_q8_1(const void *__restrict__ vbq,
+                   const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+#if QK_K == 256
+    const block_iq2_s * bq2 = (const block_iq2_s *) vbq;
+
+    const int ib32 = iqs;
+    const int8_t  * q8 = bq8_1[ib32].qs;
+    const uint8_t * signs = bq2->qs + QK_K/8 + 4*ib32;
+    const uint8_t ls1 = bq2->scales[ib32] & 0xf;
+    const uint8_t ls2 = bq2->scales[ib32] >> 4;
+    int sumi1 = 0;
+    for (int l = 0; l < 2; ++l) {
+        const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300)));
+        const uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(
+            ((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201,
+            std::equal_to<>());
+        const uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(
+            ((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201,
+            std::equal_to<>());
+        const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
+            grid[0] ^ signs0, signs0, std::minus<>());
+        const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
+            grid[1] ^ signs1, signs1, std::minus<>());
+        sumi1 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi1);
+        sumi1 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi1);
+        q8 += 8;
+    }
+    int sumi2 = 0;
+    for (int l = 2; l < 4; ++l) {
+        const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300)));
+        const uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(
+            ((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201,
+            std::equal_to<>());
+        const uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(
+            ((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201,
+            std::equal_to<>());
+        const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
+            grid[0] ^ signs0, signs0, std::minus<>());
+        const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
+            grid[1] ^ signs1, signs1, std::minus<>());
+        sumi2 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi2);
+        sumi2 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi2);
+        q8 += 8;
+    }
+    const float d = (float)bq2->d * bq8_1[ib32].ds[0] * 0.25f;
+    return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
+#else
+    assert(false);
+#endif
+}
+
 static __dpct_inline__ float
 vec_dot_iq3_xxs_q8_1(const void *__restrict__ vbq,
                      const block_q8_1 *__restrict__ bq8_1, const int &iqs,
@@ -7429,10 +7574,8 @@ vec_dot_iq3_xxs_q8_1(const void *__restrict__ vbq,

 static __dpct_inline__ float
 vec_dot_iq3_s_q8_1(const void *__restrict__ vbq,
-
-
-#if DPCT_COMPATIBILITY_TEMP >= \
-    MIN_CC_DP4A // lowest compute capability for integer intrinsics
+                   const block_q8_1 *__restrict__ bq8_1, const int &iqs,
+                   const uint32_t *iq3s_grid) {
 #if QK_K == 256
     const block_iq3_s * bq2 = (const block_iq3_s *) vbq;

@@ -7444,9 +7587,11 @@ vec_dot_iq3_s_q8_1(const void *__restrict__ vbq,
         const uint32_t * grid1 = iq3s_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256));
         const uint32_t * grid2 = iq3s_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256));
         uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(
-            ((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201,
+            ((bq2->signs[4 * ib32 + l] & 0xf) * 0x01010101) & 0x08040201,
+            0x08040201, std::equal_to<>());
         uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(
-            ((bq2->signs[4*ib32+l] >>
+            ((bq2->signs[4 * ib32 + l] >> 4) * 0x01010101) & 0x08040201,
+            0x08040201, std::equal_to<>());
         const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
             grid1[0] ^ signs0, signs0, std::minus<>());
         const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
@@ -7455,45 +7600,142 @@ vec_dot_iq3_s_q8_1(const void *__restrict__ vbq,
         sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
         q8 += 8;
     }
-    const float d =
+    const float d =
+        (float)bq2->d *
+        (1 + 2 * ((bq2->scales[ib32 / 2] >> 4 * (ib32 % 2)) & 0xf)) *
+        bq8_1[ib32].ds[0];
     return d * sumi;
 #else
     assert(false);
-    return 0.f;
-#endif
-#else
-    assert(false);
-    return 0.f;
 #endif
 }

 static __dpct_inline__ float
 vec_dot_iq1_s_q8_1(const void *__restrict__ vbq,
-
-
+                   const block_q8_1 *__restrict__ bq8_1, const int &iqs,
+                   const uint32_t *iq1s_grid_gpu) {
 #if QK_K == 256
     const block_iq1_s * bq1 = (const block_iq1_s *) vbq;

     const int ib32 = iqs;
-    const uint8_t * qs = bq1->qs + 4*ib32;
-    const int8_t  * q8 = bq8_1[ib32].qs;
     int sumi = 0;
+    const int * q8 = (const int *)bq8_1[ib32].qs;
     for (int l = 0; l < 4; ++l) {
-        const
-
-
-
-
-
-
-
+        const int * grid = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8)));
+        int grid0 = grid[0] & 0x0f0f0f0f;
+        int grid1 = (grid[0] >> 4) & 0x0f0f0f0f;
+        sumi = dpct::dp4a(q8[2 * l + 1], grid1,
+                          dpct::dp4a(q8[2 * l + 0], grid0, sumi));
+    }
+
+    const float delta = bq1->qh[ib32] & 0x8000 ? -1-IQ1S_DELTA : -1+IQ1S_DELTA;
+    const float d1q = (float)bq1->d * (2*((bq1->qh[ib32] >> 12) & 7) + 1);
+    const float d = d1q * bq8_1[ib32].ds[0];
+    const float m = d1q * bq8_1[ib32].ds[1];
+    return d * sumi + m * delta;
+#else
+    assert(false);
+#endif
+}
+
+static __dpct_inline__ float
+vec_dot_iq1_m_q8_1(const void *__restrict__ vbq,
+                   const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+#if QK_K == 256
+    const block_iq1_m * bq1 = (const block_iq1_m *) vbq;
+
+    const int ib32 = iqs;
+    int   sumi[2] = {0, 0};
+    float sumf[2] = {0.f, 0.f};
+
+    const int * q8 = (const int *)bq8_1[ib32].qs;
+    for (int l = 0; l < 4; ++l) {
+        const int * grid = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[2*ib32+l/2] >> 4*(l%2)) & 7) << 8)));
+        int grid0 = grid[0] & 0x0f0f0f0f;
+        int grid1 = (grid[0] >> 4) & 0x0f0f0f0f;
+        sumi[l/2] = dpct::dp4a(q8[2 * l + 1], grid1,
+                               dpct::dp4a(q8[2 * l + 0], grid0, sumi[l / 2]));
+        const float delta = (bq1->qh[2*ib32+l/2] >> 4*(l%2)) & 0x08 ? -1-IQ1M_DELTA : -1+IQ1M_DELTA;
+        const int sumy = dpct::dp4a(q8[2 * l + 1], 0x01010101,
+                                    dpct::dp4a(q8[2 * l + 0], 0x01010101, 0));
+        sumf[l/2] += delta*sumy;
+    }
+
+    iq1m_scale_t scale;
+    const uint16_t * sc = (const uint16_t *)bq1->scales;
+    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+    const float d = (float)scale.f16 * bq8_1[ib32].ds[0];
+    return d * ((sumi[0] + sumf[0]) * (2*((sc[ib32/2] >> 6*(ib32%2)) & 0x7) + 1) + (sumi[1] + sumf[1]) * (2*((sc[ib32/2] >> (6*(ib32%2)+3)) & 0x7) + 1));
+#else
+    assert(false);
+#endif
+}
+
+static __dpct_inline__ void get_int_from_table_16(const uint32_t &q4,
+                                                  const uint8_t *values,
+                                                  int &val1, int &val2) {
+
+    uint32_t aux32; const uint8_t * q8 = (const uint8_t *)&aux32;
+    aux32 = q4 & 0x0f0f0f0f;
+    uint16_t v1 = values[q8[0]] | (values[q8[1]] << 8);
+    uint16_t v2 = values[q8[2]] | (values[q8[3]] << 8);
+    val1 = v1 | (v2 << 16);
+    aux32 = (q4 >> 4) & 0x0f0f0f0f;
+    v1 = values[q8[0]] | (values[q8[1]] << 8);
+    v2 = values[q8[2]] | (values[q8[3]] << 8);
+    val2 = v1 | (v2 << 16);
+}
+
+
+static __dpct_inline__ float
+vec_dot_iq4_nl_q8_1(const void *__restrict__ vbq,
+                    const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+    const block_iq4_nl * bq = (const block_iq4_nl *) vbq;
+
+    const uint16_t * q4 = (const uint16_t *)bq->qs + 2*iqs;
+    const int32_t  * q8 = (const int32_t  *)bq8_1->qs + iqs;
+
+    const uint8_t * values = (const uint8_t *)kvalues_iq4nl;
+
+    int v1, v2;
+    int sumi1 = 0, sumi2 = 0;
+    for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
+        const uint32_t aux = q4[2*l] | (q4[2*l+1] << 16);
+        get_int_from_table_16(aux, values, v1, v2);
+        sumi1 = dpct::dp4a(v1, q8[l + 0], sumi1);
+        sumi2 = dpct::dp4a(v2, q8[l + 4], sumi2);
     }
-
-
+
+    const float d = (float)bq->d * bq8_1->ds[0];
+    return d * (sumi1 + sumi2);
+}
+
+
+static __dpct_inline__ float
+vec_dot_iq4_xs_q8_1(const void *__restrict__ vbq,
+                    const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+#if QK_K == 256
+    const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq;
+    const uint8_t * values = (const uint8_t *)kvalues_iq4nl;
+
+    // iqs is 0...7
+    const int ib32 = iqs;
+    const int32_t  * q8 = (const int *)bq8_1[ib32].qs;
+    const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32;
+    const int8_t ls = ((bq4->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((bq4->scales_h >> 2*ib32) & 3) << 4);
+    const float d = (float)bq4->d * (ls - 32) * bq8_1[ib32].ds[0];
+    int v1, v2;
+    int sumi1 = 0, sumi2 = 0;
+    for (int j = 0; j < 4; ++j) {
+        get_int_from_table_16(q4[j], values, v1, v2);
+        sumi1 = dpct::dp4a(v1, q8[j + 0], sumi1);
+        sumi2 = dpct::dp4a(v2, q8[j + 4], sumi2);
+    }
+    return d * (sumi1 + sumi2);
 #else
     assert(false);
-    return 0.f;
 #endif
 }

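Note: get_int_from_table_16 above is the workhorse of the IQ4 dot products: it splits a 32-bit word of eight 4-bit quants into low and high nibbles, maps each through the codebook, and repacks four codebook bytes per int so they can feed dp4a directly. A hedged host-side copy with a tiny driver (the identity table and input word are invented; like the kernel, this assumes little-endian byte order when aliasing aux32):

    #include <cstdint>
    #include <cstdio>

    static void get_int_from_table_16_host(uint32_t q4, const uint8_t *values,
                                           int &val1, int &val2) {
        uint32_t aux32; const uint8_t *q8 = (const uint8_t *)&aux32;
        aux32 = q4 & 0x0f0f0f0f;                    // four low nibbles
        uint16_t v1 = values[q8[0]] | (values[q8[1]] << 8);
        uint16_t v2 = values[q8[2]] | (values[q8[3]] << 8);
        val1 = v1 | (v2 << 16);                     // four codebook bytes in one int
        aux32 = (q4 >> 4) & 0x0f0f0f0f;             // four high nibbles
        v1 = values[q8[0]] | (values[q8[1]] << 8);
        v2 = values[q8[2]] | (values[q8[3]] << 8);
        val2 = v1 | (v2 << 16);
    }

    int main() {
        static const uint8_t values[16] = {0, 1, 2, 3, 4, 5, 6, 7,
                                           8, 9, 10, 11, 12, 13, 14, 15}; // identity demo table
        int v1, v2;
        get_int_from_table_16_host(0x73625140u, values, v1, v2);
        printf("%08x %08x\n", v1, v2);              // prints 03020100 07060504
        return 0;
    }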
@@ -8078,8 +8320,199 @@ template <bool need_check> static void

 template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
 static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
-                          const sycl::nd_item<3> &item_ct1
-
+                          const sycl::nd_item<3> &item_ct1) {
+    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+                    item_ct1.get_local_id(1);
+
+    if (row >= nrows) {
+        return;
+    }
+
+    const int blocks_per_row = ncols / qk;
+    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+
+    // partial sum for each thread
+    float tmp = 0.0f;
+
+    const block_q_t  * x = (const block_q_t  *) vx;
+    const block_q8_1 * y = (const block_q8_1 *) vy;
+
+    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
+         i += blocks_per_warp) {
+        const int ibx = row*blocks_per_row + i; // x block index
+
+        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+
+        const int iqs =
+            vdr *
+            (item_ct1.get_local_id(2) %
+             (qi / vdr)); // x block quant index when casting the quants to int
+
+        tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs);
+    }
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp +=
+            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+    }
+
+    if (item_ct1.get_local_id(2) == 0) {
+        dst[row] = tmp;
+    }
+}
+
+template <int qk, int qi, typename block_q_t, int vdr>
+static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx,
+                                       const void *__restrict__ vy,
+                                       float *__restrict__ dst, const int ncols,
+                                       const int nrows,
+                                       const sycl::nd_item<3> &item_ct1) {
+    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+                    item_ct1.get_local_id(1);
+
+    if (row >= nrows) {
+        return;
+    }
+
+    const int blocks_per_row = ncols / qk;
+    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+
+    // partial sum for each thread
+    float tmp = 0.0f;
+
+    const block_q_t  * x = (const block_q_t  *) vx;
+    const block_q8_1 * y = (const block_q8_1 *) vy;
+
+    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
+         i += blocks_per_warp) {
+        const int ibx = row*blocks_per_row + i; // x block index
+
+        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+
+        const int iqs =
+            vdr *
+            (item_ct1.get_local_id(2) %
+             (qi / vdr)); // x block quant index when casting the quants to int
+
+        tmp += vec_dot_iq2_xxs_q8_1(&x[ibx], &y[iby], iqs, iq2xxs_grid, ksigns_iq2xs, kmask_iq2xs);
+    }
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp +=
+            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+    }
+
+    if (item_ct1.get_local_id(2) == 0) {
+        dst[row] = tmp;
+    }
+}
+
+template <int qk, int qi, typename block_q_t, int vdr>
+static void mul_mat_vec_q_iq2_xs_q8_1(const void *__restrict__ vx,
+                                      const void *__restrict__ vy,
+                                      float *__restrict__ dst, const int ncols,
+                                      const int nrows,
+                                      const sycl::nd_item<3> &item_ct1) {
+    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+                    item_ct1.get_local_id(1);
+
+    if (row >= nrows) {
+        return;
+    }
+
+    const int blocks_per_row = ncols / qk;
+    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+
+    // partial sum for each thread
+    float tmp = 0.0f;
+
+    const block_q_t  * x = (const block_q_t  *) vx;
+    const block_q8_1 * y = (const block_q8_1 *) vy;
+
+    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
+         i += blocks_per_warp) {
+        const int ibx = row*blocks_per_row + i; // x block index
+
+        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+
+        const int iqs =
+            vdr *
+            (item_ct1.get_local_id(2) %
+             (qi / vdr)); // x block quant index when casting the quants to int
+
+        tmp += vec_dot_iq2_xs_q8_1(&x[ibx], &y[iby], iqs, iq2xs_grid, ksigns64);
+    }
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp +=
+            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+    }
+
+    if (item_ct1.get_local_id(2) == 0) {
+        dst[row] = tmp;
+    }
+}
+
+template <int qk, int qi, typename block_q_t, int vdr>
+static void mul_mat_vec_q_iq2_s_q8_1(const void *__restrict__ vx,
+                                     const void *__restrict__ vy,
+                                     float *__restrict__ dst, const int ncols,
+                                     const int nrows,
+                                     const sycl::nd_item<3> &item_ct1) {
+    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+                    item_ct1.get_local_id(1);
+
+    if (row >= nrows) {
+        return;
+    }
+
+    const int blocks_per_row = ncols / qk;
+    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+
+    // partial sum for each thread
+    float tmp = 0.0f;
+
+    const block_q_t  * x = (const block_q_t  *) vx;
+    const block_q8_1 * y = (const block_q8_1 *) vy;
+
+    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
+         i += blocks_per_warp) {
+        const int ibx = row*blocks_per_row + i; // x block index
+
+        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+
+        const int iqs =
+            vdr *
+            (item_ct1.get_local_id(2) %
+             (qi / vdr)); // x block quant index when casting the quants to int
+
+        tmp += vec_dot_iq2_s_q8_1(&x[ibx], &y[iby], iqs);
+    }
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp +=
+            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+    }
+
+    if (item_ct1.get_local_id(2) == 0) {
+        dst[row] = tmp;
+    }
+}
+
+template <int qk, int qi, typename block_q_t, int vdr>
+static void mul_mat_vec_q_iq3_xxs_q8_1(const void *__restrict__ vx,
+                                       const void *__restrict__ vy,
+                                       float *__restrict__ dst, const int ncols,
+                                       const int nrows,
+                                       const sycl::nd_item<3> &item_ct1) {
     const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
                     item_ct1.get_local_id(1);

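Note: every mul_mat_vec_q_* kernel above ends with the same butterfly reduction: each lane's partial dot product is combined across the 32-wide sub-group with xor shuffles, so after log2(32) = 5 steps lane 0 holds the row sum. A hedged sketch of that pattern against plain SYCL 2020 (dpct::permute_sub_group_by_xor is a thin wrapper over sycl::permute_group_by_xor):

    #include <sycl/sycl.hpp>

    // Device-side helper: sum `tmp` across a 32-lane sub-group.
    static inline float warp_reduce_sum(const sycl::sub_group &sg, float tmp) {
        for (int mask = 16; mask > 0; mask >>= 1) {
            // each lane adds the value held by the lane whose id differs in one bit
            tmp += sycl::permute_group_by_xor(sg, tmp, mask);
        }
        return tmp; // every lane now holds the full sum
    }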
@@ -8107,7 +8540,7 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy,
             (item_ct1.get_local_id(2) %
              (qi / vdr)); // x block quant index when casting the quants to int

-        tmp +=
+        tmp += vec_dot_iq3_xxs_q8_1(&x[ibx], &y[iby], iqs, iq3xxs_grid, ksigns64);
     }

     // sum up partial sums and write back result
@@ -8123,10 +8556,11 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy,
 }

 template <int qk, int qi, typename block_q_t, int vdr>
-static void
-
-
-
+static void mul_mat_vec_q_iq3_s_q8_1(const void *__restrict__ vx,
+                                     const void *__restrict__ vy,
+                                     float *__restrict__ dst, const int ncols,
+                                     const int nrows,
+                                     const sycl::nd_item<3> &item_ct1) {
     const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
                     item_ct1.get_local_id(1);

@@ -8154,7 +8588,7 @@ static void mul_mat_vec_q_iq2_xxs_q8_1(const void * __restrict__ vx, const void * __restrict__ vy,
             (item_ct1.get_local_id(2) %
              (qi / vdr)); // x block quant index when casting the quants to int

-        tmp +=
+        tmp += vec_dot_iq3_s_q8_1(&x[ibx], &y[iby], iqs, iq3s_grid);
     }

     // sum up partial sums and write back result
@@ -8170,9 +8604,11 @@ static void mul_mat_vec_q_iq2_xxs_q8_1(const void * __restrict__ vx, const void * __restrict__ vy,
 }

 template <int qk, int qi, typename block_q_t, int vdr>
-static void
-
-
+static void mul_mat_vec_q_iq1_s_q8_1(const void *__restrict__ vx,
+                                     const void *__restrict__ vy,
+                                     float *__restrict__ dst, const int ncols,
+                                     const int nrows,
+                                     const sycl::nd_item<3> &item_ct1) {
     const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
                     item_ct1.get_local_id(1);

@@ -8200,7 +8636,7 @@ static void mul_mat_vec_q_iq2_xs_q8_1(const void * __restrict__ vx, const void * __restrict__ vy,
             (item_ct1.get_local_id(2) %
              (qi / vdr)); // x block quant index when casting the quants to int

-        tmp +=
+        tmp += vec_dot_iq1_s_q8_1(&x[ibx], &y[iby], iqs, iq1s_grid_gpu);
     }

     // sum up partial sums and write back result
@@ -8216,9 +8652,11 @@ static void mul_mat_vec_q_iq2_xs_q8_1(const void * __restrict__ vx, const void * __restrict__ vy,
 }

 template <int qk, int qi, typename block_q_t, int vdr>
-static void
-
-
+static void mul_mat_vec_q_iq1_m_q8_1(const void *__restrict__ vx,
+                                     const void *__restrict__ vy,
+                                     float *__restrict__ dst, const int ncols,
+                                     const int nrows,
+                                     const sycl::nd_item<3> &item_ct1) {
     const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
                     item_ct1.get_local_id(1);

@@ -8246,7 +8684,7 @@ static void mul_mat_vec_q_iq3_xxs_q8_1(const void * __restrict__ vx, const void * __restrict__ vy,
             (item_ct1.get_local_id(2) %
              (qi / vdr)); // x block quant index when casting the quants to int

-        tmp +=
+        tmp += vec_dot_iq1_m_q8_1(&x[ibx], &y[iby], iqs);
     }

     // sum up partial sums and write back result
@@ -8262,9 +8700,11 @@ static void mul_mat_vec_q_iq3_xxs_q8_1(const void * __restrict__ vx, const void * __restrict__ vy,
 }

 template <int qk, int qi, typename block_q_t, int vdr>
-static void
-
-
+static void mul_mat_vec_q_iq4_nl_q8_1(const void *__restrict__ vx,
+                                      const void *__restrict__ vy,
+                                      float *__restrict__ dst, const int ncols,
+                                      const int nrows,
+                                      const sycl::nd_item<3> &item_ct1) {
     const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
                     item_ct1.get_local_id(1);

@@ -8292,7 +8732,7 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void * __restrict__ vx, const void * __restrict__ vy,
             (item_ct1.get_local_id(2) %
              (qi / vdr)); // x block quant index when casting the quants to int

-        tmp +=
+        tmp += vec_dot_iq4_nl_q8_1(&x[ibx], &y[iby], iqs);
     }

     // sum up partial sums and write back result
@@ -8307,10 +8747,13 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void * __restrict__ vx, const void * __restrict__ vy,
     }
 }

+
 template <int qk, int qi, typename block_q_t, int vdr>
-static void
-
-
+static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
+                                      const void *__restrict__ vy,
+                                      float *__restrict__ dst, const int ncols,
+                                      const int nrows,
+                                      const sycl::nd_item<3> &item_ct1) {
     const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
                     item_ct1.get_local_id(1);

@@ -8338,7 +8781,7 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void * __restrict__ vx, const void * __restrict__ vy,
             (item_ct1.get_local_id(2) %
              (qi / vdr)); // x block quant index when casting the quants to int

-        tmp +=
+        tmp += vec_dot_iq4_xs_q8_1(&x[ibx], &y[iby], iqs);
     }

     // sum up partial sums and write back result
@@ -8353,6 +8796,7 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void * __restrict__ vx, const void * __restrict__ vy,
     }
 }

+
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
 static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows,
                                    const sycl::nd_item<3> &item_ct1) {
@@ -8914,64 +9358,71 @@ static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
     }
 }

+
 template<typename T>
-static inline void
+static inline void ggml_sycl_swap(T & a, T & b) {
     T tmp = a;
     a = b;
     b = tmp;
 }

-template<ggml_sort_order order>
-static void
-
+template <ggml_sort_order order>
+__dpct_inline__ static void
+k_argsort_f32_i32(const float *x, int *dst, const int ncols, int ncols_pad,
+                  const sycl::nd_item<3> &item_ct1, uint8_t *dpct_local) {
     // bitonic sort
     int col = item_ct1.get_local_id(2);
     int row = item_ct1.get_group(1);

-    if (col >=
+    if (col >= ncols_pad) {
+        return;
+    }

     const float * x_row = x + row * ncols;
-
+    auto dst_row = (int *)dpct_local;

     // initialize indices
-
-
-
-    /*
-    DPCT1065:58: Consider replacing sycl::nd_item::barrier() with
-    sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
-    performance if there is no access to global memory.
-    */
-    item_ct1.barrier();
+    dst_row[col] = col;
+
+    item_ct1.barrier(sycl::access::fence_space::local_space);

-    for (int k = 2; k <=
+    for (int k = 2; k <= ncols_pad; k *= 2) {
         for (int j = k / 2; j > 0; j /= 2) {
             int ixj = col ^ j;
             if (ixj > col) {
                 if ((col & k) == 0) {
-                    if (
-
+                    if (dst_row[col] >= ncols ||
+                        (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+                    ) {
+                        ggml_sycl_swap(dst_row[col], dst_row[ixj]);
                     }
                 } else {
-                    if (
-
+                    if (dst_row[ixj] >= ncols ||
+                        (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+                    ) {
+                        ggml_sycl_swap(dst_row[col], dst_row[ixj]);
                     }
                 }
             }
             /*
-            DPCT1118:
+            DPCT1118:1: SYCL group functions and algorithms must be encountered
             in converged control flow. You may need to adjust the code.
             */
-
-            DPCT1065:59: Consider replacing sycl::nd_item::barrier() with
-            sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
-            better performance if there is no access to global memory.
-            */
-            item_ct1.barrier();
+            item_ct1.barrier(sycl::access::fence_space::local_space);
         }
     }
+
+    // copy the result to dst without the padding
+    if (col < ncols) {
+        dst[row * ncols + col] = dst_row[col];
+    }
 }

+
 static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past,
                               const sycl::nd_item<3> &item_ct1) {
     const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
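Note: the rewritten k_argsort_f32_i32 pads each row to a power of two and keeps the index array in local memory; indices at or beyond ncols act as sentinels that always lose the comparison, which is exactly what the >= ncols checks implement. A hedged serial model of the same network (one row, ascending order, data invented) that can be verified on the host:

    #include <algorithm>
    #include <cstdio>

    int main() {
        const float x[5] = {3.f, 1.f, 4.f, 1.5f, 9.f};
        const int ncols = 5, ncols_pad = 8;          // next power of two
        int idx[8];
        for (int i = 0; i < ncols_pad; ++i) idx[i] = i;
        for (int k = 2; k <= ncols_pad; k *= 2) {
            for (int j = k / 2; j > 0; j /= 2) {
                for (int col = 0; col < ncols_pad; ++col) {
                    const int ixj = col ^ j;
                    if (ixj <= col) continue;        // each pair handled once
                    const bool asc = (col & k) == 0;
                    // sentinel rule mirrors the kernel's >= ncols checks
                    const bool swap_needed = asc
                        ? (idx[col] >= ncols || (idx[ixj] < ncols && x[idx[col]] > x[idx[ixj]]))
                        : (idx[ixj] >= ncols || (idx[col] < ncols && x[idx[col]] < x[idx[ixj]]));
                    if (swap_needed) std::swap(idx[col], idx[ixj]);
                }
            }
        }
        for (int i = 0; i < ncols; ++i) printf("%d ", idx[i]); // prints 1 3 0 2 4
        printf("\n");
        return 0;
    }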
@@ -9950,28 +10401,64 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int k,
 #endif
 }

-
 template <typename dst_t>
-static void
+static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
                                       dpct::queue_ptr stream) {
     const int nb = k / QK_K;
     {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 32),
+                                               sycl::range<3>(1, 1, 32)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_iq1_s(
+                                     vx, y, item_ct1, iq1s_grid_gpu
+                                     );
+                             });
+        });
+    }
+}

+template <typename dst_t>
+static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int k,
+                                      dpct::queue_ptr stream) {
+    const int nb = k / QK_K;
+    {
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});

         stream->submit([&](sycl::handler &cgh) {
-
-
-
+            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 32),
+                                               sycl::range<3>(1, 1, 32)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_iq1_m(
+                                     vx, y, item_ct1, iq1s_grid_gpu
+                                     );
+                             });
+        });
+    }
+}
+
+template <typename dst_t>
+static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int k,
+                                        dpct::queue_ptr stream) {
+    const int nb = k / QK_K;
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});

+        stream->submit([&](sycl::handler &cgh) {
             cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
                                                    sycl::range<3>(1, 1, 32),
                                                sycl::range<3>(1, 1, 32)),
                              [=](sycl::nd_item<3> item_ct1) {
                                  dequantize_block_iq2_xxs(
-                                     vx, y, item_ct1,
-
+                                     vx, y, item_ct1, iq2xxs_grid,
+                                     ksigns_iq2xs, kmask_iq2xs);
                              });
         });
     }
@@ -9982,105 +10469,130 @@ static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int k,
                                        dpct::queue_ptr stream) {
     const int nb = k / QK_K;
     {
-
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});

         stream->submit([&](sycl::handler &cgh) {
-            auto iq2xs_grid_ptr_ct1 = &iq2xs_grid[0];
-            auto ksigns_iq2xs_ptr_ct1 = &ksigns_iq2xs[0];
-            auto kmask_iq2xs_ptr_ct1 = &kmask_iq2xs[0];
-
             cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
                                                    sycl::range<3>(1, 1, 32),
                                                sycl::range<3>(1, 1, 32)),
                              [=](sycl::nd_item<3> item_ct1) {
                                  dequantize_block_iq2_xs(
-                                     vx, y, item_ct1,
-
+                                     vx, y, item_ct1, iq2xs_grid,
+                                     ksigns_iq2xs, kmask_iq2xs);
                              });
         });
     }
 }

 template <typename dst_t>
-static void
-
+static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int k,
+                                      dpct::queue_ptr stream) {
     const int nb = k / QK_K;
     {
-
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});

         stream->submit([&](sycl::handler &cgh) {
-            auto iq3xxs_grid_ptr_ct1 = &iq3xxs_grid[0];
-            auto ksigns_iq2xs_ptr_ct1 = &ksigns_iq2xs[0];
-            auto kmask_iq2xs_ptr_ct1 = &kmask_iq2xs[0];
-
             cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
                                                    sycl::range<3>(1, 1, 32),
                                                sycl::range<3>(1, 1, 32)),
                              [=](sycl::nd_item<3> item_ct1) {
-
-                                     vx, y, item_ct1, iq3xxs_grid_ptr_ct1,
-                                     ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1);
+                                 dequantize_block_iq2_s(vx, y, item_ct1);
                              });
         });
     }
 }

+
 template <typename dst_t>
-static void
+static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k,
                                         dpct::queue_ptr stream) {
     const int nb = k / QK_K;
     {
-
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});

         stream->submit([&](sycl::handler &cgh) {
-            auto iq3s_grid_ptr_ct1 = &iq3s_grid[0];
-            auto ksigns_iq2xs_ptr_ct1 = &ksigns_iq2xs[0];
-            auto kmask_iq2xs_ptr_ct1 = &kmask_iq2xs[0];
-
             cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
                                                    sycl::range<3>(1, 1, 32),
                                                sycl::range<3>(1, 1, 32)),
                              [=](sycl::nd_item<3> item_ct1) {
-
-                                     vx, y, item_ct1,
-
+                                 dequantize_block_iq3_xxs(
+                                     vx, y, item_ct1, iq3xxs_grid,
+                                     ksigns_iq2xs, kmask_iq2xs);
                              });
         });
     }
 }

 template <typename dst_t>
-static void
+static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int k,
                                       dpct::queue_ptr stream) {
     const int nb = k / QK_K;
     {
-
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});

         stream->submit([&](sycl::handler &cgh) {
-            auto iq1s_grid_ptr_ct1 = &iq1s_grid_gpu[0];
-            auto ksigns_iq2xs_ptr_ct1 = &ksigns_iq2xs[0];
-            auto kmask_iq2xs_ptr_ct1 = &kmask_iq2xs[0];
-
             cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
                                                    sycl::range<3>(1, 1, 32),
                                                sycl::range<3>(1, 1, 32)),
                              [=](sycl::nd_item<3> item_ct1) {
-
-                                     vx, y, item_ct1,
-                                     ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1);
+                                 dequantize_block_iq3_s(
+                                     vx, y, item_ct1, kmask_iq2xs, iq3s_grid);
                              });
         });
     }
 }

+template <typename dst_t>
+static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int k,
+                                       dpct::queue_ptr stream) {
+    const int nb = (k + QK_K - 1) / QK_K;
+#if QK_K == 64
+    dequantize_row_iq4_nl_sycl(vx, y, k, stream);
+#else
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                      sycl::range<3>(1, 1, 32),
+                                  sycl::range<3>(1, 1, 32)),
+                [=](sycl::nd_item<3> item_ct1) {
+                    dequantize_block_iq4_xs(vx, y, item_ct1);
+                });
+        });
+    }
+#endif
+}
+
+
+template <typename dst_t>
+static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int k,
+                                       dpct::queue_ptr stream) {
+    const int nb = (k + QK_K - 1) / QK_K;
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                      sycl::range<3>(1, 1, 32),
+                                  sycl::range<3>(1, 1, 32)),
+                [=](sycl::nd_item<3> item_ct1) {
+                    dequantize_block_iq4_nl(vx, y, item_ct1);
+                });
+        });
+    }
+}
+
+
+
 template <typename src_t, typename dst_t>
 static void convert_unary_sycl(const void *__restrict__ vx,
                                dst_t *__restrict__ y, const int k,
@@ -10125,16 +10637,24 @@ static to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type) try {
             return dequantize_row_q5_K_sycl;
         case GGML_TYPE_Q6_K:
             return dequantize_row_q6_K_sycl;
+        case GGML_TYPE_IQ1_S:
+            return dequantize_row_iq1_s_sycl;
+        case GGML_TYPE_IQ1_M:
+            return dequantize_row_iq1_m_sycl;
         case GGML_TYPE_IQ2_XXS:
            return dequantize_row_iq2_xxs_sycl;
         case GGML_TYPE_IQ2_XS:
            return dequantize_row_iq2_xs_sycl;
+        case GGML_TYPE_IQ2_S:
+            return dequantize_row_iq2_s_sycl;
         case GGML_TYPE_IQ3_XXS:
            return dequantize_row_iq3_xxs_sycl;
         case GGML_TYPE_IQ3_S:
            return dequantize_row_iq3_s_sycl;
-        case
-        return
+        case GGML_TYPE_IQ4_XS:
+            return dequantize_row_iq4_xs_sycl;
+        case GGML_TYPE_IQ4_NL:
+            return dequantize_row_iq4_nl_sycl;
         case GGML_TYPE_F32:
             return convert_unary_sycl<float>;
         default:
@@ -10169,16 +10689,24 @@ static to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type) {
            return dequantize_row_q5_K_sycl;
         case GGML_TYPE_Q6_K:
            return dequantize_row_q6_K_sycl;
+        case GGML_TYPE_IQ1_S:
+            return dequantize_row_iq1_s_sycl;
+        case GGML_TYPE_IQ1_M:
+            return dequantize_row_iq1_m_sycl;
         case GGML_TYPE_IQ2_XXS:
            return dequantize_row_iq2_xxs_sycl;
         case GGML_TYPE_IQ2_XS:
            return dequantize_row_iq2_xs_sycl;
+        case GGML_TYPE_IQ2_S:
+            return dequantize_row_iq2_s_sycl;
         case GGML_TYPE_IQ3_XXS:
            return dequantize_row_iq3_xxs_sycl;
         case GGML_TYPE_IQ3_S:
            return dequantize_row_iq3_s_sycl;
-        case
-        return
+        case GGML_TYPE_IQ4_XS:
+            return dequantize_row_iq4_xs_sycl;
+        case GGML_TYPE_IQ4_NL:
+            return dequantize_row_iq4_nl_sycl;
         case GGML_TYPE_F16:
             return convert_unary_sycl<sycl::half>;
         default:
@@ -10641,19 +11169,13 @@ static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy,
     const sycl::range<3> block_nums(1, 1, block_num_y);
     const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
     {
-
         stream->submit([&](sycl::handler &cgh) {
-            auto iq2xxs_grid_ptr_ct1 = &iq2xxs_grid[0];
-            auto ksigns_iq2xs_ptr_ct1 = &ksigns_iq2xs[0];
-            auto kmask_iq2xs_ptr_ct1 = &kmask_iq2xs[0];
-
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
                     [[intel::reqd_sub_group_size(32)]] {
                         mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS, block_iq2_xxs, 1>(
-                            vx, vy, dst, ncols, nrows, item_ct1
-                            iq2xxs_grid_ptr_ct1, ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1);
+                            vx, vy, dst, ncols, nrows, item_ct1);
                     });
         });
     }
@@ -10678,8 +11200,32 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
                 [=](sycl::nd_item<3> item_ct1)
                     [[intel::reqd_sub_group_size(32)]] {
                         mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS, block_iq2_xs, 1>(
-                            vx, vy, dst, ncols, nrows, item_ct1
-
+                            vx, vy, dst, ncols, nrows, item_ct1);
+                    });
+        });
+    }
+}
+
+static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
+                                        float *dst, const int ncols,
+                                        const int nrows,
+                                        dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+
+        stream->submit([&](sycl::handler &cgh) {
+            auto iq2xs_grid_ptr_ct1 = &iq2xs_grid[0];
+            auto ksigns64_ptr_ct1 = &ksigns64[0];
+
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S, block_iq2_s, 1>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
                     });
         });
     }
@@ -10704,8 +11250,7 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
                 [=](sycl::nd_item<3> item_ct1)
                     [[intel::reqd_sub_group_size(32)]] {
                         mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS, block_iq3_xxs, 1>(
-                            vx, vy, dst, ncols, nrows, item_ct1
-                            iq3xxs_grid_ptr_ct1, ksigns64_ptr_ct1);
+                            vx, vy, dst, ncols, nrows, item_ct1);
                     });
         });
     }
@@ -10723,15 +11268,13 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,

         stream->submit([&](sycl::handler &cgh) {
             auto iq3s_grid_ptr_ct1 = &iq3s_grid[0];
-            auto ksigns64_ptr_ct1 = &ksigns64[0];

             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
                     [[intel::reqd_sub_group_size(32)]] {
                         mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_XS, block_iq3_s, 1>(
-                            vx, vy, dst, ncols, nrows, item_ct1
-                            iq3s_grid_ptr_ct1, ksigns64_ptr_ct1);
+                            vx, vy, dst, ncols, nrows, item_ct1);
                     });
         });
     }
@@ -10756,8 +11299,72 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
                 [=](sycl::nd_item<3> item_ct1)
                     [[intel::reqd_sub_group_size(32)]] {
                         mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
-                            vx, vy, dst, ncols, nrows, item_ct1,
-                            iq1s_grid_ptr_ct1, ksigns64_ptr_ct1);
+                            vx, vy, dst, ncols, nrows, item_ct1);
+                });
+        });
+    }
+}
+
+static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy,
+                                        float *dst, const int ncols,
+                                        const int nrows,
+                                        dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
+                });
+        });
+    }
+}
+
+static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy,
+                                         float *dst, const int ncols,
+                                         const int nrows,
+                                         dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK4_NL == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 1>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
+                });
+        });
+    }
+}
+
+static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
+                                         float *dst, const int ncols,
+                                         const int nrows,
+                                         dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS, block_iq4_xs, 1>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
                 });
         });
     }
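The three launchers added here (iq1_m, iq4_nl, iq4_xs) differ only in the QK/QI constants, the block type, and the kernel invoked. A hypothetical consolidation, purely illustrative — the file deliberately keeps one launcher per type, and no SYCL is shown:

    // Placeholder for a kernel such as mul_mat_vec_q_iq4_nl_q8_1<...>.
    struct KernelIQ4NL {
        static void run(const void *, const void *, float *, int, int) {}
    };

    // One generic launcher parameterized by block size and kernel.
    template <int QK, typename Kernel>
    static void mul_mat_vec_generic(const void *vx, const void *vy, float *dst,
                                    int ncols, int nrows) {
        if (ncols % QK != 0) return;            // mirrors the GGML_ASSERT
        Kernel::run(vx, vy, dst, ncols, nrows); // stands in for the parallel_for
    }

    int main() {
        float dst[1] = {0.0f};
        mul_mat_vec_generic<32, KernelIQ4NL>(nullptr, nullptr, dst, 32, 1);
        return 0;
    }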
@@ -12381,36 +12988,54 @@ static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
         });
 }
 
+static int next_power_of_2(int x) {
+    int n = 1;
+    while (n < x) {
+        n *= 2;
+    }
+    return n;
+}
+
 static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
                                  const int nrows, ggml_sort_order order,
                                  dpct::queue_ptr stream) {
     // bitonic sort requires ncols to be power of 2
-    GGML_ASSERT((ncols & (ncols - 1)) == 0);
+    const int ncols_pad = next_power_of_2(ncols);
 
-    const sycl::range<3> block_dims(1, 1, ncols);
+    const sycl::range<3> block_dims(1, 1, ncols_pad);
     const sycl::range<3> block_nums(1, nrows, 1);
+    const size_t shared_mem = ncols_pad * sizeof(int);
+
+    // GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
+
     if (order == GGML_SORT_ORDER_ASC) {
-        /*
-        DPCT1049: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(x, dst, ncols, item_ct1);
-            });
+        stream->submit([&](sycl::handler &cgh) {
+            sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
+                sycl::range<1>(shared_mem), cgh);
+
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
+                    k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
+                        x, dst, ncols, ncols_pad, item_ct1,
+                        dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
+                            .get());
+                });
+        });
     } else if (order == GGML_SORT_ORDER_DESC) {
-        /*
-        DPCT1049: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(x, dst, ncols, item_ct1);
-            });
+        stream->submit([&](sycl::handler &cgh) {
+            sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
+                sycl::range<1>(shared_mem), cgh);
+
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
+                    k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
+                        x, dst, ncols, ncols_pad, item_ct1,
+                        dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
+                            .get());
+                });
+        });
     } else {
         GGML_ASSERT(false);
     }
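The rewritten argsort path no longer requires ncols to be a power of two up front; it pads up to ncols_pad so the bitonic network is always well-formed, and reserves ncols_pad ints of local memory per sorted row. A standalone check of that padding arithmetic, using the same next_power_of_2 loop as the diff:

    #include <cassert>
    #include <cstddef>

    static int next_power_of_2(int x) {
        int n = 1;
        while (n < x) {
            n *= 2;
        }
        return n;
    }

    int main() {
        // e.g. 1000 columns round up to 1024, so the kernel's scratch
        // buffer holds 1024 ints (4096 bytes) per sorted row.
        const int ncols         = 1000;
        const int ncols_pad     = next_power_of_2(ncols);
        const size_t shared_mem = ncols_pad * sizeof(int);
        assert(ncols_pad == 1024);
        assert(shared_mem == 4096);
        return 0;
    }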
@@ -13538,8 +14163,12 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_SYC
     case GGML_TYPE_Q5_K:
     case GGML_TYPE_IQ2_XXS:
     case GGML_TYPE_IQ2_XS:
+    case GGML_TYPE_IQ2_S:
     case GGML_TYPE_IQ1_S:
+    case GGML_TYPE_IQ1_M:
     case GGML_TYPE_IQ3_XXS:
+    case GGML_TYPE_IQ4_XS:
+    case GGML_TYPE_IQ4_NL:
         return max_compute_capability >= VER_GEN9 ? 128 : 64;
     case GGML_TYPE_IQ3_S:
         return max_compute_capability >= VER_GEN9 ? 128 : 64;
@@ -13558,11 +14187,20 @@ inline void ggml_sycl_op_mul_mat_vec_q(
     const int64_t src1_ncols, const int64_t src1_padded_row_size,
     const dpct::queue_ptr &stream) {
 
-    GGML_ASSERT(ggml_nrows(src1) == 1);
+    const int64_t ne10 = src1->ne[0];
+    GGML_ASSERT(ne10 % QK8_1 == 0);
 
     const int64_t ne00 = src0->ne[0];
     const int64_t row_diff = row_high - row_low;
 
+    int id;
+    SYCL_CHECK(
+        CHECK_TRY_ERROR(id = get_current_device_id()));
+
+    // the main device has a larger memory buffer to hold the results from all GPUs
+    // nrows_dst == nrows of the matrix that the kernel writes into
+    const int64_t nrows_dst = dst->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device ? ne00 : row_diff;
+
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
             mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
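The new assert holds because src1 has already been quantized into q8_1 blocks (QK8_1 = 32 values each), so its row length must be block-aligned; nrows_dst then picks between the full output height on the main device and this device's row slice. A small sketch of both computations with made-up sizes:

    #include <cstdio>

    constexpr long QK8_1 = 32; // q8_1 block size in ggml

    int main() {
        const long ne10 = 4096;                 // src1 row length, block-aligned
        std::printf("q8_1 blocks per row: %ld\n", ne10 / QK8_1);

        // nrows_dst selection as in the hunk: the main device writes the
        // full matrix, every other device only its slice of rows.
        const bool on_main_device = false;      // hypothetical scenario
        const long ne00 = 8192, row_diff = 2048;
        const long nrows_dst = on_main_device ? ne00 : row_diff;
        std::printf("nrows_dst = %ld\n", nrows_dst);
        return 0;
    }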
@@ -13594,20 +14232,32 @@ inline void ggml_sycl_op_mul_mat_vec_q(
         case GGML_TYPE_Q6_K:
             mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
             break;
+        case GGML_TYPE_IQ1_S:
+            mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_IQ1_M:
+            mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
         case GGML_TYPE_IQ2_XXS:
             mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
             break;
         case GGML_TYPE_IQ2_XS:
             mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
             break;
+        case GGML_TYPE_IQ2_S:
+            mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
         case GGML_TYPE_IQ3_XXS:
             mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
             break;
         case GGML_TYPE_IQ3_S:
             mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
             break;
-        case GGML_TYPE_IQ1_S:
-            mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+        case GGML_TYPE_IQ4_NL:
+            mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_IQ4_XS:
+            mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
             break;
         default:
             GGML_ASSERT(false);
@@ -13689,6 +14339,7 @@ inline void ggml_sycl_op_dequantize_mul_mat_vec(
             convert_mul_mat_vec_f16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
             break;
         default:
+            printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type);
             GGML_ASSERT(false);
             break;
     }
@@ -14543,8 +15194,8 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0,
                 src1_padded_col_size = (i0 * ne11 + src1_col_0) * ne10;
             }
             // do the computation
-            op(src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
-                dev[i].row_low, dev[i].row_high, src1_ncols, src1_padded_col_size, stream);
+            SYCL_CHECK(CHECK_TRY_ERROR(op(src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
+                dev[i].row_low, dev[i].row_high, src1_ncols, src1_padded_col_size, stream)));
             /*
             DPCT1010:93: SYCL uses exceptions to report errors and does not
             use the error codes. The call was replaced with 0. You need to
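Wrapping the op(...) call in SYCL_CHECK(CHECK_TRY_ERROR(...)) converts a thrown sycl::exception into a checked error instead of letting it unwind through the multi-device loop. A rough sketch of how such an exception-to-code wrapper can be built (hypothetical macro; the real dpct definitions are more elaborate):

    #include <cstdio>
    #include <stdexcept>

    // Hypothetical stand-in for CHECK_TRY_ERROR: evaluate an expression,
    // map any exception to a nonzero error code.
    #define TRY_TO_CODE(expr)                                      \
        [&]() -> int {                                             \
            try { (expr); return 0; }                              \
            catch (const std::exception &e) {                      \
                std::fprintf(stderr, "error: %s\n", e.what());     \
                return 1;                                          \
            }                                                      \
        }()

    static void op() { throw std::runtime_error("device fault"); }

    int main() {
        const int err = TRY_TO_CODE(op()); // caller checks, as SYCL_CHECK does
        return err;
    }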
@@ -15125,11 +15776,17 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
 #ifdef GGML_SYCL_FORCE_DMMV
     const bool use_mul_mat_vec_q = false;
 #else
-    const bool use_mul_mat_vec_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type);
+    bool use_mul_mat_vec_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type);
+    use_mul_mat_vec_q = use_mul_mat_vec_q ||
+        (src0->type == GGML_TYPE_IQ2_XXS) || (src0->type == GGML_TYPE_IQ2_XS) || (src0->type == GGML_TYPE_IQ2_S) ||
+        (src0->type == GGML_TYPE_IQ3_XXS) || (src0->type == GGML_TYPE_IQ3_S) ||
+        (src0->type == GGML_TYPE_IQ4_NL) || (src0->type == GGML_TYPE_IQ4_XS) ||
+        (src0->type == GGML_TYPE_IQ1_S) || (src0->type == GGML_TYPE_IQ1_M);
+
+
 #endif // GGML_SYCL_FORCE_DMMV
 
     if (use_mul_mat_vec_q) {
-        // NOTE: this kernel does not support ggml_nrows(src1) > 1
         // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_vec_q path\n");
         ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
     } else {
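The widened predicate is a long OR-chain that routes every IQ quantization down the mmvq path even when the generic compute-capability/is_quantized test fails. The same membership test written as a set lookup, with a sketched enum standing in for ggml_type:

    #include <initializer_list>

    // Sketch only: stands in for the relevant ggml_type values.
    enum type_sketch { IQ1_S, IQ1_M, IQ2_XXS, IQ2_XS, IQ2_S,
                       IQ3_XXS, IQ3_S, IQ4_NL, IQ4_XS, Q4_0 };

    static bool is_iq_type(type_sketch t) {
        for (type_sketch x : {IQ1_S, IQ1_M, IQ2_XXS, IQ2_XS, IQ2_S,
                              IQ3_XXS, IQ3_S, IQ4_NL, IQ4_XS}) {
            if (t == x) return true;
        }
        return false;
    }

    int main() { return (is_iq_type(IQ2_S) && !is_iq_type(Q4_0)) ? 0 : 1; }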
@@ -16985,9 +17642,14 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
                 return false;
             }
             ggml_type a_type = a->type;
-            if (a_type == GGML_TYPE_IQ4_NL ||
-                a_type == GGML_TYPE_IQ4_XS) {
-                return false;
+            if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ4_XS ||
+                a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ3_S ||
+                a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ2_S ||
+                a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ1_M
+                ) {
+                if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
+                    return false;
+                }
             }
             return true;
         } break;
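Where the old code rejected these quantizations outright, the new check only rejects an IQ-quantized matmul when b is a broadcast batch of single-row vectors (b->ne[1] == 1 while ggml_nrows(b) > 1), the batched-mmvq case the kernels above do not handle; single vectors and genuine matrices pass. A minimal restatement of the predicate, with an ne[] layout mirroring ggml's:

    #include <cassert>
    #include <cstdint>

    // Sketch of a ggml-style shape: ne[0] = row length, ne[1] = rows,
    // ne[2..3] = batch dimensions.
    struct tensor_sketch { int64_t ne[4]; };

    static int64_t nrows(const tensor_sketch &t) {
        return t.ne[1] * t.ne[2] * t.ne[3];
    }

    // The diff's rejection test for IQ-quantized src0.
    static bool rejected(const tensor_sketch &b) {
        return b.ne[1] == 1 && nrows(b) > 1;
    }

    int main() {
        assert(!rejected({{4096, 1, 1, 1}})); // one vector: supported
        assert( rejected({{4096, 1, 8, 1}})); // batch of 8 vectors: rejected
        assert(!rejected({{4096, 8, 1, 1}})); // genuine matrix: supported
        return 0;
    }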