llama_cpp 0.15.2 → 0.15.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +8 -0
- data/ext/llama_cpp/llama_cpp.cpp +49 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +4 -0
- data/vendor/tmp/llama.cpp/Makefile +6 -17
- data/vendor/tmp/llama.cpp/ggml-common.h +0 -54
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +72 -30
- data/vendor/tmp/llama.cpp/ggml-cuda.h +1 -0
- data/vendor/tmp/llama.cpp/ggml-impl.h +40 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +4 -0
- data/vendor/tmp/llama.cpp/ggml-metal.m +68 -70
- data/vendor/tmp/llama.cpp/ggml-metal.metal +24 -409
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +4 -1
- data/vendor/tmp/llama.cpp/ggml-quants.c +1879 -2450
- data/vendor/tmp/llama.cpp/ggml-rpc.cpp +176 -53
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +40 -500
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +9351 -5627
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +202 -225
- data/vendor/tmp/llama.cpp/ggml.c +376 -758
- data/vendor/tmp/llama.cpp/ggml.h +39 -27
- data/vendor/tmp/llama.cpp/llama.cpp +823 -593
- data/vendor/tmp/llama.cpp/llama.h +10 -3
- metadata +3 -3
@@ -1640,6 +1640,7 @@ static void rope_yarn_corr_dims(
|
|
1640
1640
|
typedef void (rope_t)(
|
1641
1641
|
device const void * src0,
|
1642
1642
|
device const int32_t * src1,
|
1643
|
+
device const float * src2,
|
1643
1644
|
device float * dst,
|
1644
1645
|
constant int64_t & ne00,
|
1645
1646
|
constant int64_t & ne01,
|
@@ -1675,6 +1676,7 @@ template<typename T>
|
|
1675
1676
|
kernel void kernel_rope(
|
1676
1677
|
device const void * src0,
|
1677
1678
|
device const int32_t * src1,
|
1679
|
+
device const float * src2,
|
1678
1680
|
device float * dst,
|
1679
1681
|
constant int64_t & ne00,
|
1680
1682
|
constant int64_t & ne01,
|
@@ -1744,8 +1746,10 @@ kernel void kernel_rope(
|
|
1744
1746
|
|
1745
1747
|
// simplified from `(ib * n_dims + ic) * inv_ndims`
|
1746
1748
|
const float cur_rot = inv_ndims*ic - ib;
|
1749
|
+
const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f;
|
1750
|
+
|
1751
|
+
const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor;
|
1747
1752
|
|
1748
|
-
const float theta = theta_0 * pow(freq_base, cur_rot);
|
1749
1753
|
float cos_theta, sin_theta;
|
1750
1754
|
rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
1751
1755
|
|
@@ -2204,11 +2208,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|
2204
2208
|
// pointer to the mask
|
2205
2209
|
device const half * mp = (device const half *) (mask + iq1*nb31);
|
2206
2210
|
|
2207
|
-
|
2208
|
-
simdgroup_float8x8 mscale(scale);
|
2209
|
-
|
2210
|
-
// prepare diagonal slope matrix
|
2211
|
-
simdgroup_float8x8 mslope(1.0f);
|
2211
|
+
float slope = 1.0f;
|
2212
2212
|
|
2213
2213
|
// ALiBi
|
2214
2214
|
if (max_bias > 0.0f) {
|
@@ -2217,7 +2217,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|
2217
2217
|
const float base = h < n_head_log2 ? m0 : m1;
|
2218
2218
|
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
2219
2219
|
|
2220
|
-
|
2220
|
+
slope = pow(base, exph);
|
2221
2221
|
}
|
2222
2222
|
|
2223
2223
|
// loop over the KV cache
|
@@ -2242,18 +2242,20 @@ kernel void kernel_flash_attn_ext_f16(
|
|
2242
2242
|
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
2243
2243
|
}
|
2244
2244
|
|
2245
|
+
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
|
2246
|
+
|
2247
|
+
const short tx = tiisg%4;
|
2248
|
+
const short ty = tiisg/4;
|
2249
|
+
|
2245
2250
|
if (mask != q) {
|
2246
2251
|
// mqk = mqk*scale + mask*slope
|
2247
|
-
|
2248
|
-
|
2249
|
-
simdgroup_multiply(mm, mslope, mm);
|
2250
|
-
simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
|
2252
|
+
ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
|
2253
|
+
ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
|
2251
2254
|
} else {
|
2252
2255
|
// mqk = mqk*scale
|
2253
|
-
|
2256
|
+
ss[8*cc + ty*TF + 2*tx + 0] *= scale;
|
2257
|
+
ss[8*cc + ty*TF + 2*tx + 1] *= scale;
|
2254
2258
|
}
|
2255
|
-
|
2256
|
-
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
|
2257
2259
|
}
|
2258
2260
|
}
|
2259
2261
|
|
@@ -2816,8 +2818,7 @@ kernel void kernel_cpy_f32_f16(
|
|
2816
2818
|
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
2817
2819
|
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
2818
2820
|
|
2819
|
-
|
2820
|
-
dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0];
|
2821
|
+
dst_data[i00] = src[0];
|
2821
2822
|
}
|
2822
2823
|
}
|
2823
2824
|
|
@@ -3385,7 +3386,6 @@ void kernel_mul_mv_q2_K_f32_impl(
|
|
3385
3386
|
|
3386
3387
|
const int step = sizeof(block_q2_K) * nb;
|
3387
3388
|
|
3388
|
-
#if QK_K == 256
|
3389
3389
|
const int ix = tiisg/8; // 0...3
|
3390
3390
|
const int it = tiisg%8; // 0...7
|
3391
3391
|
const int iq = it/4; // 0 or 1
|
@@ -3437,57 +3437,6 @@ void kernel_mul_mv_q2_K_f32_impl(
|
|
3437
3437
|
|
3438
3438
|
y4 += 4 * QK_K;
|
3439
3439
|
}
|
3440
|
-
#else
|
3441
|
-
const int ix = tiisg/2; // 0...15
|
3442
|
-
const int it = tiisg%2; // 0...1
|
3443
|
-
|
3444
|
-
device const float * y4 = y + ix * QK_K + 8 * it;
|
3445
|
-
|
3446
|
-
for (int ib = ix; ib < nb; ib += 16) {
|
3447
|
-
|
3448
|
-
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
3449
|
-
for (int i = 0; i < 8; ++i) {
|
3450
|
-
yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
|
3451
|
-
yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8];
|
3452
|
-
yl[i+16] = y4[i+32]; sumy[2] += yl[i+16];
|
3453
|
-
yl[i+24] = y4[i+48]; sumy[3] += yl[i+24];
|
3454
|
-
}
|
3455
|
-
|
3456
|
-
device const uint8_t * sc = (device const uint8_t *)x[ib].scales;
|
3457
|
-
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
|
3458
|
-
device const half * dh = &x[ib].d;
|
3459
|
-
|
3460
|
-
for (int row = 0; row < N_DST; row++) {
|
3461
|
-
|
3462
|
-
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
|
3463
|
-
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
|
3464
|
-
for (int i = 0; i < 8; i += 2) {
|
3465
|
-
acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
|
3466
|
-
acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
|
3467
|
-
acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
|
3468
|
-
acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
|
3469
|
-
acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
|
3470
|
-
acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
|
3471
|
-
acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
|
3472
|
-
acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
|
3473
|
-
}
|
3474
|
-
|
3475
|
-
float dall = dh[0];
|
3476
|
-
float dmin = dh[1];
|
3477
|
-
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
|
3478
|
-
(acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
|
3479
|
-
(acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
|
3480
|
-
(acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
|
3481
|
-
dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));
|
3482
|
-
|
3483
|
-
qs += step/2;
|
3484
|
-
sc += step;
|
3485
|
-
dh += step/2;
|
3486
|
-
}
|
3487
|
-
|
3488
|
-
y4 += 16 * QK_K;
|
3489
|
-
}
|
3490
|
-
#endif
|
3491
3440
|
|
3492
3441
|
for (int row = 0; row < N_DST; ++row) {
|
3493
3442
|
all_sum = simd_sum(sumf[row]);
|
@@ -3525,7 +3474,6 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
3525
3474
|
kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
3526
3475
|
}
|
3527
3476
|
|
3528
|
-
#if QK_K == 256
|
3529
3477
|
void kernel_mul_mv_q3_K_f32_impl(
|
3530
3478
|
device const void * src0,
|
3531
3479
|
device const float * src1,
|
@@ -3684,84 +3632,6 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|
3684
3632
|
}
|
3685
3633
|
}
|
3686
3634
|
}
|
3687
|
-
#else
|
3688
|
-
void kernel_mul_mv_q3_K_f32_impl(
|
3689
|
-
device const void * src0,
|
3690
|
-
device const float * src1,
|
3691
|
-
device float * dst,
|
3692
|
-
constant int64_t & ne00,
|
3693
|
-
constant int64_t & ne01,
|
3694
|
-
constant int64_t & ne02,
|
3695
|
-
constant int64_t & ne10,
|
3696
|
-
constant int64_t & ne12,
|
3697
|
-
constant int64_t & ne0,
|
3698
|
-
constant int64_t & ne1,
|
3699
|
-
constant uint & r2,
|
3700
|
-
constant uint & r3,
|
3701
|
-
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
3702
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
3703
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
3704
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3705
|
-
|
3706
|
-
const int nb = ne00/QK_K;
|
3707
|
-
|
3708
|
-
const int64_t r0 = tgpig.x;
|
3709
|
-
const int64_t r1 = tgpig.y;
|
3710
|
-
const int64_t im = tgpig.z;
|
3711
|
-
|
3712
|
-
const int row = 2 * r0 + sgitg;
|
3713
|
-
|
3714
|
-
const uint i12 = im%ne12;
|
3715
|
-
const uint i13 = im/ne12;
|
3716
|
-
|
3717
|
-
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
3718
|
-
|
3719
|
-
device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
|
3720
|
-
device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
3721
|
-
|
3722
|
-
const int ix = tiisg/4;
|
3723
|
-
const int il = 4 * (tiisg%4);// 0, 4, 8, 12
|
3724
|
-
const int iq = il/8; // 0, 0, 1, 1
|
3725
|
-
const int in = il%8; // 0, 4, 0, 4
|
3726
|
-
|
3727
|
-
float2 sum = {0.f, 0.f};
|
3728
|
-
|
3729
|
-
for (int i = ix; i < nb; i += 8) {
|
3730
|
-
|
3731
|
-
const float d_all = (float)(x[i].d);
|
3732
|
-
|
3733
|
-
device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
|
3734
|
-
device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
|
3735
|
-
device const uint16_t * s = (device const uint16_t *)(x[i].scales);
|
3736
|
-
device const float * y = yy + i * QK_K + il;
|
3737
|
-
|
3738
|
-
const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
|
3739
|
-
const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
|
3740
|
-
const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
|
3741
|
-
const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
|
3742
|
-
|
3743
|
-
for (int l = 0; l < 4; l += 2) {
|
3744
|
-
const uint16_t hm = h[l/2] >> iq;
|
3745
|
-
sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
|
3746
|
-
+ y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
|
3747
|
-
+ y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
|
3748
|
-
+ y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
|
3749
|
-
sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
|
3750
|
-
+ y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
|
3751
|
-
+ y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
|
3752
|
-
+ y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
|
3753
|
-
}
|
3754
|
-
|
3755
|
-
}
|
3756
|
-
const float sumf = sum[0] + sum[1] * 1.f/256.f;
|
3757
|
-
|
3758
|
-
const float tot = simd_sum(sumf);
|
3759
|
-
if (tiisg == 0) {
|
3760
|
-
dst[r1*ne0 + im*ne0*ne1 + row] = tot;
|
3761
|
-
}
|
3762
|
-
|
3763
|
-
}
|
3764
|
-
#endif
|
3765
3635
|
|
3766
3636
|
[[host_name("kernel_mul_mv_q3_K_f32")]]
|
3767
3637
|
kernel void kernel_mul_mv_q3_K_f32(
|
@@ -3791,7 +3661,6 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
3791
3661
|
kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
3792
3662
|
}
|
3793
3663
|
|
3794
|
-
#if QK_K == 256
|
3795
3664
|
void kernel_mul_mv_q4_K_f32_impl(
|
3796
3665
|
device const void * src0,
|
3797
3666
|
device const float * src1,
|
@@ -3905,103 +3774,6 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
3905
3774
|
}
|
3906
3775
|
}
|
3907
3776
|
}
|
3908
|
-
#else
|
3909
|
-
void kernel_mul_mv_q4_K_f32_impl(
|
3910
|
-
device const void * src0,
|
3911
|
-
device const float * src1,
|
3912
|
-
device float * dst,
|
3913
|
-
constant int64_t & ne00,
|
3914
|
-
constant int64_t & ne01,
|
3915
|
-
constant int64_t & ne02,
|
3916
|
-
constant int64_t & ne10,
|
3917
|
-
constant int64_t & ne12,
|
3918
|
-
constant int64_t & ne0,
|
3919
|
-
constant int64_t & ne1,
|
3920
|
-
constant uint & r2,
|
3921
|
-
constant uint & r3,
|
3922
|
-
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
3923
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
3924
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
3925
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3926
|
-
|
3927
|
-
const int ix = tiisg/4; // 0...7
|
3928
|
-
const int it = tiisg%4; // 0...3
|
3929
|
-
|
3930
|
-
const int nb = ne00/QK_K;
|
3931
|
-
const int r0 = tgpig.x;
|
3932
|
-
const int r1 = tgpig.y;
|
3933
|
-
const int im = tgpig.z;
|
3934
|
-
const int first_row = r0 * N_DST;
|
3935
|
-
const int ib_row = first_row * nb;
|
3936
|
-
|
3937
|
-
const uint i12 = im%ne12;
|
3938
|
-
const uint i13 = im/ne12;
|
3939
|
-
|
3940
|
-
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
3941
|
-
|
3942
|
-
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
|
3943
|
-
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
3944
|
-
|
3945
|
-
float yl[8];
|
3946
|
-
float yh[8];
|
3947
|
-
float sumf[N_DST]={0.f}, all_sum;
|
3948
|
-
|
3949
|
-
const int step = sizeof(block_q4_K) * nb / 2;
|
3950
|
-
|
3951
|
-
device const float * y4 = y + ix * QK_K + 8 * it;
|
3952
|
-
|
3953
|
-
uint16_t sc16[4];
|
3954
|
-
|
3955
|
-
for (int ib = ix; ib < nb; ib += 8) {
|
3956
|
-
|
3957
|
-
float2 sumy = {0.f, 0.f};
|
3958
|
-
for (int i = 0; i < 8; ++i) {
|
3959
|
-
yl[i] = y4[i+ 0]; sumy[0] += yl[i];
|
3960
|
-
yh[i] = y4[i+32]; sumy[1] += yh[i];
|
3961
|
-
}
|
3962
|
-
|
3963
|
-
device const uint16_t * sc = (device const uint16_t *)x[ib].scales;
|
3964
|
-
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
|
3965
|
-
device const half * dh = x[ib].d;
|
3966
|
-
|
3967
|
-
for (int row = 0; row < N_DST; row++) {
|
3968
|
-
|
3969
|
-
sc16[0] = sc[0] & 0x000f;
|
3970
|
-
sc16[1] = sc[0] & 0x0f00;
|
3971
|
-
sc16[2] = sc[0] & 0x00f0;
|
3972
|
-
sc16[3] = sc[0] & 0xf000;
|
3973
|
-
|
3974
|
-
float2 acc1 = {0.f, 0.f};
|
3975
|
-
float2 acc2 = {0.f, 0.f};
|
3976
|
-
for (int i = 0; i < 8; i += 2) {
|
3977
|
-
acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
|
3978
|
-
acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
|
3979
|
-
acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
|
3980
|
-
acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
|
3981
|
-
}
|
3982
|
-
|
3983
|
-
float dall = dh[0];
|
3984
|
-
float dmin = dh[1];
|
3985
|
-
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
|
3986
|
-
(acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
|
3987
|
-
dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);
|
3988
|
-
|
3989
|
-
qs += step;
|
3990
|
-
sc += step;
|
3991
|
-
dh += step;
|
3992
|
-
}
|
3993
|
-
|
3994
|
-
y4 += 8 * QK_K;
|
3995
|
-
}
|
3996
|
-
|
3997
|
-
for (int row = 0; row < N_DST; ++row) {
|
3998
|
-
all_sum = simd_sum(sumf[row]);
|
3999
|
-
if (tiisg == 0) {
|
4000
|
-
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
4001
|
-
}
|
4002
|
-
}
|
4003
|
-
}
|
4004
|
-
#endif
|
4005
3777
|
|
4006
3778
|
[[host_name("kernel_mul_mv_q4_K_f32")]]
|
4007
3779
|
kernel void kernel_mul_mv_q4_K_f32(
|
@@ -4069,8 +3841,6 @@ void kernel_mul_mv_q5_K_f32_impl(
|
|
4069
3841
|
|
4070
3842
|
const int step = sizeof(block_q5_K) * nb;
|
4071
3843
|
|
4072
|
-
#if QK_K == 256
|
4073
|
-
#
|
4074
3844
|
float yl[16], yh[16];
|
4075
3845
|
|
4076
3846
|
const uint16_t kmask1 = 0x3f3f;
|
@@ -4153,54 +3923,6 @@ void kernel_mul_mv_q5_K_f32_impl(
|
|
4153
3923
|
y1 += 4 * QK_K;
|
4154
3924
|
|
4155
3925
|
}
|
4156
|
-
#else
|
4157
|
-
float yl[8], yh[8];
|
4158
|
-
|
4159
|
-
const int il = 4 * (tiisg/8); // 0, 4, 8, 12
|
4160
|
-
const int ix = tiisg%8;
|
4161
|
-
const int iq = il/8; // 0, 0, 1, 1
|
4162
|
-
const int in = il%8; // 0, 4, 0, 4
|
4163
|
-
|
4164
|
-
device const float * y = yy + ix*QK_K + il;
|
4165
|
-
|
4166
|
-
for (int i = ix; i < nb; i += 8) {
|
4167
|
-
|
4168
|
-
for (int l = 0; l < 4; ++l) {
|
4169
|
-
yl[l+0] = y[l+ 0];
|
4170
|
-
yl[l+4] = y[l+16];
|
4171
|
-
yh[l+0] = y[l+32];
|
4172
|
-
yh[l+4] = y[l+48];
|
4173
|
-
}
|
4174
|
-
|
4175
|
-
device const half * dh = &x[i].d;
|
4176
|
-
device const uint8_t * q = x[i].qs + il;
|
4177
|
-
device const uint8_t * h = x[i].qh + in;
|
4178
|
-
device const int8_t * s = x[i].scales;
|
4179
|
-
|
4180
|
-
for (int row = 0; row < 2; ++row) {
|
4181
|
-
|
4182
|
-
const float d = dh[0];
|
4183
|
-
|
4184
|
-
float2 acc = {0.f, 0.f};
|
4185
|
-
for (int l = 0; l < 4; ++l) {
|
4186
|
-
const uint8_t hl = h[l] >> iq;
|
4187
|
-
acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
|
4188
|
-
+ yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
|
4189
|
-
acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
|
4190
|
-
+ yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256));
|
4191
|
-
}
|
4192
|
-
sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]);
|
4193
|
-
|
4194
|
-
q += step;
|
4195
|
-
h += step;
|
4196
|
-
s += step;
|
4197
|
-
dh += step/2;
|
4198
|
-
|
4199
|
-
}
|
4200
|
-
|
4201
|
-
y += 8 * QK_K;
|
4202
|
-
}
|
4203
|
-
#endif
|
4204
3926
|
|
4205
3927
|
for (int row = 0; row < 2; ++row) {
|
4206
3928
|
const float tot = simd_sum(sumf[row]);
|
@@ -4279,7 +4001,6 @@ void kernel_mul_mv_q6_K_f32_impl(
|
|
4279
4001
|
|
4280
4002
|
float sumf = 0;
|
4281
4003
|
|
4282
|
-
#if QK_K == 256
|
4283
4004
|
const int tid = tiisg/2;
|
4284
4005
|
const int ix = tiisg%2;
|
4285
4006
|
const int ip = tid/8; // 0 or 1
|
@@ -4315,30 +4036,6 @@ void kernel_mul_mv_q6_K_f32_impl(
|
|
4315
4036
|
|
4316
4037
|
}
|
4317
4038
|
|
4318
|
-
#else
|
4319
|
-
const int ix = tiisg/4;
|
4320
|
-
const int il = 4*(tiisg%4);
|
4321
|
-
|
4322
|
-
for (int i = ix; i < nb; i += 8) {
|
4323
|
-
device const float * y = yy + i * QK_K + il;
|
4324
|
-
device const uint8_t * ql = x[i].ql + il;
|
4325
|
-
device const uint8_t * qh = x[i].qh + il;
|
4326
|
-
device const int8_t * s = x[i].scales;
|
4327
|
-
|
4328
|
-
const float d = x[i].d;
|
4329
|
-
|
4330
|
-
float4 sums = {0.f, 0.f, 0.f, 0.f};
|
4331
|
-
for (int l = 0; l < 4; ++l) {
|
4332
|
-
sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
|
4333
|
-
sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
|
4334
|
-
sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32);
|
4335
|
-
sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
|
4336
|
-
}
|
4337
|
-
sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
|
4338
|
-
}
|
4339
|
-
|
4340
|
-
#endif
|
4341
|
-
|
4342
4039
|
const float tot = simd_sum(sumf);
|
4343
4040
|
if (tiisg == 0) {
|
4344
4041
|
dst[r1*ne0 + im*ne0*ne1 + row] = tot;
|
@@ -5172,9 +4869,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|
5172
4869
|
|
5173
4870
|
device const float * y4 = y + 32 * ix;
|
5174
4871
|
|
5175
|
-
#if QK_K != 64
|
5176
4872
|
iq1m_scale_t scale;
|
5177
|
-
#endif
|
5178
4873
|
|
5179
4874
|
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
5180
4875
|
|
@@ -5195,10 +4890,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|
5195
4890
|
device const uint16_t * sc = (device const uint16_t *)xr->scales;
|
5196
4891
|
|
5197
4892
|
for (int row = 0; row < N_DST; row++) {
|
5198
|
-
|
5199
|
-
#if QK_K != 64
|
5200
4893
|
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
|
5201
|
-
#endif
|
5202
4894
|
|
5203
4895
|
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
|
5204
4896
|
constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
|
@@ -5214,14 +4906,9 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|
5214
4906
|
}
|
5215
4907
|
const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
5216
4908
|
const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
5217
|
-
|
5218
|
-
const float d = (float) *((device const half *)(sc - 1));
|
5219
|
-
sumf[row] += d * ((sum[0] + delta1) * (2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1) +
|
5220
|
-
(sum[1] + delta2) * (2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1));
|
5221
|
-
#else
|
4909
|
+
|
5222
4910
|
sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
|
5223
4911
|
(sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
|
5224
|
-
#endif
|
5225
4912
|
|
5226
4913
|
sc += nb*sizeof(block_iq1_m)/2;
|
5227
4914
|
qs += nb*sizeof(block_iq1_m);
|
@@ -5333,7 +5020,6 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|
5333
5020
|
}
|
5334
5021
|
}
|
5335
5022
|
|
5336
|
-
#if QK_K != 64
|
5337
5023
|
void kernel_mul_mv_iq4_xs_f32_impl(
|
5338
5024
|
device const void * src0,
|
5339
5025
|
device const float * src1,
|
@@ -5428,7 +5114,6 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|
5428
5114
|
}
|
5429
5115
|
}
|
5430
5116
|
}
|
5431
|
-
#endif
|
5432
5117
|
|
5433
5118
|
[[host_name("kernel_mul_mv_iq1_s_f32")]]
|
5434
5119
|
kernel void kernel_mul_mv_iq1_s_f32(
|
@@ -5541,11 +5226,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
|
|
5541
5226
|
uint tiisg[[thread_index_in_simdgroup]],
|
5542
5227
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
5543
5228
|
|
5544
|
-
#if QK_K == 64
|
5545
|
-
kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
5546
|
-
#else
|
5547
5229
|
kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
5548
|
-
#endif
|
5549
5230
|
}
|
5550
5231
|
|
5551
5232
|
//============================= templates and their specializations =============================
|
@@ -5671,10 +5352,9 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg
|
|
5671
5352
|
float dl, ml;
|
5672
5353
|
uint8_t sc = xb->scales[il];
|
5673
5354
|
|
5674
|
-
#if QK_K == 256
|
5675
5355
|
q = q + 32*(il/8) + 16*(il&1);
|
5676
5356
|
il = (il/2)%4;
|
5677
|
-
|
5357
|
+
|
5678
5358
|
half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
|
5679
5359
|
uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
5680
5360
|
dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
|
@@ -5690,7 +5370,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
|
|
5690
5370
|
device const uint8_t * h = (device const uint8_t *)xb->hmask;
|
5691
5371
|
device const int8_t * scales = (device const int8_t *)xb->scales;
|
5692
5372
|
|
5693
|
-
#if QK_K == 256
|
5694
5373
|
q = q + 32 * (il/8) + 16 * (il&1);
|
5695
5374
|
h = h + 16 * (il&1);
|
5696
5375
|
uint8_t m = 1 << (il/2);
|
@@ -5711,17 +5390,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
|
|
5711
5390
|
for (int i = 0; i < 16; ++i) {
|
5712
5391
|
reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
|
5713
5392
|
}
|
5714
|
-
#else
|
5715
|
-
float kcoef = il&1 ? 1.f/16.f : 1.f;
|
5716
|
-
uint16_t kmask = il&1 ? 0xF0 : 0x0F;
|
5717
|
-
float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8);
|
5718
|
-
float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
|
5719
|
-
uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
5720
|
-
uint8_t m = 1<<(il*2);
|
5721
|
-
for (int i = 0; i < 16; ++i) {
|
5722
|
-
reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef));
|
5723
|
-
}
|
5724
|
-
#endif
|
5725
5393
|
}
|
5726
5394
|
|
5727
5395
|
static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
|
@@ -5733,7 +5401,6 @@ template <typename type4x4>
|
|
5733
5401
|
void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
|
5734
5402
|
device const uchar * q = xb->qs;
|
5735
5403
|
|
5736
|
-
#if QK_K == 256
|
5737
5404
|
short is = (il/4) * 2;
|
5738
5405
|
q = q + (il/4) * 32 + 16 * (il&1);
|
5739
5406
|
il = il & 3;
|
@@ -5742,16 +5409,7 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
|
|
5742
5409
|
const float min = xb->dmin;
|
5743
5410
|
const float dl = d * sc[0];
|
5744
5411
|
const float ml = min * sc[1];
|
5745
|
-
|
5746
|
-
(void) get_scale_min_k4_just2;
|
5747
|
-
|
5748
|
-
q = q + 16 * (il&1);
|
5749
|
-
device const uint8_t * s = xb->scales;
|
5750
|
-
device const half2 * dh = (device const half2 *)xb->d;
|
5751
|
-
const float2 d = (float2)dh[0];
|
5752
|
-
const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
|
5753
|
-
const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
|
5754
|
-
#endif
|
5412
|
+
|
5755
5413
|
const ushort mask = il<2 ? 0x0F : 0xF0;
|
5756
5414
|
for (int i = 0; i < 16; ++i) {
|
5757
5415
|
reg[i/4][i%4] = dl * (q[i] & mask) - ml;
|
@@ -5763,7 +5421,6 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
|
|
5763
5421
|
device const uint8_t * q = xb->qs;
|
5764
5422
|
device const uint8_t * qh = xb->qh;
|
5765
5423
|
|
5766
|
-
#if QK_K == 256
|
5767
5424
|
short is = (il/4) * 2;
|
5768
5425
|
q = q + 32 * (il/4) + 16 * (il&1);
|
5769
5426
|
qh = qh + 16 * (il&1);
|
@@ -5780,17 +5437,6 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
|
|
5780
5437
|
for (int i = 0; i < 16; ++i) {
|
5781
5438
|
reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
|
5782
5439
|
}
|
5783
|
-
#else
|
5784
|
-
q = q + 16 * (il&1);
|
5785
|
-
device const int8_t * s = xb->scales;
|
5786
|
-
const float dl = xb->d * s[il];
|
5787
|
-
uint8_t m = 1<<(il*2);
|
5788
|
-
const float coef = il<2 ? 1.f : 1.f/16.f;
|
5789
|
-
const ushort mask = il<2 ? 0x0F : 0xF0;
|
5790
|
-
for (int i = 0; i < 16; ++i) {
|
5791
|
-
reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef));
|
5792
|
-
}
|
5793
|
-
#endif
|
5794
5440
|
}
|
5795
5441
|
|
5796
5442
|
template <typename type4x4>
|
@@ -5800,15 +5446,11 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
|
|
5800
5446
|
device const uint8_t * qh = (device const uint8_t *)xb->qh;
|
5801
5447
|
device const int8_t * scales = (device const int8_t *)xb->scales;
|
5802
5448
|
|
5803
|
-
#if QK_K == 256
|
5804
5449
|
ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
|
5805
5450
|
qh = qh + 32*(il/8) + 16*(il&1);
|
5806
5451
|
float sc = scales[(il%2) + 2 * ((il/2))];
|
5807
5452
|
il = (il/2) & 3;
|
5808
|
-
|
5809
|
-
ql = ql + 16 * (il&1);
|
5810
|
-
float sc = scales[il];
|
5811
|
-
#endif
|
5453
|
+
|
5812
5454
|
const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
5813
5455
|
const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
|
5814
5456
|
const float coef = il>1 ? 1.f/16.f : 1.f;
|
@@ -5965,20 +5607,15 @@ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 &
|
|
5965
5607
|
const int ib32 = il/2;
|
5966
5608
|
il = il%2;
|
5967
5609
|
device const uint16_t * sc = (device const uint16_t *)xb->scales;
|
5968
|
-
|
5969
|
-
const float d = xb->d;
|
5970
|
-
#else
|
5610
|
+
|
5971
5611
|
iq1m_scale_t scale;
|
5972
5612
|
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
|
5973
5613
|
const float d = scale.f16;
|
5974
|
-
|
5614
|
+
|
5975
5615
|
device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
|
5976
5616
|
device const uint8_t * qh = xb->qh + 2*ib32 + il;
|
5977
|
-
|
5978
|
-
const float dl = d * (2*((sc[ib32/2] >> (8*(ib32%2)+4*il)) & 0xf) + 1);
|
5979
|
-
#else
|
5617
|
+
|
5980
5618
|
const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
|
5981
|
-
#endif
|
5982
5619
|
const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
5983
5620
|
const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
5984
5621
|
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
|
@@ -6008,9 +5645,6 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4
|
|
6008
5645
|
|
6009
5646
|
template <typename type4x4>
|
6010
5647
|
void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
|
6011
|
-
#if QK_K == 64
|
6012
|
-
dequantize_iq4_nl(xb, il, reg);
|
6013
|
-
#else
|
6014
5648
|
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
6015
5649
|
const int ib32 = il/2;
|
6016
5650
|
il = il%2;
|
@@ -6027,7 +5661,6 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4
|
|
6027
5661
|
reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
|
6028
5662
|
reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
|
6029
5663
|
}
|
6030
|
-
#endif
|
6031
5664
|
}
|
6032
5665
|
|
6033
5666
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
@@ -6532,11 +6165,7 @@ kernel void kernel_mul_mm_id(
|
|
6532
6165
|
sgitg);
|
6533
6166
|
}
|
6534
6167
|
|
6535
|
-
#if QK_K == 256
|
6536
6168
|
#define QK_NL 16
|
6537
|
-
#else
|
6538
|
-
#define QK_NL 4
|
6539
|
-
#endif
|
6540
6169
|
|
6541
6170
|
//
|
6542
6171
|
// get rows
|
@@ -6576,11 +6205,7 @@ template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_r
|
|
6576
6205
|
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
6577
6206
|
template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
6578
6207
|
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
|
6579
|
-
#if QK_K == 64
|
6580
|
-
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, 2, dequantize_iq4_xs>;
|
6581
|
-
#else
|
6582
6208
|
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
6583
|
-
#endif
|
6584
6209
|
|
6585
6210
|
//
|
6586
6211
|
// matrix-matrix multiplication
|
@@ -6608,11 +6233,7 @@ template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_m
|
|
6608
6233
|
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
6609
6234
|
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
6610
6235
|
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
|
6611
|
-
#if QK_K == 64
|
6612
|
-
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_xs>;
|
6613
|
-
#else
|
6614
6236
|
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
6615
|
-
#endif
|
6616
6237
|
|
6617
6238
|
//
|
6618
6239
|
// indirect matrix-matrix multiplication
|
@@ -6640,11 +6261,7 @@ template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel
|
|
6640
6261
|
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
6641
6262
|
template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
6642
6263
|
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
|
6643
|
-
#if QK_K == 64
|
6644
|
-
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, 2, dequantize_iq4_xs>;
|
6645
|
-
#else
|
6646
6264
|
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
6647
|
-
#endif
|
6648
6265
|
|
6649
6266
|
//
|
6650
6267
|
// matrix-vector multiplication
|
@@ -6853,7 +6470,5 @@ template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t
|
|
6853
6470
|
template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl>>;
|
6854
6471
|
template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
|
6855
6472
|
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
|
6856
|
-
#if QK_K != 64
|
6857
6473
|
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
|
6858
|
-
#endif
|
6859
6474
|
|