llama_cpp 0.14.3 → 0.14.4
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +8 -0
- data/ext/llama_cpp/llama_cpp.cpp +4 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +4 -0
- data/vendor/tmp/llama.cpp/Makefile +71 -18
- data/vendor/tmp/llama.cpp/ggml-alloc.c +7 -2
- data/vendor/tmp/llama.cpp/ggml-backend.c +1 -1
- data/vendor/tmp/llama.cpp/ggml-common.h +25 -2
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +300 -9333
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +4 -0
- data/vendor/tmp/llama.cpp/ggml-metal.m +133 -113
- data/vendor/tmp/llama.cpp/ggml-metal.metal +344 -276
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +5 -0
- data/vendor/tmp/llama.cpp/ggml-quants.c +638 -43
- data/vendor/tmp/llama.cpp/ggml-quants.h +3 -0
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +106 -393
- data/vendor/tmp/llama.cpp/ggml-sycl.h +13 -3
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +37199 -14939
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +329 -308
- data/vendor/tmp/llama.cpp/ggml-vulkan.h +0 -11
- data/vendor/tmp/llama.cpp/ggml.c +133 -93
- data/vendor/tmp/llama.cpp/ggml.h +11 -5
- data/vendor/tmp/llama.cpp/llama.cpp +1763 -431
- data/vendor/tmp/llama.cpp/llama.h +67 -19
- data/vendor/tmp/llama.cpp/unicode-data.cpp +1651 -0
- data/vendor/tmp/llama.cpp/unicode-data.h +16 -0
- data/vendor/tmp/llama.cpp/unicode.cpp +8 -1403
- data/vendor/tmp/llama.cpp/unicode.h +2 -0
- metadata +5 -3
@@ -13,8 +13,8 @@ using namespace metal;
|
|
13
13
|
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
14
14
|
|
15
15
|
enum ggml_sort_order {
|
16
|
-
|
17
|
-
|
16
|
+
GGML_SORT_ORDER_ASC,
|
17
|
+
GGML_SORT_ORDER_DESC,
|
18
18
|
};
|
19
19
|
|
20
20
|
// general-purpose kernel for addition, multiplication and division of two tensors
|
@@ -1973,9 +1973,11 @@ kernel void kernel_timestep_embedding_f32(
|
|
1973
1973
|
|
1974
1974
|
// bitonic sort implementation following the CUDA kernels as reference
|
1975
1975
|
typedef void (argsort_t)(
|
1976
|
-
device const float
|
1977
|
-
device int32_t
|
1978
|
-
constant int64_t
|
1976
|
+
device const float * x,
|
1977
|
+
device int32_t * dst,
|
1978
|
+
constant int64_t & ncols,
|
1979
|
+
constant int64_t & ncols_pad,
|
1980
|
+
threadgroup int32_t * shared_values [[threadgroup(0)]],
|
1979
1981
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1980
1982
|
uint3 tpitg[[thread_position_in_threadgroup]]);
|
1981
1983
|
|
@@ -1984,33 +1986,42 @@ kernel void kernel_argsort_f32_i32(
|
|
1984
1986
|
device const float * x,
|
1985
1987
|
device int32_t * dst,
|
1986
1988
|
constant int64_t & ncols,
|
1989
|
+
constant int64_t & ncols_pad,
|
1990
|
+
threadgroup int32_t * shared_values [[threadgroup(0)]],
|
1987
1991
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1988
1992
|
uint3 tpitg[[thread_position_in_threadgroup]]) {
|
1989
1993
|
// bitonic sort
|
1990
1994
|
int col = tpitg[0];
|
1991
1995
|
int row = tgpig[1];
|
1992
1996
|
|
1993
|
-
if (col >=
|
1997
|
+
if (col >= ncols_pad) return;
|
1994
1998
|
|
1995
|
-
device const float * x_row = x
|
1996
|
-
|
1999
|
+
device const float * x_row = x + row * ncols;
|
2000
|
+
threadgroup int32_t * dst_row = shared_values;
|
1997
2001
|
|
1998
2002
|
// initialize indices
|
1999
|
-
|
2000
|
-
|
2001
|
-
}
|
2003
|
+
dst_row[col] = col;
|
2004
|
+
|
2002
2005
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
2003
2006
|
|
2004
|
-
for (int k = 2; k <=
|
2007
|
+
for (int k = 2; k <= ncols_pad; k *= 2) {
|
2005
2008
|
for (int j = k / 2; j > 0; j /= 2) {
|
2006
2009
|
int ixj = col ^ j;
|
2007
2010
|
if (ixj > col) {
|
2008
2011
|
if ((col & k) == 0) {
|
2009
|
-
if (
|
2012
|
+
if (dst_row[col] >= ncols ||
|
2013
|
+
(dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
|
2014
|
+
x_row[dst_row[col]] > x_row[dst_row[ixj]] :
|
2015
|
+
x_row[dst_row[col]] < x_row[dst_row[ixj]]))
|
2016
|
+
) {
|
2010
2017
|
SWAP(dst_row[col], dst_row[ixj]);
|
2011
2018
|
}
|
2012
2019
|
} else {
|
2013
|
-
if (
|
2020
|
+
if (dst_row[ixj] >= ncols ||
|
2021
|
+
(dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
|
2022
|
+
x_row[dst_row[col]] < x_row[dst_row[ixj]] :
|
2023
|
+
x_row[dst_row[col]] > x_row[dst_row[ixj]]))
|
2024
|
+
) {
|
2014
2025
|
SWAP(dst_row[col], dst_row[ixj]);
|
2015
2026
|
}
|
2016
2027
|
}
|
@@ -2018,10 +2029,15 @@ kernel void kernel_argsort_f32_i32(
|
|
2018
2029
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
2019
2030
|
}
|
2020
2031
|
}
|
2032
|
+
|
2033
|
+
// copy the result to dst without the padding
|
2034
|
+
if (col < ncols) {
|
2035
|
+
dst[row * ncols + col] = dst_row[col];
|
2036
|
+
}
|
2021
2037
|
}
|
2022
2038
|
|
2023
|
-
template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<
|
2024
|
-
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<
|
2039
|
+
template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
|
2040
|
+
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
|
2025
2041
|
|
2026
2042
|
kernel void kernel_leaky_relu_f32(
|
2027
2043
|
device const float * src0,
|
@@ -4456,6 +4472,114 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
4456
4472
|
}
|
4457
4473
|
}
|
4458
4474
|
|
4475
|
+
void kernel_mul_mv_iq1_m_f32_impl(
|
4476
|
+
device const void * src0,
|
4477
|
+
device const float * src1,
|
4478
|
+
device float * dst,
|
4479
|
+
constant int64_t & ne00,
|
4480
|
+
constant int64_t & ne01,
|
4481
|
+
constant int64_t & ne02,
|
4482
|
+
constant int64_t & ne10,
|
4483
|
+
constant int64_t & ne12,
|
4484
|
+
constant int64_t & ne0,
|
4485
|
+
constant int64_t & ne1,
|
4486
|
+
constant uint & r2,
|
4487
|
+
constant uint & r3,
|
4488
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4489
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4490
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4491
|
+
|
4492
|
+
const int nb = ne00/QK_K;
|
4493
|
+
const int r0 = tgpig.x;
|
4494
|
+
const int r1 = tgpig.y;
|
4495
|
+
const int im = tgpig.z;
|
4496
|
+
|
4497
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
4498
|
+
const int ib_row = first_row * nb;
|
4499
|
+
|
4500
|
+
const uint i12 = im%ne12;
|
4501
|
+
const uint i13 = im/ne12;
|
4502
|
+
|
4503
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
4504
|
+
device const block_iq1_m * x = (device const block_iq1_m *) src0 + ib_row + offset0;
|
4505
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
4506
|
+
|
4507
|
+
float yl[32];
|
4508
|
+
float sumf[N_DST]={0.f}, all_sum;
|
4509
|
+
|
4510
|
+
const int nb32 = nb * (QK_K / 32);
|
4511
|
+
|
4512
|
+
const int ix = tiisg;
|
4513
|
+
|
4514
|
+
device const float * y4 = y + 32 * ix;
|
4515
|
+
|
4516
|
+
#if QK_K != 64
|
4517
|
+
iq1m_scale_t scale;
|
4518
|
+
#endif
|
4519
|
+
|
4520
|
+
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
4521
|
+
|
4522
|
+
float4 sumy = {0.f};
|
4523
|
+
for (int i = 0; i < 8; ++i) {
|
4524
|
+
yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
|
4525
|
+
yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
|
4526
|
+
yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];
|
4527
|
+
yl[i+24] = y4[i+24]; sumy[3] += yl[i+24];
|
4528
|
+
}
|
4529
|
+
|
4530
|
+
const int ibl = ib32 / (QK_K / 32);
|
4531
|
+
const int ib = ib32 % (QK_K / 32);
|
4532
|
+
|
4533
|
+
device const block_iq1_m * xr = x + ibl;
|
4534
|
+
device const uint8_t * qs = xr->qs + 4 * ib;
|
4535
|
+
device const uint8_t * qh = xr->qh + 2 * ib;
|
4536
|
+
device const uint16_t * sc = (device const uint16_t *)xr->scales;
|
4537
|
+
|
4538
|
+
for (int row = 0; row < N_DST; row++) {
|
4539
|
+
|
4540
|
+
#if QK_K != 64
|
4541
|
+
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
|
4542
|
+
#endif
|
4543
|
+
|
4544
|
+
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
|
4545
|
+
constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
|
4546
|
+
constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700)));
|
4547
|
+
constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));
|
4548
|
+
|
4549
|
+
float2 sum = {0.f};
|
4550
|
+
for (int j = 0; j < 4; ++j) {
|
4551
|
+
sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
|
4552
|
+
+ yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
|
4553
|
+
sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
|
4554
|
+
+ yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
|
4555
|
+
}
|
4556
|
+
const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
4557
|
+
const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
4558
|
+
#if QK_K == 64
|
4559
|
+
const float d = (float) *((device const half *)(sc - 1));
|
4560
|
+
sumf[row] += d * ((sum[0] + delta1) * (2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1) +
|
4561
|
+
(sum[1] + delta2) * (2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1));
|
4562
|
+
#else
|
4563
|
+
sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
|
4564
|
+
(sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
|
4565
|
+
#endif
|
4566
|
+
|
4567
|
+
sc += nb*sizeof(block_iq1_m)/2;
|
4568
|
+
qs += nb*sizeof(block_iq1_m);
|
4569
|
+
qh += nb*sizeof(block_iq1_m);
|
4570
|
+
}
|
4571
|
+
|
4572
|
+
y4 += 32 * 32;
|
4573
|
+
}
|
4574
|
+
|
4575
|
+
for (int row = 0; row < N_DST; ++row) {
|
4576
|
+
all_sum = simd_sum(sumf[row]);
|
4577
|
+
if (tiisg == 0) {
|
4578
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
4579
|
+
}
|
4580
|
+
}
|
4581
|
+
}
|
4582
|
+
|
4459
4583
|
void kernel_mul_mv_iq4_nl_f32_impl(
|
4460
4584
|
device const void * src0,
|
4461
4585
|
device const float * src1,
|
@@ -4673,6 +4797,34 @@ kernel void kernel_mul_mv_iq1_s_f32(
|
|
4673
4797
|
kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
4674
4798
|
}
|
4675
4799
|
|
4800
|
+
[[host_name("kernel_mul_mv_iq1_m_f32")]]
|
4801
|
+
kernel void kernel_mul_mv_iq1_m_f32(
|
4802
|
+
device const void * src0,
|
4803
|
+
device const float * src1,
|
4804
|
+
device float * dst,
|
4805
|
+
constant int64_t & ne00,
|
4806
|
+
constant int64_t & ne01,
|
4807
|
+
constant int64_t & ne02,
|
4808
|
+
constant uint64_t & nb00,
|
4809
|
+
constant uint64_t & nb01,
|
4810
|
+
constant uint64_t & nb02,
|
4811
|
+
constant int64_t & ne10,
|
4812
|
+
constant int64_t & ne11,
|
4813
|
+
constant int64_t & ne12,
|
4814
|
+
constant uint64_t & nb10,
|
4815
|
+
constant uint64_t & nb11,
|
4816
|
+
constant uint64_t & nb12,
|
4817
|
+
constant int64_t & ne0,
|
4818
|
+
constant int64_t & ne1,
|
4819
|
+
constant uint & r2,
|
4820
|
+
constant uint & r3,
|
4821
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4822
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4823
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4824
|
+
|
4825
|
+
kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
4826
|
+
}
|
4827
|
+
|
4676
4828
|
[[host_name("kernel_mul_mv_iq4_nl_f32")]]
|
4677
4829
|
kernel void kernel_mul_mv_iq4_nl_f32(
|
4678
4830
|
device const void * src0,
|
@@ -5146,6 +5298,38 @@ void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 &
|
|
5146
5298
|
}
|
5147
5299
|
}
|
5148
5300
|
|
5301
|
+
template <typename type4x4>
|
5302
|
+
void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
|
5303
|
+
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
5304
|
+
const int ib32 = il/2;
|
5305
|
+
il = il%2;
|
5306
|
+
device const uint16_t * sc = (device const uint16_t *)xb->scales;
|
5307
|
+
#if QK_K == 64
|
5308
|
+
const float d = xb->d;
|
5309
|
+
#else
|
5310
|
+
iq1m_scale_t scale;
|
5311
|
+
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
|
5312
|
+
const float d = scale.f16;
|
5313
|
+
#endif
|
5314
|
+
device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
|
5315
|
+
device const uint8_t * qh = xb->qh + 2*ib32 + il;
|
5316
|
+
#if QK_K == 64
|
5317
|
+
const float dl = d * (2*((sc[ib32/2] >> (8*(ib32%2)+4*il)) & 0xf) + 1);
|
5318
|
+
#else
|
5319
|
+
const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
|
5320
|
+
#endif
|
5321
|
+
const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
5322
|
+
const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
5323
|
+
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
|
5324
|
+
constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
|
5325
|
+
for (int i = 0; i < 4; ++i) {
|
5326
|
+
reg[0][i] = dl * (grid1[i] & 0xf) + ml1;
|
5327
|
+
reg[1][i] = dl * (grid1[i] >> 4) + ml1;
|
5328
|
+
reg[2][i] = dl * (grid2[i] & 0xf) + ml2;
|
5329
|
+
reg[3][i] = dl * (grid2[i] >> 4) + ml2;
|
5330
|
+
}
|
5331
|
+
}
|
5332
|
+
|
5149
5333
|
template <typename type4x4>
|
5150
5334
|
void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
|
5151
5335
|
device const uint16_t * q4 = (device const uint16_t *)xb->qs;
|
@@ -5617,9 +5801,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
5617
5801
|
|
5618
5802
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
5619
5803
|
kernel void kernel_mul_mm_id(
|
5620
|
-
device const uchar *
|
5804
|
+
device const uchar * src0s,
|
5621
5805
|
device const uchar * src1,
|
5622
5806
|
device float * dst,
|
5807
|
+
device const uchar * ids,
|
5623
5808
|
constant uint64_t & nbi1,
|
5624
5809
|
constant int64_t & ne00,
|
5625
5810
|
constant int64_t & ne02,
|
@@ -5636,22 +5821,14 @@ kernel void kernel_mul_mm_id(
|
|
5636
5821
|
constant uint & r2,
|
5637
5822
|
constant uint & r3,
|
5638
5823
|
constant int & idx,
|
5639
|
-
device const uchar * src00,
|
5640
|
-
device const uchar * src01,
|
5641
|
-
device const uchar * src02,
|
5642
|
-
device const uchar * src03,
|
5643
|
-
device const uchar * src04,
|
5644
|
-
device const uchar * src05,
|
5645
|
-
device const uchar * src06,
|
5646
|
-
device const uchar * src07,
|
5647
5824
|
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
5648
5825
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
5649
5826
|
uint tiitg[[thread_index_in_threadgroup]],
|
5650
5827
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
5651
|
-
device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
5652
5828
|
|
5653
5829
|
// expert id
|
5654
5830
|
const int32_t id = tgpig.z/(ne12*ne13);
|
5831
|
+
device const uchar * src0 = src0s + id*nb02;
|
5655
5832
|
|
5656
5833
|
tgpig.z = tgpig.z%(ne12*ne13);
|
5657
5834
|
|
@@ -5666,7 +5843,7 @@ kernel void kernel_mul_mm_id(
|
|
5666
5843
|
}
|
5667
5844
|
|
5668
5845
|
kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
|
5669
|
-
|
5846
|
+
src0,
|
5670
5847
|
src1,
|
5671
5848
|
src1ids,
|
5672
5849
|
dst,
|
@@ -5730,6 +5907,7 @@ template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_r
|
|
5730
5907
|
template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
5731
5908
|
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
5732
5909
|
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
5910
|
+
template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
5733
5911
|
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
|
5734
5912
|
#if QK_K == 64
|
5735
5913
|
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, 2, dequantize_iq4_xs>;
|
@@ -5778,6 +5956,7 @@ template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_m
|
|
5778
5956
|
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
5779
5957
|
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
5780
5958
|
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
5959
|
+
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
5781
5960
|
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
|
5782
5961
|
#if QK_K == 64
|
5783
5962
|
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_xs>;
|
@@ -5790,9 +5969,10 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_m
|
|
5790
5969
|
//
|
5791
5970
|
|
5792
5971
|
typedef void (mat_mm_id_t)(
|
5793
|
-
device const uchar *
|
5972
|
+
device const uchar * src0s,
|
5794
5973
|
device const uchar * src1,
|
5795
5974
|
device float * dst,
|
5975
|
+
device const uchar * ids,
|
5796
5976
|
constant uint64_t & nbi1,
|
5797
5977
|
constant int64_t & ne00,
|
5798
5978
|
constant int64_t & ne02,
|
@@ -5809,14 +5989,6 @@ typedef void (mat_mm_id_t)(
|
|
5809
5989
|
constant uint & r2,
|
5810
5990
|
constant uint & r3,
|
5811
5991
|
constant int & idx,
|
5812
|
-
device const uchar * src00,
|
5813
|
-
device const uchar * src01,
|
5814
|
-
device const uchar * src02,
|
5815
|
-
device const uchar * src03,
|
5816
|
-
device const uchar * src04,
|
5817
|
-
device const uchar * src05,
|
5818
|
-
device const uchar * src06,
|
5819
|
-
device const uchar * src07,
|
5820
5992
|
threadgroup uchar *,
|
5821
5993
|
uint3, uint, uint);
|
5822
5994
|
|
@@ -5838,6 +6010,7 @@ template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel
|
|
5838
6010
|
template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
5839
6011
|
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
5840
6012
|
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
6013
|
+
template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
5841
6014
|
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
|
5842
6015
|
#if QK_K == 64
|
5843
6016
|
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, 2, dequantize_iq4_xs>;
|
@@ -5851,9 +6024,10 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel
|
|
5851
6024
|
|
5852
6025
|
[[host_name("kernel_mul_mv_id_f32_f32")]]
|
5853
6026
|
kernel void kernel_mul_mv_id_f32_f32(
|
5854
|
-
device const char *
|
6027
|
+
device const char * src0s,
|
5855
6028
|
device const char * src1,
|
5856
6029
|
device float * dst,
|
6030
|
+
device const char * ids,
|
5857
6031
|
constant uint64_t & nbi1,
|
5858
6032
|
constant int64_t & ne00,
|
5859
6033
|
constant int64_t & ne01,
|
@@ -5874,28 +6048,19 @@ kernel void kernel_mul_mv_id_f32_f32(
|
|
5874
6048
|
constant uint & r2,
|
5875
6049
|
constant uint & r3,
|
5876
6050
|
constant int & idx,
|
5877
|
-
device const char * src00,
|
5878
|
-
device const char * src01,
|
5879
|
-
device const char * src02,
|
5880
|
-
device const char * src03,
|
5881
|
-
device const char * src04,
|
5882
|
-
device const char * src05,
|
5883
|
-
device const char * src06,
|
5884
|
-
device const char * src07,
|
5885
6051
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
5886
6052
|
uint tiitg[[thread_index_in_threadgroup]],
|
5887
6053
|
uint tiisg[[thread_index_in_simdgroup]],
|
5888
6054
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
5889
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
5890
|
-
|
5891
6055
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
5892
6056
|
|
5893
6057
|
tgpig.z = tgpig.z%(ne12*ne13);
|
5894
6058
|
|
5895
6059
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6060
|
+
device const char * src0 = src0s + id*nb02;
|
5896
6061
|
|
5897
6062
|
kernel_mul_mv_f32_f32_impl(
|
5898
|
-
src0
|
6063
|
+
src0,
|
5899
6064
|
src1 + bid*nb11,
|
5900
6065
|
dst + bid*ne0,
|
5901
6066
|
ne00,
|
@@ -5920,9 +6085,10 @@ kernel void kernel_mul_mv_id_f32_f32(
|
|
5920
6085
|
|
5921
6086
|
[[host_name("kernel_mul_mv_id_f16_f32")]]
|
5922
6087
|
kernel void kernel_mul_mv_id_f16_f32(
|
5923
|
-
device const char *
|
6088
|
+
device const char * src0s,
|
5924
6089
|
device const char * src1,
|
5925
6090
|
device float * dst,
|
6091
|
+
device const char * ids,
|
5926
6092
|
constant uint64_t & nbi1,
|
5927
6093
|
constant int64_t & ne00,
|
5928
6094
|
constant int64_t & ne01,
|
@@ -5943,28 +6109,19 @@ kernel void kernel_mul_mv_id_f16_f32(
|
|
5943
6109
|
constant uint & r2,
|
5944
6110
|
constant uint & r3,
|
5945
6111
|
constant int & idx,
|
5946
|
-
device const char * src00,
|
5947
|
-
device const char * src01,
|
5948
|
-
device const char * src02,
|
5949
|
-
device const char * src03,
|
5950
|
-
device const char * src04,
|
5951
|
-
device const char * src05,
|
5952
|
-
device const char * src06,
|
5953
|
-
device const char * src07,
|
5954
6112
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
5955
6113
|
uint tiitg[[thread_index_in_threadgroup]],
|
5956
6114
|
uint tiisg[[thread_index_in_simdgroup]],
|
5957
6115
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
5958
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
5959
|
-
|
5960
6116
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
5961
6117
|
|
5962
6118
|
tgpig.z = tgpig.z%(ne12*ne13);
|
5963
6119
|
|
5964
6120
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6121
|
+
device const char * src0 = src0s + id*nb02;
|
5965
6122
|
|
5966
6123
|
kernel_mul_mv_f16_f32_impl(
|
5967
|
-
src0
|
6124
|
+
src0,
|
5968
6125
|
src1 + bid*nb11,
|
5969
6126
|
dst + bid*ne0,
|
5970
6127
|
ne00,
|
@@ -5989,9 +6146,10 @@ kernel void kernel_mul_mv_id_f16_f32(
|
|
5989
6146
|
|
5990
6147
|
[[host_name("kernel_mul_mv_id_q8_0_f32")]]
|
5991
6148
|
kernel void kernel_mul_mv_id_q8_0_f32(
|
5992
|
-
device const char *
|
6149
|
+
device const char * src0s,
|
5993
6150
|
device const char * src1,
|
5994
6151
|
device float * dst,
|
6152
|
+
device const char * ids,
|
5995
6153
|
constant uint64_t & nbi1,
|
5996
6154
|
constant int64_t & ne00,
|
5997
6155
|
constant int64_t & ne01,
|
@@ -6012,28 +6170,19 @@ kernel void kernel_mul_mv_id_q8_0_f32(
|
|
6012
6170
|
constant uint & r2,
|
6013
6171
|
constant uint & r3,
|
6014
6172
|
constant int & idx,
|
6015
|
-
device const char * src00,
|
6016
|
-
device const char * src01,
|
6017
|
-
device const char * src02,
|
6018
|
-
device const char * src03,
|
6019
|
-
device const char * src04,
|
6020
|
-
device const char * src05,
|
6021
|
-
device const char * src06,
|
6022
|
-
device const char * src07,
|
6023
6173
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6024
6174
|
uint tiitg[[thread_index_in_threadgroup]],
|
6025
6175
|
uint tiisg[[thread_index_in_simdgroup]],
|
6026
6176
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6027
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6028
|
-
|
6029
6177
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6030
6178
|
|
6031
6179
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6032
6180
|
|
6033
6181
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6182
|
+
device const char * src0 = src0s + id*nb02;
|
6034
6183
|
|
6035
6184
|
kernel_mul_mv_q8_0_f32_impl(
|
6036
|
-
src0
|
6185
|
+
src0,
|
6037
6186
|
(device const float *) (src1 + bid*nb11),
|
6038
6187
|
dst + bid*ne0,
|
6039
6188
|
ne00,
|
@@ -6052,9 +6201,10 @@ kernel void kernel_mul_mv_id_q8_0_f32(
|
|
6052
6201
|
|
6053
6202
|
[[host_name("kernel_mul_mv_id_q4_0_f32")]]
|
6054
6203
|
kernel void kernel_mul_mv_id_q4_0_f32(
|
6055
|
-
device const char *
|
6204
|
+
device const char * src0s,
|
6056
6205
|
device const char * src1,
|
6057
6206
|
device float * dst,
|
6207
|
+
device const char * ids,
|
6058
6208
|
constant uint64_t & nbi1,
|
6059
6209
|
constant int64_t & ne00,
|
6060
6210
|
constant int64_t & ne01,
|
@@ -6075,28 +6225,19 @@ kernel void kernel_mul_mv_id_q4_0_f32(
|
|
6075
6225
|
constant uint & r2,
|
6076
6226
|
constant uint & r3,
|
6077
6227
|
constant int & idx,
|
6078
|
-
device const char * src00,
|
6079
|
-
device const char * src01,
|
6080
|
-
device const char * src02,
|
6081
|
-
device const char * src03,
|
6082
|
-
device const char * src04,
|
6083
|
-
device const char * src05,
|
6084
|
-
device const char * src06,
|
6085
|
-
device const char * src07,
|
6086
6228
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6087
6229
|
uint tiitg[[thread_index_in_threadgroup]],
|
6088
6230
|
uint tiisg[[thread_index_in_simdgroup]],
|
6089
6231
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6090
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6091
|
-
|
6092
6232
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6093
6233
|
|
6094
6234
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6095
6235
|
|
6096
6236
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6237
|
+
device const char * src0 = src0s + id*nb02;
|
6097
6238
|
|
6098
6239
|
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
6099
|
-
src0
|
6240
|
+
src0,
|
6100
6241
|
(device const float *) (src1 + bid*nb11),
|
6101
6242
|
dst + bid*ne0,
|
6102
6243
|
ne00,
|
@@ -6115,9 +6256,10 @@ kernel void kernel_mul_mv_id_q4_0_f32(
|
|
6115
6256
|
|
6116
6257
|
[[host_name("kernel_mul_mv_id_q4_1_f32")]]
|
6117
6258
|
kernel void kernel_mul_mv_id_q4_1_f32(
|
6118
|
-
device const char *
|
6259
|
+
device const char * src0s,
|
6119
6260
|
device const char * src1,
|
6120
6261
|
device float * dst,
|
6262
|
+
device const char * ids,
|
6121
6263
|
constant uint64_t & nbi1,
|
6122
6264
|
constant int64_t & ne00,
|
6123
6265
|
constant int64_t & ne01,
|
@@ -6138,28 +6280,19 @@ kernel void kernel_mul_mv_id_q4_1_f32(
|
|
6138
6280
|
constant uint & r2,
|
6139
6281
|
constant uint & r3,
|
6140
6282
|
constant int & idx,
|
6141
|
-
device const char * src00,
|
6142
|
-
device const char * src01,
|
6143
|
-
device const char * src02,
|
6144
|
-
device const char * src03,
|
6145
|
-
device const char * src04,
|
6146
|
-
device const char * src05,
|
6147
|
-
device const char * src06,
|
6148
|
-
device const char * src07,
|
6149
6283
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6150
6284
|
uint tiitg[[thread_index_in_threadgroup]],
|
6151
6285
|
uint tiisg[[thread_index_in_simdgroup]],
|
6152
6286
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6153
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6154
|
-
|
6155
6287
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6156
6288
|
|
6157
6289
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6158
6290
|
|
6159
6291
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6292
|
+
device const char * src0 = src0s + id*nb02;
|
6160
6293
|
|
6161
6294
|
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
6162
|
-
src0
|
6295
|
+
src0,
|
6163
6296
|
(device const float *) (src1 + bid*nb11),
|
6164
6297
|
dst + bid*ne0,
|
6165
6298
|
ne00,
|
@@ -6178,9 +6311,10 @@ kernel void kernel_mul_mv_id_q4_1_f32(
|
|
6178
6311
|
|
6179
6312
|
[[host_name("kernel_mul_mv_id_q5_0_f32")]]
|
6180
6313
|
kernel void kernel_mul_mv_id_q5_0_f32(
|
6181
|
-
device const char *
|
6314
|
+
device const char * src0s,
|
6182
6315
|
device const char * src1,
|
6183
6316
|
device float * dst,
|
6317
|
+
device const char * ids,
|
6184
6318
|
constant uint64_t & nbi1,
|
6185
6319
|
constant int64_t & ne00,
|
6186
6320
|
constant int64_t & ne01,
|
@@ -6201,28 +6335,19 @@ kernel void kernel_mul_mv_id_q5_0_f32(
|
|
6201
6335
|
constant uint & r2,
|
6202
6336
|
constant uint & r3,
|
6203
6337
|
constant int & idx,
|
6204
|
-
device const char * src00,
|
6205
|
-
device const char * src01,
|
6206
|
-
device const char * src02,
|
6207
|
-
device const char * src03,
|
6208
|
-
device const char * src04,
|
6209
|
-
device const char * src05,
|
6210
|
-
device const char * src06,
|
6211
|
-
device const char * src07,
|
6212
6338
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6213
6339
|
uint tiitg[[thread_index_in_threadgroup]],
|
6214
6340
|
uint tiisg[[thread_index_in_simdgroup]],
|
6215
6341
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6216
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6217
|
-
|
6218
6342
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6219
6343
|
|
6220
6344
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6221
6345
|
|
6222
6346
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6347
|
+
device const char * src0 = src0s + id*nb02;
|
6223
6348
|
|
6224
6349
|
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
6225
|
-
src0
|
6350
|
+
src0,
|
6226
6351
|
(device const float *) (src1 + bid*nb11),
|
6227
6352
|
dst + bid*ne0,
|
6228
6353
|
ne00,
|
@@ -6241,9 +6366,10 @@ kernel void kernel_mul_mv_id_q5_0_f32(
|
|
6241
6366
|
|
6242
6367
|
[[host_name("kernel_mul_mv_id_q5_1_f32")]]
|
6243
6368
|
kernel void kernel_mul_mv_id_q5_1_f32(
|
6244
|
-
device const char *
|
6369
|
+
device const char * src0s,
|
6245
6370
|
device const char * src1,
|
6246
6371
|
device float * dst,
|
6372
|
+
device const char * ids,
|
6247
6373
|
constant uint64_t & nbi1,
|
6248
6374
|
constant int64_t & ne00,
|
6249
6375
|
constant int64_t & ne01,
|
@@ -6264,28 +6390,19 @@ kernel void kernel_mul_mv_id_q5_1_f32(
|
|
6264
6390
|
constant uint & r2,
|
6265
6391
|
constant uint & r3,
|
6266
6392
|
constant int & idx,
|
6267
|
-
device const char * src00,
|
6268
|
-
device const char * src01,
|
6269
|
-
device const char * src02,
|
6270
|
-
device const char * src03,
|
6271
|
-
device const char * src04,
|
6272
|
-
device const char * src05,
|
6273
|
-
device const char * src06,
|
6274
|
-
device const char * src07,
|
6275
6393
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6276
6394
|
uint tiitg[[thread_index_in_threadgroup]],
|
6277
6395
|
uint tiisg[[thread_index_in_simdgroup]],
|
6278
6396
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6279
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6280
|
-
|
6281
6397
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6282
6398
|
|
6283
6399
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6284
6400
|
|
6285
6401
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6402
|
+
device const char * src0 = src0s + id*nb02;
|
6286
6403
|
|
6287
6404
|
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
6288
|
-
src0
|
6405
|
+
src0,
|
6289
6406
|
(device const float *) (src1 + bid*nb11),
|
6290
6407
|
dst + bid*ne0,
|
6291
6408
|
ne00,
|
@@ -6304,9 +6421,10 @@ kernel void kernel_mul_mv_id_q5_1_f32(
|
|
6304
6421
|
|
6305
6422
|
[[host_name("kernel_mul_mv_id_q2_K_f32")]]
|
6306
6423
|
kernel void kernel_mul_mv_id_q2_K_f32(
|
6307
|
-
device const char *
|
6424
|
+
device const char * src0s,
|
6308
6425
|
device const char * src1,
|
6309
6426
|
device float * dst,
|
6427
|
+
device const char * ids,
|
6310
6428
|
constant uint64_t & nbi1,
|
6311
6429
|
constant int64_t & ne00,
|
6312
6430
|
constant int64_t & ne01,
|
@@ -6327,28 +6445,19 @@ kernel void kernel_mul_mv_id_q2_K_f32(
|
|
6327
6445
|
constant uint & r2,
|
6328
6446
|
constant uint & r3,
|
6329
6447
|
constant int & idx,
|
6330
|
-
device const char * src00,
|
6331
|
-
device const char * src01,
|
6332
|
-
device const char * src02,
|
6333
|
-
device const char * src03,
|
6334
|
-
device const char * src04,
|
6335
|
-
device const char * src05,
|
6336
|
-
device const char * src06,
|
6337
|
-
device const char * src07,
|
6338
6448
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6339
6449
|
uint tiitg[[thread_index_in_threadgroup]],
|
6340
6450
|
uint tiisg[[thread_index_in_simdgroup]],
|
6341
6451
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6342
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6343
|
-
|
6344
6452
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6345
6453
|
|
6346
6454
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6347
6455
|
|
6348
6456
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6457
|
+
device const char * src0 = src0s + id*nb02;
|
6349
6458
|
|
6350
6459
|
kernel_mul_mv_q2_K_f32_impl(
|
6351
|
-
src0
|
6460
|
+
src0,
|
6352
6461
|
(device const float *) (src1 + bid*nb11),
|
6353
6462
|
dst + bid*ne0,
|
6354
6463
|
ne00,
|
@@ -6367,9 +6476,10 @@ kernel void kernel_mul_mv_id_q2_K_f32(
|
|
6367
6476
|
|
6368
6477
|
[[host_name("kernel_mul_mv_id_q3_K_f32")]]
|
6369
6478
|
kernel void kernel_mul_mv_id_q3_K_f32(
|
6370
|
-
device const char *
|
6479
|
+
device const char * src0s,
|
6371
6480
|
device const char * src1,
|
6372
6481
|
device float * dst,
|
6482
|
+
device const char * ids,
|
6373
6483
|
constant uint64_t & nbi1,
|
6374
6484
|
constant int64_t & ne00,
|
6375
6485
|
constant int64_t & ne01,
|
@@ -6390,28 +6500,19 @@ kernel void kernel_mul_mv_id_q3_K_f32(
|
|
6390
6500
|
constant uint & r2,
|
6391
6501
|
constant uint & r3,
|
6392
6502
|
constant int & idx,
|
6393
|
-
device const char * src00,
|
6394
|
-
device const char * src01,
|
6395
|
-
device const char * src02,
|
6396
|
-
device const char * src03,
|
6397
|
-
device const char * src04,
|
6398
|
-
device const char * src05,
|
6399
|
-
device const char * src06,
|
6400
|
-
device const char * src07,
|
6401
6503
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6402
6504
|
uint tiitg[[thread_index_in_threadgroup]],
|
6403
6505
|
uint tiisg[[thread_index_in_simdgroup]],
|
6404
6506
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6405
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6406
|
-
|
6407
6507
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6408
6508
|
|
6409
6509
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6410
6510
|
|
6411
6511
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6512
|
+
device const char * src0 = src0s + id*nb02;
|
6412
6513
|
|
6413
6514
|
kernel_mul_mv_q3_K_f32_impl(
|
6414
|
-
src0
|
6515
|
+
src0,
|
6415
6516
|
(device const float *) (src1 + bid*nb11),
|
6416
6517
|
dst + bid*ne0,
|
6417
6518
|
ne00,
|
@@ -6430,9 +6531,10 @@ kernel void kernel_mul_mv_id_q3_K_f32(
|
|
6430
6531
|
|
6431
6532
|
[[host_name("kernel_mul_mv_id_q4_K_f32")]]
|
6432
6533
|
kernel void kernel_mul_mv_id_q4_K_f32(
|
6433
|
-
device const char *
|
6534
|
+
device const char * src0s,
|
6434
6535
|
device const char * src1,
|
6435
6536
|
device float * dst,
|
6537
|
+
device const char * ids,
|
6436
6538
|
constant uint64_t & nbi1,
|
6437
6539
|
constant int64_t & ne00,
|
6438
6540
|
constant int64_t & ne01,
|
@@ -6453,28 +6555,19 @@ kernel void kernel_mul_mv_id_q4_K_f32(
|
|
6453
6555
|
constant uint & r2,
|
6454
6556
|
constant uint & r3,
|
6455
6557
|
constant int & idx,
|
6456
|
-
device const char * src00,
|
6457
|
-
device const char * src01,
|
6458
|
-
device const char * src02,
|
6459
|
-
device const char * src03,
|
6460
|
-
device const char * src04,
|
6461
|
-
device const char * src05,
|
6462
|
-
device const char * src06,
|
6463
|
-
device const char * src07,
|
6464
6558
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6465
6559
|
uint tiitg[[thread_index_in_threadgroup]],
|
6466
6560
|
uint tiisg[[thread_index_in_simdgroup]],
|
6467
6561
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6468
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6469
|
-
|
6470
6562
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6471
6563
|
|
6472
6564
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6473
6565
|
|
6474
6566
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6567
|
+
device const char * src0 = src0s + id*nb02;
|
6475
6568
|
|
6476
6569
|
kernel_mul_mv_q4_K_f32_impl(
|
6477
|
-
src0
|
6570
|
+
src0,
|
6478
6571
|
(device const float *) (src1 + bid*nb11),
|
6479
6572
|
dst + bid*ne0,
|
6480
6573
|
ne00,
|
@@ -6493,9 +6586,10 @@ kernel void kernel_mul_mv_id_q4_K_f32(
|
|
6493
6586
|
|
6494
6587
|
[[host_name("kernel_mul_mv_id_q5_K_f32")]]
|
6495
6588
|
kernel void kernel_mul_mv_id_q5_K_f32(
|
6496
|
-
device const char *
|
6589
|
+
device const char * src0s,
|
6497
6590
|
device const char * src1,
|
6498
6591
|
device float * dst,
|
6592
|
+
device const char * ids,
|
6499
6593
|
constant uint64_t & nbi1,
|
6500
6594
|
constant int64_t & ne00,
|
6501
6595
|
constant int64_t & ne01,
|
@@ -6516,28 +6610,19 @@ kernel void kernel_mul_mv_id_q5_K_f32(
|
|
6516
6610
|
constant uint & r2,
|
6517
6611
|
constant uint & r3,
|
6518
6612
|
constant int & idx,
|
6519
|
-
device const char * src00,
|
6520
|
-
device const char * src01,
|
6521
|
-
device const char * src02,
|
6522
|
-
device const char * src03,
|
6523
|
-
device const char * src04,
|
6524
|
-
device const char * src05,
|
6525
|
-
device const char * src06,
|
6526
|
-
device const char * src07,
|
6527
6613
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6528
6614
|
uint tiitg[[thread_index_in_threadgroup]],
|
6529
6615
|
uint tiisg[[thread_index_in_simdgroup]],
|
6530
6616
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6531
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6532
|
-
|
6533
6617
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6534
6618
|
|
6535
6619
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6536
6620
|
|
6537
6621
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6622
|
+
device const char * src0 = src0s + id*nb02;
|
6538
6623
|
|
6539
6624
|
kernel_mul_mv_q5_K_f32_impl(
|
6540
|
-
src0
|
6625
|
+
src0,
|
6541
6626
|
(device const float *) (src1 + bid*nb11),
|
6542
6627
|
dst + bid*ne0,
|
6543
6628
|
ne00,
|
@@ -6556,9 +6641,10 @@ kernel void kernel_mul_mv_id_q5_K_f32(
|
|
6556
6641
|
|
6557
6642
|
[[host_name("kernel_mul_mv_id_q6_K_f32")]]
|
6558
6643
|
kernel void kernel_mul_mv_id_q6_K_f32(
|
6559
|
-
device const char *
|
6644
|
+
device const char * src0s,
|
6560
6645
|
device const char * src1,
|
6561
6646
|
device float * dst,
|
6647
|
+
device const char * ids,
|
6562
6648
|
constant uint64_t & nbi1,
|
6563
6649
|
constant int64_t & ne00,
|
6564
6650
|
constant int64_t & ne01,
|
@@ -6579,28 +6665,19 @@ kernel void kernel_mul_mv_id_q6_K_f32(
|
|
6579
6665
|
constant uint & r2,
|
6580
6666
|
constant uint & r3,
|
6581
6667
|
constant int & idx,
|
6582
|
-
device const char * src00,
|
6583
|
-
device const char * src01,
|
6584
|
-
device const char * src02,
|
6585
|
-
device const char * src03,
|
6586
|
-
device const char * src04,
|
6587
|
-
device const char * src05,
|
6588
|
-
device const char * src06,
|
6589
|
-
device const char * src07,
|
6590
6668
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6591
6669
|
uint tiitg[[thread_index_in_threadgroup]],
|
6592
6670
|
uint tiisg[[thread_index_in_simdgroup]],
|
6593
6671
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6594
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6595
|
-
|
6596
6672
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6597
6673
|
|
6598
6674
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6599
6675
|
|
6600
6676
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6677
|
+
device const char * src0 = src0s + id*nb02;
|
6601
6678
|
|
6602
6679
|
kernel_mul_mv_q6_K_f32_impl(
|
6603
|
-
src0
|
6680
|
+
src0,
|
6604
6681
|
(device const float *) (src1 + bid*nb11),
|
6605
6682
|
dst + bid*ne0,
|
6606
6683
|
ne00,
|
@@ -6619,9 +6696,10 @@ kernel void kernel_mul_mv_id_q6_K_f32(
|
|
6619
6696
|
|
6620
6697
|
[[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
|
6621
6698
|
kernel void kernel_mul_mv_id_iq2_xxs_f32(
|
6622
|
-
device const char *
|
6699
|
+
device const char * src0s,
|
6623
6700
|
device const char * src1,
|
6624
6701
|
device float * dst,
|
6702
|
+
device const char * ids,
|
6625
6703
|
constant uint64_t & nbi1,
|
6626
6704
|
constant int64_t & ne00,
|
6627
6705
|
constant int64_t & ne01,
|
@@ -6642,29 +6720,20 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
|
|
6642
6720
|
constant uint & r2,
|
6643
6721
|
constant uint & r3,
|
6644
6722
|
constant int & idx,
|
6645
|
-
device const char * src00,
|
6646
|
-
device const char * src01,
|
6647
|
-
device const char * src02,
|
6648
|
-
device const char * src03,
|
6649
|
-
device const char * src04,
|
6650
|
-
device const char * src05,
|
6651
|
-
device const char * src06,
|
6652
|
-
device const char * src07,
|
6653
6723
|
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
6654
6724
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6655
6725
|
uint tiitg[[thread_index_in_threadgroup]],
|
6656
6726
|
uint tiisg[[thread_index_in_simdgroup]],
|
6657
6727
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6658
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6659
|
-
|
6660
6728
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6661
6729
|
|
6662
6730
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6663
6731
|
|
6664
6732
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6733
|
+
device const char * src0 = src0s + id*nb02;
|
6665
6734
|
|
6666
6735
|
kernel_mul_mv_iq2_xxs_f32_impl(
|
6667
|
-
src0
|
6736
|
+
src0,
|
6668
6737
|
(device const float *) (src1 + bid*nb11),
|
6669
6738
|
dst + bid*ne0,
|
6670
6739
|
ne00,
|
@@ -6684,9 +6753,10 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
|
|
6684
6753
|
|
6685
6754
|
[[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
|
6686
6755
|
kernel void kernel_mul_mv_id_iq2_xs_f32(
|
6687
|
-
device const char *
|
6756
|
+
device const char * src0s,
|
6688
6757
|
device const char * src1,
|
6689
6758
|
device float * dst,
|
6759
|
+
device const char * ids,
|
6690
6760
|
constant uint64_t & nbi1,
|
6691
6761
|
constant int64_t & ne00,
|
6692
6762
|
constant int64_t & ne01,
|
@@ -6707,29 +6777,20 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
|
|
6707
6777
|
constant uint & r2,
|
6708
6778
|
constant uint & r3,
|
6709
6779
|
constant int & idx,
|
6710
|
-
device const char * src00,
|
6711
|
-
device const char * src01,
|
6712
|
-
device const char * src02,
|
6713
|
-
device const char * src03,
|
6714
|
-
device const char * src04,
|
6715
|
-
device const char * src05,
|
6716
|
-
device const char * src06,
|
6717
|
-
device const char * src07,
|
6718
6780
|
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
6719
6781
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6720
6782
|
uint tiitg[[thread_index_in_threadgroup]],
|
6721
6783
|
uint tiisg[[thread_index_in_simdgroup]],
|
6722
6784
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6723
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6724
|
-
|
6725
6785
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6726
6786
|
|
6727
6787
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6728
6788
|
|
6729
6789
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6790
|
+
device const char * src0 = src0s + id*nb02;
|
6730
6791
|
|
6731
6792
|
kernel_mul_mv_iq2_xs_f32_impl(
|
6732
|
-
src0
|
6793
|
+
src0,
|
6733
6794
|
(device const float *) (src1 + bid*nb11),
|
6734
6795
|
dst + bid*ne0,
|
6735
6796
|
ne00,
|
@@ -6749,9 +6810,10 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
|
|
6749
6810
|
|
6750
6811
|
[[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
|
6751
6812
|
kernel void kernel_mul_mv_id_iq3_xxs_f32(
|
6752
|
-
device const char *
|
6813
|
+
device const char * src0s,
|
6753
6814
|
device const char * src1,
|
6754
6815
|
device float * dst,
|
6816
|
+
device const char * ids,
|
6755
6817
|
constant uint64_t & nbi1,
|
6756
6818
|
constant int64_t & ne00,
|
6757
6819
|
constant int64_t & ne01,
|
@@ -6772,29 +6834,20 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
|
|
6772
6834
|
constant uint & r2,
|
6773
6835
|
constant uint & r3,
|
6774
6836
|
constant int & idx,
|
6775
|
-
device const char * src00,
|
6776
|
-
device const char * src01,
|
6777
|
-
device const char * src02,
|
6778
|
-
device const char * src03,
|
6779
|
-
device const char * src04,
|
6780
|
-
device const char * src05,
|
6781
|
-
device const char * src06,
|
6782
|
-
device const char * src07,
|
6783
6837
|
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
6784
6838
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6785
6839
|
uint tiitg[[thread_index_in_threadgroup]],
|
6786
6840
|
uint tiisg[[thread_index_in_simdgroup]],
|
6787
6841
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6788
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6789
|
-
|
6790
6842
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6791
6843
|
|
6792
6844
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6793
6845
|
|
6794
6846
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6847
|
+
device const char * src0 = src0s + id*nb02;
|
6795
6848
|
|
6796
6849
|
kernel_mul_mv_iq3_xxs_f32_impl(
|
6797
|
-
src0
|
6850
|
+
src0,
|
6798
6851
|
(device const float *) (src1 + bid*nb11),
|
6799
6852
|
dst + bid*ne0,
|
6800
6853
|
ne00,
|
@@ -6814,9 +6867,10 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
|
|
6814
6867
|
|
6815
6868
|
[[host_name("kernel_mul_mv_id_iq3_s_f32")]]
|
6816
6869
|
kernel void kernel_mul_mv_id_iq3_s_f32(
|
6817
|
-
device const char *
|
6870
|
+
device const char * src0s,
|
6818
6871
|
device const char * src1,
|
6819
6872
|
device float * dst,
|
6873
|
+
device const char * ids,
|
6820
6874
|
constant uint64_t & nbi1,
|
6821
6875
|
constant int64_t & ne00,
|
6822
6876
|
constant int64_t & ne01,
|
@@ -6837,29 +6891,20 @@ kernel void kernel_mul_mv_id_iq3_s_f32(
|
|
6837
6891
|
constant uint & r2,
|
6838
6892
|
constant uint & r3,
|
6839
6893
|
constant int & idx,
|
6840
|
-
device const char * src00,
|
6841
|
-
device const char * src01,
|
6842
|
-
device const char * src02,
|
6843
|
-
device const char * src03,
|
6844
|
-
device const char * src04,
|
6845
|
-
device const char * src05,
|
6846
|
-
device const char * src06,
|
6847
|
-
device const char * src07,
|
6848
6894
|
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
6849
6895
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6850
6896
|
uint tiitg[[thread_index_in_threadgroup]],
|
6851
6897
|
uint tiisg[[thread_index_in_simdgroup]],
|
6852
6898
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6853
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6854
|
-
|
6855
6899
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6856
6900
|
|
6857
6901
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6858
6902
|
|
6859
6903
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6904
|
+
device const char * src0 = src0s + id*nb02;
|
6860
6905
|
|
6861
6906
|
kernel_mul_mv_iq3_s_f32_impl(
|
6862
|
-
src0
|
6907
|
+
src0,
|
6863
6908
|
(device const float *) (src1 + bid*nb11),
|
6864
6909
|
dst + bid*ne0,
|
6865
6910
|
ne00,
|
@@ -6879,9 +6924,10 @@ kernel void kernel_mul_mv_id_iq3_s_f32(
|
|
6879
6924
|
|
6880
6925
|
[[host_name("kernel_mul_mv_id_iq2_s_f32")]]
|
6881
6926
|
kernel void kernel_mul_mv_id_iq2_s_f32(
|
6882
|
-
device const char *
|
6927
|
+
device const char * src0s,
|
6883
6928
|
device const char * src1,
|
6884
6929
|
device float * dst,
|
6930
|
+
device const char * ids,
|
6885
6931
|
constant uint64_t & nbi1,
|
6886
6932
|
constant int64_t & ne00,
|
6887
6933
|
constant int64_t & ne01,
|
@@ -6902,29 +6948,20 @@ kernel void kernel_mul_mv_id_iq2_s_f32(
|
|
6902
6948
|
constant uint & r2,
|
6903
6949
|
constant uint & r3,
|
6904
6950
|
constant int & idx,
|
6905
|
-
device const char * src00,
|
6906
|
-
device const char * src01,
|
6907
|
-
device const char * src02,
|
6908
|
-
device const char * src03,
|
6909
|
-
device const char * src04,
|
6910
|
-
device const char * src05,
|
6911
|
-
device const char * src06,
|
6912
|
-
device const char * src07,
|
6913
6951
|
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
6914
6952
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6915
6953
|
uint tiitg[[thread_index_in_threadgroup]],
|
6916
6954
|
uint tiisg[[thread_index_in_simdgroup]],
|
6917
6955
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6918
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6919
|
-
|
6920
6956
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6921
6957
|
|
6922
6958
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6923
6959
|
|
6924
6960
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6961
|
+
device const char * src0 = src0s + id*nb02;
|
6925
6962
|
|
6926
6963
|
kernel_mul_mv_iq2_s_f32_impl(
|
6927
|
-
src0
|
6964
|
+
src0,
|
6928
6965
|
(device const float *) (src1 + bid*nb11),
|
6929
6966
|
dst + bid*ne0,
|
6930
6967
|
ne00,
|
@@ -6944,9 +6981,10 @@ kernel void kernel_mul_mv_id_iq2_s_f32(
|
|
6944
6981
|
|
6945
6982
|
[[host_name("kernel_mul_mv_id_iq1_s_f32")]]
|
6946
6983
|
kernel void kernel_mul_mv_id_iq1_s_f32(
|
6947
|
-
device const char *
|
6984
|
+
device const char * src0s,
|
6948
6985
|
device const char * src1,
|
6949
6986
|
device float * dst,
|
6987
|
+
device const char * ids,
|
6950
6988
|
constant uint64_t & nbi1,
|
6951
6989
|
constant int64_t & ne00,
|
6952
6990
|
constant int64_t & ne01,
|
@@ -6967,28 +7005,74 @@ kernel void kernel_mul_mv_id_iq1_s_f32(
|
|
6967
7005
|
constant uint & r2,
|
6968
7006
|
constant uint & r3,
|
6969
7007
|
constant int & idx,
|
6970
|
-
device const char * src00,
|
6971
|
-
device const char * src01,
|
6972
|
-
device const char * src02,
|
6973
|
-
device const char * src03,
|
6974
|
-
device const char * src04,
|
6975
|
-
device const char * src05,
|
6976
|
-
device const char * src06,
|
6977
|
-
device const char * src07,
|
6978
7008
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6979
7009
|
uint tiitg[[thread_index_in_threadgroup]],
|
6980
7010
|
uint tiisg[[thread_index_in_simdgroup]],
|
6981
7011
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6982
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6983
|
-
|
6984
7012
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6985
7013
|
|
6986
7014
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6987
7015
|
|
6988
7016
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
7017
|
+
device const char * src0 = src0s + id*nb02;
|
6989
7018
|
|
6990
7019
|
kernel_mul_mv_iq1_s_f32_impl(
|
6991
|
-
src0
|
7020
|
+
src0,
|
7021
|
+
(device const float *) (src1 + bid*nb11),
|
7022
|
+
dst + bid*ne0,
|
7023
|
+
ne00,
|
7024
|
+
ne01,
|
7025
|
+
ne02,
|
7026
|
+
ne10,
|
7027
|
+
ne12,
|
7028
|
+
ne0,
|
7029
|
+
ne1,
|
7030
|
+
r2,
|
7031
|
+
r3,
|
7032
|
+
tgpig,
|
7033
|
+
tiisg,
|
7034
|
+
sgitg);
|
7035
|
+
}
|
7036
|
+
|
7037
|
+
[[host_name("kernel_mul_mv_id_iq1_m_f32")]]
|
7038
|
+
kernel void kernel_mul_mv_id_iq1_m_f32(
|
7039
|
+
device const char * src0s,
|
7040
|
+
device const char * src1,
|
7041
|
+
device float * dst,
|
7042
|
+
device const char * ids,
|
7043
|
+
constant uint64_t & nbi1,
|
7044
|
+
constant int64_t & ne00,
|
7045
|
+
constant int64_t & ne01,
|
7046
|
+
constant int64_t & ne02,
|
7047
|
+
constant uint64_t & nb00,
|
7048
|
+
constant uint64_t & nb01,
|
7049
|
+
constant uint64_t & nb02,
|
7050
|
+
constant int64_t & ne10,
|
7051
|
+
constant int64_t & ne11,
|
7052
|
+
constant int64_t & ne12,
|
7053
|
+
constant int64_t & ne13,
|
7054
|
+
constant uint64_t & nb10,
|
7055
|
+
constant uint64_t & nb11,
|
7056
|
+
constant uint64_t & nb12,
|
7057
|
+
constant int64_t & ne0,
|
7058
|
+
constant int64_t & ne1,
|
7059
|
+
constant uint64_t & nb1,
|
7060
|
+
constant uint & r2,
|
7061
|
+
constant uint & r3,
|
7062
|
+
constant int & idx,
|
7063
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
7064
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
7065
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
7066
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
7067
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
7068
|
+
|
7069
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
7070
|
+
|
7071
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
7072
|
+
device const char * src0 = src0s + id*nb02;
|
7073
|
+
|
7074
|
+
kernel_mul_mv_iq1_m_f32_impl(
|
7075
|
+
src0,
|
6992
7076
|
(device const float *) (src1 + bid*nb11),
|
6993
7077
|
dst + bid*ne0,
|
6994
7078
|
ne00,
|
@@ -7007,9 +7091,10 @@ kernel void kernel_mul_mv_id_iq1_s_f32(
|
|
7007
7091
|
|
7008
7092
|
[[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
|
7009
7093
|
kernel void kernel_mul_mv_id_iq4_nl_f32(
|
7010
|
-
device const char *
|
7094
|
+
device const char * src0s,
|
7011
7095
|
device const char * src1,
|
7012
7096
|
device float * dst,
|
7097
|
+
device const char * ids,
|
7013
7098
|
constant uint64_t & nbi1,
|
7014
7099
|
constant int64_t & ne00,
|
7015
7100
|
constant int64_t & ne01,
|
@@ -7030,29 +7115,20 @@ kernel void kernel_mul_mv_id_iq4_nl_f32(
|
|
7030
7115
|
constant uint & r2,
|
7031
7116
|
constant uint & r3,
|
7032
7117
|
constant int & idx,
|
7033
|
-
device const char * src00,
|
7034
|
-
device const char * src01,
|
7035
|
-
device const char * src02,
|
7036
|
-
device const char * src03,
|
7037
|
-
device const char * src04,
|
7038
|
-
device const char * src05,
|
7039
|
-
device const char * src06,
|
7040
|
-
device const char * src07,
|
7041
7118
|
threadgroup float * shared_values [[threadgroup(0)]],
|
7042
7119
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
7043
7120
|
uint tiitg[[thread_index_in_threadgroup]],
|
7044
7121
|
uint tiisg[[thread_index_in_simdgroup]],
|
7045
7122
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
7046
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
7047
|
-
|
7048
7123
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
7049
7124
|
|
7050
7125
|
tgpig.z = tgpig.z%(ne12*ne13);
|
7051
7126
|
|
7052
7127
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
7128
|
+
device const char * src0 = src0s + id*nb02;
|
7053
7129
|
|
7054
7130
|
kernel_mul_mv_iq4_nl_f32_impl(
|
7055
|
-
src0
|
7131
|
+
src0,
|
7056
7132
|
(device const float *) (src1 + bid*nb11),
|
7057
7133
|
dst + bid*ne0,
|
7058
7134
|
ne00,
|
@@ -7072,9 +7148,10 @@ kernel void kernel_mul_mv_id_iq4_nl_f32(
|
|
7072
7148
|
|
7073
7149
|
[[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
|
7074
7150
|
kernel void kernel_mul_mv_id_iq4_xs_f32(
|
7075
|
-
device const char *
|
7151
|
+
device const char * src0s,
|
7076
7152
|
device const char * src1,
|
7077
7153
|
device float * dst,
|
7154
|
+
device const char * ids,
|
7078
7155
|
constant uint64_t & nbi1,
|
7079
7156
|
constant int64_t & ne00,
|
7080
7157
|
constant int64_t & ne01,
|
@@ -7095,33 +7172,24 @@ kernel void kernel_mul_mv_id_iq4_xs_f32(
|
|
7095
7172
|
constant uint & r2,
|
7096
7173
|
constant uint & r3,
|
7097
7174
|
constant int & idx,
|
7098
|
-
device const char * src00,
|
7099
|
-
device const char * src01,
|
7100
|
-
device const char * src02,
|
7101
|
-
device const char * src03,
|
7102
|
-
device const char * src04,
|
7103
|
-
device const char * src05,
|
7104
|
-
device const char * src06,
|
7105
|
-
device const char * src07,
|
7106
7175
|
threadgroup float * shared_values [[threadgroup(0)]],
|
7107
7176
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
7108
7177
|
uint tiitg[[thread_index_in_threadgroup]],
|
7109
7178
|
uint tiisg[[thread_index_in_simdgroup]],
|
7110
7179
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
7111
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
7112
|
-
|
7113
7180
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
7114
7181
|
|
7115
7182
|
tgpig.z = tgpig.z%(ne12*ne13);
|
7116
7183
|
|
7117
7184
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
7185
|
+
device const char * src0 = src0s + id*nb02;
|
7118
7186
|
|
7119
7187
|
#if QK_K == 64
|
7120
7188
|
kernel_mul_mv_iq4_nl_f32_impl(
|
7121
7189
|
#else
|
7122
7190
|
kernel_mul_mv_iq4_xs_f32_impl(
|
7123
7191
|
#endif
|
7124
|
-
src0
|
7192
|
+
src0,
|
7125
7193
|
(device const float *) (src1 + bid*nb11),
|
7126
7194
|
dst + bid*ne0,
|
7127
7195
|
ne00,
|