llama_cpp 0.15.2 → 0.15.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -168,6 +168,53 @@ kernel void kernel_div(
168
168
  }
169
169
  }
170
170
 
171
// kernel_repeat<T>: fills dst with elements of type T read from src0, wrapping
// the source index with a modulo in every dimension (i % ne0x), so src0 is
// repeated across dst wherever dst is larger than src0.
//
// Parameters:
//   src0, dst    - source / destination buffers, addressed as raw bytes
//   ne00..ne03   - extents used to wrap the source indices in dims 0..3
//   nb00..nb03   - source strides; applied to char pointers, i.e. byte offsets
//   ne0          - upper bound of the dim-0 loop over dst
//   ne1..ne3     - not read in this body (presumably the host sizes the grid
//                  from them - TODO confirm against the dispatch code)
//   nb0..nb3     - destination strides, also byte offsets
//   tgpig        - threadgroup position; x/y/z supply the dst indices i1/i2/i3
//   tpitg, ntg   - thread position / threads per threadgroup; only .x is used,
//                  as the start and step of the dim-0 loop
//
// NOTE(review): i0 and i00 are 32-bit int while ne0/ne00 are int64_t - looks
// like this assumes ne0 fits in an int; confirm.
+ template<typename T>
172
+ kernel void kernel_repeat(
173
+ device const char * src0,
174
+ device char * dst,
175
+ constant int64_t & ne00,
176
+ constant int64_t & ne01,
177
+ constant int64_t & ne02,
178
+ constant int64_t & ne03,
179
+ constant uint64_t & nb00,
180
+ constant uint64_t & nb01,
181
+ constant uint64_t & nb02,
182
+ constant uint64_t & nb03,
183
+ constant int64_t & ne0,
184
+ constant int64_t & ne1,
185
+ constant int64_t & ne2,
186
+ constant int64_t & ne3,
187
+ constant uint64_t & nb0,
188
+ constant uint64_t & nb1,
189
+ constant uint64_t & nb2,
190
+ constant uint64_t & nb3,
191
+ uint3 tgpig[[threadgroup_position_in_grid]],
192
+ uint3 tpitg[[thread_position_in_threadgroup]],
193
+ uint3 ntg[[threads_per_threadgroup]]) {
194
+ const int64_t i3 = tgpig.z; // dst index in dim 3
195
+ const int64_t i2 = tgpig.y; // dst index in dim 2
196
+ const int64_t i1 = tgpig.x; // dst index in dim 1
197
+
198
+ const int64_t i03 = i3 % ne03; // matching source indices, wrapped into the source extents
199
+ const int64_t i02 = i2 % ne02;
200
+ const int64_t i01 = i1 % ne01;
201
+
202
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; // byte address of the source row
203
+ device char * dst_ptr = dst + i3*nb3 + i2*nb2 + i1*nb1 ; // byte address of the destination row
204
+
205
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { // threads of the group step through dim 0 by ntg.x
206
+ const int i00 = i0 % ne00; // wrapped source index in dim 0
207
+ *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00)); // copy one element of type T
208
+ }
209
+ }
210
+
211
// kernel_repeat_t names the function type of kernel_repeat<float>. T does not
// appear in kernel_repeat's parameter list, so the same type is reused for
// every instantiation below.
+ typedef decltype(kernel_repeat<float>) kernel_repeat_t;
212
+
213
// Explicit instantiations for float, half, int and short elements; each one is
// tagged with the host_name string given in its attribute.
+ template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
214
+ template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
215
+ template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
216
+ template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
217
+
171
218
  // assumption: src1 is a row
172
219
  // broadcast src1 into src0
173
220
  kernel void kernel_add_row(
@@ -1640,6 +1687,7 @@ static void rope_yarn_corr_dims(
1640
1687
  typedef void (rope_t)(
1641
1688
  device const void * src0,
1642
1689
  device const int32_t * src1,
1690
+ device const float * src2,
1643
1691
  device float * dst,
1644
1692
  constant int64_t & ne00,
1645
1693
  constant int64_t & ne01,
@@ -1675,6 +1723,7 @@ template<typename T>
1675
1723
  kernel void kernel_rope(
1676
1724
  device const void * src0,
1677
1725
  device const int32_t * src1,
1726
+ device const float * src2,
1678
1727
  device float * dst,
1679
1728
  constant int64_t & ne00,
1680
1729
  constant int64_t & ne01,
@@ -1718,13 +1767,13 @@ kernel void kernel_rope(
1718
1767
 
1719
1768
  const int64_t p = pos[i2];
1720
1769
 
1721
- const float theta_0 = (float)p;
1770
+ const float theta_base = (float)p;
1722
1771
  const float inv_ndims = -1.f/n_dims;
1723
1772
 
1724
1773
  if (!is_neox) {
1725
1774
  for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
1775
+ const float theta = theta_base * pow(freq_base, inv_ndims*i0);
1726
1776
 
1727
- const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
1728
1777
  float cos_theta, sin_theta;
1729
1778
  rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
1730
1779
 
@@ -1740,16 +1789,14 @@ kernel void kernel_rope(
1740
1789
  } else {
1741
1790
  for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
1742
1791
  if (ic < n_dims) {
1743
- const int64_t ib = 0;
1792
+ const int64_t i0 = ic/2;
1744
1793
 
1745
- // simplified from `(ib * n_dims + ic) * inv_ndims`
1746
- const float cur_rot = inv_ndims*ic - ib;
1794
+ const float freq_factor = src2 != src0 ? src2[i0] : 1.0f;
1747
1795
 
1748
- const float theta = theta_0 * pow(freq_base, cur_rot);
1749
- float cos_theta, sin_theta;
1750
- rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
1796
+ const float theta = theta_base * pow(freq_base, inv_ndims*ic);
1751
1797
 
1752
- const int64_t i0 = ib*n_dims + ic/2;
1798
+ float cos_theta, sin_theta;
1799
+ rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
1753
1800
 
1754
1801
  device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1755
1802
  device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -2204,11 +2251,7 @@ kernel void kernel_flash_attn_ext_f16(
2204
2251
  // pointer to the mask
2205
2252
  device const half * mp = (device const half *) (mask + iq1*nb31);
2206
2253
 
2207
- // prepare diagonal scale matrix
2208
- simdgroup_float8x8 mscale(scale);
2209
-
2210
- // prepare diagonal slope matrix
2211
- simdgroup_float8x8 mslope(1.0f);
2254
+ float slope = 1.0f;
2212
2255
 
2213
2256
  // ALiBi
2214
2257
  if (max_bias > 0.0f) {
@@ -2217,7 +2260,7 @@ kernel void kernel_flash_attn_ext_f16(
2217
2260
  const float base = h < n_head_log2 ? m0 : m1;
2218
2261
  const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
2219
2262
 
2220
- mslope = simdgroup_float8x8(pow(base, exph));
2263
+ slope = pow(base, exph);
2221
2264
  }
2222
2265
 
2223
2266
  // loop over the KV cache
@@ -2242,18 +2285,20 @@ kernel void kernel_flash_attn_ext_f16(
2242
2285
  simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
2243
2286
  }
2244
2287
 
2288
+ simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
2289
+
2290
+ const short tx = tiisg%4;
2291
+ const short ty = tiisg/4;
2292
+
2245
2293
  if (mask != q) {
2246
2294
  // mqk = mqk*scale + mask*slope
2247
- simdgroup_half8x8 mm;
2248
- simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
2249
- simdgroup_multiply(mm, mslope, mm);
2250
- simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
2295
+ ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
2296
+ ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
2251
2297
  } else {
2252
2298
  // mqk = mqk*scale
2253
- simdgroup_multiply(mqk, mscale, mqk);
2299
+ ss[8*cc + ty*TF + 2*tx + 0] *= scale;
2300
+ ss[8*cc + ty*TF + 2*tx + 1] *= scale;
2254
2301
  }
2255
-
2256
- simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
2257
2302
  }
2258
2303
  }
2259
2304
 
@@ -2416,7 +2461,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f
2416
2461
  template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
2417
2462
  template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
2418
2463
  template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
2419
- template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
2464
+ //template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
2420
2465
 
2421
2466
  template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
2422
2467
  kernel void kernel_flash_attn_ext_vec_f16(
@@ -2694,7 +2739,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
2694
2739
  }
2695
2740
 
2696
2741
  template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
2697
- template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
2742
+ //template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
2698
2743
 
2699
2744
  kernel void kernel_cpy_f16_f16(
2700
2745
  device const half * src0,
@@ -2816,8 +2861,7 @@ kernel void kernel_cpy_f32_f16(
2816
2861
  for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
2817
2862
  device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2818
2863
 
2819
- // TODO: is there a better way to handle -INFINITY?
2820
- dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0];
2864
+ dst_data[i00] = src[0];
2821
2865
  }
2822
2866
  }
2823
2867
 
@@ -3318,31 +3362,30 @@ kernel void kernel_concat(
3318
3362
  constant uint64_t & nb1,
3319
3363
  constant uint64_t & nb2,
3320
3364
  constant uint64_t & nb3,
3365
+ constant int32_t & dim,
3321
3366
  uint3 tgpig[[threadgroup_position_in_grid]],
3322
3367
  uint3 tpitg[[thread_position_in_threadgroup]],
3323
3368
  uint3 ntg[[threads_per_threadgroup]]) {
3324
3369
 
3325
- const int64_t i03 = tgpig.z;
3326
- const int64_t i02 = tgpig.y;
3327
- const int64_t i01 = tgpig.x;
3370
+ const int64_t i3 = tgpig.z;
3371
+ const int64_t i2 = tgpig.y;
3372
+ const int64_t i1 = tgpig.x;
3328
3373
 
3329
- const int64_t i13 = i03 % ne13;
3330
- const int64_t i12 = i02 % ne12;
3331
- const int64_t i11 = i01 % ne11;
3374
+ int64_t o[4] = {0, 0, 0, 0};
3375
+ o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
3332
3376
 
3333
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
3334
- device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
3335
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
3377
+ device const float * x;
3336
3378
 
3337
3379
  for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
3338
- if (i02 < ne02) {
3339
- ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
3340
- src0_ptr += ntg.x*nb00;
3380
+ if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
3381
+ x = (device const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00);
3341
3382
  } else {
3342
- ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
3343
- src1_ptr += ntg.x*nb10;
3383
+ x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
3344
3384
  }
3345
- dst_ptr += ntg.x*nb0;
3385
+
3386
+ device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
3387
+
3388
+ *y = *x;
3346
3389
  }
3347
3390
  }
3348
3391
 
@@ -3385,7 +3428,6 @@ void kernel_mul_mv_q2_K_f32_impl(
3385
3428
 
3386
3429
  const int step = sizeof(block_q2_K) * nb;
3387
3430
 
3388
- #if QK_K == 256
3389
3431
  const int ix = tiisg/8; // 0...3
3390
3432
  const int it = tiisg%8; // 0...7
3391
3433
  const int iq = it/4; // 0 or 1
@@ -3437,57 +3479,6 @@ void kernel_mul_mv_q2_K_f32_impl(
3437
3479
 
3438
3480
  y4 += 4 * QK_K;
3439
3481
  }
3440
- #else
3441
- const int ix = tiisg/2; // 0...15
3442
- const int it = tiisg%2; // 0...1
3443
-
3444
- device const float * y4 = y + ix * QK_K + 8 * it;
3445
-
3446
- for (int ib = ix; ib < nb; ib += 16) {
3447
-
3448
- float4 sumy = {0.f, 0.f, 0.f, 0.f};
3449
- for (int i = 0; i < 8; ++i) {
3450
- yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
3451
- yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8];
3452
- yl[i+16] = y4[i+32]; sumy[2] += yl[i+16];
3453
- yl[i+24] = y4[i+48]; sumy[3] += yl[i+24];
3454
- }
3455
-
3456
- device const uint8_t * sc = (device const uint8_t *)x[ib].scales;
3457
- device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
3458
- device const half * dh = &x[ib].d;
3459
-
3460
- for (int row = 0; row < N_DST; row++) {
3461
-
3462
- float4 acc1 = {0.f, 0.f, 0.f, 0.f};
3463
- float4 acc2 = {0.f, 0.f, 0.f, 0.f};
3464
- for (int i = 0; i < 8; i += 2) {
3465
- acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
3466
- acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
3467
- acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
3468
- acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
3469
- acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
3470
- acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
3471
- acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
3472
- acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
3473
- }
3474
-
3475
- float dall = dh[0];
3476
- float dmin = dh[1];
3477
- sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
3478
- (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
3479
- (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
3480
- (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
3481
- dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));
3482
-
3483
- qs += step/2;
3484
- sc += step;
3485
- dh += step/2;
3486
- }
3487
-
3488
- y4 += 16 * QK_K;
3489
- }
3490
- #endif
3491
3482
 
3492
3483
  for (int row = 0; row < N_DST; ++row) {
3493
3484
  all_sum = simd_sum(sumf[row]);
@@ -3525,7 +3516,6 @@ kernel void kernel_mul_mv_q2_K_f32(
3525
3516
  kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
3526
3517
  }
3527
3518
 
3528
- #if QK_K == 256
3529
3519
  void kernel_mul_mv_q3_K_f32_impl(
3530
3520
  device const void * src0,
3531
3521
  device const float * src1,
@@ -3684,84 +3674,6 @@ void kernel_mul_mv_q3_K_f32_impl(
3684
3674
  }
3685
3675
  }
3686
3676
  }
3687
- #else
3688
- void kernel_mul_mv_q3_K_f32_impl(
3689
- device const void * src0,
3690
- device const float * src1,
3691
- device float * dst,
3692
- constant int64_t & ne00,
3693
- constant int64_t & ne01,
3694
- constant int64_t & ne02,
3695
- constant int64_t & ne10,
3696
- constant int64_t & ne12,
3697
- constant int64_t & ne0,
3698
- constant int64_t & ne1,
3699
- constant uint & r2,
3700
- constant uint & r3,
3701
- threadgroup int8_t * shared_values [[threadgroup(0)]],
3702
- uint3 tgpig[[threadgroup_position_in_grid]],
3703
- uint tiisg[[thread_index_in_simdgroup]],
3704
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3705
-
3706
- const int nb = ne00/QK_K;
3707
-
3708
- const int64_t r0 = tgpig.x;
3709
- const int64_t r1 = tgpig.y;
3710
- const int64_t im = tgpig.z;
3711
-
3712
- const int row = 2 * r0 + sgitg;
3713
-
3714
- const uint i12 = im%ne12;
3715
- const uint i13 = im/ne12;
3716
-
3717
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
3718
-
3719
- device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
3720
- device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
3721
-
3722
- const int ix = tiisg/4;
3723
- const int il = 4 * (tiisg%4);// 0, 4, 8, 12
3724
- const int iq = il/8; // 0, 0, 1, 1
3725
- const int in = il%8; // 0, 4, 0, 4
3726
-
3727
- float2 sum = {0.f, 0.f};
3728
-
3729
- for (int i = ix; i < nb; i += 8) {
3730
-
3731
- const float d_all = (float)(x[i].d);
3732
-
3733
- device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
3734
- device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
3735
- device const uint16_t * s = (device const uint16_t *)(x[i].scales);
3736
- device const float * y = yy + i * QK_K + il;
3737
-
3738
- const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
3739
- const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
3740
- const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
3741
- const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
3742
-
3743
- for (int l = 0; l < 4; l += 2) {
3744
- const uint16_t hm = h[l/2] >> iq;
3745
- sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
3746
- + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
3747
- + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
3748
- + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
3749
- sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
3750
- + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
3751
- + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
3752
- + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
3753
- }
3754
-
3755
- }
3756
- const float sumf = sum[0] + sum[1] * 1.f/256.f;
3757
-
3758
- const float tot = simd_sum(sumf);
3759
- if (tiisg == 0) {
3760
- dst[r1*ne0 + im*ne0*ne1 + row] = tot;
3761
- }
3762
-
3763
- }
3764
- #endif
3765
3677
 
3766
3678
  [[host_name("kernel_mul_mv_q3_K_f32")]]
3767
3679
  kernel void kernel_mul_mv_q3_K_f32(
@@ -3791,7 +3703,6 @@ kernel void kernel_mul_mv_q3_K_f32(
3791
3703
  kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
3792
3704
  }
3793
3705
 
3794
- #if QK_K == 256
3795
3706
  void kernel_mul_mv_q4_K_f32_impl(
3796
3707
  device const void * src0,
3797
3708
  device const float * src1,
@@ -3905,103 +3816,6 @@ void kernel_mul_mv_q4_K_f32_impl(
3905
3816
  }
3906
3817
  }
3907
3818
  }
3908
- #else
3909
- void kernel_mul_mv_q4_K_f32_impl(
3910
- device const void * src0,
3911
- device const float * src1,
3912
- device float * dst,
3913
- constant int64_t & ne00,
3914
- constant int64_t & ne01,
3915
- constant int64_t & ne02,
3916
- constant int64_t & ne10,
3917
- constant int64_t & ne12,
3918
- constant int64_t & ne0,
3919
- constant int64_t & ne1,
3920
- constant uint & r2,
3921
- constant uint & r3,
3922
- threadgroup int8_t * shared_values [[threadgroup(0)]],
3923
- uint3 tgpig[[threadgroup_position_in_grid]],
3924
- uint tiisg[[thread_index_in_simdgroup]],
3925
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3926
-
3927
- const int ix = tiisg/4; // 0...7
3928
- const int it = tiisg%4; // 0...3
3929
-
3930
- const int nb = ne00/QK_K;
3931
- const int r0 = tgpig.x;
3932
- const int r1 = tgpig.y;
3933
- const int im = tgpig.z;
3934
- const int first_row = r0 * N_DST;
3935
- const int ib_row = first_row * nb;
3936
-
3937
- const uint i12 = im%ne12;
3938
- const uint i13 = im/ne12;
3939
-
3940
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
3941
-
3942
- device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
3943
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
3944
-
3945
- float yl[8];
3946
- float yh[8];
3947
- float sumf[N_DST]={0.f}, all_sum;
3948
-
3949
- const int step = sizeof(block_q4_K) * nb / 2;
3950
-
3951
- device const float * y4 = y + ix * QK_K + 8 * it;
3952
-
3953
- uint16_t sc16[4];
3954
-
3955
- for (int ib = ix; ib < nb; ib += 8) {
3956
-
3957
- float2 sumy = {0.f, 0.f};
3958
- for (int i = 0; i < 8; ++i) {
3959
- yl[i] = y4[i+ 0]; sumy[0] += yl[i];
3960
- yh[i] = y4[i+32]; sumy[1] += yh[i];
3961
- }
3962
-
3963
- device const uint16_t * sc = (device const uint16_t *)x[ib].scales;
3964
- device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
3965
- device const half * dh = x[ib].d;
3966
-
3967
- for (int row = 0; row < N_DST; row++) {
3968
-
3969
- sc16[0] = sc[0] & 0x000f;
3970
- sc16[1] = sc[0] & 0x0f00;
3971
- sc16[2] = sc[0] & 0x00f0;
3972
- sc16[3] = sc[0] & 0xf000;
3973
-
3974
- float2 acc1 = {0.f, 0.f};
3975
- float2 acc2 = {0.f, 0.f};
3976
- for (int i = 0; i < 8; i += 2) {
3977
- acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
3978
- acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
3979
- acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
3980
- acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
3981
- }
3982
-
3983
- float dall = dh[0];
3984
- float dmin = dh[1];
3985
- sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
3986
- (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
3987
- dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);
3988
-
3989
- qs += step;
3990
- sc += step;
3991
- dh += step;
3992
- }
3993
-
3994
- y4 += 8 * QK_K;
3995
- }
3996
-
3997
- for (int row = 0; row < N_DST; ++row) {
3998
- all_sum = simd_sum(sumf[row]);
3999
- if (tiisg == 0) {
4000
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
4001
- }
4002
- }
4003
- }
4004
- #endif
4005
3819
 
4006
3820
  [[host_name("kernel_mul_mv_q4_K_f32")]]
4007
3821
  kernel void kernel_mul_mv_q4_K_f32(
@@ -4069,8 +3883,6 @@ void kernel_mul_mv_q5_K_f32_impl(
4069
3883
 
4070
3884
  const int step = sizeof(block_q5_K) * nb;
4071
3885
 
4072
- #if QK_K == 256
4073
- #
4074
3886
  float yl[16], yh[16];
4075
3887
 
4076
3888
  const uint16_t kmask1 = 0x3f3f;
@@ -4153,54 +3965,6 @@ void kernel_mul_mv_q5_K_f32_impl(
4153
3965
  y1 += 4 * QK_K;
4154
3966
 
4155
3967
  }
4156
- #else
4157
- float yl[8], yh[8];
4158
-
4159
- const int il = 4 * (tiisg/8); // 0, 4, 8, 12
4160
- const int ix = tiisg%8;
4161
- const int iq = il/8; // 0, 0, 1, 1
4162
- const int in = il%8; // 0, 4, 0, 4
4163
-
4164
- device const float * y = yy + ix*QK_K + il;
4165
-
4166
- for (int i = ix; i < nb; i += 8) {
4167
-
4168
- for (int l = 0; l < 4; ++l) {
4169
- yl[l+0] = y[l+ 0];
4170
- yl[l+4] = y[l+16];
4171
- yh[l+0] = y[l+32];
4172
- yh[l+4] = y[l+48];
4173
- }
4174
-
4175
- device const half * dh = &x[i].d;
4176
- device const uint8_t * q = x[i].qs + il;
4177
- device const uint8_t * h = x[i].qh + in;
4178
- device const int8_t * s = x[i].scales;
4179
-
4180
- for (int row = 0; row < 2; ++row) {
4181
-
4182
- const float d = dh[0];
4183
-
4184
- float2 acc = {0.f, 0.f};
4185
- for (int l = 0; l < 4; ++l) {
4186
- const uint8_t hl = h[l] >> iq;
4187
- acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
4188
- + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
4189
- acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
4190
- + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256));
4191
- }
4192
- sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]);
4193
-
4194
- q += step;
4195
- h += step;
4196
- s += step;
4197
- dh += step/2;
4198
-
4199
- }
4200
-
4201
- y += 8 * QK_K;
4202
- }
4203
- #endif
4204
3968
 
4205
3969
  for (int row = 0; row < 2; ++row) {
4206
3970
  const float tot = simd_sum(sumf[row]);
@@ -4279,7 +4043,6 @@ void kernel_mul_mv_q6_K_f32_impl(
4279
4043
 
4280
4044
  float sumf = 0;
4281
4045
 
4282
- #if QK_K == 256
4283
4046
  const int tid = tiisg/2;
4284
4047
  const int ix = tiisg%2;
4285
4048
  const int ip = tid/8; // 0 or 1
@@ -4315,30 +4078,6 @@ void kernel_mul_mv_q6_K_f32_impl(
4315
4078
 
4316
4079
  }
4317
4080
 
4318
- #else
4319
- const int ix = tiisg/4;
4320
- const int il = 4*(tiisg%4);
4321
-
4322
- for (int i = ix; i < nb; i += 8) {
4323
- device const float * y = yy + i * QK_K + il;
4324
- device const uint8_t * ql = x[i].ql + il;
4325
- device const uint8_t * qh = x[i].qh + il;
4326
- device const int8_t * s = x[i].scales;
4327
-
4328
- const float d = x[i].d;
4329
-
4330
- float4 sums = {0.f, 0.f, 0.f, 0.f};
4331
- for (int l = 0; l < 4; ++l) {
4332
- sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
4333
- sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
4334
- sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32);
4335
- sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
4336
- }
4337
- sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
4338
- }
4339
-
4340
- #endif
4341
-
4342
4081
  const float tot = simd_sum(sumf);
4343
4082
  if (tiisg == 0) {
4344
4083
  dst[r1*ne0 + im*ne0*ne1 + row] = tot;
@@ -5172,9 +4911,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
5172
4911
 
5173
4912
  device const float * y4 = y + 32 * ix;
5174
4913
 
5175
- #if QK_K != 64
5176
4914
  iq1m_scale_t scale;
5177
- #endif
5178
4915
 
5179
4916
  for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
5180
4917
 
@@ -5195,10 +4932,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
5195
4932
  device const uint16_t * sc = (device const uint16_t *)xr->scales;
5196
4933
 
5197
4934
  for (int row = 0; row < N_DST; row++) {
5198
-
5199
- #if QK_K != 64
5200
4935
  scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
5201
- #endif
5202
4936
 
5203
4937
  constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
5204
4938
  constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
@@ -5214,14 +4948,9 @@ void kernel_mul_mv_iq1_m_f32_impl(
5214
4948
  }
5215
4949
  const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
5216
4950
  const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
5217
- #if QK_K == 64
5218
- const float d = (float) *((device const half *)(sc - 1));
5219
- sumf[row] += d * ((sum[0] + delta1) * (2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1) +
5220
- (sum[1] + delta2) * (2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1));
5221
- #else
4951
+
5222
4952
  sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
5223
4953
  (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
5224
- #endif
5225
4954
 
5226
4955
  sc += nb*sizeof(block_iq1_m)/2;
5227
4956
  qs += nb*sizeof(block_iq1_m);
@@ -5333,7 +5062,6 @@ void kernel_mul_mv_iq4_nl_f32_impl(
5333
5062
  }
5334
5063
  }
5335
5064
 
5336
- #if QK_K != 64
5337
5065
  void kernel_mul_mv_iq4_xs_f32_impl(
5338
5066
  device const void * src0,
5339
5067
  device const float * src1,
@@ -5428,7 +5156,6 @@ void kernel_mul_mv_iq4_xs_f32_impl(
5428
5156
  }
5429
5157
  }
5430
5158
  }
5431
- #endif
5432
5159
 
5433
5160
  [[host_name("kernel_mul_mv_iq1_s_f32")]]
5434
5161
  kernel void kernel_mul_mv_iq1_s_f32(
@@ -5541,11 +5268,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
5541
5268
  uint tiisg[[thread_index_in_simdgroup]],
5542
5269
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5543
5270
 
5544
- #if QK_K == 64
5545
- kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
5546
- #else
5547
5271
  kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
5548
- #endif
5549
5272
  }
5550
5273
 
5551
5274
  //============================= templates and their specializations =============================
@@ -5671,10 +5394,9 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg
5671
5394
  float dl, ml;
5672
5395
  uint8_t sc = xb->scales[il];
5673
5396
 
5674
- #if QK_K == 256
5675
5397
  q = q + 32*(il/8) + 16*(il&1);
5676
5398
  il = (il/2)%4;
5677
- #endif
5399
+
5678
5400
  half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
5679
5401
  uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
5680
5402
  dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
@@ -5690,7 +5412,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
5690
5412
  device const uint8_t * h = (device const uint8_t *)xb->hmask;
5691
5413
  device const int8_t * scales = (device const int8_t *)xb->scales;
5692
5414
 
5693
- #if QK_K == 256
5694
5415
  q = q + 32 * (il/8) + 16 * (il&1);
5695
5416
  h = h + 16 * (il&1);
5696
5417
  uint8_t m = 1 << (il/2);
@@ -5711,17 +5432,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
5711
5432
  for (int i = 0; i < 16; ++i) {
5712
5433
  reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
5713
5434
  }
5714
- #else
5715
- float kcoef = il&1 ? 1.f/16.f : 1.f;
5716
- uint16_t kmask = il&1 ? 0xF0 : 0x0F;
5717
- float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8);
5718
- float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
5719
- uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
5720
- uint8_t m = 1<<(il*2);
5721
- for (int i = 0; i < 16; ++i) {
5722
- reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef));
5723
- }
5724
- #endif
5725
5435
  }
5726
5436
 
5727
5437
  static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
@@ -5733,7 +5443,6 @@ template <typename type4x4>
5733
5443
  void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
5734
5444
  device const uchar * q = xb->qs;
5735
5445
 
5736
- #if QK_K == 256
5737
5446
  short is = (il/4) * 2;
5738
5447
  q = q + (il/4) * 32 + 16 * (il&1);
5739
5448
  il = il & 3;
@@ -5742,16 +5451,7 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
5742
5451
  const float min = xb->dmin;
5743
5452
  const float dl = d * sc[0];
5744
5453
  const float ml = min * sc[1];
5745
- #else
5746
- (void) get_scale_min_k4_just2;
5747
-
5748
- q = q + 16 * (il&1);
5749
- device const uint8_t * s = xb->scales;
5750
- device const half2 * dh = (device const half2 *)xb->d;
5751
- const float2 d = (float2)dh[0];
5752
- const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
5753
- const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
5754
- #endif
5454
+
5755
5455
  const ushort mask = il<2 ? 0x0F : 0xF0;
5756
5456
  for (int i = 0; i < 16; ++i) {
5757
5457
  reg[i/4][i%4] = dl * (q[i] & mask) - ml;
@@ -5763,7 +5463,6 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
5763
5463
  device const uint8_t * q = xb->qs;
5764
5464
  device const uint8_t * qh = xb->qh;
5765
5465
 
5766
- #if QK_K == 256
5767
5466
  short is = (il/4) * 2;
5768
5467
  q = q + 32 * (il/4) + 16 * (il&1);
5769
5468
  qh = qh + 16 * (il&1);
@@ -5780,17 +5479,6 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
5780
5479
  for (int i = 0; i < 16; ++i) {
5781
5480
  reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
5782
5481
  }
5783
- #else
5784
- q = q + 16 * (il&1);
5785
- device const int8_t * s = xb->scales;
5786
- const float dl = xb->d * s[il];
5787
- uint8_t m = 1<<(il*2);
5788
- const float coef = il<2 ? 1.f : 1.f/16.f;
5789
- const ushort mask = il<2 ? 0x0F : 0xF0;
5790
- for (int i = 0; i < 16; ++i) {
5791
- reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef));
5792
- }
5793
- #endif
5794
5482
  }
5795
5483
 
5796
5484
  template <typename type4x4>
@@ -5800,15 +5488,11 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
5800
5488
  device const uint8_t * qh = (device const uint8_t *)xb->qh;
5801
5489
  device const int8_t * scales = (device const int8_t *)xb->scales;
5802
5490
 
5803
- #if QK_K == 256
5804
5491
  ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
5805
5492
  qh = qh + 32*(il/8) + 16*(il&1);
5806
5493
  float sc = scales[(il%2) + 2 * ((il/2))];
5807
5494
  il = (il/2) & 3;
5808
- #else
5809
- ql = ql + 16 * (il&1);
5810
- float sc = scales[il];
5811
- #endif
5495
+
5812
5496
  const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
5813
5497
  const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
5814
5498
  const float coef = il>1 ? 1.f/16.f : 1.f;
@@ -5965,20 +5649,15 @@ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 &
5965
5649
  const int ib32 = il/2;
5966
5650
  il = il%2;
5967
5651
  device const uint16_t * sc = (device const uint16_t *)xb->scales;
5968
- #if QK_K == 64
5969
- const float d = xb->d;
5970
- #else
5652
+
5971
5653
  iq1m_scale_t scale;
5972
5654
  scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
5973
5655
  const float d = scale.f16;
5974
- #endif
5656
+
5975
5657
  device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
5976
5658
  device const uint8_t * qh = xb->qh + 2*ib32 + il;
5977
- #if QK_K == 64
5978
- const float dl = d * (2*((sc[ib32/2] >> (8*(ib32%2)+4*il)) & 0xf) + 1);
5979
- #else
5659
+
5980
5660
  const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
5981
- #endif
5982
5661
  const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
5983
5662
  const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
5984
5663
  constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
@@ -6008,9 +5687,6 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4
6008
5687
 
6009
5688
  template <typename type4x4>
6010
5689
  void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
6011
- #if QK_K == 64
6012
- dequantize_iq4_nl(xb, il, reg);
6013
- #else
6014
5690
  // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
6015
5691
  const int ib32 = il/2;
6016
5692
  il = il%2;
@@ -6027,7 +5703,6 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4
6027
5703
  reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
6028
5704
  reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
6029
5705
  }
6030
- #endif
6031
5706
  }
6032
5707
 
6033
5708
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
@@ -6532,11 +6207,7 @@ kernel void kernel_mul_mm_id(
6532
6207
  sgitg);
6533
6208
  }
6534
6209
 
6535
- #if QK_K == 256
6536
6210
  #define QK_NL 16
6537
- #else
6538
- #define QK_NL 4
6539
- #endif
6540
6211
 
6541
6212
  //
6542
6213
  // get rows
@@ -6576,11 +6247,7 @@ template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_r
6576
6247
  template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
6577
6248
  template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows<block_iq1_m, QK_NL, dequantize_iq1_m>;
6578
6249
  template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
6579
- #if QK_K == 64
6580
- template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, 2, dequantize_iq4_xs>;
6581
- #else
6582
6250
  template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6583
- #endif
6584
6251
 
6585
6252
  //
6586
6253
  // matrix-matrix multiplication
@@ -6608,11 +6275,7 @@ template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_m
6608
6275
  template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
6609
6276
  template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_m, QK_NL, dequantize_iq1_m>;
6610
6277
  template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
6611
- #if QK_K == 64
6612
- template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_xs>;
6613
- #else
6614
6278
  template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6615
- #endif
6616
6279
 
6617
6280
  //
6618
6281
  // indirect matrix-matrix multiplication
@@ -6640,11 +6303,7 @@ template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel
6640
6303
  template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
6641
6304
  template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m, QK_NL, dequantize_iq1_m>;
6642
6305
  template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
6643
- #if QK_K == 64
6644
- template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, 2, dequantize_iq4_xs>;
6645
- #else
6646
6306
  template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6647
- #endif
6648
6307
 
6649
6308
  //
6650
6309
  // matrix-vector multiplication
@@ -6853,7 +6512,5 @@ template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t
6853
6512
  template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl>>;
6854
6513
  template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
6855
6514
  template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
6856
- #if QK_K != 64
6857
6515
  template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
6858
- #endif
6859
6516