llama_cpp 0.15.2 → 0.15.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1640,6 +1640,7 @@ static void rope_yarn_corr_dims(
1640
1640
  typedef void (rope_t)(
1641
1641
  device const void * src0,
1642
1642
  device const int32_t * src1,
1643
+ device const float * src2,
1643
1644
  device float * dst,
1644
1645
  constant int64_t & ne00,
1645
1646
  constant int64_t & ne01,
@@ -1675,6 +1676,7 @@ template<typename T>
1675
1676
  kernel void kernel_rope(
1676
1677
  device const void * src0,
1677
1678
  device const int32_t * src1,
1679
+ device const float * src2,
1678
1680
  device float * dst,
1679
1681
  constant int64_t & ne00,
1680
1682
  constant int64_t & ne01,
@@ -1744,8 +1746,10 @@ kernel void kernel_rope(
1744
1746
 
1745
1747
  // simplified from `(ib * n_dims + ic) * inv_ndims`
1746
1748
  const float cur_rot = inv_ndims*ic - ib;
1749
+ const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f;
1750
+
1751
+ const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor;
1747
1752
 
1748
- const float theta = theta_0 * pow(freq_base, cur_rot);
1749
1753
  float cos_theta, sin_theta;
1750
1754
  rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
1751
1755
 
@@ -2204,11 +2208,7 @@ kernel void kernel_flash_attn_ext_f16(
2204
2208
  // pointer to the mask
2205
2209
  device const half * mp = (device const half *) (mask + iq1*nb31);
2206
2210
 
2207
- // prepare diagonal scale matrix
2208
- simdgroup_float8x8 mscale(scale);
2209
-
2210
- // prepare diagonal slope matrix
2211
- simdgroup_float8x8 mslope(1.0f);
2211
+ float slope = 1.0f;
2212
2212
 
2213
2213
  // ALiBi
2214
2214
  if (max_bias > 0.0f) {
@@ -2217,7 +2217,7 @@ kernel void kernel_flash_attn_ext_f16(
2217
2217
  const float base = h < n_head_log2 ? m0 : m1;
2218
2218
  const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
2219
2219
 
2220
- mslope = simdgroup_float8x8(pow(base, exph));
2220
+ slope = pow(base, exph);
2221
2221
  }
2222
2222
 
2223
2223
  // loop over the KV cache
@@ -2242,18 +2242,20 @@ kernel void kernel_flash_attn_ext_f16(
2242
2242
  simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
2243
2243
  }
2244
2244
 
2245
+ simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
2246
+
2247
+ const short tx = tiisg%4;
2248
+ const short ty = tiisg/4;
2249
+
2245
2250
  if (mask != q) {
2246
2251
  // mqk = mqk*scale + mask*slope
2247
- simdgroup_half8x8 mm;
2248
- simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
2249
- simdgroup_multiply(mm, mslope, mm);
2250
- simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
2252
+ ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
2253
+ ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
2251
2254
  } else {
2252
2255
  // mqk = mqk*scale
2253
- simdgroup_multiply(mqk, mscale, mqk);
2256
+ ss[8*cc + ty*TF + 2*tx + 0] *= scale;
2257
+ ss[8*cc + ty*TF + 2*tx + 1] *= scale;
2254
2258
  }
2255
-
2256
- simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
2257
2259
  }
2258
2260
  }
2259
2261
 
@@ -2816,8 +2818,7 @@ kernel void kernel_cpy_f32_f16(
2816
2818
  for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
2817
2819
  device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2818
2820
 
2819
- // TODO: is there a better way to handle -INFINITY?
2820
- dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0];
2821
+ dst_data[i00] = src[0];
2821
2822
  }
2822
2823
  }
2823
2824
 
@@ -3385,7 +3386,6 @@ void kernel_mul_mv_q2_K_f32_impl(
3385
3386
 
3386
3387
  const int step = sizeof(block_q2_K) * nb;
3387
3388
 
3388
- #if QK_K == 256
3389
3389
  const int ix = tiisg/8; // 0...3
3390
3390
  const int it = tiisg%8; // 0...7
3391
3391
  const int iq = it/4; // 0 or 1
@@ -3437,57 +3437,6 @@ void kernel_mul_mv_q2_K_f32_impl(
3437
3437
 
3438
3438
  y4 += 4 * QK_K;
3439
3439
  }
3440
- #else
3441
- const int ix = tiisg/2; // 0...15
3442
- const int it = tiisg%2; // 0...1
3443
-
3444
- device const float * y4 = y + ix * QK_K + 8 * it;
3445
-
3446
- for (int ib = ix; ib < nb; ib += 16) {
3447
-
3448
- float4 sumy = {0.f, 0.f, 0.f, 0.f};
3449
- for (int i = 0; i < 8; ++i) {
3450
- yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
3451
- yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8];
3452
- yl[i+16] = y4[i+32]; sumy[2] += yl[i+16];
3453
- yl[i+24] = y4[i+48]; sumy[3] += yl[i+24];
3454
- }
3455
-
3456
- device const uint8_t * sc = (device const uint8_t *)x[ib].scales;
3457
- device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
3458
- device const half * dh = &x[ib].d;
3459
-
3460
- for (int row = 0; row < N_DST; row++) {
3461
-
3462
- float4 acc1 = {0.f, 0.f, 0.f, 0.f};
3463
- float4 acc2 = {0.f, 0.f, 0.f, 0.f};
3464
- for (int i = 0; i < 8; i += 2) {
3465
- acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
3466
- acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
3467
- acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
3468
- acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
3469
- acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
3470
- acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
3471
- acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
3472
- acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
3473
- }
3474
-
3475
- float dall = dh[0];
3476
- float dmin = dh[1];
3477
- sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
3478
- (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
3479
- (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
3480
- (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
3481
- dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));
3482
-
3483
- qs += step/2;
3484
- sc += step;
3485
- dh += step/2;
3486
- }
3487
-
3488
- y4 += 16 * QK_K;
3489
- }
3490
- #endif
3491
3440
 
3492
3441
  for (int row = 0; row < N_DST; ++row) {
3493
3442
  all_sum = simd_sum(sumf[row]);
@@ -3525,7 +3474,6 @@ kernel void kernel_mul_mv_q2_K_f32(
3525
3474
  kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
3526
3475
  }
3527
3476
 
3528
- #if QK_K == 256
3529
3477
  void kernel_mul_mv_q3_K_f32_impl(
3530
3478
  device const void * src0,
3531
3479
  device const float * src1,
@@ -3684,84 +3632,6 @@ void kernel_mul_mv_q3_K_f32_impl(
3684
3632
  }
3685
3633
  }
3686
3634
  }
3687
- #else
3688
- void kernel_mul_mv_q3_K_f32_impl(
3689
- device const void * src0,
3690
- device const float * src1,
3691
- device float * dst,
3692
- constant int64_t & ne00,
3693
- constant int64_t & ne01,
3694
- constant int64_t & ne02,
3695
- constant int64_t & ne10,
3696
- constant int64_t & ne12,
3697
- constant int64_t & ne0,
3698
- constant int64_t & ne1,
3699
- constant uint & r2,
3700
- constant uint & r3,
3701
- threadgroup int8_t * shared_values [[threadgroup(0)]],
3702
- uint3 tgpig[[threadgroup_position_in_grid]],
3703
- uint tiisg[[thread_index_in_simdgroup]],
3704
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3705
-
3706
- const int nb = ne00/QK_K;
3707
-
3708
- const int64_t r0 = tgpig.x;
3709
- const int64_t r1 = tgpig.y;
3710
- const int64_t im = tgpig.z;
3711
-
3712
- const int row = 2 * r0 + sgitg;
3713
-
3714
- const uint i12 = im%ne12;
3715
- const uint i13 = im/ne12;
3716
-
3717
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
3718
-
3719
- device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
3720
- device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
3721
-
3722
- const int ix = tiisg/4;
3723
- const int il = 4 * (tiisg%4);// 0, 4, 8, 12
3724
- const int iq = il/8; // 0, 0, 1, 1
3725
- const int in = il%8; // 0, 4, 0, 4
3726
-
3727
- float2 sum = {0.f, 0.f};
3728
-
3729
- for (int i = ix; i < nb; i += 8) {
3730
-
3731
- const float d_all = (float)(x[i].d);
3732
-
3733
- device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
3734
- device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
3735
- device const uint16_t * s = (device const uint16_t *)(x[i].scales);
3736
- device const float * y = yy + i * QK_K + il;
3737
-
3738
- const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
3739
- const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
3740
- const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
3741
- const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
3742
-
3743
- for (int l = 0; l < 4; l += 2) {
3744
- const uint16_t hm = h[l/2] >> iq;
3745
- sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
3746
- + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
3747
- + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
3748
- + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
3749
- sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
3750
- + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
3751
- + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
3752
- + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
3753
- }
3754
-
3755
- }
3756
- const float sumf = sum[0] + sum[1] * 1.f/256.f;
3757
-
3758
- const float tot = simd_sum(sumf);
3759
- if (tiisg == 0) {
3760
- dst[r1*ne0 + im*ne0*ne1 + row] = tot;
3761
- }
3762
-
3763
- }
3764
- #endif
3765
3635
 
3766
3636
  [[host_name("kernel_mul_mv_q3_K_f32")]]
3767
3637
  kernel void kernel_mul_mv_q3_K_f32(
@@ -3791,7 +3661,6 @@ kernel void kernel_mul_mv_q3_K_f32(
3791
3661
  kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
3792
3662
  }
3793
3663
 
3794
- #if QK_K == 256
3795
3664
  void kernel_mul_mv_q4_K_f32_impl(
3796
3665
  device const void * src0,
3797
3666
  device const float * src1,
@@ -3905,103 +3774,6 @@ void kernel_mul_mv_q4_K_f32_impl(
3905
3774
  }
3906
3775
  }
3907
3776
  }
3908
- #else
3909
- void kernel_mul_mv_q4_K_f32_impl(
3910
- device const void * src0,
3911
- device const float * src1,
3912
- device float * dst,
3913
- constant int64_t & ne00,
3914
- constant int64_t & ne01,
3915
- constant int64_t & ne02,
3916
- constant int64_t & ne10,
3917
- constant int64_t & ne12,
3918
- constant int64_t & ne0,
3919
- constant int64_t & ne1,
3920
- constant uint & r2,
3921
- constant uint & r3,
3922
- threadgroup int8_t * shared_values [[threadgroup(0)]],
3923
- uint3 tgpig[[threadgroup_position_in_grid]],
3924
- uint tiisg[[thread_index_in_simdgroup]],
3925
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3926
-
3927
- const int ix = tiisg/4; // 0...7
3928
- const int it = tiisg%4; // 0...3
3929
-
3930
- const int nb = ne00/QK_K;
3931
- const int r0 = tgpig.x;
3932
- const int r1 = tgpig.y;
3933
- const int im = tgpig.z;
3934
- const int first_row = r0 * N_DST;
3935
- const int ib_row = first_row * nb;
3936
-
3937
- const uint i12 = im%ne12;
3938
- const uint i13 = im/ne12;
3939
-
3940
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
3941
-
3942
- device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
3943
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
3944
-
3945
- float yl[8];
3946
- float yh[8];
3947
- float sumf[N_DST]={0.f}, all_sum;
3948
-
3949
- const int step = sizeof(block_q4_K) * nb / 2;
3950
-
3951
- device const float * y4 = y + ix * QK_K + 8 * it;
3952
-
3953
- uint16_t sc16[4];
3954
-
3955
- for (int ib = ix; ib < nb; ib += 8) {
3956
-
3957
- float2 sumy = {0.f, 0.f};
3958
- for (int i = 0; i < 8; ++i) {
3959
- yl[i] = y4[i+ 0]; sumy[0] += yl[i];
3960
- yh[i] = y4[i+32]; sumy[1] += yh[i];
3961
- }
3962
-
3963
- device const uint16_t * sc = (device const uint16_t *)x[ib].scales;
3964
- device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
3965
- device const half * dh = x[ib].d;
3966
-
3967
- for (int row = 0; row < N_DST; row++) {
3968
-
3969
- sc16[0] = sc[0] & 0x000f;
3970
- sc16[1] = sc[0] & 0x0f00;
3971
- sc16[2] = sc[0] & 0x00f0;
3972
- sc16[3] = sc[0] & 0xf000;
3973
-
3974
- float2 acc1 = {0.f, 0.f};
3975
- float2 acc2 = {0.f, 0.f};
3976
- for (int i = 0; i < 8; i += 2) {
3977
- acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
3978
- acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
3979
- acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
3980
- acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
3981
- }
3982
-
3983
- float dall = dh[0];
3984
- float dmin = dh[1];
3985
- sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
3986
- (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
3987
- dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);
3988
-
3989
- qs += step;
3990
- sc += step;
3991
- dh += step;
3992
- }
3993
-
3994
- y4 += 8 * QK_K;
3995
- }
3996
-
3997
- for (int row = 0; row < N_DST; ++row) {
3998
- all_sum = simd_sum(sumf[row]);
3999
- if (tiisg == 0) {
4000
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
4001
- }
4002
- }
4003
- }
4004
- #endif
4005
3777
 
4006
3778
  [[host_name("kernel_mul_mv_q4_K_f32")]]
4007
3779
  kernel void kernel_mul_mv_q4_K_f32(
@@ -4069,8 +3841,6 @@ void kernel_mul_mv_q5_K_f32_impl(
4069
3841
 
4070
3842
  const int step = sizeof(block_q5_K) * nb;
4071
3843
 
4072
- #if QK_K == 256
4073
- #
4074
3844
  float yl[16], yh[16];
4075
3845
 
4076
3846
  const uint16_t kmask1 = 0x3f3f;
@@ -4153,54 +3923,6 @@ void kernel_mul_mv_q5_K_f32_impl(
4153
3923
  y1 += 4 * QK_K;
4154
3924
 
4155
3925
  }
4156
- #else
4157
- float yl[8], yh[8];
4158
-
4159
- const int il = 4 * (tiisg/8); // 0, 4, 8, 12
4160
- const int ix = tiisg%8;
4161
- const int iq = il/8; // 0, 0, 1, 1
4162
- const int in = il%8; // 0, 4, 0, 4
4163
-
4164
- device const float * y = yy + ix*QK_K + il;
4165
-
4166
- for (int i = ix; i < nb; i += 8) {
4167
-
4168
- for (int l = 0; l < 4; ++l) {
4169
- yl[l+0] = y[l+ 0];
4170
- yl[l+4] = y[l+16];
4171
- yh[l+0] = y[l+32];
4172
- yh[l+4] = y[l+48];
4173
- }
4174
-
4175
- device const half * dh = &x[i].d;
4176
- device const uint8_t * q = x[i].qs + il;
4177
- device const uint8_t * h = x[i].qh + in;
4178
- device const int8_t * s = x[i].scales;
4179
-
4180
- for (int row = 0; row < 2; ++row) {
4181
-
4182
- const float d = dh[0];
4183
-
4184
- float2 acc = {0.f, 0.f};
4185
- for (int l = 0; l < 4; ++l) {
4186
- const uint8_t hl = h[l] >> iq;
4187
- acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
4188
- + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
4189
- acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
4190
- + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256));
4191
- }
4192
- sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]);
4193
-
4194
- q += step;
4195
- h += step;
4196
- s += step;
4197
- dh += step/2;
4198
-
4199
- }
4200
-
4201
- y += 8 * QK_K;
4202
- }
4203
- #endif
4204
3926
 
4205
3927
  for (int row = 0; row < 2; ++row) {
4206
3928
  const float tot = simd_sum(sumf[row]);
@@ -4279,7 +4001,6 @@ void kernel_mul_mv_q6_K_f32_impl(
4279
4001
 
4280
4002
  float sumf = 0;
4281
4003
 
4282
- #if QK_K == 256
4283
4004
  const int tid = tiisg/2;
4284
4005
  const int ix = tiisg%2;
4285
4006
  const int ip = tid/8; // 0 or 1
@@ -4315,30 +4036,6 @@ void kernel_mul_mv_q6_K_f32_impl(
4315
4036
 
4316
4037
  }
4317
4038
 
4318
- #else
4319
- const int ix = tiisg/4;
4320
- const int il = 4*(tiisg%4);
4321
-
4322
- for (int i = ix; i < nb; i += 8) {
4323
- device const float * y = yy + i * QK_K + il;
4324
- device const uint8_t * ql = x[i].ql + il;
4325
- device const uint8_t * qh = x[i].qh + il;
4326
- device const int8_t * s = x[i].scales;
4327
-
4328
- const float d = x[i].d;
4329
-
4330
- float4 sums = {0.f, 0.f, 0.f, 0.f};
4331
- for (int l = 0; l < 4; ++l) {
4332
- sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
4333
- sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
4334
- sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32);
4335
- sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
4336
- }
4337
- sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
4338
- }
4339
-
4340
- #endif
4341
-
4342
4039
  const float tot = simd_sum(sumf);
4343
4040
  if (tiisg == 0) {
4344
4041
  dst[r1*ne0 + im*ne0*ne1 + row] = tot;
@@ -5172,9 +4869,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
5172
4869
 
5173
4870
  device const float * y4 = y + 32 * ix;
5174
4871
 
5175
- #if QK_K != 64
5176
4872
  iq1m_scale_t scale;
5177
- #endif
5178
4873
 
5179
4874
  for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
5180
4875
 
@@ -5195,10 +4890,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
5195
4890
  device const uint16_t * sc = (device const uint16_t *)xr->scales;
5196
4891
 
5197
4892
  for (int row = 0; row < N_DST; row++) {
5198
-
5199
- #if QK_K != 64
5200
4893
  scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
5201
- #endif
5202
4894
 
5203
4895
  constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
5204
4896
  constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
@@ -5214,14 +4906,9 @@ void kernel_mul_mv_iq1_m_f32_impl(
5214
4906
  }
5215
4907
  const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
5216
4908
  const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
5217
- #if QK_K == 64
5218
- const float d = (float) *((device const half *)(sc - 1));
5219
- sumf[row] += d * ((sum[0] + delta1) * (2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1) +
5220
- (sum[1] + delta2) * (2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1));
5221
- #else
4909
+
5222
4910
  sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
5223
4911
  (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
5224
- #endif
5225
4912
 
5226
4913
  sc += nb*sizeof(block_iq1_m)/2;
5227
4914
  qs += nb*sizeof(block_iq1_m);
@@ -5333,7 +5020,6 @@ void kernel_mul_mv_iq4_nl_f32_impl(
5333
5020
  }
5334
5021
  }
5335
5022
 
5336
- #if QK_K != 64
5337
5023
  void kernel_mul_mv_iq4_xs_f32_impl(
5338
5024
  device const void * src0,
5339
5025
  device const float * src1,
@@ -5428,7 +5114,6 @@ void kernel_mul_mv_iq4_xs_f32_impl(
5428
5114
  }
5429
5115
  }
5430
5116
  }
5431
- #endif
5432
5117
 
5433
5118
  [[host_name("kernel_mul_mv_iq1_s_f32")]]
5434
5119
  kernel void kernel_mul_mv_iq1_s_f32(
@@ -5541,11 +5226,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
5541
5226
  uint tiisg[[thread_index_in_simdgroup]],
5542
5227
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5543
5228
 
5544
- #if QK_K == 64
5545
- kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
5546
- #else
5547
5229
  kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
5548
- #endif
5549
5230
  }
5550
5231
 
5551
5232
  //============================= templates and their specializations =============================
@@ -5671,10 +5352,9 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg
5671
5352
  float dl, ml;
5672
5353
  uint8_t sc = xb->scales[il];
5673
5354
 
5674
- #if QK_K == 256
5675
5355
  q = q + 32*(il/8) + 16*(il&1);
5676
5356
  il = (il/2)%4;
5677
- #endif
5357
+
5678
5358
  half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
5679
5359
  uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
5680
5360
  dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
@@ -5690,7 +5370,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
5690
5370
  device const uint8_t * h = (device const uint8_t *)xb->hmask;
5691
5371
  device const int8_t * scales = (device const int8_t *)xb->scales;
5692
5372
 
5693
- #if QK_K == 256
5694
5373
  q = q + 32 * (il/8) + 16 * (il&1);
5695
5374
  h = h + 16 * (il&1);
5696
5375
  uint8_t m = 1 << (il/2);
@@ -5711,17 +5390,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
5711
5390
  for (int i = 0; i < 16; ++i) {
5712
5391
  reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
5713
5392
  }
5714
- #else
5715
- float kcoef = il&1 ? 1.f/16.f : 1.f;
5716
- uint16_t kmask = il&1 ? 0xF0 : 0x0F;
5717
- float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8);
5718
- float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
5719
- uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
5720
- uint8_t m = 1<<(il*2);
5721
- for (int i = 0; i < 16; ++i) {
5722
- reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef));
5723
- }
5724
- #endif
5725
5393
  }
5726
5394
 
5727
5395
  static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
@@ -5733,7 +5401,6 @@ template <typename type4x4>
5733
5401
  void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
5734
5402
  device const uchar * q = xb->qs;
5735
5403
 
5736
- #if QK_K == 256
5737
5404
  short is = (il/4) * 2;
5738
5405
  q = q + (il/4) * 32 + 16 * (il&1);
5739
5406
  il = il & 3;
@@ -5742,16 +5409,7 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
5742
5409
  const float min = xb->dmin;
5743
5410
  const float dl = d * sc[0];
5744
5411
  const float ml = min * sc[1];
5745
- #else
5746
- (void) get_scale_min_k4_just2;
5747
-
5748
- q = q + 16 * (il&1);
5749
- device const uint8_t * s = xb->scales;
5750
- device const half2 * dh = (device const half2 *)xb->d;
5751
- const float2 d = (float2)dh[0];
5752
- const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
5753
- const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
5754
- #endif
5412
+
5755
5413
  const ushort mask = il<2 ? 0x0F : 0xF0;
5756
5414
  for (int i = 0; i < 16; ++i) {
5757
5415
  reg[i/4][i%4] = dl * (q[i] & mask) - ml;
@@ -5763,7 +5421,6 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
5763
5421
  device const uint8_t * q = xb->qs;
5764
5422
  device const uint8_t * qh = xb->qh;
5765
5423
 
5766
- #if QK_K == 256
5767
5424
  short is = (il/4) * 2;
5768
5425
  q = q + 32 * (il/4) + 16 * (il&1);
5769
5426
  qh = qh + 16 * (il&1);
@@ -5780,17 +5437,6 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
5780
5437
  for (int i = 0; i < 16; ++i) {
5781
5438
  reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
5782
5439
  }
5783
- #else
5784
- q = q + 16 * (il&1);
5785
- device const int8_t * s = xb->scales;
5786
- const float dl = xb->d * s[il];
5787
- uint8_t m = 1<<(il*2);
5788
- const float coef = il<2 ? 1.f : 1.f/16.f;
5789
- const ushort mask = il<2 ? 0x0F : 0xF0;
5790
- for (int i = 0; i < 16; ++i) {
5791
- reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef));
5792
- }
5793
- #endif
5794
5440
  }
5795
5441
 
5796
5442
  template <typename type4x4>
@@ -5800,15 +5446,11 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
5800
5446
  device const uint8_t * qh = (device const uint8_t *)xb->qh;
5801
5447
  device const int8_t * scales = (device const int8_t *)xb->scales;
5802
5448
 
5803
- #if QK_K == 256
5804
5449
  ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
5805
5450
  qh = qh + 32*(il/8) + 16*(il&1);
5806
5451
  float sc = scales[(il%2) + 2 * ((il/2))];
5807
5452
  il = (il/2) & 3;
5808
- #else
5809
- ql = ql + 16 * (il&1);
5810
- float sc = scales[il];
5811
- #endif
5453
+
5812
5454
  const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
5813
5455
  const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
5814
5456
  const float coef = il>1 ? 1.f/16.f : 1.f;
@@ -5965,20 +5607,15 @@ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 &
5965
5607
  const int ib32 = il/2;
5966
5608
  il = il%2;
5967
5609
  device const uint16_t * sc = (device const uint16_t *)xb->scales;
5968
- #if QK_K == 64
5969
- const float d = xb->d;
5970
- #else
5610
+
5971
5611
  iq1m_scale_t scale;
5972
5612
  scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
5973
5613
  const float d = scale.f16;
5974
- #endif
5614
+
5975
5615
  device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
5976
5616
  device const uint8_t * qh = xb->qh + 2*ib32 + il;
5977
- #if QK_K == 64
5978
- const float dl = d * (2*((sc[ib32/2] >> (8*(ib32%2)+4*il)) & 0xf) + 1);
5979
- #else
5617
+
5980
5618
  const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
5981
- #endif
5982
5619
  const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
5983
5620
  const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
5984
5621
  constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
@@ -6008,9 +5645,6 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4
6008
5645
 
6009
5646
  template <typename type4x4>
6010
5647
  void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
6011
- #if QK_K == 64
6012
- dequantize_iq4_nl(xb, il, reg);
6013
- #else
6014
5648
  // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
6015
5649
  const int ib32 = il/2;
6016
5650
  il = il%2;
@@ -6027,7 +5661,6 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4
6027
5661
  reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
6028
5662
  reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
6029
5663
  }
6030
- #endif
6031
5664
  }
6032
5665
 
6033
5666
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
@@ -6532,11 +6165,7 @@ kernel void kernel_mul_mm_id(
6532
6165
  sgitg);
6533
6166
  }
6534
6167
 
6535
- #if QK_K == 256
6536
6168
  #define QK_NL 16
6537
- #else
6538
- #define QK_NL 4
6539
- #endif
6540
6169
 
6541
6170
  //
6542
6171
  // get rows
@@ -6576,11 +6205,7 @@ template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_r
6576
6205
  template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
6577
6206
  template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows<block_iq1_m, QK_NL, dequantize_iq1_m>;
6578
6207
  template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
6579
- #if QK_K == 64
6580
- template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, 2, dequantize_iq4_xs>;
6581
- #else
6582
6208
  template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6583
- #endif
6584
6209
 
6585
6210
  //
6586
6211
  // matrix-matrix multiplication
@@ -6608,11 +6233,7 @@ template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_m
6608
6233
  template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
6609
6234
  template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_m, QK_NL, dequantize_iq1_m>;
6610
6235
  template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
6611
- #if QK_K == 64
6612
- template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_xs>;
6613
- #else
6614
6236
  template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6615
- #endif
6616
6237
 
6617
6238
  //
6618
6239
  // indirect matrix-matrix multiplication
@@ -6640,11 +6261,7 @@ template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel
6640
6261
  template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
6641
6262
  template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m, QK_NL, dequantize_iq1_m>;
6642
6263
  template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
6643
- #if QK_K == 64
6644
- template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, 2, dequantize_iq4_xs>;
6645
- #else
6646
6264
  template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6647
- #endif
6648
6265
 
6649
6266
  //
6650
6267
  // matrix-vector multiplication
@@ -6853,7 +6470,5 @@ template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t
6853
6470
  template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl>>;
6854
6471
  template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
6855
6472
  template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
6856
- #if QK_K != 64
6857
6473
  template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
6858
- #endif
6859
6474