llama_cpp 0.15.2 → 0.15.4

Sign up to get free protection for your applications and to get access to all the features.
@@ -168,6 +168,53 @@ kernel void kernel_div(
168
168
  }
169
169
  }
170
170
 
171
+ template<typename T> // element type copied by the kernel; instantiated below for float, half, int and short
172
+ kernel void kernel_repeat( // tiles src0 into dst: dst element (i0,i1,i2,i3) is read from src0 at each index taken modulo the matching src0 extent
173
+ device const char * src0, // source tensor base address (byte pointer; offsets below are computed in bytes)
174
+ device char * dst, // destination tensor base address (byte pointer)
175
+ constant int64_t & ne00, // src0 extent along dim 0 (used as modulus for i0)
176
+ constant int64_t & ne01, // src0 extent along dim 1 (used as modulus for i1)
177
+ constant int64_t & ne02, // src0 extent along dim 2 (used as modulus for i2)
178
+ constant int64_t & ne03, // src0 extent along dim 3 (used as modulus for i3)
179
+ constant uint64_t & nb00, // src0 byte stride along dim 0
180
+ constant uint64_t & nb01, // src0 byte stride along dim 1
181
+ constant uint64_t & nb02, // src0 byte stride along dim 2
182
+ constant uint64_t & nb03, // src0 byte stride along dim 3
183
+ constant int64_t & ne0, // dst extent along dim 0 (row length walked by the loop below)
184
+ constant int64_t & ne1, // dst extent along dim 1 -- not read here; presumably enforced via the grid size (confirm in host code)
185
+ constant int64_t & ne2, // dst extent along dim 2 -- not read here
186
+ constant int64_t & ne3, // dst extent along dim 3 -- not read here
187
+ constant uint64_t & nb0, // dst byte stride along dim 0
188
+ constant uint64_t & nb1, // dst byte stride along dim 1
189
+ constant uint64_t & nb2, // dst byte stride along dim 2
190
+ constant uint64_t & nb3, // dst byte stride along dim 3
191
+ uint3 tgpig[[threadgroup_position_in_grid]], // one threadgroup per dst row: x -> i1, y -> i2, z -> i3
192
+ uint3 tpitg[[thread_position_in_threadgroup]], // x component is this thread's starting column within the row
193
+ uint3 ntg[[threads_per_threadgroup]]) { // x component is the column stride between iterations of one thread
194
+ const int64_t i3 = tgpig.z; // dst index along dim 3
195
+ const int64_t i2 = tgpig.y; // dst index along dim 2
196
+ const int64_t i1 = tgpig.x; // dst index along dim 1
197
+
198
+ const int64_t i03 = i3 % ne03; // wrap dst indices back into src0's extents (repeat/tile)
199
+ const int64_t i02 = i2 % ne02;
200
+ const int64_t i01 = i1 % ne01;
201
+
202
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; // start of the source row being repeated
203
+ device char * dst_ptr = dst + i3*nb3 + i2*nb2 + i1*nb1 ; // start of the destination row
204
+
205
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { // threads of the group walk the dst row with stride ntg.x; NOTE(review): `int` index vs int64_t ne0 would overflow past INT_MAX elements
206
+ const int i00 = i0 % ne00; // source column, wrapped so the src0 row repeats along dim 0
207
+ *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00)); // element width comes from T, addressing from the byte strides
208
+ }
209
+ }
210
+
211
+ typedef decltype(kernel_repeat<float>) kernel_repeat_t; // common function type used to spell the explicit instantiations below
212
+
213
+ template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>; // exported under host name kernel_repeat_f32
214
+ template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>; // exported under host name kernel_repeat_f16
215
+ template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>; // exported under host name kernel_repeat_i32
216
+ template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>; // exported under host name kernel_repeat_i16
217
+
171
218
  // assumption: src1 is a row
172
219
  // broadcast src1 into src0
173
220
  kernel void kernel_add_row(
@@ -1640,6 +1687,7 @@ static void rope_yarn_corr_dims(
1640
1687
  typedef void (rope_t)(
1641
1688
  device const void * src0,
1642
1689
  device const int32_t * src1,
1690
+ device const float * src2,
1643
1691
  device float * dst,
1644
1692
  constant int64_t & ne00,
1645
1693
  constant int64_t & ne01,
@@ -1675,6 +1723,7 @@ template<typename T>
1675
1723
  kernel void kernel_rope(
1676
1724
  device const void * src0,
1677
1725
  device const int32_t * src1,
1726
+ device const float * src2,
1678
1727
  device float * dst,
1679
1728
  constant int64_t & ne00,
1680
1729
  constant int64_t & ne01,
@@ -1718,13 +1767,13 @@ kernel void kernel_rope(
1718
1767
 
1719
1768
  const int64_t p = pos[i2];
1720
1769
 
1721
- const float theta_0 = (float)p;
1770
+ const float theta_base = (float)p;
1722
1771
  const float inv_ndims = -1.f/n_dims;
1723
1772
 
1724
1773
  if (!is_neox) {
1725
1774
  for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
1775
+ const float theta = theta_base * pow(freq_base, inv_ndims*i0);
1726
1776
 
1727
- const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
1728
1777
  float cos_theta, sin_theta;
1729
1778
  rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
1730
1779
 
@@ -1740,16 +1789,14 @@ kernel void kernel_rope(
1740
1789
  } else {
1741
1790
  for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
1742
1791
  if (ic < n_dims) {
1743
- const int64_t ib = 0;
1792
+ const int64_t i0 = ic/2;
1744
1793
 
1745
- // simplified from `(ib * n_dims + ic) * inv_ndims`
1746
- const float cur_rot = inv_ndims*ic - ib;
1794
+ const float freq_factor = src2 != src0 ? src2[i0] : 1.0f;
1747
1795
 
1748
- const float theta = theta_0 * pow(freq_base, cur_rot);
1749
- float cos_theta, sin_theta;
1750
- rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
1796
+ const float theta = theta_base * pow(freq_base, inv_ndims*ic);
1751
1797
 
1752
- const int64_t i0 = ib*n_dims + ic/2;
1798
+ float cos_theta, sin_theta;
1799
+ rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
1753
1800
 
1754
1801
  device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1755
1802
  device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -2204,11 +2251,7 @@ kernel void kernel_flash_attn_ext_f16(
2204
2251
  // pointer to the mask
2205
2252
  device const half * mp = (device const half *) (mask + iq1*nb31);
2206
2253
 
2207
- // prepare diagonal scale matrix
2208
- simdgroup_float8x8 mscale(scale);
2209
-
2210
- // prepare diagonal slope matrix
2211
- simdgroup_float8x8 mslope(1.0f);
2254
+ float slope = 1.0f;
2212
2255
 
2213
2256
  // ALiBi
2214
2257
  if (max_bias > 0.0f) {
@@ -2217,7 +2260,7 @@ kernel void kernel_flash_attn_ext_f16(
2217
2260
  const float base = h < n_head_log2 ? m0 : m1;
2218
2261
  const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
2219
2262
 
2220
- mslope = simdgroup_float8x8(pow(base, exph));
2263
+ slope = pow(base, exph);
2221
2264
  }
2222
2265
 
2223
2266
  // loop over the KV cache
@@ -2242,18 +2285,20 @@ kernel void kernel_flash_attn_ext_f16(
2242
2285
  simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
2243
2286
  }
2244
2287
 
2288
+ simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
2289
+
2290
+ const short tx = tiisg%4;
2291
+ const short ty = tiisg/4;
2292
+
2245
2293
  if (mask != q) {
2246
2294
  // mqk = mqk*scale + mask*slope
2247
- simdgroup_half8x8 mm;
2248
- simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
2249
- simdgroup_multiply(mm, mslope, mm);
2250
- simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
2295
+ ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
2296
+ ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
2251
2297
  } else {
2252
2298
  // mqk = mqk*scale
2253
- simdgroup_multiply(mqk, mscale, mqk);
2299
+ ss[8*cc + ty*TF + 2*tx + 0] *= scale;
2300
+ ss[8*cc + ty*TF + 2*tx + 1] *= scale;
2254
2301
  }
2255
-
2256
- simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
2257
2302
  }
2258
2303
  }
2259
2304
 
@@ -2416,7 +2461,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f
2416
2461
  template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
2417
2462
  template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
2418
2463
  template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
2419
- template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
2464
+ //template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
2420
2465
 
2421
2466
  template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
2422
2467
  kernel void kernel_flash_attn_ext_vec_f16(
@@ -2694,7 +2739,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
2694
2739
  }
2695
2740
 
2696
2741
  template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
2697
- template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
2742
+ //template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
2698
2743
 
2699
2744
  kernel void kernel_cpy_f16_f16(
2700
2745
  device const half * src0,
@@ -2816,8 +2861,7 @@ kernel void kernel_cpy_f32_f16(
2816
2861
  for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
2817
2862
  device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2818
2863
 
2819
- // TODO: is there a better way to handle -INFINITY?
2820
- dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0];
2864
+ dst_data[i00] = src[0];
2821
2865
  }
2822
2866
  }
2823
2867
 
@@ -3318,31 +3362,30 @@ kernel void kernel_concat(
3318
3362
  constant uint64_t & nb1,
3319
3363
  constant uint64_t & nb2,
3320
3364
  constant uint64_t & nb3,
3365
+ constant int32_t & dim,
3321
3366
  uint3 tgpig[[threadgroup_position_in_grid]],
3322
3367
  uint3 tpitg[[thread_position_in_threadgroup]],
3323
3368
  uint3 ntg[[threads_per_threadgroup]]) {
3324
3369
 
3325
- const int64_t i03 = tgpig.z;
3326
- const int64_t i02 = tgpig.y;
3327
- const int64_t i01 = tgpig.x;
3370
+ const int64_t i3 = tgpig.z;
3371
+ const int64_t i2 = tgpig.y;
3372
+ const int64_t i1 = tgpig.x;
3328
3373
 
3329
- const int64_t i13 = i03 % ne13;
3330
- const int64_t i12 = i02 % ne12;
3331
- const int64_t i11 = i01 % ne11;
3374
+ int64_t o[4] = {0, 0, 0, 0};
3375
+ o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
3332
3376
 
3333
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
3334
- device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
3335
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
3377
+ device const float * x;
3336
3378
 
3337
3379
  for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
3338
- if (i02 < ne02) {
3339
- ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
3340
- src0_ptr += ntg.x*nb00;
3380
+ if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
3381
+ x = (device const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00);
3341
3382
  } else {
3342
- ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
3343
- src1_ptr += ntg.x*nb10;
3383
+ x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
3344
3384
  }
3345
- dst_ptr += ntg.x*nb0;
3385
+
3386
+ device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
3387
+
3388
+ *y = *x;
3346
3389
  }
3347
3390
  }
3348
3391
 
@@ -3385,7 +3428,6 @@ void kernel_mul_mv_q2_K_f32_impl(
3385
3428
 
3386
3429
  const int step = sizeof(block_q2_K) * nb;
3387
3430
 
3388
- #if QK_K == 256
3389
3431
  const int ix = tiisg/8; // 0...3
3390
3432
  const int it = tiisg%8; // 0...7
3391
3433
  const int iq = it/4; // 0 or 1
@@ -3437,57 +3479,6 @@ void kernel_mul_mv_q2_K_f32_impl(
3437
3479
 
3438
3480
  y4 += 4 * QK_K;
3439
3481
  }
3440
- #else
3441
- const int ix = tiisg/2; // 0...15
3442
- const int it = tiisg%2; // 0...1
3443
-
3444
- device const float * y4 = y + ix * QK_K + 8 * it;
3445
-
3446
- for (int ib = ix; ib < nb; ib += 16) {
3447
-
3448
- float4 sumy = {0.f, 0.f, 0.f, 0.f};
3449
- for (int i = 0; i < 8; ++i) {
3450
- yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
3451
- yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8];
3452
- yl[i+16] = y4[i+32]; sumy[2] += yl[i+16];
3453
- yl[i+24] = y4[i+48]; sumy[3] += yl[i+24];
3454
- }
3455
-
3456
- device const uint8_t * sc = (device const uint8_t *)x[ib].scales;
3457
- device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
3458
- device const half * dh = &x[ib].d;
3459
-
3460
- for (int row = 0; row < N_DST; row++) {
3461
-
3462
- float4 acc1 = {0.f, 0.f, 0.f, 0.f};
3463
- float4 acc2 = {0.f, 0.f, 0.f, 0.f};
3464
- for (int i = 0; i < 8; i += 2) {
3465
- acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
3466
- acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
3467
- acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
3468
- acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
3469
- acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
3470
- acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
3471
- acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
3472
- acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
3473
- }
3474
-
3475
- float dall = dh[0];
3476
- float dmin = dh[1];
3477
- sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
3478
- (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
3479
- (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
3480
- (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
3481
- dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));
3482
-
3483
- qs += step/2;
3484
- sc += step;
3485
- dh += step/2;
3486
- }
3487
-
3488
- y4 += 16 * QK_K;
3489
- }
3490
- #endif
3491
3482
 
3492
3483
  for (int row = 0; row < N_DST; ++row) {
3493
3484
  all_sum = simd_sum(sumf[row]);
@@ -3525,7 +3516,6 @@ kernel void kernel_mul_mv_q2_K_f32(
3525
3516
  kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
3526
3517
  }
3527
3518
 
3528
- #if QK_K == 256
3529
3519
  void kernel_mul_mv_q3_K_f32_impl(
3530
3520
  device const void * src0,
3531
3521
  device const float * src1,
@@ -3684,84 +3674,6 @@ void kernel_mul_mv_q3_K_f32_impl(
3684
3674
  }
3685
3675
  }
3686
3676
  }
3687
- #else
3688
- void kernel_mul_mv_q3_K_f32_impl(
3689
- device const void * src0,
3690
- device const float * src1,
3691
- device float * dst,
3692
- constant int64_t & ne00,
3693
- constant int64_t & ne01,
3694
- constant int64_t & ne02,
3695
- constant int64_t & ne10,
3696
- constant int64_t & ne12,
3697
- constant int64_t & ne0,
3698
- constant int64_t & ne1,
3699
- constant uint & r2,
3700
- constant uint & r3,
3701
- threadgroup int8_t * shared_values [[threadgroup(0)]],
3702
- uint3 tgpig[[threadgroup_position_in_grid]],
3703
- uint tiisg[[thread_index_in_simdgroup]],
3704
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3705
-
3706
- const int nb = ne00/QK_K;
3707
-
3708
- const int64_t r0 = tgpig.x;
3709
- const int64_t r1 = tgpig.y;
3710
- const int64_t im = tgpig.z;
3711
-
3712
- const int row = 2 * r0 + sgitg;
3713
-
3714
- const uint i12 = im%ne12;
3715
- const uint i13 = im/ne12;
3716
-
3717
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
3718
-
3719
- device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
3720
- device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
3721
-
3722
- const int ix = tiisg/4;
3723
- const int il = 4 * (tiisg%4);// 0, 4, 8, 12
3724
- const int iq = il/8; // 0, 0, 1, 1
3725
- const int in = il%8; // 0, 4, 0, 4
3726
-
3727
- float2 sum = {0.f, 0.f};
3728
-
3729
- for (int i = ix; i < nb; i += 8) {
3730
-
3731
- const float d_all = (float)(x[i].d);
3732
-
3733
- device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
3734
- device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
3735
- device const uint16_t * s = (device const uint16_t *)(x[i].scales);
3736
- device const float * y = yy + i * QK_K + il;
3737
-
3738
- const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
3739
- const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
3740
- const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
3741
- const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
3742
-
3743
- for (int l = 0; l < 4; l += 2) {
3744
- const uint16_t hm = h[l/2] >> iq;
3745
- sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
3746
- + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
3747
- + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
3748
- + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
3749
- sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
3750
- + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
3751
- + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
3752
- + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
3753
- }
3754
-
3755
- }
3756
- const float sumf = sum[0] + sum[1] * 1.f/256.f;
3757
-
3758
- const float tot = simd_sum(sumf);
3759
- if (tiisg == 0) {
3760
- dst[r1*ne0 + im*ne0*ne1 + row] = tot;
3761
- }
3762
-
3763
- }
3764
- #endif
3765
3677
 
3766
3678
  [[host_name("kernel_mul_mv_q3_K_f32")]]
3767
3679
  kernel void kernel_mul_mv_q3_K_f32(
@@ -3791,7 +3703,6 @@ kernel void kernel_mul_mv_q3_K_f32(
3791
3703
  kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
3792
3704
  }
3793
3705
 
3794
- #if QK_K == 256
3795
3706
  void kernel_mul_mv_q4_K_f32_impl(
3796
3707
  device const void * src0,
3797
3708
  device const float * src1,
@@ -3905,103 +3816,6 @@ void kernel_mul_mv_q4_K_f32_impl(
3905
3816
  }
3906
3817
  }
3907
3818
  }
3908
- #else
3909
- void kernel_mul_mv_q4_K_f32_impl(
3910
- device const void * src0,
3911
- device const float * src1,
3912
- device float * dst,
3913
- constant int64_t & ne00,
3914
- constant int64_t & ne01,
3915
- constant int64_t & ne02,
3916
- constant int64_t & ne10,
3917
- constant int64_t & ne12,
3918
- constant int64_t & ne0,
3919
- constant int64_t & ne1,
3920
- constant uint & r2,
3921
- constant uint & r3,
3922
- threadgroup int8_t * shared_values [[threadgroup(0)]],
3923
- uint3 tgpig[[threadgroup_position_in_grid]],
3924
- uint tiisg[[thread_index_in_simdgroup]],
3925
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3926
-
3927
- const int ix = tiisg/4; // 0...7
3928
- const int it = tiisg%4; // 0...3
3929
-
3930
- const int nb = ne00/QK_K;
3931
- const int r0 = tgpig.x;
3932
- const int r1 = tgpig.y;
3933
- const int im = tgpig.z;
3934
- const int first_row = r0 * N_DST;
3935
- const int ib_row = first_row * nb;
3936
-
3937
- const uint i12 = im%ne12;
3938
- const uint i13 = im/ne12;
3939
-
3940
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
3941
-
3942
- device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
3943
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
3944
-
3945
- float yl[8];
3946
- float yh[8];
3947
- float sumf[N_DST]={0.f}, all_sum;
3948
-
3949
- const int step = sizeof(block_q4_K) * nb / 2;
3950
-
3951
- device const float * y4 = y + ix * QK_K + 8 * it;
3952
-
3953
- uint16_t sc16[4];
3954
-
3955
- for (int ib = ix; ib < nb; ib += 8) {
3956
-
3957
- float2 sumy = {0.f, 0.f};
3958
- for (int i = 0; i < 8; ++i) {
3959
- yl[i] = y4[i+ 0]; sumy[0] += yl[i];
3960
- yh[i] = y4[i+32]; sumy[1] += yh[i];
3961
- }
3962
-
3963
- device const uint16_t * sc = (device const uint16_t *)x[ib].scales;
3964
- device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
3965
- device const half * dh = x[ib].d;
3966
-
3967
- for (int row = 0; row < N_DST; row++) {
3968
-
3969
- sc16[0] = sc[0] & 0x000f;
3970
- sc16[1] = sc[0] & 0x0f00;
3971
- sc16[2] = sc[0] & 0x00f0;
3972
- sc16[3] = sc[0] & 0xf000;
3973
-
3974
- float2 acc1 = {0.f, 0.f};
3975
- float2 acc2 = {0.f, 0.f};
3976
- for (int i = 0; i < 8; i += 2) {
3977
- acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
3978
- acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
3979
- acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
3980
- acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
3981
- }
3982
-
3983
- float dall = dh[0];
3984
- float dmin = dh[1];
3985
- sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
3986
- (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
3987
- dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);
3988
-
3989
- qs += step;
3990
- sc += step;
3991
- dh += step;
3992
- }
3993
-
3994
- y4 += 8 * QK_K;
3995
- }
3996
-
3997
- for (int row = 0; row < N_DST; ++row) {
3998
- all_sum = simd_sum(sumf[row]);
3999
- if (tiisg == 0) {
4000
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
4001
- }
4002
- }
4003
- }
4004
- #endif
4005
3819
 
4006
3820
  [[host_name("kernel_mul_mv_q4_K_f32")]]
4007
3821
  kernel void kernel_mul_mv_q4_K_f32(
@@ -4069,8 +3883,6 @@ void kernel_mul_mv_q5_K_f32_impl(
4069
3883
 
4070
3884
  const int step = sizeof(block_q5_K) * nb;
4071
3885
 
4072
- #if QK_K == 256
4073
- #
4074
3886
  float yl[16], yh[16];
4075
3887
 
4076
3888
  const uint16_t kmask1 = 0x3f3f;
@@ -4153,54 +3965,6 @@ void kernel_mul_mv_q5_K_f32_impl(
4153
3965
  y1 += 4 * QK_K;
4154
3966
 
4155
3967
  }
4156
- #else
4157
- float yl[8], yh[8];
4158
-
4159
- const int il = 4 * (tiisg/8); // 0, 4, 8, 12
4160
- const int ix = tiisg%8;
4161
- const int iq = il/8; // 0, 0, 1, 1
4162
- const int in = il%8; // 0, 4, 0, 4
4163
-
4164
- device const float * y = yy + ix*QK_K + il;
4165
-
4166
- for (int i = ix; i < nb; i += 8) {
4167
-
4168
- for (int l = 0; l < 4; ++l) {
4169
- yl[l+0] = y[l+ 0];
4170
- yl[l+4] = y[l+16];
4171
- yh[l+0] = y[l+32];
4172
- yh[l+4] = y[l+48];
4173
- }
4174
-
4175
- device const half * dh = &x[i].d;
4176
- device const uint8_t * q = x[i].qs + il;
4177
- device const uint8_t * h = x[i].qh + in;
4178
- device const int8_t * s = x[i].scales;
4179
-
4180
- for (int row = 0; row < 2; ++row) {
4181
-
4182
- const float d = dh[0];
4183
-
4184
- float2 acc = {0.f, 0.f};
4185
- for (int l = 0; l < 4; ++l) {
4186
- const uint8_t hl = h[l] >> iq;
4187
- acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
4188
- + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
4189
- acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
4190
- + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256));
4191
- }
4192
- sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]);
4193
-
4194
- q += step;
4195
- h += step;
4196
- s += step;
4197
- dh += step/2;
4198
-
4199
- }
4200
-
4201
- y += 8 * QK_K;
4202
- }
4203
- #endif
4204
3968
 
4205
3969
  for (int row = 0; row < 2; ++row) {
4206
3970
  const float tot = simd_sum(sumf[row]);
@@ -4279,7 +4043,6 @@ void kernel_mul_mv_q6_K_f32_impl(
4279
4043
 
4280
4044
  float sumf = 0;
4281
4045
 
4282
- #if QK_K == 256
4283
4046
  const int tid = tiisg/2;
4284
4047
  const int ix = tiisg%2;
4285
4048
  const int ip = tid/8; // 0 or 1
@@ -4315,30 +4078,6 @@ void kernel_mul_mv_q6_K_f32_impl(
4315
4078
 
4316
4079
  }
4317
4080
 
4318
- #else
4319
- const int ix = tiisg/4;
4320
- const int il = 4*(tiisg%4);
4321
-
4322
- for (int i = ix; i < nb; i += 8) {
4323
- device const float * y = yy + i * QK_K + il;
4324
- device const uint8_t * ql = x[i].ql + il;
4325
- device const uint8_t * qh = x[i].qh + il;
4326
- device const int8_t * s = x[i].scales;
4327
-
4328
- const float d = x[i].d;
4329
-
4330
- float4 sums = {0.f, 0.f, 0.f, 0.f};
4331
- for (int l = 0; l < 4; ++l) {
4332
- sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
4333
- sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
4334
- sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32);
4335
- sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
4336
- }
4337
- sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
4338
- }
4339
-
4340
- #endif
4341
-
4342
4081
  const float tot = simd_sum(sumf);
4343
4082
  if (tiisg == 0) {
4344
4083
  dst[r1*ne0 + im*ne0*ne1 + row] = tot;
@@ -5172,9 +4911,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
5172
4911
 
5173
4912
  device const float * y4 = y + 32 * ix;
5174
4913
 
5175
- #if QK_K != 64
5176
4914
  iq1m_scale_t scale;
5177
- #endif
5178
4915
 
5179
4916
  for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
5180
4917
 
@@ -5195,10 +4932,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
5195
4932
  device const uint16_t * sc = (device const uint16_t *)xr->scales;
5196
4933
 
5197
4934
  for (int row = 0; row < N_DST; row++) {
5198
-
5199
- #if QK_K != 64
5200
4935
  scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
5201
- #endif
5202
4936
 
5203
4937
  constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
5204
4938
  constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
@@ -5214,14 +4948,9 @@ void kernel_mul_mv_iq1_m_f32_impl(
5214
4948
  }
5215
4949
  const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
5216
4950
  const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
5217
- #if QK_K == 64
5218
- const float d = (float) *((device const half *)(sc - 1));
5219
- sumf[row] += d * ((sum[0] + delta1) * (2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1) +
5220
- (sum[1] + delta2) * (2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1));
5221
- #else
4951
+
5222
4952
  sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
5223
4953
  (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
5224
- #endif
5225
4954
 
5226
4955
  sc += nb*sizeof(block_iq1_m)/2;
5227
4956
  qs += nb*sizeof(block_iq1_m);
@@ -5333,7 +5062,6 @@ void kernel_mul_mv_iq4_nl_f32_impl(
5333
5062
  }
5334
5063
  }
5335
5064
 
5336
- #if QK_K != 64
5337
5065
  void kernel_mul_mv_iq4_xs_f32_impl(
5338
5066
  device const void * src0,
5339
5067
  device const float * src1,
@@ -5428,7 +5156,6 @@ void kernel_mul_mv_iq4_xs_f32_impl(
5428
5156
  }
5429
5157
  }
5430
5158
  }
5431
- #endif
5432
5159
 
5433
5160
  [[host_name("kernel_mul_mv_iq1_s_f32")]]
5434
5161
  kernel void kernel_mul_mv_iq1_s_f32(
@@ -5541,11 +5268,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
5541
5268
  uint tiisg[[thread_index_in_simdgroup]],
5542
5269
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5543
5270
 
5544
- #if QK_K == 64
5545
- kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
5546
- #else
5547
5271
  kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
5548
- #endif
5549
5272
  }
5550
5273
 
5551
5274
  //============================= templates and their specializations =============================
@@ -5671,10 +5394,9 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg
5671
5394
  float dl, ml;
5672
5395
  uint8_t sc = xb->scales[il];
5673
5396
 
5674
- #if QK_K == 256
5675
5397
  q = q + 32*(il/8) + 16*(il&1);
5676
5398
  il = (il/2)%4;
5677
- #endif
5399
+
5678
5400
  half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
5679
5401
  uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
5680
5402
  dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
@@ -5690,7 +5412,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
5690
5412
  device const uint8_t * h = (device const uint8_t *)xb->hmask;
5691
5413
  device const int8_t * scales = (device const int8_t *)xb->scales;
5692
5414
 
5693
- #if QK_K == 256
5694
5415
  q = q + 32 * (il/8) + 16 * (il&1);
5695
5416
  h = h + 16 * (il&1);
5696
5417
  uint8_t m = 1 << (il/2);
@@ -5711,17 +5432,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
5711
5432
  for (int i = 0; i < 16; ++i) {
5712
5433
  reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
5713
5434
  }
5714
- #else
5715
- float kcoef = il&1 ? 1.f/16.f : 1.f;
5716
- uint16_t kmask = il&1 ? 0xF0 : 0x0F;
5717
- float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8);
5718
- float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
5719
- uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
5720
- uint8_t m = 1<<(il*2);
5721
- for (int i = 0; i < 16; ++i) {
5722
- reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef));
5723
- }
5724
- #endif
5725
5435
  }
5726
5436
 
5727
5437
  static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
@@ -5733,7 +5443,6 @@ template <typename type4x4>
5733
5443
  void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
5734
5444
  device const uchar * q = xb->qs;
5735
5445
 
5736
- #if QK_K == 256
5737
5446
  short is = (il/4) * 2;
5738
5447
  q = q + (il/4) * 32 + 16 * (il&1);
5739
5448
  il = il & 3;
@@ -5742,16 +5451,7 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
5742
5451
  const float min = xb->dmin;
5743
5452
  const float dl = d * sc[0];
5744
5453
  const float ml = min * sc[1];
5745
- #else
5746
- (void) get_scale_min_k4_just2;
5747
-
5748
- q = q + 16 * (il&1);
5749
- device const uint8_t * s = xb->scales;
5750
- device const half2 * dh = (device const half2 *)xb->d;
5751
- const float2 d = (float2)dh[0];
5752
- const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
5753
- const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
5754
- #endif
5454
+
5755
5455
  const ushort mask = il<2 ? 0x0F : 0xF0;
5756
5456
  for (int i = 0; i < 16; ++i) {
5757
5457
  reg[i/4][i%4] = dl * (q[i] & mask) - ml;
@@ -5763,7 +5463,6 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
5763
5463
  device const uint8_t * q = xb->qs;
5764
5464
  device const uint8_t * qh = xb->qh;
5765
5465
 
5766
- #if QK_K == 256
5767
5466
  short is = (il/4) * 2;
5768
5467
  q = q + 32 * (il/4) + 16 * (il&1);
5769
5468
  qh = qh + 16 * (il&1);
@@ -5780,17 +5479,6 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
5780
5479
  for (int i = 0; i < 16; ++i) {
5781
5480
  reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
5782
5481
  }
5783
- #else
5784
- q = q + 16 * (il&1);
5785
- device const int8_t * s = xb->scales;
5786
- const float dl = xb->d * s[il];
5787
- uint8_t m = 1<<(il*2);
5788
- const float coef = il<2 ? 1.f : 1.f/16.f;
5789
- const ushort mask = il<2 ? 0x0F : 0xF0;
5790
- for (int i = 0; i < 16; ++i) {
5791
- reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef));
5792
- }
5793
- #endif
5794
5482
  }
5795
5483
 
5796
5484
  template <typename type4x4>
@@ -5800,15 +5488,11 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
5800
5488
  device const uint8_t * qh = (device const uint8_t *)xb->qh;
5801
5489
  device const int8_t * scales = (device const int8_t *)xb->scales;
5802
5490
 
5803
- #if QK_K == 256
5804
5491
  ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
5805
5492
  qh = qh + 32*(il/8) + 16*(il&1);
5806
5493
  float sc = scales[(il%2) + 2 * ((il/2))];
5807
5494
  il = (il/2) & 3;
5808
- #else
5809
- ql = ql + 16 * (il&1);
5810
- float sc = scales[il];
5811
- #endif
5495
+
5812
5496
  const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
5813
5497
  const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
5814
5498
  const float coef = il>1 ? 1.f/16.f : 1.f;
@@ -5965,20 +5649,15 @@ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 &
5965
5649
  const int ib32 = il/2;
5966
5650
  il = il%2;
5967
5651
  device const uint16_t * sc = (device const uint16_t *)xb->scales;
5968
- #if QK_K == 64
5969
- const float d = xb->d;
5970
- #else
5652
+
5971
5653
  iq1m_scale_t scale;
5972
5654
  scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
5973
5655
  const float d = scale.f16;
5974
- #endif
5656
+
5975
5657
  device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
5976
5658
  device const uint8_t * qh = xb->qh + 2*ib32 + il;
5977
- #if QK_K == 64
5978
- const float dl = d * (2*((sc[ib32/2] >> (8*(ib32%2)+4*il)) & 0xf) + 1);
5979
- #else
5659
+
5980
5660
  const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
5981
- #endif
5982
5661
  const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
5983
5662
  const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
5984
5663
  constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
@@ -6008,9 +5687,6 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4
6008
5687
 
6009
5688
  template <typename type4x4>
6010
5689
  void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
6011
- #if QK_K == 64
6012
- dequantize_iq4_nl(xb, il, reg);
6013
- #else
6014
5690
  // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
6015
5691
  const int ib32 = il/2;
6016
5692
  il = il%2;
@@ -6027,7 +5703,6 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4
6027
5703
  reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
6028
5704
  reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
6029
5705
  }
6030
- #endif
6031
5706
  }
6032
5707
 
6033
5708
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
@@ -6532,11 +6207,7 @@ kernel void kernel_mul_mm_id(
6532
6207
  sgitg);
6533
6208
  }
6534
6209
 
6535
- #if QK_K == 256
6536
6210
  #define QK_NL 16
6537
- #else
6538
- #define QK_NL 4
6539
- #endif
6540
6211
 
6541
6212
  //
6542
6213
  // get rows
@@ -6576,11 +6247,7 @@ template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_r
6576
6247
  template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
6577
6248
  template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows<block_iq1_m, QK_NL, dequantize_iq1_m>;
6578
6249
  template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
6579
- #if QK_K == 64
6580
- template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, 2, dequantize_iq4_xs>;
6581
- #else
6582
6250
  template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6583
- #endif
6584
6251
 
6585
6252
  //
6586
6253
  // matrix-matrix multiplication
@@ -6608,11 +6275,7 @@ template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_m
6608
6275
  template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
6609
6276
  template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_m, QK_NL, dequantize_iq1_m>;
6610
6277
  template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
6611
- #if QK_K == 64
6612
- template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_xs>;
6613
- #else
6614
6278
  template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6615
- #endif
6616
6279
 
6617
6280
  //
6618
6281
  // indirect matrix-matrix multiplication
@@ -6640,11 +6303,7 @@ template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel
6640
6303
  template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
6641
6304
  template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m, QK_NL, dequantize_iq1_m>;
6642
6305
  template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
6643
- #if QK_K == 64
6644
- template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, 2, dequantize_iq4_xs>;
6645
- #else
6646
6306
  template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6647
- #endif
6648
6307
 
6649
6308
  //
6650
6309
  // matrix-vector multiplication
@@ -6853,7 +6512,5 @@ template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t
6853
6512
  template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl>>;
6854
6513
  template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
6855
6514
  template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
6856
- #if QK_K != 64
6857
6515
  template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
6858
- #endif
6859
6516