llama_cpp 0.14.3 → 0.14.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/examples/chat.rb +2 -4
- data/ext/llama_cpp/extconf.rb +1 -0
- data/ext/llama_cpp/llama_cpp.cpp +27 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +14 -0
- data/vendor/tmp/llama.cpp/LICENSE +1 -1
- data/vendor/tmp/llama.cpp/Makefile +81 -20
- data/vendor/tmp/llama.cpp/ggml-alloc.c +7 -2
- data/vendor/tmp/llama.cpp/ggml-backend.c +1 -1
- data/vendor/tmp/llama.cpp/ggml-backend.h +1 -1
- data/vendor/tmp/llama.cpp/ggml-common.h +25 -2
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +295 -9324
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +4 -0
- data/vendor/tmp/llama.cpp/ggml-metal.m +133 -113
- data/vendor/tmp/llama.cpp/ggml-metal.metal +344 -276
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +5 -0
- data/vendor/tmp/llama.cpp/ggml-quants.c +785 -190
- data/vendor/tmp/llama.cpp/ggml-quants.h +83 -80
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +963 -588
- data/vendor/tmp/llama.cpp/ggml-sycl.h +13 -3
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +37199 -14939
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +329 -308
- data/vendor/tmp/llama.cpp/ggml-vulkan.h +0 -11
- data/vendor/tmp/llama.cpp/ggml.c +141 -101
- data/vendor/tmp/llama.cpp/ggml.h +18 -12
- data/vendor/tmp/llama.cpp/llama.cpp +2519 -625
- data/vendor/tmp/llama.cpp/llama.h +145 -29
- data/vendor/tmp/llama.cpp/unicode-data.cpp +1651 -0
- data/vendor/tmp/llama.cpp/unicode-data.h +16 -0
- data/vendor/tmp/llama.cpp/unicode.cpp +8 -1403
- data/vendor/tmp/llama.cpp/unicode.h +2 -0
- metadata +5 -3
@@ -13,8 +13,8 @@ using namespace metal;
|
|
13
13
|
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
14
14
|
|
15
15
|
enum ggml_sort_order {
|
16
|
-
|
17
|
-
|
16
|
+
GGML_SORT_ORDER_ASC,
|
17
|
+
GGML_SORT_ORDER_DESC,
|
18
18
|
};
|
19
19
|
|
20
20
|
// general-purpose kernel for addition, multiplication and division of two tensors
|
@@ -1973,9 +1973,11 @@ kernel void kernel_timestep_embedding_f32(
|
|
1973
1973
|
|
1974
1974
|
// bitonic sort implementation following the CUDA kernels as reference
|
1975
1975
|
typedef void (argsort_t)(
|
1976
|
-
device const float
|
1977
|
-
device int32_t
|
1978
|
-
constant int64_t
|
1976
|
+
device const float * x,
|
1977
|
+
device int32_t * dst,
|
1978
|
+
constant int64_t & ncols,
|
1979
|
+
constant int64_t & ncols_pad,
|
1980
|
+
threadgroup int32_t * shared_values [[threadgroup(0)]],
|
1979
1981
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1980
1982
|
uint3 tpitg[[thread_position_in_threadgroup]]);
|
1981
1983
|
|
@@ -1984,33 +1986,42 @@ kernel void kernel_argsort_f32_i32(
|
|
1984
1986
|
device const float * x,
|
1985
1987
|
device int32_t * dst,
|
1986
1988
|
constant int64_t & ncols,
|
1989
|
+
constant int64_t & ncols_pad,
|
1990
|
+
threadgroup int32_t * shared_values [[threadgroup(0)]],
|
1987
1991
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1988
1992
|
uint3 tpitg[[thread_position_in_threadgroup]]) {
|
1989
1993
|
// bitonic sort
|
1990
1994
|
int col = tpitg[0];
|
1991
1995
|
int row = tgpig[1];
|
1992
1996
|
|
1993
|
-
if (col >=
|
1997
|
+
if (col >= ncols_pad) return;
|
1994
1998
|
|
1995
|
-
device const float * x_row = x
|
1996
|
-
|
1999
|
+
device const float * x_row = x + row * ncols;
|
2000
|
+
threadgroup int32_t * dst_row = shared_values;
|
1997
2001
|
|
1998
2002
|
// initialize indices
|
1999
|
-
|
2000
|
-
|
2001
|
-
}
|
2003
|
+
dst_row[col] = col;
|
2004
|
+
|
2002
2005
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
2003
2006
|
|
2004
|
-
for (int k = 2; k <=
|
2007
|
+
for (int k = 2; k <= ncols_pad; k *= 2) {
|
2005
2008
|
for (int j = k / 2; j > 0; j /= 2) {
|
2006
2009
|
int ixj = col ^ j;
|
2007
2010
|
if (ixj > col) {
|
2008
2011
|
if ((col & k) == 0) {
|
2009
|
-
if (
|
2012
|
+
if (dst_row[col] >= ncols ||
|
2013
|
+
(dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
|
2014
|
+
x_row[dst_row[col]] > x_row[dst_row[ixj]] :
|
2015
|
+
x_row[dst_row[col]] < x_row[dst_row[ixj]]))
|
2016
|
+
) {
|
2010
2017
|
SWAP(dst_row[col], dst_row[ixj]);
|
2011
2018
|
}
|
2012
2019
|
} else {
|
2013
|
-
if (
|
2020
|
+
if (dst_row[ixj] >= ncols ||
|
2021
|
+
(dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
|
2022
|
+
x_row[dst_row[col]] < x_row[dst_row[ixj]] :
|
2023
|
+
x_row[dst_row[col]] > x_row[dst_row[ixj]]))
|
2024
|
+
) {
|
2014
2025
|
SWAP(dst_row[col], dst_row[ixj]);
|
2015
2026
|
}
|
2016
2027
|
}
|
@@ -2018,10 +2029,15 @@ kernel void kernel_argsort_f32_i32(
|
|
2018
2029
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
2019
2030
|
}
|
2020
2031
|
}
|
2032
|
+
|
2033
|
+
// copy the result to dst without the padding
|
2034
|
+
if (col < ncols) {
|
2035
|
+
dst[row * ncols + col] = dst_row[col];
|
2036
|
+
}
|
2021
2037
|
}
|
2022
2038
|
|
2023
|
-
template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<
|
2024
|
-
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<
|
2039
|
+
template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
|
2040
|
+
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
|
2025
2041
|
|
2026
2042
|
kernel void kernel_leaky_relu_f32(
|
2027
2043
|
device const float * src0,
|
@@ -4456,6 +4472,114 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
4456
4472
|
}
|
4457
4473
|
}
|
4458
4474
|
|
4475
|
+
void kernel_mul_mv_iq1_m_f32_impl(
|
4476
|
+
device const void * src0,
|
4477
|
+
device const float * src1,
|
4478
|
+
device float * dst,
|
4479
|
+
constant int64_t & ne00,
|
4480
|
+
constant int64_t & ne01,
|
4481
|
+
constant int64_t & ne02,
|
4482
|
+
constant int64_t & ne10,
|
4483
|
+
constant int64_t & ne12,
|
4484
|
+
constant int64_t & ne0,
|
4485
|
+
constant int64_t & ne1,
|
4486
|
+
constant uint & r2,
|
4487
|
+
constant uint & r3,
|
4488
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4489
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4490
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4491
|
+
|
4492
|
+
const int nb = ne00/QK_K;
|
4493
|
+
const int r0 = tgpig.x;
|
4494
|
+
const int r1 = tgpig.y;
|
4495
|
+
const int im = tgpig.z;
|
4496
|
+
|
4497
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
4498
|
+
const int ib_row = first_row * nb;
|
4499
|
+
|
4500
|
+
const uint i12 = im%ne12;
|
4501
|
+
const uint i13 = im/ne12;
|
4502
|
+
|
4503
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
4504
|
+
device const block_iq1_m * x = (device const block_iq1_m *) src0 + ib_row + offset0;
|
4505
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
4506
|
+
|
4507
|
+
float yl[32];
|
4508
|
+
float sumf[N_DST]={0.f}, all_sum;
|
4509
|
+
|
4510
|
+
const int nb32 = nb * (QK_K / 32);
|
4511
|
+
|
4512
|
+
const int ix = tiisg;
|
4513
|
+
|
4514
|
+
device const float * y4 = y + 32 * ix;
|
4515
|
+
|
4516
|
+
#if QK_K != 64
|
4517
|
+
iq1m_scale_t scale;
|
4518
|
+
#endif
|
4519
|
+
|
4520
|
+
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
4521
|
+
|
4522
|
+
float4 sumy = {0.f};
|
4523
|
+
for (int i = 0; i < 8; ++i) {
|
4524
|
+
yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
|
4525
|
+
yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
|
4526
|
+
yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];
|
4527
|
+
yl[i+24] = y4[i+24]; sumy[3] += yl[i+24];
|
4528
|
+
}
|
4529
|
+
|
4530
|
+
const int ibl = ib32 / (QK_K / 32);
|
4531
|
+
const int ib = ib32 % (QK_K / 32);
|
4532
|
+
|
4533
|
+
device const block_iq1_m * xr = x + ibl;
|
4534
|
+
device const uint8_t * qs = xr->qs + 4 * ib;
|
4535
|
+
device const uint8_t * qh = xr->qh + 2 * ib;
|
4536
|
+
device const uint16_t * sc = (device const uint16_t *)xr->scales;
|
4537
|
+
|
4538
|
+
for (int row = 0; row < N_DST; row++) {
|
4539
|
+
|
4540
|
+
#if QK_K != 64
|
4541
|
+
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
|
4542
|
+
#endif
|
4543
|
+
|
4544
|
+
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
|
4545
|
+
constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
|
4546
|
+
constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700)));
|
4547
|
+
constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));
|
4548
|
+
|
4549
|
+
float2 sum = {0.f};
|
4550
|
+
for (int j = 0; j < 4; ++j) {
|
4551
|
+
sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
|
4552
|
+
+ yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
|
4553
|
+
sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
|
4554
|
+
+ yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
|
4555
|
+
}
|
4556
|
+
const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
4557
|
+
const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
4558
|
+
#if QK_K == 64
|
4559
|
+
const float d = (float) *((device const half *)(sc - 1));
|
4560
|
+
sumf[row] += d * ((sum[0] + delta1) * (2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1) +
|
4561
|
+
(sum[1] + delta2) * (2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1));
|
4562
|
+
#else
|
4563
|
+
sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
|
4564
|
+
(sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
|
4565
|
+
#endif
|
4566
|
+
|
4567
|
+
sc += nb*sizeof(block_iq1_m)/2;
|
4568
|
+
qs += nb*sizeof(block_iq1_m);
|
4569
|
+
qh += nb*sizeof(block_iq1_m);
|
4570
|
+
}
|
4571
|
+
|
4572
|
+
y4 += 32 * 32;
|
4573
|
+
}
|
4574
|
+
|
4575
|
+
for (int row = 0; row < N_DST; ++row) {
|
4576
|
+
all_sum = simd_sum(sumf[row]);
|
4577
|
+
if (tiisg == 0) {
|
4578
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
4579
|
+
}
|
4580
|
+
}
|
4581
|
+
}
|
4582
|
+
|
4459
4583
|
void kernel_mul_mv_iq4_nl_f32_impl(
|
4460
4584
|
device const void * src0,
|
4461
4585
|
device const float * src1,
|
@@ -4673,6 +4797,34 @@ kernel void kernel_mul_mv_iq1_s_f32(
|
|
4673
4797
|
kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
4674
4798
|
}
|
4675
4799
|
|
4800
|
+
[[host_name("kernel_mul_mv_iq1_m_f32")]]
|
4801
|
+
kernel void kernel_mul_mv_iq1_m_f32(
|
4802
|
+
device const void * src0,
|
4803
|
+
device const float * src1,
|
4804
|
+
device float * dst,
|
4805
|
+
constant int64_t & ne00,
|
4806
|
+
constant int64_t & ne01,
|
4807
|
+
constant int64_t & ne02,
|
4808
|
+
constant uint64_t & nb00,
|
4809
|
+
constant uint64_t & nb01,
|
4810
|
+
constant uint64_t & nb02,
|
4811
|
+
constant int64_t & ne10,
|
4812
|
+
constant int64_t & ne11,
|
4813
|
+
constant int64_t & ne12,
|
4814
|
+
constant uint64_t & nb10,
|
4815
|
+
constant uint64_t & nb11,
|
4816
|
+
constant uint64_t & nb12,
|
4817
|
+
constant int64_t & ne0,
|
4818
|
+
constant int64_t & ne1,
|
4819
|
+
constant uint & r2,
|
4820
|
+
constant uint & r3,
|
4821
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4822
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4823
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4824
|
+
|
4825
|
+
kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
4826
|
+
}
|
4827
|
+
|
4676
4828
|
[[host_name("kernel_mul_mv_iq4_nl_f32")]]
|
4677
4829
|
kernel void kernel_mul_mv_iq4_nl_f32(
|
4678
4830
|
device const void * src0,
|
@@ -5146,6 +5298,38 @@ void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 &
|
|
5146
5298
|
}
|
5147
5299
|
}
|
5148
5300
|
|
5301
|
+
template <typename type4x4>
|
5302
|
+
void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
|
5303
|
+
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
5304
|
+
const int ib32 = il/2;
|
5305
|
+
il = il%2;
|
5306
|
+
device const uint16_t * sc = (device const uint16_t *)xb->scales;
|
5307
|
+
#if QK_K == 64
|
5308
|
+
const float d = xb->d;
|
5309
|
+
#else
|
5310
|
+
iq1m_scale_t scale;
|
5311
|
+
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
|
5312
|
+
const float d = scale.f16;
|
5313
|
+
#endif
|
5314
|
+
device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
|
5315
|
+
device const uint8_t * qh = xb->qh + 2*ib32 + il;
|
5316
|
+
#if QK_K == 64
|
5317
|
+
const float dl = d * (2*((sc[ib32/2] >> (8*(ib32%2)+4*il)) & 0xf) + 1);
|
5318
|
+
#else
|
5319
|
+
const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
|
5320
|
+
#endif
|
5321
|
+
const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
5322
|
+
const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
5323
|
+
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
|
5324
|
+
constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
|
5325
|
+
for (int i = 0; i < 4; ++i) {
|
5326
|
+
reg[0][i] = dl * (grid1[i] & 0xf) + ml1;
|
5327
|
+
reg[1][i] = dl * (grid1[i] >> 4) + ml1;
|
5328
|
+
reg[2][i] = dl * (grid2[i] & 0xf) + ml2;
|
5329
|
+
reg[3][i] = dl * (grid2[i] >> 4) + ml2;
|
5330
|
+
}
|
5331
|
+
}
|
5332
|
+
|
5149
5333
|
template <typename type4x4>
|
5150
5334
|
void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
|
5151
5335
|
device const uint16_t * q4 = (device const uint16_t *)xb->qs;
|
@@ -5617,9 +5801,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
5617
5801
|
|
5618
5802
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
5619
5803
|
kernel void kernel_mul_mm_id(
|
5620
|
-
device const uchar *
|
5804
|
+
device const uchar * src0s,
|
5621
5805
|
device const uchar * src1,
|
5622
5806
|
device float * dst,
|
5807
|
+
device const uchar * ids,
|
5623
5808
|
constant uint64_t & nbi1,
|
5624
5809
|
constant int64_t & ne00,
|
5625
5810
|
constant int64_t & ne02,
|
@@ -5636,22 +5821,14 @@ kernel void kernel_mul_mm_id(
|
|
5636
5821
|
constant uint & r2,
|
5637
5822
|
constant uint & r3,
|
5638
5823
|
constant int & idx,
|
5639
|
-
device const uchar * src00,
|
5640
|
-
device const uchar * src01,
|
5641
|
-
device const uchar * src02,
|
5642
|
-
device const uchar * src03,
|
5643
|
-
device const uchar * src04,
|
5644
|
-
device const uchar * src05,
|
5645
|
-
device const uchar * src06,
|
5646
|
-
device const uchar * src07,
|
5647
5824
|
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
5648
5825
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
5649
5826
|
uint tiitg[[thread_index_in_threadgroup]],
|
5650
5827
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
5651
|
-
device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
5652
5828
|
|
5653
5829
|
// expert id
|
5654
5830
|
const int32_t id = tgpig.z/(ne12*ne13);
|
5831
|
+
device const uchar * src0 = src0s + id*nb02;
|
5655
5832
|
|
5656
5833
|
tgpig.z = tgpig.z%(ne12*ne13);
|
5657
5834
|
|
@@ -5666,7 +5843,7 @@ kernel void kernel_mul_mm_id(
|
|
5666
5843
|
}
|
5667
5844
|
|
5668
5845
|
kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
|
5669
|
-
|
5846
|
+
src0,
|
5670
5847
|
src1,
|
5671
5848
|
src1ids,
|
5672
5849
|
dst,
|
@@ -5730,6 +5907,7 @@ template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_r
|
|
5730
5907
|
template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
5731
5908
|
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
5732
5909
|
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
5910
|
+
template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
5733
5911
|
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
|
5734
5912
|
#if QK_K == 64
|
5735
5913
|
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, 2, dequantize_iq4_xs>;
|
@@ -5778,6 +5956,7 @@ template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_m
|
|
5778
5956
|
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
5779
5957
|
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
5780
5958
|
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
5959
|
+
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
5781
5960
|
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
|
5782
5961
|
#if QK_K == 64
|
5783
5962
|
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_xs>;
|
@@ -5790,9 +5969,10 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_m
|
|
5790
5969
|
//
|
5791
5970
|
|
5792
5971
|
typedef void (mat_mm_id_t)(
|
5793
|
-
device const uchar *
|
5972
|
+
device const uchar * src0s,
|
5794
5973
|
device const uchar * src1,
|
5795
5974
|
device float * dst,
|
5975
|
+
device const uchar * ids,
|
5796
5976
|
constant uint64_t & nbi1,
|
5797
5977
|
constant int64_t & ne00,
|
5798
5978
|
constant int64_t & ne02,
|
@@ -5809,14 +5989,6 @@ typedef void (mat_mm_id_t)(
|
|
5809
5989
|
constant uint & r2,
|
5810
5990
|
constant uint & r3,
|
5811
5991
|
constant int & idx,
|
5812
|
-
device const uchar * src00,
|
5813
|
-
device const uchar * src01,
|
5814
|
-
device const uchar * src02,
|
5815
|
-
device const uchar * src03,
|
5816
|
-
device const uchar * src04,
|
5817
|
-
device const uchar * src05,
|
5818
|
-
device const uchar * src06,
|
5819
|
-
device const uchar * src07,
|
5820
5992
|
threadgroup uchar *,
|
5821
5993
|
uint3, uint, uint);
|
5822
5994
|
|
@@ -5838,6 +6010,7 @@ template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel
|
|
5838
6010
|
template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
5839
6011
|
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
5840
6012
|
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
6013
|
+
template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
5841
6014
|
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
|
5842
6015
|
#if QK_K == 64
|
5843
6016
|
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, 2, dequantize_iq4_xs>;
|
@@ -5851,9 +6024,10 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel
|
|
5851
6024
|
|
5852
6025
|
[[host_name("kernel_mul_mv_id_f32_f32")]]
|
5853
6026
|
kernel void kernel_mul_mv_id_f32_f32(
|
5854
|
-
device const char *
|
6027
|
+
device const char * src0s,
|
5855
6028
|
device const char * src1,
|
5856
6029
|
device float * dst,
|
6030
|
+
device const char * ids,
|
5857
6031
|
constant uint64_t & nbi1,
|
5858
6032
|
constant int64_t & ne00,
|
5859
6033
|
constant int64_t & ne01,
|
@@ -5874,28 +6048,19 @@ kernel void kernel_mul_mv_id_f32_f32(
|
|
5874
6048
|
constant uint & r2,
|
5875
6049
|
constant uint & r3,
|
5876
6050
|
constant int & idx,
|
5877
|
-
device const char * src00,
|
5878
|
-
device const char * src01,
|
5879
|
-
device const char * src02,
|
5880
|
-
device const char * src03,
|
5881
|
-
device const char * src04,
|
5882
|
-
device const char * src05,
|
5883
|
-
device const char * src06,
|
5884
|
-
device const char * src07,
|
5885
6051
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
5886
6052
|
uint tiitg[[thread_index_in_threadgroup]],
|
5887
6053
|
uint tiisg[[thread_index_in_simdgroup]],
|
5888
6054
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
5889
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
5890
|
-
|
5891
6055
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
5892
6056
|
|
5893
6057
|
tgpig.z = tgpig.z%(ne12*ne13);
|
5894
6058
|
|
5895
6059
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6060
|
+
device const char * src0 = src0s + id*nb02;
|
5896
6061
|
|
5897
6062
|
kernel_mul_mv_f32_f32_impl(
|
5898
|
-
src0
|
6063
|
+
src0,
|
5899
6064
|
src1 + bid*nb11,
|
5900
6065
|
dst + bid*ne0,
|
5901
6066
|
ne00,
|
@@ -5920,9 +6085,10 @@ kernel void kernel_mul_mv_id_f32_f32(
|
|
5920
6085
|
|
5921
6086
|
[[host_name("kernel_mul_mv_id_f16_f32")]]
|
5922
6087
|
kernel void kernel_mul_mv_id_f16_f32(
|
5923
|
-
device const char *
|
6088
|
+
device const char * src0s,
|
5924
6089
|
device const char * src1,
|
5925
6090
|
device float * dst,
|
6091
|
+
device const char * ids,
|
5926
6092
|
constant uint64_t & nbi1,
|
5927
6093
|
constant int64_t & ne00,
|
5928
6094
|
constant int64_t & ne01,
|
@@ -5943,28 +6109,19 @@ kernel void kernel_mul_mv_id_f16_f32(
|
|
5943
6109
|
constant uint & r2,
|
5944
6110
|
constant uint & r3,
|
5945
6111
|
constant int & idx,
|
5946
|
-
device const char * src00,
|
5947
|
-
device const char * src01,
|
5948
|
-
device const char * src02,
|
5949
|
-
device const char * src03,
|
5950
|
-
device const char * src04,
|
5951
|
-
device const char * src05,
|
5952
|
-
device const char * src06,
|
5953
|
-
device const char * src07,
|
5954
6112
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
5955
6113
|
uint tiitg[[thread_index_in_threadgroup]],
|
5956
6114
|
uint tiisg[[thread_index_in_simdgroup]],
|
5957
6115
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
5958
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
5959
|
-
|
5960
6116
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
5961
6117
|
|
5962
6118
|
tgpig.z = tgpig.z%(ne12*ne13);
|
5963
6119
|
|
5964
6120
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6121
|
+
device const char * src0 = src0s + id*nb02;
|
5965
6122
|
|
5966
6123
|
kernel_mul_mv_f16_f32_impl(
|
5967
|
-
src0
|
6124
|
+
src0,
|
5968
6125
|
src1 + bid*nb11,
|
5969
6126
|
dst + bid*ne0,
|
5970
6127
|
ne00,
|
@@ -5989,9 +6146,10 @@ kernel void kernel_mul_mv_id_f16_f32(
|
|
5989
6146
|
|
5990
6147
|
[[host_name("kernel_mul_mv_id_q8_0_f32")]]
|
5991
6148
|
kernel void kernel_mul_mv_id_q8_0_f32(
|
5992
|
-
device const char *
|
6149
|
+
device const char * src0s,
|
5993
6150
|
device const char * src1,
|
5994
6151
|
device float * dst,
|
6152
|
+
device const char * ids,
|
5995
6153
|
constant uint64_t & nbi1,
|
5996
6154
|
constant int64_t & ne00,
|
5997
6155
|
constant int64_t & ne01,
|
@@ -6012,28 +6170,19 @@ kernel void kernel_mul_mv_id_q8_0_f32(
|
|
6012
6170
|
constant uint & r2,
|
6013
6171
|
constant uint & r3,
|
6014
6172
|
constant int & idx,
|
6015
|
-
device const char * src00,
|
6016
|
-
device const char * src01,
|
6017
|
-
device const char * src02,
|
6018
|
-
device const char * src03,
|
6019
|
-
device const char * src04,
|
6020
|
-
device const char * src05,
|
6021
|
-
device const char * src06,
|
6022
|
-
device const char * src07,
|
6023
6173
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6024
6174
|
uint tiitg[[thread_index_in_threadgroup]],
|
6025
6175
|
uint tiisg[[thread_index_in_simdgroup]],
|
6026
6176
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6027
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6028
|
-
|
6029
6177
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6030
6178
|
|
6031
6179
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6032
6180
|
|
6033
6181
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6182
|
+
device const char * src0 = src0s + id*nb02;
|
6034
6183
|
|
6035
6184
|
kernel_mul_mv_q8_0_f32_impl(
|
6036
|
-
src0
|
6185
|
+
src0,
|
6037
6186
|
(device const float *) (src1 + bid*nb11),
|
6038
6187
|
dst + bid*ne0,
|
6039
6188
|
ne00,
|
@@ -6052,9 +6201,10 @@ kernel void kernel_mul_mv_id_q8_0_f32(
|
|
6052
6201
|
|
6053
6202
|
[[host_name("kernel_mul_mv_id_q4_0_f32")]]
|
6054
6203
|
kernel void kernel_mul_mv_id_q4_0_f32(
|
6055
|
-
device const char *
|
6204
|
+
device const char * src0s,
|
6056
6205
|
device const char * src1,
|
6057
6206
|
device float * dst,
|
6207
|
+
device const char * ids,
|
6058
6208
|
constant uint64_t & nbi1,
|
6059
6209
|
constant int64_t & ne00,
|
6060
6210
|
constant int64_t & ne01,
|
@@ -6075,28 +6225,19 @@ kernel void kernel_mul_mv_id_q4_0_f32(
|
|
6075
6225
|
constant uint & r2,
|
6076
6226
|
constant uint & r3,
|
6077
6227
|
constant int & idx,
|
6078
|
-
device const char * src00,
|
6079
|
-
device const char * src01,
|
6080
|
-
device const char * src02,
|
6081
|
-
device const char * src03,
|
6082
|
-
device const char * src04,
|
6083
|
-
device const char * src05,
|
6084
|
-
device const char * src06,
|
6085
|
-
device const char * src07,
|
6086
6228
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6087
6229
|
uint tiitg[[thread_index_in_threadgroup]],
|
6088
6230
|
uint tiisg[[thread_index_in_simdgroup]],
|
6089
6231
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6090
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6091
|
-
|
6092
6232
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6093
6233
|
|
6094
6234
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6095
6235
|
|
6096
6236
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6237
|
+
device const char * src0 = src0s + id*nb02;
|
6097
6238
|
|
6098
6239
|
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
6099
|
-
src0
|
6240
|
+
src0,
|
6100
6241
|
(device const float *) (src1 + bid*nb11),
|
6101
6242
|
dst + bid*ne0,
|
6102
6243
|
ne00,
|
@@ -6115,9 +6256,10 @@ kernel void kernel_mul_mv_id_q4_0_f32(
|
|
6115
6256
|
|
6116
6257
|
[[host_name("kernel_mul_mv_id_q4_1_f32")]]
|
6117
6258
|
kernel void kernel_mul_mv_id_q4_1_f32(
|
6118
|
-
device const char *
|
6259
|
+
device const char * src0s,
|
6119
6260
|
device const char * src1,
|
6120
6261
|
device float * dst,
|
6262
|
+
device const char * ids,
|
6121
6263
|
constant uint64_t & nbi1,
|
6122
6264
|
constant int64_t & ne00,
|
6123
6265
|
constant int64_t & ne01,
|
@@ -6138,28 +6280,19 @@ kernel void kernel_mul_mv_id_q4_1_f32(
|
|
6138
6280
|
constant uint & r2,
|
6139
6281
|
constant uint & r3,
|
6140
6282
|
constant int & idx,
|
6141
|
-
device const char * src00,
|
6142
|
-
device const char * src01,
|
6143
|
-
device const char * src02,
|
6144
|
-
device const char * src03,
|
6145
|
-
device const char * src04,
|
6146
|
-
device const char * src05,
|
6147
|
-
device const char * src06,
|
6148
|
-
device const char * src07,
|
6149
6283
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6150
6284
|
uint tiitg[[thread_index_in_threadgroup]],
|
6151
6285
|
uint tiisg[[thread_index_in_simdgroup]],
|
6152
6286
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6153
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6154
|
-
|
6155
6287
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6156
6288
|
|
6157
6289
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6158
6290
|
|
6159
6291
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6292
|
+
device const char * src0 = src0s + id*nb02;
|
6160
6293
|
|
6161
6294
|
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
6162
|
-
src0
|
6295
|
+
src0,
|
6163
6296
|
(device const float *) (src1 + bid*nb11),
|
6164
6297
|
dst + bid*ne0,
|
6165
6298
|
ne00,
|
@@ -6178,9 +6311,10 @@ kernel void kernel_mul_mv_id_q4_1_f32(
|
|
6178
6311
|
|
6179
6312
|
[[host_name("kernel_mul_mv_id_q5_0_f32")]]
|
6180
6313
|
kernel void kernel_mul_mv_id_q5_0_f32(
|
6181
|
-
device const char *
|
6314
|
+
device const char * src0s,
|
6182
6315
|
device const char * src1,
|
6183
6316
|
device float * dst,
|
6317
|
+
device const char * ids,
|
6184
6318
|
constant uint64_t & nbi1,
|
6185
6319
|
constant int64_t & ne00,
|
6186
6320
|
constant int64_t & ne01,
|
@@ -6201,28 +6335,19 @@ kernel void kernel_mul_mv_id_q5_0_f32(
|
|
6201
6335
|
constant uint & r2,
|
6202
6336
|
constant uint & r3,
|
6203
6337
|
constant int & idx,
|
6204
|
-
device const char * src00,
|
6205
|
-
device const char * src01,
|
6206
|
-
device const char * src02,
|
6207
|
-
device const char * src03,
|
6208
|
-
device const char * src04,
|
6209
|
-
device const char * src05,
|
6210
|
-
device const char * src06,
|
6211
|
-
device const char * src07,
|
6212
6338
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6213
6339
|
uint tiitg[[thread_index_in_threadgroup]],
|
6214
6340
|
uint tiisg[[thread_index_in_simdgroup]],
|
6215
6341
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6216
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6217
|
-
|
6218
6342
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6219
6343
|
|
6220
6344
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6221
6345
|
|
6222
6346
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6347
|
+
device const char * src0 = src0s + id*nb02;
|
6223
6348
|
|
6224
6349
|
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
6225
|
-
src0
|
6350
|
+
src0,
|
6226
6351
|
(device const float *) (src1 + bid*nb11),
|
6227
6352
|
dst + bid*ne0,
|
6228
6353
|
ne00,
|
@@ -6241,9 +6366,10 @@ kernel void kernel_mul_mv_id_q5_0_f32(
|
|
6241
6366
|
|
6242
6367
|
[[host_name("kernel_mul_mv_id_q5_1_f32")]]
|
6243
6368
|
kernel void kernel_mul_mv_id_q5_1_f32(
|
6244
|
-
device const char *
|
6369
|
+
device const char * src0s,
|
6245
6370
|
device const char * src1,
|
6246
6371
|
device float * dst,
|
6372
|
+
device const char * ids,
|
6247
6373
|
constant uint64_t & nbi1,
|
6248
6374
|
constant int64_t & ne00,
|
6249
6375
|
constant int64_t & ne01,
|
@@ -6264,28 +6390,19 @@ kernel void kernel_mul_mv_id_q5_1_f32(
|
|
6264
6390
|
constant uint & r2,
|
6265
6391
|
constant uint & r3,
|
6266
6392
|
constant int & idx,
|
6267
|
-
device const char * src00,
|
6268
|
-
device const char * src01,
|
6269
|
-
device const char * src02,
|
6270
|
-
device const char * src03,
|
6271
|
-
device const char * src04,
|
6272
|
-
device const char * src05,
|
6273
|
-
device const char * src06,
|
6274
|
-
device const char * src07,
|
6275
6393
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6276
6394
|
uint tiitg[[thread_index_in_threadgroup]],
|
6277
6395
|
uint tiisg[[thread_index_in_simdgroup]],
|
6278
6396
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6279
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6280
|
-
|
6281
6397
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6282
6398
|
|
6283
6399
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6284
6400
|
|
6285
6401
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6402
|
+
device const char * src0 = src0s + id*nb02;
|
6286
6403
|
|
6287
6404
|
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
6288
|
-
src0
|
6405
|
+
src0,
|
6289
6406
|
(device const float *) (src1 + bid*nb11),
|
6290
6407
|
dst + bid*ne0,
|
6291
6408
|
ne00,
|
@@ -6304,9 +6421,10 @@ kernel void kernel_mul_mv_id_q5_1_f32(
|
|
6304
6421
|
|
6305
6422
|
[[host_name("kernel_mul_mv_id_q2_K_f32")]]
|
6306
6423
|
kernel void kernel_mul_mv_id_q2_K_f32(
|
6307
|
-
device const char *
|
6424
|
+
device const char * src0s,
|
6308
6425
|
device const char * src1,
|
6309
6426
|
device float * dst,
|
6427
|
+
device const char * ids,
|
6310
6428
|
constant uint64_t & nbi1,
|
6311
6429
|
constant int64_t & ne00,
|
6312
6430
|
constant int64_t & ne01,
|
@@ -6327,28 +6445,19 @@ kernel void kernel_mul_mv_id_q2_K_f32(
|
|
6327
6445
|
constant uint & r2,
|
6328
6446
|
constant uint & r3,
|
6329
6447
|
constant int & idx,
|
6330
|
-
device const char * src00,
|
6331
|
-
device const char * src01,
|
6332
|
-
device const char * src02,
|
6333
|
-
device const char * src03,
|
6334
|
-
device const char * src04,
|
6335
|
-
device const char * src05,
|
6336
|
-
device const char * src06,
|
6337
|
-
device const char * src07,
|
6338
6448
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6339
6449
|
uint tiitg[[thread_index_in_threadgroup]],
|
6340
6450
|
uint tiisg[[thread_index_in_simdgroup]],
|
6341
6451
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6342
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6343
|
-
|
6344
6452
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6345
6453
|
|
6346
6454
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6347
6455
|
|
6348
6456
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6457
|
+
device const char * src0 = src0s + id*nb02;
|
6349
6458
|
|
6350
6459
|
kernel_mul_mv_q2_K_f32_impl(
|
6351
|
-
src0
|
6460
|
+
src0,
|
6352
6461
|
(device const float *) (src1 + bid*nb11),
|
6353
6462
|
dst + bid*ne0,
|
6354
6463
|
ne00,
|
@@ -6367,9 +6476,10 @@ kernel void kernel_mul_mv_id_q2_K_f32(
|
|
6367
6476
|
|
6368
6477
|
[[host_name("kernel_mul_mv_id_q3_K_f32")]]
|
6369
6478
|
kernel void kernel_mul_mv_id_q3_K_f32(
|
6370
|
-
device const char *
|
6479
|
+
device const char * src0s,
|
6371
6480
|
device const char * src1,
|
6372
6481
|
device float * dst,
|
6482
|
+
device const char * ids,
|
6373
6483
|
constant uint64_t & nbi1,
|
6374
6484
|
constant int64_t & ne00,
|
6375
6485
|
constant int64_t & ne01,
|
@@ -6390,28 +6500,19 @@ kernel void kernel_mul_mv_id_q3_K_f32(
|
|
6390
6500
|
constant uint & r2,
|
6391
6501
|
constant uint & r3,
|
6392
6502
|
constant int & idx,
|
6393
|
-
device const char * src00,
|
6394
|
-
device const char * src01,
|
6395
|
-
device const char * src02,
|
6396
|
-
device const char * src03,
|
6397
|
-
device const char * src04,
|
6398
|
-
device const char * src05,
|
6399
|
-
device const char * src06,
|
6400
|
-
device const char * src07,
|
6401
6503
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6402
6504
|
uint tiitg[[thread_index_in_threadgroup]],
|
6403
6505
|
uint tiisg[[thread_index_in_simdgroup]],
|
6404
6506
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6405
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6406
|
-
|
6407
6507
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6408
6508
|
|
6409
6509
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6410
6510
|
|
6411
6511
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6512
|
+
device const char * src0 = src0s + id*nb02;
|
6412
6513
|
|
6413
6514
|
kernel_mul_mv_q3_K_f32_impl(
|
6414
|
-
src0
|
6515
|
+
src0,
|
6415
6516
|
(device const float *) (src1 + bid*nb11),
|
6416
6517
|
dst + bid*ne0,
|
6417
6518
|
ne00,
|
@@ -6430,9 +6531,10 @@ kernel void kernel_mul_mv_id_q3_K_f32(
|
|
6430
6531
|
|
6431
6532
|
[[host_name("kernel_mul_mv_id_q4_K_f32")]]
|
6432
6533
|
kernel void kernel_mul_mv_id_q4_K_f32(
|
6433
|
-
device const char *
|
6534
|
+
device const char * src0s,
|
6434
6535
|
device const char * src1,
|
6435
6536
|
device float * dst,
|
6537
|
+
device const char * ids,
|
6436
6538
|
constant uint64_t & nbi1,
|
6437
6539
|
constant int64_t & ne00,
|
6438
6540
|
constant int64_t & ne01,
|
@@ -6453,28 +6555,19 @@ kernel void kernel_mul_mv_id_q4_K_f32(
|
|
6453
6555
|
constant uint & r2,
|
6454
6556
|
constant uint & r3,
|
6455
6557
|
constant int & idx,
|
6456
|
-
device const char * src00,
|
6457
|
-
device const char * src01,
|
6458
|
-
device const char * src02,
|
6459
|
-
device const char * src03,
|
6460
|
-
device const char * src04,
|
6461
|
-
device const char * src05,
|
6462
|
-
device const char * src06,
|
6463
|
-
device const char * src07,
|
6464
6558
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6465
6559
|
uint tiitg[[thread_index_in_threadgroup]],
|
6466
6560
|
uint tiisg[[thread_index_in_simdgroup]],
|
6467
6561
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6468
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6469
|
-
|
6470
6562
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6471
6563
|
|
6472
6564
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6473
6565
|
|
6474
6566
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6567
|
+
device const char * src0 = src0s + id*nb02;
|
6475
6568
|
|
6476
6569
|
kernel_mul_mv_q4_K_f32_impl(
|
6477
|
-
src0
|
6570
|
+
src0,
|
6478
6571
|
(device const float *) (src1 + bid*nb11),
|
6479
6572
|
dst + bid*ne0,
|
6480
6573
|
ne00,
|
@@ -6493,9 +6586,10 @@ kernel void kernel_mul_mv_id_q4_K_f32(
|
|
6493
6586
|
|
6494
6587
|
[[host_name("kernel_mul_mv_id_q5_K_f32")]]
|
6495
6588
|
kernel void kernel_mul_mv_id_q5_K_f32(
|
6496
|
-
device const char *
|
6589
|
+
device const char * src0s,
|
6497
6590
|
device const char * src1,
|
6498
6591
|
device float * dst,
|
6592
|
+
device const char * ids,
|
6499
6593
|
constant uint64_t & nbi1,
|
6500
6594
|
constant int64_t & ne00,
|
6501
6595
|
constant int64_t & ne01,
|
@@ -6516,28 +6610,19 @@ kernel void kernel_mul_mv_id_q5_K_f32(
|
|
6516
6610
|
constant uint & r2,
|
6517
6611
|
constant uint & r3,
|
6518
6612
|
constant int & idx,
|
6519
|
-
device const char * src00,
|
6520
|
-
device const char * src01,
|
6521
|
-
device const char * src02,
|
6522
|
-
device const char * src03,
|
6523
|
-
device const char * src04,
|
6524
|
-
device const char * src05,
|
6525
|
-
device const char * src06,
|
6526
|
-
device const char * src07,
|
6527
6613
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6528
6614
|
uint tiitg[[thread_index_in_threadgroup]],
|
6529
6615
|
uint tiisg[[thread_index_in_simdgroup]],
|
6530
6616
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6531
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6532
|
-
|
6533
6617
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6534
6618
|
|
6535
6619
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6536
6620
|
|
6537
6621
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6622
|
+
device const char * src0 = src0s + id*nb02;
|
6538
6623
|
|
6539
6624
|
kernel_mul_mv_q5_K_f32_impl(
|
6540
|
-
src0
|
6625
|
+
src0,
|
6541
6626
|
(device const float *) (src1 + bid*nb11),
|
6542
6627
|
dst + bid*ne0,
|
6543
6628
|
ne00,
|
@@ -6556,9 +6641,10 @@ kernel void kernel_mul_mv_id_q5_K_f32(
|
|
6556
6641
|
|
6557
6642
|
[[host_name("kernel_mul_mv_id_q6_K_f32")]]
|
6558
6643
|
kernel void kernel_mul_mv_id_q6_K_f32(
|
6559
|
-
device const char *
|
6644
|
+
device const char * src0s,
|
6560
6645
|
device const char * src1,
|
6561
6646
|
device float * dst,
|
6647
|
+
device const char * ids,
|
6562
6648
|
constant uint64_t & nbi1,
|
6563
6649
|
constant int64_t & ne00,
|
6564
6650
|
constant int64_t & ne01,
|
@@ -6579,28 +6665,19 @@ kernel void kernel_mul_mv_id_q6_K_f32(
|
|
6579
6665
|
constant uint & r2,
|
6580
6666
|
constant uint & r3,
|
6581
6667
|
constant int & idx,
|
6582
|
-
device const char * src00,
|
6583
|
-
device const char * src01,
|
6584
|
-
device const char * src02,
|
6585
|
-
device const char * src03,
|
6586
|
-
device const char * src04,
|
6587
|
-
device const char * src05,
|
6588
|
-
device const char * src06,
|
6589
|
-
device const char * src07,
|
6590
6668
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6591
6669
|
uint tiitg[[thread_index_in_threadgroup]],
|
6592
6670
|
uint tiisg[[thread_index_in_simdgroup]],
|
6593
6671
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6594
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6595
|
-
|
6596
6672
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6597
6673
|
|
6598
6674
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6599
6675
|
|
6600
6676
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6677
|
+
device const char * src0 = src0s + id*nb02;
|
6601
6678
|
|
6602
6679
|
kernel_mul_mv_q6_K_f32_impl(
|
6603
|
-
src0
|
6680
|
+
src0,
|
6604
6681
|
(device const float *) (src1 + bid*nb11),
|
6605
6682
|
dst + bid*ne0,
|
6606
6683
|
ne00,
|
@@ -6619,9 +6696,10 @@ kernel void kernel_mul_mv_id_q6_K_f32(
|
|
6619
6696
|
|
6620
6697
|
[[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
|
6621
6698
|
kernel void kernel_mul_mv_id_iq2_xxs_f32(
|
6622
|
-
device const char *
|
6699
|
+
device const char * src0s,
|
6623
6700
|
device const char * src1,
|
6624
6701
|
device float * dst,
|
6702
|
+
device const char * ids,
|
6625
6703
|
constant uint64_t & nbi1,
|
6626
6704
|
constant int64_t & ne00,
|
6627
6705
|
constant int64_t & ne01,
|
@@ -6642,29 +6720,20 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
|
|
6642
6720
|
constant uint & r2,
|
6643
6721
|
constant uint & r3,
|
6644
6722
|
constant int & idx,
|
6645
|
-
device const char * src00,
|
6646
|
-
device const char * src01,
|
6647
|
-
device const char * src02,
|
6648
|
-
device const char * src03,
|
6649
|
-
device const char * src04,
|
6650
|
-
device const char * src05,
|
6651
|
-
device const char * src06,
|
6652
|
-
device const char * src07,
|
6653
6723
|
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
6654
6724
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6655
6725
|
uint tiitg[[thread_index_in_threadgroup]],
|
6656
6726
|
uint tiisg[[thread_index_in_simdgroup]],
|
6657
6727
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6658
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6659
|
-
|
6660
6728
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6661
6729
|
|
6662
6730
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6663
6731
|
|
6664
6732
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6733
|
+
device const char * src0 = src0s + id*nb02;
|
6665
6734
|
|
6666
6735
|
kernel_mul_mv_iq2_xxs_f32_impl(
|
6667
|
-
src0
|
6736
|
+
src0,
|
6668
6737
|
(device const float *) (src1 + bid*nb11),
|
6669
6738
|
dst + bid*ne0,
|
6670
6739
|
ne00,
|
@@ -6684,9 +6753,10 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
|
|
6684
6753
|
|
6685
6754
|
[[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
|
6686
6755
|
kernel void kernel_mul_mv_id_iq2_xs_f32(
|
6687
|
-
device const char *
|
6756
|
+
device const char * src0s,
|
6688
6757
|
device const char * src1,
|
6689
6758
|
device float * dst,
|
6759
|
+
device const char * ids,
|
6690
6760
|
constant uint64_t & nbi1,
|
6691
6761
|
constant int64_t & ne00,
|
6692
6762
|
constant int64_t & ne01,
|
@@ -6707,29 +6777,20 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
|
|
6707
6777
|
constant uint & r2,
|
6708
6778
|
constant uint & r3,
|
6709
6779
|
constant int & idx,
|
6710
|
-
device const char * src00,
|
6711
|
-
device const char * src01,
|
6712
|
-
device const char * src02,
|
6713
|
-
device const char * src03,
|
6714
|
-
device const char * src04,
|
6715
|
-
device const char * src05,
|
6716
|
-
device const char * src06,
|
6717
|
-
device const char * src07,
|
6718
6780
|
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
6719
6781
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6720
6782
|
uint tiitg[[thread_index_in_threadgroup]],
|
6721
6783
|
uint tiisg[[thread_index_in_simdgroup]],
|
6722
6784
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6723
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6724
|
-
|
6725
6785
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6726
6786
|
|
6727
6787
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6728
6788
|
|
6729
6789
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6790
|
+
device const char * src0 = src0s + id*nb02;
|
6730
6791
|
|
6731
6792
|
kernel_mul_mv_iq2_xs_f32_impl(
|
6732
|
-
src0
|
6793
|
+
src0,
|
6733
6794
|
(device const float *) (src1 + bid*nb11),
|
6734
6795
|
dst + bid*ne0,
|
6735
6796
|
ne00,
|
@@ -6749,9 +6810,10 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
|
|
6749
6810
|
|
6750
6811
|
[[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
|
6751
6812
|
kernel void kernel_mul_mv_id_iq3_xxs_f32(
|
6752
|
-
device const char *
|
6813
|
+
device const char * src0s,
|
6753
6814
|
device const char * src1,
|
6754
6815
|
device float * dst,
|
6816
|
+
device const char * ids,
|
6755
6817
|
constant uint64_t & nbi1,
|
6756
6818
|
constant int64_t & ne00,
|
6757
6819
|
constant int64_t & ne01,
|
@@ -6772,29 +6834,20 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
|
|
6772
6834
|
constant uint & r2,
|
6773
6835
|
constant uint & r3,
|
6774
6836
|
constant int & idx,
|
6775
|
-
device const char * src00,
|
6776
|
-
device const char * src01,
|
6777
|
-
device const char * src02,
|
6778
|
-
device const char * src03,
|
6779
|
-
device const char * src04,
|
6780
|
-
device const char * src05,
|
6781
|
-
device const char * src06,
|
6782
|
-
device const char * src07,
|
6783
6837
|
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
6784
6838
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6785
6839
|
uint tiitg[[thread_index_in_threadgroup]],
|
6786
6840
|
uint tiisg[[thread_index_in_simdgroup]],
|
6787
6841
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6788
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6789
|
-
|
6790
6842
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6791
6843
|
|
6792
6844
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6793
6845
|
|
6794
6846
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6847
|
+
device const char * src0 = src0s + id*nb02;
|
6795
6848
|
|
6796
6849
|
kernel_mul_mv_iq3_xxs_f32_impl(
|
6797
|
-
src0
|
6850
|
+
src0,
|
6798
6851
|
(device const float *) (src1 + bid*nb11),
|
6799
6852
|
dst + bid*ne0,
|
6800
6853
|
ne00,
|
@@ -6814,9 +6867,10 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
|
|
6814
6867
|
|
6815
6868
|
[[host_name("kernel_mul_mv_id_iq3_s_f32")]]
|
6816
6869
|
kernel void kernel_mul_mv_id_iq3_s_f32(
|
6817
|
-
device const char *
|
6870
|
+
device const char * src0s,
|
6818
6871
|
device const char * src1,
|
6819
6872
|
device float * dst,
|
6873
|
+
device const char * ids,
|
6820
6874
|
constant uint64_t & nbi1,
|
6821
6875
|
constant int64_t & ne00,
|
6822
6876
|
constant int64_t & ne01,
|
@@ -6837,29 +6891,20 @@ kernel void kernel_mul_mv_id_iq3_s_f32(
|
|
6837
6891
|
constant uint & r2,
|
6838
6892
|
constant uint & r3,
|
6839
6893
|
constant int & idx,
|
6840
|
-
device const char * src00,
|
6841
|
-
device const char * src01,
|
6842
|
-
device const char * src02,
|
6843
|
-
device const char * src03,
|
6844
|
-
device const char * src04,
|
6845
|
-
device const char * src05,
|
6846
|
-
device const char * src06,
|
6847
|
-
device const char * src07,
|
6848
6894
|
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
6849
6895
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6850
6896
|
uint tiitg[[thread_index_in_threadgroup]],
|
6851
6897
|
uint tiisg[[thread_index_in_simdgroup]],
|
6852
6898
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6853
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6854
|
-
|
6855
6899
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6856
6900
|
|
6857
6901
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6858
6902
|
|
6859
6903
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6904
|
+
device const char * src0 = src0s + id*nb02;
|
6860
6905
|
|
6861
6906
|
kernel_mul_mv_iq3_s_f32_impl(
|
6862
|
-
src0
|
6907
|
+
src0,
|
6863
6908
|
(device const float *) (src1 + bid*nb11),
|
6864
6909
|
dst + bid*ne0,
|
6865
6910
|
ne00,
|
@@ -6879,9 +6924,10 @@ kernel void kernel_mul_mv_id_iq3_s_f32(
|
|
6879
6924
|
|
6880
6925
|
[[host_name("kernel_mul_mv_id_iq2_s_f32")]]
|
6881
6926
|
kernel void kernel_mul_mv_id_iq2_s_f32(
|
6882
|
-
device const char *
|
6927
|
+
device const char * src0s,
|
6883
6928
|
device const char * src1,
|
6884
6929
|
device float * dst,
|
6930
|
+
device const char * ids,
|
6885
6931
|
constant uint64_t & nbi1,
|
6886
6932
|
constant int64_t & ne00,
|
6887
6933
|
constant int64_t & ne01,
|
@@ -6902,29 +6948,20 @@ kernel void kernel_mul_mv_id_iq2_s_f32(
|
|
6902
6948
|
constant uint & r2,
|
6903
6949
|
constant uint & r3,
|
6904
6950
|
constant int & idx,
|
6905
|
-
device const char * src00,
|
6906
|
-
device const char * src01,
|
6907
|
-
device const char * src02,
|
6908
|
-
device const char * src03,
|
6909
|
-
device const char * src04,
|
6910
|
-
device const char * src05,
|
6911
|
-
device const char * src06,
|
6912
|
-
device const char * src07,
|
6913
6951
|
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
6914
6952
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6915
6953
|
uint tiitg[[thread_index_in_threadgroup]],
|
6916
6954
|
uint tiisg[[thread_index_in_simdgroup]],
|
6917
6955
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6918
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6919
|
-
|
6920
6956
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6921
6957
|
|
6922
6958
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6923
6959
|
|
6924
6960
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6961
|
+
device const char * src0 = src0s + id*nb02;
|
6925
6962
|
|
6926
6963
|
kernel_mul_mv_iq2_s_f32_impl(
|
6927
|
-
src0
|
6964
|
+
src0,
|
6928
6965
|
(device const float *) (src1 + bid*nb11),
|
6929
6966
|
dst + bid*ne0,
|
6930
6967
|
ne00,
|
@@ -6944,9 +6981,10 @@ kernel void kernel_mul_mv_id_iq2_s_f32(
|
|
6944
6981
|
|
6945
6982
|
[[host_name("kernel_mul_mv_id_iq1_s_f32")]]
|
6946
6983
|
kernel void kernel_mul_mv_id_iq1_s_f32(
|
6947
|
-
device const char *
|
6984
|
+
device const char * src0s,
|
6948
6985
|
device const char * src1,
|
6949
6986
|
device float * dst,
|
6987
|
+
device const char * ids,
|
6950
6988
|
constant uint64_t & nbi1,
|
6951
6989
|
constant int64_t & ne00,
|
6952
6990
|
constant int64_t & ne01,
|
@@ -6967,28 +7005,74 @@ kernel void kernel_mul_mv_id_iq1_s_f32(
|
|
6967
7005
|
constant uint & r2,
|
6968
7006
|
constant uint & r3,
|
6969
7007
|
constant int & idx,
|
6970
|
-
device const char * src00,
|
6971
|
-
device const char * src01,
|
6972
|
-
device const char * src02,
|
6973
|
-
device const char * src03,
|
6974
|
-
device const char * src04,
|
6975
|
-
device const char * src05,
|
6976
|
-
device const char * src06,
|
6977
|
-
device const char * src07,
|
6978
7008
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6979
7009
|
uint tiitg[[thread_index_in_threadgroup]],
|
6980
7010
|
uint tiisg[[thread_index_in_simdgroup]],
|
6981
7011
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6982
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6983
|
-
|
6984
7012
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6985
7013
|
|
6986
7014
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6987
7015
|
|
6988
7016
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
7017
|
+
device const char * src0 = src0s + id*nb02;
|
6989
7018
|
|
6990
7019
|
kernel_mul_mv_iq1_s_f32_impl(
|
6991
|
-
src0
|
7020
|
+
src0,
|
7021
|
+
(device const float *) (src1 + bid*nb11),
|
7022
|
+
dst + bid*ne0,
|
7023
|
+
ne00,
|
7024
|
+
ne01,
|
7025
|
+
ne02,
|
7026
|
+
ne10,
|
7027
|
+
ne12,
|
7028
|
+
ne0,
|
7029
|
+
ne1,
|
7030
|
+
r2,
|
7031
|
+
r3,
|
7032
|
+
tgpig,
|
7033
|
+
tiisg,
|
7034
|
+
sgitg);
|
7035
|
+
}
|
7036
|
+
|
7037
|
+
[[host_name("kernel_mul_mv_id_iq1_m_f32")]]
|
7038
|
+
kernel void kernel_mul_mv_id_iq1_m_f32(
|
7039
|
+
device const char * src0s,
|
7040
|
+
device const char * src1,
|
7041
|
+
device float * dst,
|
7042
|
+
device const char * ids,
|
7043
|
+
constant uint64_t & nbi1,
|
7044
|
+
constant int64_t & ne00,
|
7045
|
+
constant int64_t & ne01,
|
7046
|
+
constant int64_t & ne02,
|
7047
|
+
constant uint64_t & nb00,
|
7048
|
+
constant uint64_t & nb01,
|
7049
|
+
constant uint64_t & nb02,
|
7050
|
+
constant int64_t & ne10,
|
7051
|
+
constant int64_t & ne11,
|
7052
|
+
constant int64_t & ne12,
|
7053
|
+
constant int64_t & ne13,
|
7054
|
+
constant uint64_t & nb10,
|
7055
|
+
constant uint64_t & nb11,
|
7056
|
+
constant uint64_t & nb12,
|
7057
|
+
constant int64_t & ne0,
|
7058
|
+
constant int64_t & ne1,
|
7059
|
+
constant uint64_t & nb1,
|
7060
|
+
constant uint & r2,
|
7061
|
+
constant uint & r3,
|
7062
|
+
constant int & idx,
|
7063
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
7064
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
7065
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
7066
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
7067
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
7068
|
+
|
7069
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
7070
|
+
|
7071
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
7072
|
+
device const char * src0 = src0s + id*nb02;
|
7073
|
+
|
7074
|
+
kernel_mul_mv_iq1_m_f32_impl(
|
7075
|
+
src0,
|
6992
7076
|
(device const float *) (src1 + bid*nb11),
|
6993
7077
|
dst + bid*ne0,
|
6994
7078
|
ne00,
|
@@ -7007,9 +7091,10 @@ kernel void kernel_mul_mv_id_iq1_s_f32(
|
|
7007
7091
|
|
7008
7092
|
[[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
|
7009
7093
|
kernel void kernel_mul_mv_id_iq4_nl_f32(
|
7010
|
-
device const char *
|
7094
|
+
device const char * src0s,
|
7011
7095
|
device const char * src1,
|
7012
7096
|
device float * dst,
|
7097
|
+
device const char * ids,
|
7013
7098
|
constant uint64_t & nbi1,
|
7014
7099
|
constant int64_t & ne00,
|
7015
7100
|
constant int64_t & ne01,
|
@@ -7030,29 +7115,20 @@ kernel void kernel_mul_mv_id_iq4_nl_f32(
|
|
7030
7115
|
constant uint & r2,
|
7031
7116
|
constant uint & r3,
|
7032
7117
|
constant int & idx,
|
7033
|
-
device const char * src00,
|
7034
|
-
device const char * src01,
|
7035
|
-
device const char * src02,
|
7036
|
-
device const char * src03,
|
7037
|
-
device const char * src04,
|
7038
|
-
device const char * src05,
|
7039
|
-
device const char * src06,
|
7040
|
-
device const char * src07,
|
7041
7118
|
threadgroup float * shared_values [[threadgroup(0)]],
|
7042
7119
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
7043
7120
|
uint tiitg[[thread_index_in_threadgroup]],
|
7044
7121
|
uint tiisg[[thread_index_in_simdgroup]],
|
7045
7122
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
7046
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
7047
|
-
|
7048
7123
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
7049
7124
|
|
7050
7125
|
tgpig.z = tgpig.z%(ne12*ne13);
|
7051
7126
|
|
7052
7127
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
7128
|
+
device const char * src0 = src0s + id*nb02;
|
7053
7129
|
|
7054
7130
|
kernel_mul_mv_iq4_nl_f32_impl(
|
7055
|
-
src0
|
7131
|
+
src0,
|
7056
7132
|
(device const float *) (src1 + bid*nb11),
|
7057
7133
|
dst + bid*ne0,
|
7058
7134
|
ne00,
|
@@ -7072,9 +7148,10 @@ kernel void kernel_mul_mv_id_iq4_nl_f32(
|
|
7072
7148
|
|
7073
7149
|
[[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
|
7074
7150
|
kernel void kernel_mul_mv_id_iq4_xs_f32(
|
7075
|
-
device const char *
|
7151
|
+
device const char * src0s,
|
7076
7152
|
device const char * src1,
|
7077
7153
|
device float * dst,
|
7154
|
+
device const char * ids,
|
7078
7155
|
constant uint64_t & nbi1,
|
7079
7156
|
constant int64_t & ne00,
|
7080
7157
|
constant int64_t & ne01,
|
@@ -7095,33 +7172,24 @@ kernel void kernel_mul_mv_id_iq4_xs_f32(
|
|
7095
7172
|
constant uint & r2,
|
7096
7173
|
constant uint & r3,
|
7097
7174
|
constant int & idx,
|
7098
|
-
device const char * src00,
|
7099
|
-
device const char * src01,
|
7100
|
-
device const char * src02,
|
7101
|
-
device const char * src03,
|
7102
|
-
device const char * src04,
|
7103
|
-
device const char * src05,
|
7104
|
-
device const char * src06,
|
7105
|
-
device const char * src07,
|
7106
7175
|
threadgroup float * shared_values [[threadgroup(0)]],
|
7107
7176
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
7108
7177
|
uint tiitg[[thread_index_in_threadgroup]],
|
7109
7178
|
uint tiisg[[thread_index_in_simdgroup]],
|
7110
7179
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
7111
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
7112
|
-
|
7113
7180
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
7114
7181
|
|
7115
7182
|
tgpig.z = tgpig.z%(ne12*ne13);
|
7116
7183
|
|
7117
7184
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
7185
|
+
device const char * src0 = src0s + id*nb02;
|
7118
7186
|
|
7119
7187
|
#if QK_K == 64
|
7120
7188
|
kernel_mul_mv_iq4_nl_f32_impl(
|
7121
7189
|
#else
|
7122
7190
|
kernel_mul_mv_iq4_xs_f32_impl(
|
7123
7191
|
#endif
|
7124
|
-
src0
|
7192
|
+
src0,
|
7125
7193
|
(device const float *) (src1 + bid*nb11),
|
7126
7194
|
dst + bid*ne0,
|
7127
7195
|
ne00,
|