llama_cpp 0.15.2 → 0.15.3

@@ -1640,6 +1640,7 @@ static void rope_yarn_corr_dims(
  typedef void (rope_t)(
  device const void * src0,
  device const int32_t * src1,
+ device const float * src2,
  device float * dst,
  constant int64_t & ne00,
  constant int64_t & ne01,
@@ -1675,6 +1676,7 @@ template<typename T>
  kernel void kernel_rope(
  device const void * src0,
  device const int32_t * src1,
+ device const float * src2,
  device float * dst,
  constant int64_t & ne00,
  constant int64_t & ne01,
@@ -1744,8 +1746,10 @@ kernel void kernel_rope(

  // simplified from `(ib * n_dims + ic) * inv_ndims`
  const float cur_rot = inv_ndims*ic - ib;
+ const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f;
+
+ const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor;

- const float theta = theta_0 * pow(freq_base, cur_rot);
  float cos_theta, sin_theta;
  rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);

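Net effect of this hunk: the per-pair rotation angle is now divided by an optional frequency factor read from the new src2 tensor, and treated as 1.0 when src2 is not provided (the kernel checks src2 != src0 for that). A minimal host-side C++ sketch of the same arithmetic; the helper name rope_theta is ours, purely for illustration:

    #include <cmath>

    // Illustrative sketch: per-pair RoPE angle with an optional frequency-factor table.
    // freq_factors may be null, in which case every factor is treated as 1.0f.
    static float rope_theta(float theta_0, float freq_base, float cur_rot,
                            const float * freq_factors, int ic) {
        const float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
        return theta_0 * std::pow(freq_base, cur_rot) / freq_factor;
    }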
@@ -2204,11 +2208,7 @@ kernel void kernel_flash_attn_ext_f16(
  // pointer to the mask
  device const half * mp = (device const half *) (mask + iq1*nb31);

- // prepare diagonal scale matrix
- simdgroup_float8x8 mscale(scale);
-
- // prepare diagonal slope matrix
- simdgroup_float8x8 mslope(1.0f);
+ float slope = 1.0f;

  // ALiBi
  if (max_bias > 0.0f) {
@@ -2217,7 +2217,7 @@ kernel void kernel_flash_attn_ext_f16(
  const float base = h < n_head_log2 ? m0 : m1;
  const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

- mslope = simdgroup_float8x8(pow(base, exph));
+ slope = pow(base, exph);
  }

  // loop over the KV cache
@@ -2242,18 +2242,20 @@ kernel void kernel_flash_attn_ext_f16(
  simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
  }

+ simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
+
+ const short tx = tiisg%4;
+ const short ty = tiisg/4;
+
  if (mask != q) {
  // mqk = mqk*scale + mask*slope
- simdgroup_half8x8 mm;
- simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
- simdgroup_multiply(mm, mslope, mm);
- simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
+ ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
+ ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
  } else {
  // mqk = mqk*scale
- simdgroup_multiply(mqk, mscale, mqk);
+ ss[8*cc + ty*TF + 2*tx + 0] *= scale;
+ ss[8*cc + ty*TF + 2*tx + 1] *= scale;
  }
-
- simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
  }
  }

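In words: the QK^T tile is now stored to the ss buffer first, and each lane applies the softmax scale and the slope-weighted mask to its own elements directly, instead of going through the diagonal simdgroup matrices mscale/mslope. A serial C++ sketch of the per-element update; function and parameter names are ours, and this is not Metal, just the arithmetic:

    // Illustrative sketch: apply s = s*scale + slope*mask element-wise to an 8x8
    // tile stored with row stride tf; with no mask, only the scale is applied.
    static void scale_and_bias_tile(float * ss, const float * mp, int tf, int mask_stride,
                                    float scale, float slope, bool has_mask) {
        for (int ty = 0; ty < 8; ++ty) {
            for (int tx = 0; tx < 8; ++tx) {
                const float m = has_mask ? mp[ty*mask_stride + tx] : 0.0f;
                ss[ty*tf + tx] = scale*ss[ty*tf + tx] + slope*m;
            }
        }
    }

In the kernel this loop is spread over the 32 lanes of a simdgroup, each lane updating two adjacent elements of one row (tx = tiisg%4, ty = tiisg/4).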
@@ -2816,8 +2818,7 @@ kernel void kernel_cpy_f32_f16(
  for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
  device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);

- // TODO: is there a better way to handle -INFINITY?
- dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0];
+ dst_data[i00] = src[0];
  }
  }

@@ -3385,7 +3386,6 @@ void kernel_mul_mv_q2_K_f32_impl(

  const int step = sizeof(block_q2_K) * nb;

- #if QK_K == 256
  const int ix = tiisg/8; // 0...3
  const int it = tiisg%8; // 0...7
  const int iq = it/4; // 0 or 1
@@ -3437,57 +3437,6 @@ void kernel_mul_mv_q2_K_f32_impl(

  y4 += 4 * QK_K;
  }
- #else
- const int ix = tiisg/2; // 0...15
- const int it = tiisg%2; // 0...1
-
- device const float * y4 = y + ix * QK_K + 8 * it;
-
- for (int ib = ix; ib < nb; ib += 16) {
-
- float4 sumy = {0.f, 0.f, 0.f, 0.f};
- for (int i = 0; i < 8; ++i) {
- yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
- yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8];
- yl[i+16] = y4[i+32]; sumy[2] += yl[i+16];
- yl[i+24] = y4[i+48]; sumy[3] += yl[i+24];
- }
-
- device const uint8_t * sc = (device const uint8_t *)x[ib].scales;
- device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
- device const half * dh = &x[ib].d;
-
- for (int row = 0; row < N_DST; row++) {
-
- float4 acc1 = {0.f, 0.f, 0.f, 0.f};
- float4 acc2 = {0.f, 0.f, 0.f, 0.f};
- for (int i = 0; i < 8; i += 2) {
- acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
- acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
- acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
- acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
- acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
- acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
- acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
- acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
- }
-
- float dall = dh[0];
- float dmin = dh[1];
- sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
- (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
- (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
- (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
- dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));
-
- qs += step/2;
- sc += step;
- dh += step/2;
- }
-
- y4 += 16 * QK_K;
- }
- #endif

  for (int row = 0; row < N_DST; ++row) {
  all_sum = simd_sum(sumf[row]);
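This and the later #else branches dropped in this release are the QK_K == 64 fallback paths; only the 256-value super-block branch is kept. For orientation, the surviving Q2_K math packs a 4-bit scale and a 4-bit min into each scale byte; a tiny C++ sketch of dequantizing one 2-bit value, with a helper name of our own choosing:

    #include <cstdint>

    // Illustrative sketch: one Q2_K value from its super-block scales d and dmin,
    // the sub-block scale byte, and a 2-bit quant q2 in [0, 3].
    static float dequant_q2_K_value(float d, float dmin, uint8_t scale_byte, int q2) {
        const float dl = d    * (scale_byte & 0xF); // sub-block scale
        const float ml = dmin * (scale_byte >> 4);  // sub-block min
        return dl * (float) q2 - ml;
    }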
@@ -3525,7 +3474,6 @@ kernel void kernel_mul_mv_q2_K_f32(
  kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
  }

- #if QK_K == 256
  void kernel_mul_mv_q3_K_f32_impl(
  device const void * src0,
  device const float * src1,
@@ -3684,84 +3632,6 @@ void kernel_mul_mv_q3_K_f32_impl(
  }
  }
  }
- #else
- void kernel_mul_mv_q3_K_f32_impl(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne10,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
- const int nb = ne00/QK_K;
-
- const int64_t r0 = tgpig.x;
- const int64_t r1 = tgpig.y;
- const int64_t im = tgpig.z;
-
- const int row = 2 * r0 + sgitg;
-
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
-
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
-
- device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
- device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
-
- const int ix = tiisg/4;
- const int il = 4 * (tiisg%4);// 0, 4, 8, 12
- const int iq = il/8; // 0, 0, 1, 1
- const int in = il%8; // 0, 4, 0, 4
-
- float2 sum = {0.f, 0.f};
-
- for (int i = ix; i < nb; i += 8) {
-
- const float d_all = (float)(x[i].d);
-
- device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
- device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
- device const uint16_t * s = (device const uint16_t *)(x[i].scales);
- device const float * y = yy + i * QK_K + il;
-
- const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
- const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
- const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
- const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
-
- for (int l = 0; l < 4; l += 2) {
- const uint16_t hm = h[l/2] >> iq;
- sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
- + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
- + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
- + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
- sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
- + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
- + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
- + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
- }
-
- }
- const float sumf = sum[0] + sum[1] * 1.f/256.f;
-
- const float tot = simd_sum(sumf);
- if (tiisg == 0) {
- dst[r1*ne0 + im*ne0*ne1 + row] = tot;
- }
-
- }
- #endif

  [[host_name("kernel_mul_mv_q3_K_f32")]]
  kernel void kernel_mul_mv_q3_K_f32(
@@ -3791,7 +3661,6 @@ kernel void kernel_mul_mv_q3_K_f32(
  kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
  }

- #if QK_K == 256
  void kernel_mul_mv_q4_K_f32_impl(
  device const void * src0,
  device const float * src1,
@@ -3905,103 +3774,6 @@ void kernel_mul_mv_q4_K_f32_impl(
  }
  }
  }
- #else
- void kernel_mul_mv_q4_K_f32_impl(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne10,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
- const int ix = tiisg/4; // 0...7
- const int it = tiisg%4; // 0...3
-
- const int nb = ne00/QK_K;
- const int r0 = tgpig.x;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
- const int first_row = r0 * N_DST;
- const int ib_row = first_row * nb;
-
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
-
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
-
- device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
-
- float yl[8];
- float yh[8];
- float sumf[N_DST]={0.f}, all_sum;
-
- const int step = sizeof(block_q4_K) * nb / 2;
-
- device const float * y4 = y + ix * QK_K + 8 * it;
-
- uint16_t sc16[4];
-
- for (int ib = ix; ib < nb; ib += 8) {
-
- float2 sumy = {0.f, 0.f};
- for (int i = 0; i < 8; ++i) {
- yl[i] = y4[i+ 0]; sumy[0] += yl[i];
- yh[i] = y4[i+32]; sumy[1] += yh[i];
- }
-
- device const uint16_t * sc = (device const uint16_t *)x[ib].scales;
- device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
- device const half * dh = x[ib].d;
-
- for (int row = 0; row < N_DST; row++) {
-
- sc16[0] = sc[0] & 0x000f;
- sc16[1] = sc[0] & 0x0f00;
- sc16[2] = sc[0] & 0x00f0;
- sc16[3] = sc[0] & 0xf000;
-
- float2 acc1 = {0.f, 0.f};
- float2 acc2 = {0.f, 0.f};
- for (int i = 0; i < 8; i += 2) {
- acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
- acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
- acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
- acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
- }
-
- float dall = dh[0];
- float dmin = dh[1];
- sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
- (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
- dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);
-
- qs += step;
- sc += step;
- dh += step;
- }
-
- y4 += 8 * QK_K;
- }
-
- for (int row = 0; row < N_DST; ++row) {
- all_sum = simd_sum(sumf[row]);
- if (tiisg == 0) {
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
- }
- }
- }
- #endif

  [[host_name("kernel_mul_mv_q4_K_f32")]]
  kernel void kernel_mul_mv_q4_K_f32(
@@ -4069,8 +3841,6 @@ void kernel_mul_mv_q5_K_f32_impl(

  const int step = sizeof(block_q5_K) * nb;

- #if QK_K == 256
- #
  float yl[16], yh[16];

  const uint16_t kmask1 = 0x3f3f;
@@ -4153,54 +3923,6 @@ void kernel_mul_mv_q5_K_f32_impl(
  y1 += 4 * QK_K;

  }
- #else
- float yl[8], yh[8];
-
- const int il = 4 * (tiisg/8); // 0, 4, 8, 12
- const int ix = tiisg%8;
- const int iq = il/8; // 0, 0, 1, 1
- const int in = il%8; // 0, 4, 0, 4
-
- device const float * y = yy + ix*QK_K + il;
-
- for (int i = ix; i < nb; i += 8) {
-
- for (int l = 0; l < 4; ++l) {
- yl[l+0] = y[l+ 0];
- yl[l+4] = y[l+16];
- yh[l+0] = y[l+32];
- yh[l+4] = y[l+48];
- }
-
- device const half * dh = &x[i].d;
- device const uint8_t * q = x[i].qs + il;
- device const uint8_t * h = x[i].qh + in;
- device const int8_t * s = x[i].scales;
-
- for (int row = 0; row < 2; ++row) {
-
- const float d = dh[0];
-
- float2 acc = {0.f, 0.f};
- for (int l = 0; l < 4; ++l) {
- const uint8_t hl = h[l] >> iq;
- acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
- + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
- acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
- + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256));
- }
- sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]);
-
- q += step;
- h += step;
- s += step;
- dh += step/2;
-
- }
-
- y += 8 * QK_K;
- }
- #endif

  for (int row = 0; row < 2; ++row) {
  const float tot = simd_sum(sumf[row]);
@@ -4279,7 +4001,6 @@ void kernel_mul_mv_q6_K_f32_impl(

  float sumf = 0;

- #if QK_K == 256
  const int tid = tiisg/2;
  const int ix = tiisg%2;
  const int ip = tid/8; // 0 or 1
@@ -4315,30 +4036,6 @@ void kernel_mul_mv_q6_K_f32_impl(

  }

- #else
- const int ix = tiisg/4;
- const int il = 4*(tiisg%4);
-
- for (int i = ix; i < nb; i += 8) {
- device const float * y = yy + i * QK_K + il;
- device const uint8_t * ql = x[i].ql + il;
- device const uint8_t * qh = x[i].qh + il;
- device const int8_t * s = x[i].scales;
-
- const float d = x[i].d;
-
- float4 sums = {0.f, 0.f, 0.f, 0.f};
- for (int l = 0; l < 4; ++l) {
- sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
- sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
- sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32);
- sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
- }
- sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
- }
-
- #endif
-
  const float tot = simd_sum(sumf);
  if (tiisg == 0) {
  dst[r1*ne0 + im*ne0*ne1 + row] = tot;
@@ -5172,9 +4869,7 @@ void kernel_mul_mv_iq1_m_f32_impl(

  device const float * y4 = y + 32 * ix;

- #if QK_K != 64
  iq1m_scale_t scale;
- #endif

  for (int ib32 = ix; ib32 < nb32; ib32 += 32) {

@@ -5195,10 +4890,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
  device const uint16_t * sc = (device const uint16_t *)xr->scales;

  for (int row = 0; row < N_DST; row++) {
-
- #if QK_K != 64
  scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
- #endif

  constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
  constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
@@ -5214,14 +4906,9 @@ void kernel_mul_mv_iq1_m_f32_impl(
  }
  const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
  const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
- #if QK_K == 64
- const float d = (float) *((device const half *)(sc - 1));
- sumf[row] += d * ((sum[0] + delta1) * (2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1) +
- (sum[1] + delta2) * (2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1));
- #else
+
  sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
  (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
- #endif

  sc += nb*sizeof(block_iq1_m)/2;
  qs += nb*sizeof(block_iq1_m);
@@ -5333,7 +5020,6 @@ void kernel_mul_mv_iq4_nl_f32_impl(
  }
  }

- #if QK_K != 64
  void kernel_mul_mv_iq4_xs_f32_impl(
  device const void * src0,
  device const float * src1,
@@ -5428,7 +5114,6 @@ void kernel_mul_mv_iq4_xs_f32_impl(
  }
  }
  }
- #endif

  [[host_name("kernel_mul_mv_iq1_s_f32")]]
  kernel void kernel_mul_mv_iq1_s_f32(
@@ -5541,11 +5226,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
  uint tiisg[[thread_index_in_simdgroup]],
  uint sgitg[[simdgroup_index_in_threadgroup]]) {

- #if QK_K == 64
- kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
- #else
  kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
- #endif
  }

  //============================= templates and their specializations =============================
@@ -5671,10 +5352,9 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg
  float dl, ml;
  uint8_t sc = xb->scales[il];

- #if QK_K == 256
  q = q + 32*(il/8) + 16*(il&1);
  il = (il/2)%4;
- #endif
+
  half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
  uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
  dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
@@ -5690,7 +5370,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
  device const uint8_t * h = (device const uint8_t *)xb->hmask;
  device const int8_t * scales = (device const int8_t *)xb->scales;

- #if QK_K == 256
  q = q + 32 * (il/8) + 16 * (il&1);
  h = h + 16 * (il&1);
  uint8_t m = 1 << (il/2);
@@ -5711,17 +5390,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
  for (int i = 0; i < 16; ++i) {
  reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
  }
- #else
- float kcoef = il&1 ? 1.f/16.f : 1.f;
- uint16_t kmask = il&1 ? 0xF0 : 0x0F;
- float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8);
- float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
- uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
- uint8_t m = 1<<(il*2);
- for (int i = 0; i < 16; ++i) {
- reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef));
- }
- #endif
  }

  static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
@@ -5733,7 +5401,6 @@ template <typename type4x4>
  void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
  device const uchar * q = xb->qs;

- #if QK_K == 256
  short is = (il/4) * 2;
  q = q + (il/4) * 32 + 16 * (il&1);
  il = il & 3;
@@ -5742,16 +5409,7 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
  const float min = xb->dmin;
  const float dl = d * sc[0];
  const float ml = min * sc[1];
- #else
- (void) get_scale_min_k4_just2;
-
- q = q + 16 * (il&1);
- device const uint8_t * s = xb->scales;
- device const half2 * dh = (device const half2 *)xb->d;
- const float2 d = (float2)dh[0];
- const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
- const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
- #endif
+
  const ushort mask = il<2 ? 0x0F : 0xF0;
  for (int i = 0; i < 16; ++i) {
  reg[i/4][i%4] = dl * (q[i] & mask) - ml;
@@ -5763,7 +5421,6 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
  device const uint8_t * q = xb->qs;
  device const uint8_t * qh = xb->qh;

- #if QK_K == 256
  short is = (il/4) * 2;
  q = q + 32 * (il/4) + 16 * (il&1);
  qh = qh + 16 * (il&1);
@@ -5780,17 +5437,6 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
  for (int i = 0; i < 16; ++i) {
  reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
  }
- #else
- q = q + 16 * (il&1);
- device const int8_t * s = xb->scales;
- const float dl = xb->d * s[il];
- uint8_t m = 1<<(il*2);
- const float coef = il<2 ? 1.f : 1.f/16.f;
- const ushort mask = il<2 ? 0x0F : 0xF0;
- for (int i = 0; i < 16; ++i) {
- reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef));
- }
- #endif
  }

  template <typename type4x4>
@@ -5800,15 +5446,11 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
  device const uint8_t * qh = (device const uint8_t *)xb->qh;
  device const int8_t * scales = (device const int8_t *)xb->scales;

- #if QK_K == 256
  ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
  qh = qh + 32*(il/8) + 16*(il&1);
  float sc = scales[(il%2) + 2 * ((il/2))];
  il = (il/2) & 3;
- #else
- ql = ql + 16 * (il&1);
- float sc = scales[il];
- #endif
+
  const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
  const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
  const float coef = il>1 ? 1.f/16.f : 1.f;
@@ -5965,20 +5607,15 @@ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 &
  const int ib32 = il/2;
  il = il%2;
  device const uint16_t * sc = (device const uint16_t *)xb->scales;
- #if QK_K == 64
- const float d = xb->d;
- #else
+
  iq1m_scale_t scale;
  scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
  const float d = scale.f16;
- #endif
+
  device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
  device const uint8_t * qh = xb->qh + 2*ib32 + il;
- #if QK_K == 64
- const float dl = d * (2*((sc[ib32/2] >> (8*(ib32%2)+4*il)) & 0xf) + 1);
- #else
+
  const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
- #endif
  const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
  const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
  constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
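The kept IQ1_M path reassembles the fp16 super-block scale from the top nibbles of the four 16-bit scale words, as in the scale.u16 expression above. A C++ sketch of just that bit assembly; the helper name iq1m_scale_bits is ours, for illustration only:

    #include <cstdint>

    // Illustrative sketch: rebuild the 16-bit scale of an IQ1_M super-block from
    // the top nibbles of its four 16-bit scale words.
    static uint16_t iq1m_scale_bits(const uint16_t sc[4]) {
        return (uint16_t)(( sc[0] >> 12)           |
                          ((sc[1] >>  8) & 0x00f0) |
                          ((sc[2] >>  4) & 0x0f00) |
                          ( sc[3]        & 0xf000));
    }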
@@ -6008,9 +5645,6 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4

  template <typename type4x4>
  void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
- #if QK_K == 64
- dequantize_iq4_nl(xb, il, reg);
- #else
  // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
  const int ib32 = il/2;
  il = il%2;
@@ -6027,7 +5661,6 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4
  reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
  reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
  }
- #endif
  }

  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
@@ -6532,11 +6165,7 @@ kernel void kernel_mul_mm_id(
  sgitg);
  }

- #if QK_K == 256
  #define QK_NL 16
- #else
- #define QK_NL 4
- #endif

  //
  // get rows
@@ -6576,11 +6205,7 @@ template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_r
  template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
  template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows<block_iq1_m, QK_NL, dequantize_iq1_m>;
  template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
- #if QK_K == 64
- template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, 2, dequantize_iq4_xs>;
- #else
  template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
- #endif

  //
  // matrix-matrix multiplication
@@ -6608,11 +6233,7 @@ template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_m
  template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
  template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_m, QK_NL, dequantize_iq1_m>;
  template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
- #if QK_K == 64
- template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_xs>;
- #else
  template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
- #endif

  //
  // indirect matrix-matrix multiplication
@@ -6640,11 +6261,7 @@ template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel
  template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
  template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m, QK_NL, dequantize_iq1_m>;
  template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
- #if QK_K == 64
- template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, 2, dequantize_iq4_xs>;
- #else
  template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
- #endif

  //
  // matrix-vector multiplication
@@ -6853,7 +6470,5 @@ template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t
  template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl>>;
  template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
  template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
- #if QK_K != 64
  template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
- #endif