llama_cpp 0.15.2 → 0.15.3
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +8 -0
- data/ext/llama_cpp/llama_cpp.cpp +49 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +4 -0
- data/vendor/tmp/llama.cpp/Makefile +6 -17
- data/vendor/tmp/llama.cpp/ggml-common.h +0 -54
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +72 -30
- data/vendor/tmp/llama.cpp/ggml-cuda.h +1 -0
- data/vendor/tmp/llama.cpp/ggml-impl.h +40 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +4 -0
- data/vendor/tmp/llama.cpp/ggml-metal.m +68 -70
- data/vendor/tmp/llama.cpp/ggml-metal.metal +24 -409
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +4 -1
- data/vendor/tmp/llama.cpp/ggml-quants.c +1879 -2450
- data/vendor/tmp/llama.cpp/ggml-rpc.cpp +176 -53
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +40 -500
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +9351 -5627
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +202 -225
- data/vendor/tmp/llama.cpp/ggml.c +376 -758
- data/vendor/tmp/llama.cpp/ggml.h +39 -27
- data/vendor/tmp/llama.cpp/llama.cpp +823 -593
- data/vendor/tmp/llama.cpp/llama.h +10 -3
- metadata +3 -3
@@ -1640,6 +1640,7 @@ static void rope_yarn_corr_dims(
|
|
1640
1640
|
typedef void (rope_t)(
|
1641
1641
|
device const void * src0,
|
1642
1642
|
device const int32_t * src1,
|
1643
|
+
device const float * src2,
|
1643
1644
|
device float * dst,
|
1644
1645
|
constant int64_t & ne00,
|
1645
1646
|
constant int64_t & ne01,
|
@@ -1675,6 +1676,7 @@ template<typename T>
|
|
1675
1676
|
kernel void kernel_rope(
|
1676
1677
|
device const void * src0,
|
1677
1678
|
device const int32_t * src1,
|
1679
|
+
device const float * src2,
|
1678
1680
|
device float * dst,
|
1679
1681
|
constant int64_t & ne00,
|
1680
1682
|
constant int64_t & ne01,
|
@@ -1744,8 +1746,10 @@ kernel void kernel_rope(
|
|
1744
1746
|
|
1745
1747
|
// simplified from `(ib * n_dims + ic) * inv_ndims`
|
1746
1748
|
const float cur_rot = inv_ndims*ic - ib;
|
1749
|
+
const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f;
|
1750
|
+
|
1751
|
+
const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor;
|
1747
1752
|
|
1748
|
-
const float theta = theta_0 * pow(freq_base, cur_rot);
|
1749
1753
|
float cos_theta, sin_theta;
|
1750
1754
|
rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
1751
1755
|
|
@@ -2204,11 +2208,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|
2204
2208
|
// pointer to the mask
|
2205
2209
|
device const half * mp = (device const half *) (mask + iq1*nb31);
|
2206
2210
|
|
2207
|
-
|
2208
|
-
simdgroup_float8x8 mscale(scale);
|
2209
|
-
|
2210
|
-
// prepare diagonal slope matrix
|
2211
|
-
simdgroup_float8x8 mslope(1.0f);
|
2211
|
+
float slope = 1.0f;
|
2212
2212
|
|
2213
2213
|
// ALiBi
|
2214
2214
|
if (max_bias > 0.0f) {
|
@@ -2217,7 +2217,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|
2217
2217
|
const float base = h < n_head_log2 ? m0 : m1;
|
2218
2218
|
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
2219
2219
|
|
2220
|
-
|
2220
|
+
slope = pow(base, exph);
|
2221
2221
|
}
|
2222
2222
|
|
2223
2223
|
// loop over the KV cache
|
@@ -2242,18 +2242,20 @@ kernel void kernel_flash_attn_ext_f16(
|
|
2242
2242
|
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
2243
2243
|
}
|
2244
2244
|
|
2245
|
+
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
|
2246
|
+
|
2247
|
+
const short tx = tiisg%4;
|
2248
|
+
const short ty = tiisg/4;
|
2249
|
+
|
2245
2250
|
if (mask != q) {
|
2246
2251
|
// mqk = mqk*scale + mask*slope
|
2247
|
-
|
2248
|
-
|
2249
|
-
simdgroup_multiply(mm, mslope, mm);
|
2250
|
-
simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
|
2252
|
+
ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
|
2253
|
+
ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
|
2251
2254
|
} else {
|
2252
2255
|
// mqk = mqk*scale
|
2253
|
-
|
2256
|
+
ss[8*cc + ty*TF + 2*tx + 0] *= scale;
|
2257
|
+
ss[8*cc + ty*TF + 2*tx + 1] *= scale;
|
2254
2258
|
}
|
2255
|
-
|
2256
|
-
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
|
2257
2259
|
}
|
2258
2260
|
}
|
2259
2261
|
|
@@ -2816,8 +2818,7 @@ kernel void kernel_cpy_f32_f16(
|
|
2816
2818
|
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
2817
2819
|
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
2818
2820
|
|
2819
|
-
|
2820
|
-
dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0];
|
2821
|
+
dst_data[i00] = src[0];
|
2821
2822
|
}
|
2822
2823
|
}
|
2823
2824
|
|
@@ -3385,7 +3386,6 @@ void kernel_mul_mv_q2_K_f32_impl(
|
|
3385
3386
|
|
3386
3387
|
const int step = sizeof(block_q2_K) * nb;
|
3387
3388
|
|
3388
|
-
#if QK_K == 256
|
3389
3389
|
const int ix = tiisg/8; // 0...3
|
3390
3390
|
const int it = tiisg%8; // 0...7
|
3391
3391
|
const int iq = it/4; // 0 or 1
|
@@ -3437,57 +3437,6 @@ void kernel_mul_mv_q2_K_f32_impl(
|
|
3437
3437
|
|
3438
3438
|
y4 += 4 * QK_K;
|
3439
3439
|
}
|
3440
|
-
#else
|
3441
|
-
const int ix = tiisg/2; // 0...15
|
3442
|
-
const int it = tiisg%2; // 0...1
|
3443
|
-
|
3444
|
-
device const float * y4 = y + ix * QK_K + 8 * it;
|
3445
|
-
|
3446
|
-
for (int ib = ix; ib < nb; ib += 16) {
|
3447
|
-
|
3448
|
-
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
3449
|
-
for (int i = 0; i < 8; ++i) {
|
3450
|
-
yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
|
3451
|
-
yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8];
|
3452
|
-
yl[i+16] = y4[i+32]; sumy[2] += yl[i+16];
|
3453
|
-
yl[i+24] = y4[i+48]; sumy[3] += yl[i+24];
|
3454
|
-
}
|
3455
|
-
|
3456
|
-
device const uint8_t * sc = (device const uint8_t *)x[ib].scales;
|
3457
|
-
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
|
3458
|
-
device const half * dh = &x[ib].d;
|
3459
|
-
|
3460
|
-
for (int row = 0; row < N_DST; row++) {
|
3461
|
-
|
3462
|
-
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
|
3463
|
-
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
|
3464
|
-
for (int i = 0; i < 8; i += 2) {
|
3465
|
-
acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
|
3466
|
-
acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
|
3467
|
-
acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
|
3468
|
-
acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
|
3469
|
-
acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
|
3470
|
-
acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
|
3471
|
-
acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
|
3472
|
-
acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
|
3473
|
-
}
|
3474
|
-
|
3475
|
-
float dall = dh[0];
|
3476
|
-
float dmin = dh[1];
|
3477
|
-
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
|
3478
|
-
(acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
|
3479
|
-
(acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
|
3480
|
-
(acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
|
3481
|
-
dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));
|
3482
|
-
|
3483
|
-
qs += step/2;
|
3484
|
-
sc += step;
|
3485
|
-
dh += step/2;
|
3486
|
-
}
|
3487
|
-
|
3488
|
-
y4 += 16 * QK_K;
|
3489
|
-
}
|
3490
|
-
#endif
|
3491
3440
|
|
3492
3441
|
for (int row = 0; row < N_DST; ++row) {
|
3493
3442
|
all_sum = simd_sum(sumf[row]);
|
@@ -3525,7 +3474,6 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
3525
3474
|
kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
3526
3475
|
}
|
3527
3476
|
|
3528
|
-
#if QK_K == 256
|
3529
3477
|
void kernel_mul_mv_q3_K_f32_impl(
|
3530
3478
|
device const void * src0,
|
3531
3479
|
device const float * src1,
|
@@ -3684,84 +3632,6 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|
3684
3632
|
}
|
3685
3633
|
}
|
3686
3634
|
}
|
3687
|
-
#else
|
3688
|
-
void kernel_mul_mv_q3_K_f32_impl(
|
3689
|
-
device const void * src0,
|
3690
|
-
device const float * src1,
|
3691
|
-
device float * dst,
|
3692
|
-
constant int64_t & ne00,
|
3693
|
-
constant int64_t & ne01,
|
3694
|
-
constant int64_t & ne02,
|
3695
|
-
constant int64_t & ne10,
|
3696
|
-
constant int64_t & ne12,
|
3697
|
-
constant int64_t & ne0,
|
3698
|
-
constant int64_t & ne1,
|
3699
|
-
constant uint & r2,
|
3700
|
-
constant uint & r3,
|
3701
|
-
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
3702
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
3703
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
3704
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3705
|
-
|
3706
|
-
const int nb = ne00/QK_K;
|
3707
|
-
|
3708
|
-
const int64_t r0 = tgpig.x;
|
3709
|
-
const int64_t r1 = tgpig.y;
|
3710
|
-
const int64_t im = tgpig.z;
|
3711
|
-
|
3712
|
-
const int row = 2 * r0 + sgitg;
|
3713
|
-
|
3714
|
-
const uint i12 = im%ne12;
|
3715
|
-
const uint i13 = im/ne12;
|
3716
|
-
|
3717
|
-
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
3718
|
-
|
3719
|
-
device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
|
3720
|
-
device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
3721
|
-
|
3722
|
-
const int ix = tiisg/4;
|
3723
|
-
const int il = 4 * (tiisg%4);// 0, 4, 8, 12
|
3724
|
-
const int iq = il/8; // 0, 0, 1, 1
|
3725
|
-
const int in = il%8; // 0, 4, 0, 4
|
3726
|
-
|
3727
|
-
float2 sum = {0.f, 0.f};
|
3728
|
-
|
3729
|
-
for (int i = ix; i < nb; i += 8) {
|
3730
|
-
|
3731
|
-
const float d_all = (float)(x[i].d);
|
3732
|
-
|
3733
|
-
device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
|
3734
|
-
device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
|
3735
|
-
device const uint16_t * s = (device const uint16_t *)(x[i].scales);
|
3736
|
-
device const float * y = yy + i * QK_K + il;
|
3737
|
-
|
3738
|
-
const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
|
3739
|
-
const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
|
3740
|
-
const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
|
3741
|
-
const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
|
3742
|
-
|
3743
|
-
for (int l = 0; l < 4; l += 2) {
|
3744
|
-
const uint16_t hm = h[l/2] >> iq;
|
3745
|
-
sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
|
3746
|
-
+ y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
|
3747
|
-
+ y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
|
3748
|
-
+ y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
|
3749
|
-
sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
|
3750
|
-
+ y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
|
3751
|
-
+ y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
|
3752
|
-
+ y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
|
3753
|
-
}
|
3754
|
-
|
3755
|
-
}
|
3756
|
-
const float sumf = sum[0] + sum[1] * 1.f/256.f;
|
3757
|
-
|
3758
|
-
const float tot = simd_sum(sumf);
|
3759
|
-
if (tiisg == 0) {
|
3760
|
-
dst[r1*ne0 + im*ne0*ne1 + row] = tot;
|
3761
|
-
}
|
3762
|
-
|
3763
|
-
}
|
3764
|
-
#endif
|
3765
3635
|
|
3766
3636
|
[[host_name("kernel_mul_mv_q3_K_f32")]]
|
3767
3637
|
kernel void kernel_mul_mv_q3_K_f32(
|
@@ -3791,7 +3661,6 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
3791
3661
|
kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
3792
3662
|
}
|
3793
3663
|
|
3794
|
-
#if QK_K == 256
|
3795
3664
|
void kernel_mul_mv_q4_K_f32_impl(
|
3796
3665
|
device const void * src0,
|
3797
3666
|
device const float * src1,
|
@@ -3905,103 +3774,6 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
3905
3774
|
}
|
3906
3775
|
}
|
3907
3776
|
}
|
3908
|
-
#else
|
3909
|
-
void kernel_mul_mv_q4_K_f32_impl(
|
3910
|
-
device const void * src0,
|
3911
|
-
device const float * src1,
|
3912
|
-
device float * dst,
|
3913
|
-
constant int64_t & ne00,
|
3914
|
-
constant int64_t & ne01,
|
3915
|
-
constant int64_t & ne02,
|
3916
|
-
constant int64_t & ne10,
|
3917
|
-
constant int64_t & ne12,
|
3918
|
-
constant int64_t & ne0,
|
3919
|
-
constant int64_t & ne1,
|
3920
|
-
constant uint & r2,
|
3921
|
-
constant uint & r3,
|
3922
|
-
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
3923
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
3924
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
3925
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3926
|
-
|
3927
|
-
const int ix = tiisg/4; // 0...7
|
3928
|
-
const int it = tiisg%4; // 0...3
|
3929
|
-
|
3930
|
-
const int nb = ne00/QK_K;
|
3931
|
-
const int r0 = tgpig.x;
|
3932
|
-
const int r1 = tgpig.y;
|
3933
|
-
const int im = tgpig.z;
|
3934
|
-
const int first_row = r0 * N_DST;
|
3935
|
-
const int ib_row = first_row * nb;
|
3936
|
-
|
3937
|
-
const uint i12 = im%ne12;
|
3938
|
-
const uint i13 = im/ne12;
|
3939
|
-
|
3940
|
-
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
3941
|
-
|
3942
|
-
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
|
3943
|
-
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
3944
|
-
|
3945
|
-
float yl[8];
|
3946
|
-
float yh[8];
|
3947
|
-
float sumf[N_DST]={0.f}, all_sum;
|
3948
|
-
|
3949
|
-
const int step = sizeof(block_q4_K) * nb / 2;
|
3950
|
-
|
3951
|
-
device const float * y4 = y + ix * QK_K + 8 * it;
|
3952
|
-
|
3953
|
-
uint16_t sc16[4];
|
3954
|
-
|
3955
|
-
for (int ib = ix; ib < nb; ib += 8) {
|
3956
|
-
|
3957
|
-
float2 sumy = {0.f, 0.f};
|
3958
|
-
for (int i = 0; i < 8; ++i) {
|
3959
|
-
yl[i] = y4[i+ 0]; sumy[0] += yl[i];
|
3960
|
-
yh[i] = y4[i+32]; sumy[1] += yh[i];
|
3961
|
-
}
|
3962
|
-
|
3963
|
-
device const uint16_t * sc = (device const uint16_t *)x[ib].scales;
|
3964
|
-
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
|
3965
|
-
device const half * dh = x[ib].d;
|
3966
|
-
|
3967
|
-
for (int row = 0; row < N_DST; row++) {
|
3968
|
-
|
3969
|
-
sc16[0] = sc[0] & 0x000f;
|
3970
|
-
sc16[1] = sc[0] & 0x0f00;
|
3971
|
-
sc16[2] = sc[0] & 0x00f0;
|
3972
|
-
sc16[3] = sc[0] & 0xf000;
|
3973
|
-
|
3974
|
-
float2 acc1 = {0.f, 0.f};
|
3975
|
-
float2 acc2 = {0.f, 0.f};
|
3976
|
-
for (int i = 0; i < 8; i += 2) {
|
3977
|
-
acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
|
3978
|
-
acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
|
3979
|
-
acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
|
3980
|
-
acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
|
3981
|
-
}
|
3982
|
-
|
3983
|
-
float dall = dh[0];
|
3984
|
-
float dmin = dh[1];
|
3985
|
-
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
|
3986
|
-
(acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
|
3987
|
-
dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);
|
3988
|
-
|
3989
|
-
qs += step;
|
3990
|
-
sc += step;
|
3991
|
-
dh += step;
|
3992
|
-
}
|
3993
|
-
|
3994
|
-
y4 += 8 * QK_K;
|
3995
|
-
}
|
3996
|
-
|
3997
|
-
for (int row = 0; row < N_DST; ++row) {
|
3998
|
-
all_sum = simd_sum(sumf[row]);
|
3999
|
-
if (tiisg == 0) {
|
4000
|
-
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
4001
|
-
}
|
4002
|
-
}
|
4003
|
-
}
|
4004
|
-
#endif
|
4005
3777
|
|
4006
3778
|
[[host_name("kernel_mul_mv_q4_K_f32")]]
|
4007
3779
|
kernel void kernel_mul_mv_q4_K_f32(
|
@@ -4069,8 +3841,6 @@ void kernel_mul_mv_q5_K_f32_impl(
|
|
4069
3841
|
|
4070
3842
|
const int step = sizeof(block_q5_K) * nb;
|
4071
3843
|
|
4072
|
-
#if QK_K == 256
|
4073
|
-
#
|
4074
3844
|
float yl[16], yh[16];
|
4075
3845
|
|
4076
3846
|
const uint16_t kmask1 = 0x3f3f;
|
@@ -4153,54 +3923,6 @@ void kernel_mul_mv_q5_K_f32_impl(
|
|
4153
3923
|
y1 += 4 * QK_K;
|
4154
3924
|
|
4155
3925
|
}
|
4156
|
-
#else
|
4157
|
-
float yl[8], yh[8];
|
4158
|
-
|
4159
|
-
const int il = 4 * (tiisg/8); // 0, 4, 8, 12
|
4160
|
-
const int ix = tiisg%8;
|
4161
|
-
const int iq = il/8; // 0, 0, 1, 1
|
4162
|
-
const int in = il%8; // 0, 4, 0, 4
|
4163
|
-
|
4164
|
-
device const float * y = yy + ix*QK_K + il;
|
4165
|
-
|
4166
|
-
for (int i = ix; i < nb; i += 8) {
|
4167
|
-
|
4168
|
-
for (int l = 0; l < 4; ++l) {
|
4169
|
-
yl[l+0] = y[l+ 0];
|
4170
|
-
yl[l+4] = y[l+16];
|
4171
|
-
yh[l+0] = y[l+32];
|
4172
|
-
yh[l+4] = y[l+48];
|
4173
|
-
}
|
4174
|
-
|
4175
|
-
device const half * dh = &x[i].d;
|
4176
|
-
device const uint8_t * q = x[i].qs + il;
|
4177
|
-
device const uint8_t * h = x[i].qh + in;
|
4178
|
-
device const int8_t * s = x[i].scales;
|
4179
|
-
|
4180
|
-
for (int row = 0; row < 2; ++row) {
|
4181
|
-
|
4182
|
-
const float d = dh[0];
|
4183
|
-
|
4184
|
-
float2 acc = {0.f, 0.f};
|
4185
|
-
for (int l = 0; l < 4; ++l) {
|
4186
|
-
const uint8_t hl = h[l] >> iq;
|
4187
|
-
acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
|
4188
|
-
+ yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
|
4189
|
-
acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
|
4190
|
-
+ yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256));
|
4191
|
-
}
|
4192
|
-
sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]);
|
4193
|
-
|
4194
|
-
q += step;
|
4195
|
-
h += step;
|
4196
|
-
s += step;
|
4197
|
-
dh += step/2;
|
4198
|
-
|
4199
|
-
}
|
4200
|
-
|
4201
|
-
y += 8 * QK_K;
|
4202
|
-
}
|
4203
|
-
#endif
|
4204
3926
|
|
4205
3927
|
for (int row = 0; row < 2; ++row) {
|
4206
3928
|
const float tot = simd_sum(sumf[row]);
|
@@ -4279,7 +4001,6 @@ void kernel_mul_mv_q6_K_f32_impl(
|
|
4279
4001
|
|
4280
4002
|
float sumf = 0;
|
4281
4003
|
|
4282
|
-
#if QK_K == 256
|
4283
4004
|
const int tid = tiisg/2;
|
4284
4005
|
const int ix = tiisg%2;
|
4285
4006
|
const int ip = tid/8; // 0 or 1
|
@@ -4315,30 +4036,6 @@ void kernel_mul_mv_q6_K_f32_impl(
|
|
4315
4036
|
|
4316
4037
|
}
|
4317
4038
|
|
4318
|
-
#else
|
4319
|
-
const int ix = tiisg/4;
|
4320
|
-
const int il = 4*(tiisg%4);
|
4321
|
-
|
4322
|
-
for (int i = ix; i < nb; i += 8) {
|
4323
|
-
device const float * y = yy + i * QK_K + il;
|
4324
|
-
device const uint8_t * ql = x[i].ql + il;
|
4325
|
-
device const uint8_t * qh = x[i].qh + il;
|
4326
|
-
device const int8_t * s = x[i].scales;
|
4327
|
-
|
4328
|
-
const float d = x[i].d;
|
4329
|
-
|
4330
|
-
float4 sums = {0.f, 0.f, 0.f, 0.f};
|
4331
|
-
for (int l = 0; l < 4; ++l) {
|
4332
|
-
sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
|
4333
|
-
sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
|
4334
|
-
sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32);
|
4335
|
-
sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
|
4336
|
-
}
|
4337
|
-
sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
|
4338
|
-
}
|
4339
|
-
|
4340
|
-
#endif
|
4341
|
-
|
4342
4039
|
const float tot = simd_sum(sumf);
|
4343
4040
|
if (tiisg == 0) {
|
4344
4041
|
dst[r1*ne0 + im*ne0*ne1 + row] = tot;
|
@@ -5172,9 +4869,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|
5172
4869
|
|
5173
4870
|
device const float * y4 = y + 32 * ix;
|
5174
4871
|
|
5175
|
-
#if QK_K != 64
|
5176
4872
|
iq1m_scale_t scale;
|
5177
|
-
#endif
|
5178
4873
|
|
5179
4874
|
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
5180
4875
|
|
@@ -5195,10 +4890,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|
5195
4890
|
device const uint16_t * sc = (device const uint16_t *)xr->scales;
|
5196
4891
|
|
5197
4892
|
for (int row = 0; row < N_DST; row++) {
|
5198
|
-
|
5199
|
-
#if QK_K != 64
|
5200
4893
|
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
|
5201
|
-
#endif
|
5202
4894
|
|
5203
4895
|
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
|
5204
4896
|
constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
|
@@ -5214,14 +4906,9 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|
5214
4906
|
}
|
5215
4907
|
const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
5216
4908
|
const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
5217
|
-
|
5218
|
-
const float d = (float) *((device const half *)(sc - 1));
|
5219
|
-
sumf[row] += d * ((sum[0] + delta1) * (2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1) +
|
5220
|
-
(sum[1] + delta2) * (2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1));
|
5221
|
-
#else
|
4909
|
+
|
5222
4910
|
sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
|
5223
4911
|
(sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
|
5224
|
-
#endif
|
5225
4912
|
|
5226
4913
|
sc += nb*sizeof(block_iq1_m)/2;
|
5227
4914
|
qs += nb*sizeof(block_iq1_m);
|
@@ -5333,7 +5020,6 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|
5333
5020
|
}
|
5334
5021
|
}
|
5335
5022
|
|
5336
|
-
#if QK_K != 64
|
5337
5023
|
void kernel_mul_mv_iq4_xs_f32_impl(
|
5338
5024
|
device const void * src0,
|
5339
5025
|
device const float * src1,
|
@@ -5428,7 +5114,6 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|
5428
5114
|
}
|
5429
5115
|
}
|
5430
5116
|
}
|
5431
|
-
#endif
|
5432
5117
|
|
5433
5118
|
[[host_name("kernel_mul_mv_iq1_s_f32")]]
|
5434
5119
|
kernel void kernel_mul_mv_iq1_s_f32(
|
@@ -5541,11 +5226,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
|
|
5541
5226
|
uint tiisg[[thread_index_in_simdgroup]],
|
5542
5227
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
5543
5228
|
|
5544
|
-
#if QK_K == 64
|
5545
|
-
kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
5546
|
-
#else
|
5547
5229
|
kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
5548
|
-
#endif
|
5549
5230
|
}
|
5550
5231
|
|
5551
5232
|
//============================= templates and their specializations =============================
|
@@ -5671,10 +5352,9 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg
|
|
5671
5352
|
float dl, ml;
|
5672
5353
|
uint8_t sc = xb->scales[il];
|
5673
5354
|
|
5674
|
-
#if QK_K == 256
|
5675
5355
|
q = q + 32*(il/8) + 16*(il&1);
|
5676
5356
|
il = (il/2)%4;
|
5677
|
-
|
5357
|
+
|
5678
5358
|
half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
|
5679
5359
|
uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
5680
5360
|
dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
|
@@ -5690,7 +5370,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
|
|
5690
5370
|
device const uint8_t * h = (device const uint8_t *)xb->hmask;
|
5691
5371
|
device const int8_t * scales = (device const int8_t *)xb->scales;
|
5692
5372
|
|
5693
|
-
#if QK_K == 256
|
5694
5373
|
q = q + 32 * (il/8) + 16 * (il&1);
|
5695
5374
|
h = h + 16 * (il&1);
|
5696
5375
|
uint8_t m = 1 << (il/2);
|
@@ -5711,17 +5390,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
|
|
5711
5390
|
for (int i = 0; i < 16; ++i) {
|
5712
5391
|
reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
|
5713
5392
|
}
|
5714
|
-
#else
|
5715
|
-
float kcoef = il&1 ? 1.f/16.f : 1.f;
|
5716
|
-
uint16_t kmask = il&1 ? 0xF0 : 0x0F;
|
5717
|
-
float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8);
|
5718
|
-
float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
|
5719
|
-
uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
5720
|
-
uint8_t m = 1<<(il*2);
|
5721
|
-
for (int i = 0; i < 16; ++i) {
|
5722
|
-
reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef));
|
5723
|
-
}
|
5724
|
-
#endif
|
5725
5393
|
}
|
5726
5394
|
|
5727
5395
|
static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
|
@@ -5733,7 +5401,6 @@ template <typename type4x4>
|
|
5733
5401
|
void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
|
5734
5402
|
device const uchar * q = xb->qs;
|
5735
5403
|
|
5736
|
-
#if QK_K == 256
|
5737
5404
|
short is = (il/4) * 2;
|
5738
5405
|
q = q + (il/4) * 32 + 16 * (il&1);
|
5739
5406
|
il = il & 3;
|
@@ -5742,16 +5409,7 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
|
|
5742
5409
|
const float min = xb->dmin;
|
5743
5410
|
const float dl = d * sc[0];
|
5744
5411
|
const float ml = min * sc[1];
|
5745
|
-
|
5746
|
-
(void) get_scale_min_k4_just2;
|
5747
|
-
|
5748
|
-
q = q + 16 * (il&1);
|
5749
|
-
device const uint8_t * s = xb->scales;
|
5750
|
-
device const half2 * dh = (device const half2 *)xb->d;
|
5751
|
-
const float2 d = (float2)dh[0];
|
5752
|
-
const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
|
5753
|
-
const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
|
5754
|
-
#endif
|
5412
|
+
|
5755
5413
|
const ushort mask = il<2 ? 0x0F : 0xF0;
|
5756
5414
|
for (int i = 0; i < 16; ++i) {
|
5757
5415
|
reg[i/4][i%4] = dl * (q[i] & mask) - ml;
|
@@ -5763,7 +5421,6 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
|
|
5763
5421
|
device const uint8_t * q = xb->qs;
|
5764
5422
|
device const uint8_t * qh = xb->qh;
|
5765
5423
|
|
5766
|
-
#if QK_K == 256
|
5767
5424
|
short is = (il/4) * 2;
|
5768
5425
|
q = q + 32 * (il/4) + 16 * (il&1);
|
5769
5426
|
qh = qh + 16 * (il&1);
|
@@ -5780,17 +5437,6 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
|
|
5780
5437
|
for (int i = 0; i < 16; ++i) {
|
5781
5438
|
reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
|
5782
5439
|
}
|
5783
|
-
#else
|
5784
|
-
q = q + 16 * (il&1);
|
5785
|
-
device const int8_t * s = xb->scales;
|
5786
|
-
const float dl = xb->d * s[il];
|
5787
|
-
uint8_t m = 1<<(il*2);
|
5788
|
-
const float coef = il<2 ? 1.f : 1.f/16.f;
|
5789
|
-
const ushort mask = il<2 ? 0x0F : 0xF0;
|
5790
|
-
for (int i = 0; i < 16; ++i) {
|
5791
|
-
reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef));
|
5792
|
-
}
|
5793
|
-
#endif
|
5794
5440
|
}
|
5795
5441
|
|
5796
5442
|
template <typename type4x4>
|
@@ -5800,15 +5446,11 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
|
|
5800
5446
|
device const uint8_t * qh = (device const uint8_t *)xb->qh;
|
5801
5447
|
device const int8_t * scales = (device const int8_t *)xb->scales;
|
5802
5448
|
|
5803
|
-
#if QK_K == 256
|
5804
5449
|
ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
|
5805
5450
|
qh = qh + 32*(il/8) + 16*(il&1);
|
5806
5451
|
float sc = scales[(il%2) + 2 * ((il/2))];
|
5807
5452
|
il = (il/2) & 3;
|
5808
|
-
|
5809
|
-
ql = ql + 16 * (il&1);
|
5810
|
-
float sc = scales[il];
|
5811
|
-
#endif
|
5453
|
+
|
5812
5454
|
const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
5813
5455
|
const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
|
5814
5456
|
const float coef = il>1 ? 1.f/16.f : 1.f;
|
@@ -5965,20 +5607,15 @@ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 &
|
|
5965
5607
|
const int ib32 = il/2;
|
5966
5608
|
il = il%2;
|
5967
5609
|
device const uint16_t * sc = (device const uint16_t *)xb->scales;
|
5968
|
-
|
5969
|
-
const float d = xb->d;
|
5970
|
-
#else
|
5610
|
+
|
5971
5611
|
iq1m_scale_t scale;
|
5972
5612
|
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
|
5973
5613
|
const float d = scale.f16;
|
5974
|
-
|
5614
|
+
|
5975
5615
|
device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
|
5976
5616
|
device const uint8_t * qh = xb->qh + 2*ib32 + il;
|
5977
|
-
|
5978
|
-
const float dl = d * (2*((sc[ib32/2] >> (8*(ib32%2)+4*il)) & 0xf) + 1);
|
5979
|
-
#else
|
5617
|
+
|
5980
5618
|
const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
|
5981
|
-
#endif
|
5982
5619
|
const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
5983
5620
|
const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
5984
5621
|
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
|
@@ -6008,9 +5645,6 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4
|
|
6008
5645
|
|
6009
5646
|
template <typename type4x4>
|
6010
5647
|
void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
|
6011
|
-
#if QK_K == 64
|
6012
|
-
dequantize_iq4_nl(xb, il, reg);
|
6013
|
-
#else
|
6014
5648
|
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
6015
5649
|
const int ib32 = il/2;
|
6016
5650
|
il = il%2;
|
@@ -6027,7 +5661,6 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4
|
|
6027
5661
|
reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
|
6028
5662
|
reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
|
6029
5663
|
}
|
6030
|
-
#endif
|
6031
5664
|
}
|
6032
5665
|
|
6033
5666
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
@@ -6532,11 +6165,7 @@ kernel void kernel_mul_mm_id(
|
|
6532
6165
|
sgitg);
|
6533
6166
|
}
|
6534
6167
|
|
6535
|
-
#if QK_K == 256
|
6536
6168
|
#define QK_NL 16
|
6537
|
-
#else
|
6538
|
-
#define QK_NL 4
|
6539
|
-
#endif
|
6540
6169
|
|
6541
6170
|
//
|
6542
6171
|
// get rows
|
@@ -6576,11 +6205,7 @@ template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_r
|
|
6576
6205
|
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
6577
6206
|
template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
6578
6207
|
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
|
6579
|
-
#if QK_K == 64
|
6580
|
-
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, 2, dequantize_iq4_xs>;
|
6581
|
-
#else
|
6582
6208
|
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
6583
|
-
#endif
|
6584
6209
|
|
6585
6210
|
//
|
6586
6211
|
// matrix-matrix multiplication
|
@@ -6608,11 +6233,7 @@ template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_m
|
|
6608
6233
|
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
6609
6234
|
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
6610
6235
|
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
|
6611
|
-
#if QK_K == 64
|
6612
|
-
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_xs>;
|
6613
|
-
#else
|
6614
6236
|
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
6615
|
-
#endif
|
6616
6237
|
|
6617
6238
|
//
|
6618
6239
|
// indirect matrix-matrix multiplication
|
@@ -6640,11 +6261,7 @@ template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel
|
|
6640
6261
|
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
6641
6262
|
template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
6642
6263
|
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
|
6643
|
-
#if QK_K == 64
|
6644
|
-
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, 2, dequantize_iq4_xs>;
|
6645
|
-
#else
|
6646
6264
|
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
6647
|
-
#endif
|
6648
6265
|
|
6649
6266
|
//
|
6650
6267
|
// matrix-vector multiplication
|
@@ -6853,7 +6470,5 @@ template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t
|
|
6853
6470
|
template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl>>;
|
6854
6471
|
template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
|
6855
6472
|
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
|
6856
|
-
#if QK_K != 64
|
6857
6473
|
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
|
6858
|
-
#endif
|
6859
6474
|
|