llama_cpp 0.15.1 → 0.15.3

Sign up to get free protection for your applications and to get access to all the features.
@@ -229,6 +229,13 @@ kernel void kernel_relu(
229
229
  dst[tpig] = max(0.0f, src0[tpig]);
230
230
  }
231
231
 
232
+ kernel void kernel_sigmoid(
233
+ device const float * src0,
234
+ device float * dst,
235
+ uint tpig[[thread_position_in_grid]]) {
236
+ dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
237
+ }
238
+
232
239
  kernel void kernel_tanh(
233
240
  device const float * src0,
234
241
  device float * dst,
@@ -356,7 +363,6 @@ template<typename T>
356
363
  kernel void kernel_soft_max(
357
364
  device const char * src0,
358
365
  device const char * src1,
359
- device const char * src2,
360
366
  device char * dst,
361
367
  constant int64_t & ne00,
362
368
  constant int64_t & ne01,
@@ -378,10 +384,9 @@ kernel void kernel_soft_max(
378
384
 
379
385
  device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
380
386
  device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
381
- device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
382
387
  device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
383
388
 
384
- float slope = 0.0f;
389
+ float slope = 1.0f;
385
390
 
386
391
  // ALiBi
387
392
  if (max_bias > 0.0f) {
@@ -397,7 +402,7 @@ kernel void kernel_soft_max(
397
402
  float lmax = -INFINITY;
398
403
 
399
404
  for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
400
- lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
405
+ lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
401
406
  }
402
407
 
403
408
  // find the max value in the block
@@ -422,7 +427,7 @@ kernel void kernel_soft_max(
422
427
  // parallel sum
423
428
  float lsum = 0.0f;
424
429
  for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
425
- const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
430
+ const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
426
431
  lsum += exp_psrc0;
427
432
  pdst[i00] = exp_psrc0;
428
433
  }
@@ -461,7 +466,6 @@ template<typename T>
461
466
  kernel void kernel_soft_max_4(
462
467
  device const char * src0,
463
468
  device const char * src1,
464
- device const char * src2,
465
469
  device char * dst,
466
470
  constant int64_t & ne00,
467
471
  constant int64_t & ne01,
@@ -483,10 +487,9 @@ kernel void kernel_soft_max_4(
483
487
 
484
488
  device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
485
489
  device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
486
- device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
487
490
  device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
488
491
 
489
- float slope = 0.0f;
492
+ float slope = 1.0f;
490
493
 
491
494
  if (max_bias > 0.0f) {
492
495
  const int64_t h = i02;
@@ -501,7 +504,7 @@ kernel void kernel_soft_max_4(
501
504
  float4 lmax4 = -INFINITY;
502
505
 
503
506
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
504
- lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)));
507
+ lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
505
508
  }
506
509
 
507
510
  const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@@ -527,7 +530,7 @@ kernel void kernel_soft_max_4(
527
530
  // parallel sum
528
531
  float4 lsum4 = 0.0f;
529
532
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
530
- const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))) - max_val);
533
+ const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
531
534
  lsum4 += exp_psrc4;
532
535
  pdst4[i00] = exp_psrc4;
533
536
  }
@@ -1595,60 +1598,6 @@ kernel void kernel_mul_mv_f16_f32_l4(
1595
1598
  }
1596
1599
  }
1597
1600
 
1598
- kernel void kernel_alibi_f32(
1599
- device const float * src0,
1600
- device float * dst,
1601
- constant int64_t & ne00,
1602
- constant int64_t & ne01,
1603
- constant int64_t & ne02,
1604
- constant int64_t & ne03,
1605
- constant uint64_t & nb00,
1606
- constant uint64_t & nb01,
1607
- constant uint64_t & nb02,
1608
- constant uint64_t & nb03,
1609
- constant int64_t & ne0,
1610
- constant int64_t & ne1,
1611
- constant int64_t & ne2,
1612
- constant int64_t & ne3,
1613
- constant uint64_t & nb0,
1614
- constant uint64_t & nb1,
1615
- constant uint64_t & nb2,
1616
- constant uint64_t & nb3,
1617
- constant float & m0,
1618
- constant float & m1,
1619
- constant int & n_heads_log2_floor,
1620
- uint3 tgpig[[threadgroup_position_in_grid]],
1621
- uint3 tpitg[[thread_position_in_threadgroup]],
1622
- uint3 ntg[[threads_per_threadgroup]]) {
1623
- const int64_t i03 = tgpig[2];
1624
- const int64_t i02 = tgpig[1];
1625
- const int64_t i01 = tgpig[0];
1626
-
1627
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
1628
-
1629
- const int64_t i3 = n / (ne2*ne1*ne0);
1630
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1631
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1632
- //const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1633
-
1634
- const int64_t k = i3*ne3 + i2;
1635
-
1636
- float m_k;
1637
- if (k < n_heads_log2_floor) {
1638
- m_k = pow(m0, k + 1);
1639
- } else {
1640
- m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
1641
- }
1642
-
1643
- device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
1644
- device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
1645
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
1646
- const float src_v = *(device float *)(src_row + i00*nb00);
1647
- device float * dst_v = (device float *)(dst_row + i00*nb0);
1648
- *dst_v = i00 * m_k + src_v;
1649
- }
1650
- }
1651
-
1652
1601
  static float rope_yarn_ramp(const float low, const float high, const int i0) {
1653
1602
  const float y = (i0 / 2 - low) / max(0.001f, high - low);
1654
1603
  return 1.0f - min(1.0f, max(0.0f, y));
@@ -1691,6 +1640,7 @@ static void rope_yarn_corr_dims(
1691
1640
  typedef void (rope_t)(
1692
1641
  device const void * src0,
1693
1642
  device const int32_t * src1,
1643
+ device const float * src2,
1694
1644
  device float * dst,
1695
1645
  constant int64_t & ne00,
1696
1646
  constant int64_t & ne01,
@@ -1726,6 +1676,7 @@ template<typename T>
1726
1676
  kernel void kernel_rope(
1727
1677
  device const void * src0,
1728
1678
  device const int32_t * src1,
1679
+ device const float * src2,
1729
1680
  device float * dst,
1730
1681
  constant int64_t & ne00,
1731
1682
  constant int64_t & ne01,
@@ -1795,8 +1746,10 @@ kernel void kernel_rope(
1795
1746
 
1796
1747
  // simplified from `(ib * n_dims + ic) * inv_ndims`
1797
1748
  const float cur_rot = inv_ndims*ic - ib;
1749
+ const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f;
1750
+
1751
+ const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor;
1798
1752
 
1799
- const float theta = theta_0 * pow(freq_base, cur_rot);
1800
1753
  float cos_theta, sin_theta;
1801
1754
  rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
1802
1755
 
@@ -1903,7 +1856,10 @@ kernel void kernel_upscale_f32(
1903
1856
  constant uint64_t & nb1,
1904
1857
  constant uint64_t & nb2,
1905
1858
  constant uint64_t & nb3,
1906
- constant int32_t & sf,
1859
+ constant float & sf0,
1860
+ constant float & sf1,
1861
+ constant float & sf2,
1862
+ constant float & sf3,
1907
1863
  uint3 tgpig[[threadgroup_position_in_grid]],
1908
1864
  uint3 tpitg[[thread_position_in_threadgroup]],
1909
1865
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -1912,15 +1868,17 @@ kernel void kernel_upscale_f32(
1912
1868
  const int64_t i2 = tgpig.y;
1913
1869
  const int64_t i1 = tgpig.x;
1914
1870
 
1915
- const int64_t i03 = i3;
1916
- const int64_t i02 = i2;
1917
- const int64_t i01 = i1/sf;
1918
-
1919
- device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
1920
- device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
1871
+ const int64_t i03 = i3/sf3;
1872
+ const int64_t i02 = i2/sf2;
1873
+ const int64_t i01 = i1/sf1;
1921
1874
 
1922
1875
  for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
1923
- dst_ptr[i0] = src0_ptr[i0/sf];
1876
+ const int64_t i00 = i0/sf0;
1877
+
1878
+ device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1879
+ device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1880
+
1881
+ dst_ptr[0] = src0_ptr[0];
1924
1882
  }
1925
1883
  }
1926
1884
 
@@ -2100,29 +2058,29 @@ typedef void (flash_attn_ext_f16_t)(
2100
2058
  device const char * v,
2101
2059
  device const char * mask,
2102
2060
  device float * dst,
2103
- constant int64_t & ne00,
2104
2061
  constant int64_t & ne01,
2105
2062
  constant int64_t & ne02,
2106
2063
  constant int64_t & ne03,
2107
- constant uint64_t & nb00,
2108
2064
  constant uint64_t & nb01,
2109
2065
  constant uint64_t & nb02,
2110
2066
  constant uint64_t & nb03,
2111
- constant int64_t & ne10,
2112
2067
  constant int64_t & ne11,
2113
2068
  constant int64_t & ne12,
2114
2069
  constant int64_t & ne13,
2115
- constant uint64_t & nb10,
2116
2070
  constant uint64_t & nb11,
2117
2071
  constant uint64_t & nb12,
2118
2072
  constant uint64_t & nb13,
2119
- constant int64_t & ne31,
2073
+ constant uint64_t & nb21,
2074
+ constant uint64_t & nb22,
2075
+ constant uint64_t & nb23,
2120
2076
  constant uint64_t & nb31,
2121
- constant int64_t & ne0,
2122
2077
  constant int64_t & ne1,
2123
2078
  constant int64_t & ne2,
2124
- constant int64_t & ne3,
2125
2079
  constant float & scale,
2080
+ constant float & max_bias,
2081
+ constant float & m0,
2082
+ constant float & m1,
2083
+ constant uint32_t & n_head_log2,
2126
2084
  threadgroup half * shared,
2127
2085
  uint3 tgpig[[threadgroup_position_in_grid]],
2128
2086
  uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2138,29 +2096,29 @@ kernel void kernel_flash_attn_ext_f16(
2138
2096
  device const char * v,
2139
2097
  device const char * mask,
2140
2098
  device float * dst,
2141
- constant int64_t & ne00,
2142
2099
  constant int64_t & ne01,
2143
2100
  constant int64_t & ne02,
2144
2101
  constant int64_t & ne03,
2145
- constant uint64_t & nb00,
2146
2102
  constant uint64_t & nb01,
2147
2103
  constant uint64_t & nb02,
2148
2104
  constant uint64_t & nb03,
2149
- constant int64_t & ne10,
2150
2105
  constant int64_t & ne11,
2151
2106
  constant int64_t & ne12,
2152
2107
  constant int64_t & ne13,
2153
- constant uint64_t & nb10,
2154
2108
  constant uint64_t & nb11,
2155
2109
  constant uint64_t & nb12,
2156
2110
  constant uint64_t & nb13,
2157
- constant int64_t & ne31,
2111
+ constant uint64_t & nb21,
2112
+ constant uint64_t & nb22,
2113
+ constant uint64_t & nb23,
2158
2114
  constant uint64_t & nb31,
2159
- constant int64_t & ne0,
2160
2115
  constant int64_t & ne1,
2161
2116
  constant int64_t & ne2,
2162
- constant int64_t & ne3,
2163
2117
  constant float & scale,
2118
+ constant float & max_bias,
2119
+ constant float & m0,
2120
+ constant float & m1,
2121
+ constant uint32_t & n_head_log2,
2164
2122
  threadgroup half * shared [[threadgroup(0)]],
2165
2123
  uint3 tgpig[[threadgroup_position_in_grid]],
2166
2124
  uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2225,10 +2183,6 @@ kernel void kernel_flash_attn_ext_f16(
2225
2183
  const short ne22 = ne12;
2226
2184
  const short ne23 = ne13;
2227
2185
 
2228
- const uint nb21 = nb11;
2229
- const uint nb22 = nb12;
2230
- const uint nb23 = nb13;
2231
-
2232
2186
  // broadcast
2233
2187
  const short rk2 = ne02/ne12;
2234
2188
  const short rk3 = ne03/ne13;
@@ -2254,8 +2208,17 @@ kernel void kernel_flash_attn_ext_f16(
2254
2208
  // pointer to the mask
2255
2209
  device const half * mp = (device const half *) (mask + iq1*nb31);
2256
2210
 
2257
- // prepare diagonal scale matrix
2258
- simdgroup_float8x8 mscale(scale);
2211
+ float slope = 1.0f;
2212
+
2213
+ // ALiBi
2214
+ if (max_bias > 0.0f) {
2215
+ const uint32_t h = iq2;
2216
+
2217
+ const float base = h < n_head_log2 ? m0 : m1;
2218
+ const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
2219
+
2220
+ slope = pow(base, exph);
2221
+ }
2259
2222
 
2260
2223
  // loop over the KV cache
2261
2224
  // each simdgroup handles blocks of Q rows and C columns
@@ -2279,12 +2242,20 @@ kernel void kernel_flash_attn_ext_f16(
2279
2242
  simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
2280
2243
  }
2281
2244
 
2282
- // mqk = mqk*scale + mask
2283
- simdgroup_half8x8 mm;
2284
- simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
2285
- simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
2286
-
2287
2245
  simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
2246
+
2247
+ const short tx = tiisg%4;
2248
+ const short ty = tiisg/4;
2249
+
2250
+ if (mask != q) {
2251
+ // mqk = mqk*scale + mask*slope
2252
+ ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
2253
+ ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
2254
+ } else {
2255
+ // mqk = mqk*scale
2256
+ ss[8*cc + ty*TF + 2*tx + 0] *= scale;
2257
+ ss[8*cc + ty*TF + 2*tx + 1] *= scale;
2258
+ }
2288
2259
  }
2289
2260
  }
2290
2261
 
@@ -2456,29 +2427,29 @@ kernel void kernel_flash_attn_ext_vec_f16(
2456
2427
  device const char * v,
2457
2428
  device const char * mask,
2458
2429
  device float * dst,
2459
- constant int64_t & ne00,
2460
2430
  constant int64_t & ne01,
2461
2431
  constant int64_t & ne02,
2462
2432
  constant int64_t & ne03,
2463
- constant uint64_t & nb00,
2464
2433
  constant uint64_t & nb01,
2465
2434
  constant uint64_t & nb02,
2466
2435
  constant uint64_t & nb03,
2467
- constant int64_t & ne10,
2468
2436
  constant int64_t & ne11,
2469
2437
  constant int64_t & ne12,
2470
2438
  constant int64_t & ne13,
2471
- constant uint64_t & nb10,
2472
2439
  constant uint64_t & nb11,
2473
2440
  constant uint64_t & nb12,
2474
2441
  constant uint64_t & nb13,
2475
- constant int64_t & ne31,
2442
+ constant uint64_t & nb21,
2443
+ constant uint64_t & nb22,
2444
+ constant uint64_t & nb23,
2476
2445
  constant uint64_t & nb31,
2477
- constant int64_t & ne0,
2478
2446
  constant int64_t & ne1,
2479
2447
  constant int64_t & ne2,
2480
- constant int64_t & ne3,
2481
2448
  constant float & scale,
2449
+ constant float & max_bias,
2450
+ constant float & m0,
2451
+ constant float & m1,
2452
+ constant uint32_t & n_head_log2,
2482
2453
  threadgroup half * shared [[threadgroup(0)]],
2483
2454
  uint3 tgpig[[threadgroup_position_in_grid]],
2484
2455
  uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2497,6 +2468,18 @@ kernel void kernel_flash_attn_ext_vec_f16(
2497
2468
 
2498
2469
  const short T = D + 2*nsg*SH; // shared memory size per query in (half)
2499
2470
 
2471
+ float slope = 1.0f;
2472
+
2473
+ // ALiBi
2474
+ if (max_bias > 0.0f) {
2475
+ const uint32_t h = iq2;
2476
+
2477
+ const float base = h < n_head_log2 ? m0 : m1;
2478
+ const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
2479
+
2480
+ slope = pow(base, exp);
2481
+ }
2482
+
2500
2483
  //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
2501
2484
  threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
2502
2485
  threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
@@ -2537,10 +2520,6 @@ kernel void kernel_flash_attn_ext_vec_f16(
2537
2520
  const short ne22 = ne12;
2538
2521
  const short ne23 = ne13;
2539
2522
 
2540
- const uint nb21 = nb11;
2541
- const uint nb22 = nb12;
2542
- const uint nb23 = nb13;
2543
-
2544
2523
  // broadcast
2545
2524
  const short rk2 = ne02/ne12;
2546
2525
  const short rk3 = ne03/ne13;
@@ -2603,10 +2582,9 @@ kernel void kernel_flash_attn_ext_vec_f16(
2603
2582
  mqk += simd_shuffle_down(mqk, 2);
2604
2583
  mqk += simd_shuffle_down(mqk, 1);
2605
2584
 
2606
- // mqk = mqk*scale + mask
2585
+ // mqk = mqk*scale + mask*slope
2607
2586
  if (tiisg == 0) {
2608
- float4 mm = (float4) mp4[ic/4 + cc];
2609
- mqk = mqk*scale + mm;
2587
+ mqk = mqk*scale + ((mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f);
2610
2588
 
2611
2589
  ss4[cc] = mqk;
2612
2590
  }
@@ -3408,7 +3386,6 @@ void kernel_mul_mv_q2_K_f32_impl(
3408
3386
 
3409
3387
  const int step = sizeof(block_q2_K) * nb;
3410
3388
 
3411
- #if QK_K == 256
3412
3389
  const int ix = tiisg/8; // 0...3
3413
3390
  const int it = tiisg%8; // 0...7
3414
3391
  const int iq = it/4; // 0 or 1
@@ -3460,57 +3437,6 @@ void kernel_mul_mv_q2_K_f32_impl(
3460
3437
 
3461
3438
  y4 += 4 * QK_K;
3462
3439
  }
3463
- #else
3464
- const int ix = tiisg/2; // 0...15
3465
- const int it = tiisg%2; // 0...1
3466
-
3467
- device const float * y4 = y + ix * QK_K + 8 * it;
3468
-
3469
- for (int ib = ix; ib < nb; ib += 16) {
3470
-
3471
- float4 sumy = {0.f, 0.f, 0.f, 0.f};
3472
- for (int i = 0; i < 8; ++i) {
3473
- yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
3474
- yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8];
3475
- yl[i+16] = y4[i+32]; sumy[2] += yl[i+16];
3476
- yl[i+24] = y4[i+48]; sumy[3] += yl[i+24];
3477
- }
3478
-
3479
- device const uint8_t * sc = (device const uint8_t *)x[ib].scales;
3480
- device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
3481
- device const half * dh = &x[ib].d;
3482
-
3483
- for (int row = 0; row < N_DST; row++) {
3484
-
3485
- float4 acc1 = {0.f, 0.f, 0.f, 0.f};
3486
- float4 acc2 = {0.f, 0.f, 0.f, 0.f};
3487
- for (int i = 0; i < 8; i += 2) {
3488
- acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
3489
- acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
3490
- acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
3491
- acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
3492
- acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
3493
- acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
3494
- acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
3495
- acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
3496
- }
3497
-
3498
- float dall = dh[0];
3499
- float dmin = dh[1];
3500
- sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
3501
- (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
3502
- (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
3503
- (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
3504
- dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));
3505
-
3506
- qs += step/2;
3507
- sc += step;
3508
- dh += step/2;
3509
- }
3510
-
3511
- y4 += 16 * QK_K;
3512
- }
3513
- #endif
3514
3440
 
3515
3441
  for (int row = 0; row < N_DST; ++row) {
3516
3442
  all_sum = simd_sum(sumf[row]);
@@ -3548,7 +3474,6 @@ kernel void kernel_mul_mv_q2_K_f32(
3548
3474
  kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
3549
3475
  }
3550
3476
 
3551
- #if QK_K == 256
3552
3477
  void kernel_mul_mv_q3_K_f32_impl(
3553
3478
  device const void * src0,
3554
3479
  device const float * src1,
@@ -3707,84 +3632,6 @@ void kernel_mul_mv_q3_K_f32_impl(
3707
3632
  }
3708
3633
  }
3709
3634
  }
3710
- #else
3711
- void kernel_mul_mv_q3_K_f32_impl(
3712
- device const void * src0,
3713
- device const float * src1,
3714
- device float * dst,
3715
- constant int64_t & ne00,
3716
- constant int64_t & ne01,
3717
- constant int64_t & ne02,
3718
- constant int64_t & ne10,
3719
- constant int64_t & ne12,
3720
- constant int64_t & ne0,
3721
- constant int64_t & ne1,
3722
- constant uint & r2,
3723
- constant uint & r3,
3724
- threadgroup int8_t * shared_values [[threadgroup(0)]],
3725
- uint3 tgpig[[threadgroup_position_in_grid]],
3726
- uint tiisg[[thread_index_in_simdgroup]],
3727
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3728
-
3729
- const int nb = ne00/QK_K;
3730
-
3731
- const int64_t r0 = tgpig.x;
3732
- const int64_t r1 = tgpig.y;
3733
- const int64_t im = tgpig.z;
3734
-
3735
- const int row = 2 * r0 + sgitg;
3736
-
3737
- const uint i12 = im%ne12;
3738
- const uint i13 = im/ne12;
3739
-
3740
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
3741
-
3742
- device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
3743
- device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
3744
-
3745
- const int ix = tiisg/4;
3746
- const int il = 4 * (tiisg%4);// 0, 4, 8, 12
3747
- const int iq = il/8; // 0, 0, 1, 1
3748
- const int in = il%8; // 0, 4, 0, 4
3749
-
3750
- float2 sum = {0.f, 0.f};
3751
-
3752
- for (int i = ix; i < nb; i += 8) {
3753
-
3754
- const float d_all = (float)(x[i].d);
3755
-
3756
- device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
3757
- device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
3758
- device const uint16_t * s = (device const uint16_t *)(x[i].scales);
3759
- device const float * y = yy + i * QK_K + il;
3760
-
3761
- const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
3762
- const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
3763
- const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
3764
- const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
3765
-
3766
- for (int l = 0; l < 4; l += 2) {
3767
- const uint16_t hm = h[l/2] >> iq;
3768
- sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
3769
- + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
3770
- + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
3771
- + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
3772
- sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
3773
- + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
3774
- + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
3775
- + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
3776
- }
3777
-
3778
- }
3779
- const float sumf = sum[0] + sum[1] * 1.f/256.f;
3780
-
3781
- const float tot = simd_sum(sumf);
3782
- if (tiisg == 0) {
3783
- dst[r1*ne0 + im*ne0*ne1 + row] = tot;
3784
- }
3785
-
3786
- }
3787
- #endif
3788
3635
 
3789
3636
  [[host_name("kernel_mul_mv_q3_K_f32")]]
3790
3637
  kernel void kernel_mul_mv_q3_K_f32(
@@ -3814,7 +3661,6 @@ kernel void kernel_mul_mv_q3_K_f32(
3814
3661
  kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
3815
3662
  }
3816
3663
 
3817
- #if QK_K == 256
3818
3664
  void kernel_mul_mv_q4_K_f32_impl(
3819
3665
  device const void * src0,
3820
3666
  device const float * src1,
@@ -3928,103 +3774,6 @@ void kernel_mul_mv_q4_K_f32_impl(
3928
3774
  }
3929
3775
  }
3930
3776
  }
3931
- #else
3932
- void kernel_mul_mv_q4_K_f32_impl(
3933
- device const void * src0,
3934
- device const float * src1,
3935
- device float * dst,
3936
- constant int64_t & ne00,
3937
- constant int64_t & ne01,
3938
- constant int64_t & ne02,
3939
- constant int64_t & ne10,
3940
- constant int64_t & ne12,
3941
- constant int64_t & ne0,
3942
- constant int64_t & ne1,
3943
- constant uint & r2,
3944
- constant uint & r3,
3945
- threadgroup int8_t * shared_values [[threadgroup(0)]],
3946
- uint3 tgpig[[threadgroup_position_in_grid]],
3947
- uint tiisg[[thread_index_in_simdgroup]],
3948
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3949
-
3950
- const int ix = tiisg/4; // 0...7
3951
- const int it = tiisg%4; // 0...3
3952
-
3953
- const int nb = ne00/QK_K;
3954
- const int r0 = tgpig.x;
3955
- const int r1 = tgpig.y;
3956
- const int im = tgpig.z;
3957
- const int first_row = r0 * N_DST;
3958
- const int ib_row = first_row * nb;
3959
-
3960
- const uint i12 = im%ne12;
3961
- const uint i13 = im/ne12;
3962
-
3963
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
3964
-
3965
- device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
3966
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
3967
-
3968
- float yl[8];
3969
- float yh[8];
3970
- float sumf[N_DST]={0.f}, all_sum;
3971
-
3972
- const int step = sizeof(block_q4_K) * nb / 2;
3973
-
3974
- device const float * y4 = y + ix * QK_K + 8 * it;
3975
-
3976
- uint16_t sc16[4];
3977
-
3978
- for (int ib = ix; ib < nb; ib += 8) {
3979
-
3980
- float2 sumy = {0.f, 0.f};
3981
- for (int i = 0; i < 8; ++i) {
3982
- yl[i] = y4[i+ 0]; sumy[0] += yl[i];
3983
- yh[i] = y4[i+32]; sumy[1] += yh[i];
3984
- }
3985
-
3986
- device const uint16_t * sc = (device const uint16_t *)x[ib].scales;
3987
- device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
3988
- device const half * dh = x[ib].d;
3989
-
3990
- for (int row = 0; row < N_DST; row++) {
3991
-
3992
- sc16[0] = sc[0] & 0x000f;
3993
- sc16[1] = sc[0] & 0x0f00;
3994
- sc16[2] = sc[0] & 0x00f0;
3995
- sc16[3] = sc[0] & 0xf000;
3996
-
3997
- float2 acc1 = {0.f, 0.f};
3998
- float2 acc2 = {0.f, 0.f};
3999
- for (int i = 0; i < 8; i += 2) {
4000
- acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
4001
- acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
4002
- acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
4003
- acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
4004
- }
4005
-
4006
- float dall = dh[0];
4007
- float dmin = dh[1];
4008
- sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
4009
- (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
4010
- dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);
4011
-
4012
- qs += step;
4013
- sc += step;
4014
- dh += step;
4015
- }
4016
-
4017
- y4 += 8 * QK_K;
4018
- }
4019
-
4020
- for (int row = 0; row < N_DST; ++row) {
4021
- all_sum = simd_sum(sumf[row]);
4022
- if (tiisg == 0) {
4023
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
4024
- }
4025
- }
4026
- }
4027
- #endif
4028
3777
 
4029
3778
  [[host_name("kernel_mul_mv_q4_K_f32")]]
4030
3779
  kernel void kernel_mul_mv_q4_K_f32(
@@ -4092,8 +3841,6 @@ void kernel_mul_mv_q5_K_f32_impl(
4092
3841
 
4093
3842
  const int step = sizeof(block_q5_K) * nb;
4094
3843
 
4095
- #if QK_K == 256
4096
- #
4097
3844
  float yl[16], yh[16];
4098
3845
 
4099
3846
  const uint16_t kmask1 = 0x3f3f;
@@ -4176,54 +3923,6 @@ void kernel_mul_mv_q5_K_f32_impl(
4176
3923
  y1 += 4 * QK_K;
4177
3924
 
4178
3925
  }
4179
- #else
4180
- float yl[8], yh[8];
4181
-
4182
- const int il = 4 * (tiisg/8); // 0, 4, 8, 12
4183
- const int ix = tiisg%8;
4184
- const int iq = il/8; // 0, 0, 1, 1
4185
- const int in = il%8; // 0, 4, 0, 4
4186
-
4187
- device const float * y = yy + ix*QK_K + il;
4188
-
4189
- for (int i = ix; i < nb; i += 8) {
4190
-
4191
- for (int l = 0; l < 4; ++l) {
4192
- yl[l+0] = y[l+ 0];
4193
- yl[l+4] = y[l+16];
4194
- yh[l+0] = y[l+32];
4195
- yh[l+4] = y[l+48];
4196
- }
4197
-
4198
- device const half * dh = &x[i].d;
4199
- device const uint8_t * q = x[i].qs + il;
4200
- device const uint8_t * h = x[i].qh + in;
4201
- device const int8_t * s = x[i].scales;
4202
-
4203
- for (int row = 0; row < 2; ++row) {
4204
-
4205
- const float d = dh[0];
4206
-
4207
- float2 acc = {0.f, 0.f};
4208
- for (int l = 0; l < 4; ++l) {
4209
- const uint8_t hl = h[l] >> iq;
4210
- acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
4211
- + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
4212
- acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
4213
- + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256));
4214
- }
4215
- sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]);
4216
-
4217
- q += step;
4218
- h += step;
4219
- s += step;
4220
- dh += step/2;
4221
-
4222
- }
4223
-
4224
- y += 8 * QK_K;
4225
- }
4226
- #endif
4227
3926
 
4228
3927
  for (int row = 0; row < 2; ++row) {
4229
3928
  const float tot = simd_sum(sumf[row]);
@@ -4302,7 +4001,6 @@ void kernel_mul_mv_q6_K_f32_impl(
4302
4001
 
4303
4002
  float sumf = 0;
4304
4003
 
4305
- #if QK_K == 256
4306
4004
  const int tid = tiisg/2;
4307
4005
  const int ix = tiisg%2;
4308
4006
  const int ip = tid/8; // 0 or 1
@@ -4338,30 +4036,6 @@ void kernel_mul_mv_q6_K_f32_impl(
4338
4036
 
4339
4037
  }
4340
4038
 
4341
- #else
4342
- const int ix = tiisg/4;
4343
- const int il = 4*(tiisg%4);
4344
-
4345
- for (int i = ix; i < nb; i += 8) {
4346
- device const float * y = yy + i * QK_K + il;
4347
- device const uint8_t * ql = x[i].ql + il;
4348
- device const uint8_t * qh = x[i].qh + il;
4349
- device const int8_t * s = x[i].scales;
4350
-
4351
- const float d = x[i].d;
4352
-
4353
- float4 sums = {0.f, 0.f, 0.f, 0.f};
4354
- for (int l = 0; l < 4; ++l) {
4355
- sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
4356
- sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
4357
- sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32);
4358
- sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
4359
- }
4360
- sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
4361
- }
4362
-
4363
- #endif
4364
-
4365
4039
  const float tot = simd_sum(sumf);
4366
4040
  if (tiisg == 0) {
4367
4041
  dst[r1*ne0 + im*ne0*ne1 + row] = tot;
@@ -5195,9 +4869,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
5195
4869
 
5196
4870
  device const float * y4 = y + 32 * ix;
5197
4871
 
5198
- #if QK_K != 64
5199
4872
  iq1m_scale_t scale;
5200
- #endif
5201
4873
 
5202
4874
  for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
5203
4875
 
@@ -5218,10 +4890,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
5218
4890
  device const uint16_t * sc = (device const uint16_t *)xr->scales;
5219
4891
 
5220
4892
  for (int row = 0; row < N_DST; row++) {
5221
-
5222
- #if QK_K != 64
5223
4893
  scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
5224
- #endif
5225
4894
 
5226
4895
  constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
5227
4896
  constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
@@ -5237,14 +4906,9 @@ void kernel_mul_mv_iq1_m_f32_impl(
5237
4906
  }
5238
4907
  const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
5239
4908
  const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
5240
- #if QK_K == 64
5241
- const float d = (float) *((device const half *)(sc - 1));
5242
- sumf[row] += d * ((sum[0] + delta1) * (2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1) +
5243
- (sum[1] + delta2) * (2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1));
5244
- #else
4909
+
5245
4910
  sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
5246
4911
  (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
5247
- #endif
5248
4912
 
5249
4913
  sc += nb*sizeof(block_iq1_m)/2;
5250
4914
  qs += nb*sizeof(block_iq1_m);
@@ -5356,7 +5020,6 @@ void kernel_mul_mv_iq4_nl_f32_impl(
5356
5020
  }
5357
5021
  }
5358
5022
 
5359
- #if QK_K != 64
5360
5023
  void kernel_mul_mv_iq4_xs_f32_impl(
5361
5024
  device const void * src0,
5362
5025
  device const float * src1,
@@ -5451,7 +5114,6 @@ void kernel_mul_mv_iq4_xs_f32_impl(
5451
5114
  }
5452
5115
  }
5453
5116
  }
5454
- #endif
5455
5117
 
5456
5118
  [[host_name("kernel_mul_mv_iq1_s_f32")]]
5457
5119
  kernel void kernel_mul_mv_iq1_s_f32(
@@ -5564,11 +5226,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
5564
5226
  uint tiisg[[thread_index_in_simdgroup]],
5565
5227
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5566
5228
 
5567
- #if QK_K == 64
5568
- kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
5569
- #else
5570
5229
  kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
5571
- #endif
5572
5230
  }
5573
5231
 
5574
5232
  //============================= templates and their specializations =============================
@@ -5694,10 +5352,9 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg
5694
5352
  float dl, ml;
5695
5353
  uint8_t sc = xb->scales[il];
5696
5354
 
5697
- #if QK_K == 256
5698
5355
  q = q + 32*(il/8) + 16*(il&1);
5699
5356
  il = (il/2)%4;
5700
- #endif
5357
+
5701
5358
  half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
5702
5359
  uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
5703
5360
  dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
@@ -5713,7 +5370,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
5713
5370
  device const uint8_t * h = (device const uint8_t *)xb->hmask;
5714
5371
  device const int8_t * scales = (device const int8_t *)xb->scales;
5715
5372
 
5716
- #if QK_K == 256
5717
5373
  q = q + 32 * (il/8) + 16 * (il&1);
5718
5374
  h = h + 16 * (il&1);
5719
5375
  uint8_t m = 1 << (il/2);
@@ -5734,17 +5390,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
5734
5390
  for (int i = 0; i < 16; ++i) {
5735
5391
  reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
5736
5392
  }
5737
- #else
5738
- float kcoef = il&1 ? 1.f/16.f : 1.f;
5739
- uint16_t kmask = il&1 ? 0xF0 : 0x0F;
5740
- float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8);
5741
- float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
5742
- uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
5743
- uint8_t m = 1<<(il*2);
5744
- for (int i = 0; i < 16; ++i) {
5745
- reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef));
5746
- }
5747
- #endif
5748
5393
  }
5749
5394
 
5750
5395
  static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
@@ -5756,7 +5401,6 @@ template <typename type4x4>
5756
5401
  void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
5757
5402
  device const uchar * q = xb->qs;
5758
5403
 
5759
- #if QK_K == 256
5760
5404
  short is = (il/4) * 2;
5761
5405
  q = q + (il/4) * 32 + 16 * (il&1);
5762
5406
  il = il & 3;
@@ -5765,16 +5409,7 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
5765
5409
  const float min = xb->dmin;
5766
5410
  const float dl = d * sc[0];
5767
5411
  const float ml = min * sc[1];
5768
- #else
5769
- (void) get_scale_min_k4_just2;
5770
-
5771
- q = q + 16 * (il&1);
5772
- device const uint8_t * s = xb->scales;
5773
- device const half2 * dh = (device const half2 *)xb->d;
5774
- const float2 d = (float2)dh[0];
5775
- const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
5776
- const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
5777
- #endif
5412
+
5778
5413
  const ushort mask = il<2 ? 0x0F : 0xF0;
5779
5414
  for (int i = 0; i < 16; ++i) {
5780
5415
  reg[i/4][i%4] = dl * (q[i] & mask) - ml;
@@ -5786,7 +5421,6 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
5786
5421
  device const uint8_t * q = xb->qs;
5787
5422
  device const uint8_t * qh = xb->qh;
5788
5423
 
5789
- #if QK_K == 256
5790
5424
  short is = (il/4) * 2;
5791
5425
  q = q + 32 * (il/4) + 16 * (il&1);
5792
5426
  qh = qh + 16 * (il&1);
@@ -5803,17 +5437,6 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
5803
5437
  for (int i = 0; i < 16; ++i) {
5804
5438
  reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
5805
5439
  }
5806
- #else
5807
- q = q + 16 * (il&1);
5808
- device const int8_t * s = xb->scales;
5809
- const float dl = xb->d * s[il];
5810
- uint8_t m = 1<<(il*2);
5811
- const float coef = il<2 ? 1.f : 1.f/16.f;
5812
- const ushort mask = il<2 ? 0x0F : 0xF0;
5813
- for (int i = 0; i < 16; ++i) {
5814
- reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef));
5815
- }
5816
- #endif
5817
5440
  }
5818
5441
 
5819
5442
  template <typename type4x4>
@@ -5823,15 +5446,11 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
5823
5446
  device const uint8_t * qh = (device const uint8_t *)xb->qh;
5824
5447
  device const int8_t * scales = (device const int8_t *)xb->scales;
5825
5448
 
5826
- #if QK_K == 256
5827
5449
  ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
5828
5450
  qh = qh + 32*(il/8) + 16*(il&1);
5829
5451
  float sc = scales[(il%2) + 2 * ((il/2))];
5830
5452
  il = (il/2) & 3;
5831
- #else
5832
- ql = ql + 16 * (il&1);
5833
- float sc = scales[il];
5834
- #endif
5453
+
5835
5454
  const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
5836
5455
  const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
5837
5456
  const float coef = il>1 ? 1.f/16.f : 1.f;
@@ -5988,20 +5607,15 @@ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 &
5988
5607
  const int ib32 = il/2;
5989
5608
  il = il%2;
5990
5609
  device const uint16_t * sc = (device const uint16_t *)xb->scales;
5991
- #if QK_K == 64
5992
- const float d = xb->d;
5993
- #else
5610
+
5994
5611
  iq1m_scale_t scale;
5995
5612
  scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
5996
5613
  const float d = scale.f16;
5997
- #endif
5614
+
5998
5615
  device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
5999
5616
  device const uint8_t * qh = xb->qh + 2*ib32 + il;
6000
- #if QK_K == 64
6001
- const float dl = d * (2*((sc[ib32/2] >> (8*(ib32%2)+4*il)) & 0xf) + 1);
6002
- #else
5617
+
6003
5618
  const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
6004
- #endif
6005
5619
  const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
6006
5620
  const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
6007
5621
  constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
@@ -6031,9 +5645,6 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4
6031
5645
 
6032
5646
  template <typename type4x4>
6033
5647
  void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
6034
- #if QK_K == 64
6035
- dequantize_iq4_nl(xb, il, reg);
6036
- #else
6037
5648
  // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
6038
5649
  const int ib32 = il/2;
6039
5650
  il = il%2;
@@ -6050,7 +5661,6 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4
6050
5661
  reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
6051
5662
  reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
6052
5663
  }
6053
- #endif
6054
5664
  }
6055
5665
 
6056
5666
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
@@ -6555,11 +6165,7 @@ kernel void kernel_mul_mm_id(
6555
6165
  sgitg);
6556
6166
  }
6557
6167
 
6558
- #if QK_K == 256
6559
6168
  #define QK_NL 16
6560
- #else
6561
- #define QK_NL 4
6562
- #endif
6563
6169
 
6564
6170
  //
6565
6171
  // get rows
@@ -6599,11 +6205,7 @@ template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_r
6599
6205
  template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
6600
6206
  template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows<block_iq1_m, QK_NL, dequantize_iq1_m>;
6601
6207
  template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
6602
- #if QK_K == 64
6603
- template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, 2, dequantize_iq4_xs>;
6604
- #else
6605
6208
  template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6606
- #endif
6607
6209
 
6608
6210
  //
6609
6211
  // matrix-matrix multiplication
@@ -6631,11 +6233,7 @@ template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_m
6631
6233
  template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
6632
6234
  template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_m, QK_NL, dequantize_iq1_m>;
6633
6235
  template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
6634
- #if QK_K == 64
6635
- template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_xs>;
6636
- #else
6637
6236
  template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6638
- #endif
6639
6237
 
6640
6238
  //
6641
6239
  // indirect matrix-matrix multiplication
@@ -6663,11 +6261,7 @@ template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel
6663
6261
  template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
6664
6262
  template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m, QK_NL, dequantize_iq1_m>;
6665
6263
  template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
6666
- #if QK_K == 64
6667
- template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, 2, dequantize_iq4_xs>;
6668
- #else
6669
6264
  template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6670
- #endif
6671
6265
 
6672
6266
  //
6673
6267
  // matrix-vector multiplication
@@ -6876,7 +6470,5 @@ template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t
6876
6470
  template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl>>;
6877
6471
  template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
6878
6472
  template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
6879
- #if QK_K != 64
6880
6473
  template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
6881
- #endif
6882
6474