llama_cpp 0.15.1 → 0.15.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -229,6 +229,13 @@ kernel void kernel_relu(
229
229
  dst[tpig] = max(0.0f, src0[tpig]);
230
230
  }
231
231
 
232
+ kernel void kernel_sigmoid(
233
+ device const float * src0,
234
+ device float * dst,
235
+ uint tpig[[thread_position_in_grid]]) {
236
+ dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
237
+ }
238
+
232
239
  kernel void kernel_tanh(
233
240
  device const float * src0,
234
241
  device float * dst,
@@ -356,7 +363,6 @@ template<typename T>
356
363
  kernel void kernel_soft_max(
357
364
  device const char * src0,
358
365
  device const char * src1,
359
- device const char * src2,
360
366
  device char * dst,
361
367
  constant int64_t & ne00,
362
368
  constant int64_t & ne01,
@@ -378,10 +384,9 @@ kernel void kernel_soft_max(
378
384
 
379
385
  device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
380
386
  device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
381
- device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
382
387
  device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
383
388
 
384
- float slope = 0.0f;
389
+ float slope = 1.0f;
385
390
 
386
391
  // ALiBi
387
392
  if (max_bias > 0.0f) {
@@ -397,7 +402,7 @@ kernel void kernel_soft_max(
397
402
  float lmax = -INFINITY;
398
403
 
399
404
  for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
400
- lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
405
+ lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
401
406
  }
402
407
 
403
408
  // find the max value in the block
@@ -422,7 +427,7 @@ kernel void kernel_soft_max(
422
427
  // parallel sum
423
428
  float lsum = 0.0f;
424
429
  for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
425
- const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
430
+ const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
426
431
  lsum += exp_psrc0;
427
432
  pdst[i00] = exp_psrc0;
428
433
  }
@@ -461,7 +466,6 @@ template<typename T>
461
466
  kernel void kernel_soft_max_4(
462
467
  device const char * src0,
463
468
  device const char * src1,
464
- device const char * src2,
465
469
  device char * dst,
466
470
  constant int64_t & ne00,
467
471
  constant int64_t & ne01,
@@ -483,10 +487,9 @@ kernel void kernel_soft_max_4(
483
487
 
484
488
  device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
485
489
  device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
486
- device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
487
490
  device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
488
491
 
489
- float slope = 0.0f;
492
+ float slope = 1.0f;
490
493
 
491
494
  if (max_bias > 0.0f) {
492
495
  const int64_t h = i02;
@@ -501,7 +504,7 @@ kernel void kernel_soft_max_4(
501
504
  float4 lmax4 = -INFINITY;
502
505
 
503
506
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
504
- lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)));
507
+ lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
505
508
  }
506
509
 
507
510
  const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@@ -527,7 +530,7 @@ kernel void kernel_soft_max_4(
527
530
  // parallel sum
528
531
  float4 lsum4 = 0.0f;
529
532
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
530
- const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))) - max_val);
533
+ const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
531
534
  lsum4 += exp_psrc4;
532
535
  pdst4[i00] = exp_psrc4;
533
536
  }
@@ -1595,60 +1598,6 @@ kernel void kernel_mul_mv_f16_f32_l4(
1595
1598
  }
1596
1599
  }
1597
1600
 
1598
- kernel void kernel_alibi_f32(
1599
- device const float * src0,
1600
- device float * dst,
1601
- constant int64_t & ne00,
1602
- constant int64_t & ne01,
1603
- constant int64_t & ne02,
1604
- constant int64_t & ne03,
1605
- constant uint64_t & nb00,
1606
- constant uint64_t & nb01,
1607
- constant uint64_t & nb02,
1608
- constant uint64_t & nb03,
1609
- constant int64_t & ne0,
1610
- constant int64_t & ne1,
1611
- constant int64_t & ne2,
1612
- constant int64_t & ne3,
1613
- constant uint64_t & nb0,
1614
- constant uint64_t & nb1,
1615
- constant uint64_t & nb2,
1616
- constant uint64_t & nb3,
1617
- constant float & m0,
1618
- constant float & m1,
1619
- constant int & n_heads_log2_floor,
1620
- uint3 tgpig[[threadgroup_position_in_grid]],
1621
- uint3 tpitg[[thread_position_in_threadgroup]],
1622
- uint3 ntg[[threads_per_threadgroup]]) {
1623
- const int64_t i03 = tgpig[2];
1624
- const int64_t i02 = tgpig[1];
1625
- const int64_t i01 = tgpig[0];
1626
-
1627
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
1628
-
1629
- const int64_t i3 = n / (ne2*ne1*ne0);
1630
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1631
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1632
- //const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1633
-
1634
- const int64_t k = i3*ne3 + i2;
1635
-
1636
- float m_k;
1637
- if (k < n_heads_log2_floor) {
1638
- m_k = pow(m0, k + 1);
1639
- } else {
1640
- m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
1641
- }
1642
-
1643
- device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
1644
- device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
1645
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
1646
- const float src_v = *(device float *)(src_row + i00*nb00);
1647
- device float * dst_v = (device float *)(dst_row + i00*nb0);
1648
- *dst_v = i00 * m_k + src_v;
1649
- }
1650
- }
1651
-
1652
1601
  static float rope_yarn_ramp(const float low, const float high, const int i0) {
1653
1602
  const float y = (i0 / 2 - low) / max(0.001f, high - low);
1654
1603
  return 1.0f - min(1.0f, max(0.0f, y));
@@ -1691,6 +1640,7 @@ static void rope_yarn_corr_dims(
1691
1640
  typedef void (rope_t)(
1692
1641
  device const void * src0,
1693
1642
  device const int32_t * src1,
1643
+ device const float * src2,
1694
1644
  device float * dst,
1695
1645
  constant int64_t & ne00,
1696
1646
  constant int64_t & ne01,
@@ -1726,6 +1676,7 @@ template<typename T>
1726
1676
  kernel void kernel_rope(
1727
1677
  device const void * src0,
1728
1678
  device const int32_t * src1,
1679
+ device const float * src2,
1729
1680
  device float * dst,
1730
1681
  constant int64_t & ne00,
1731
1682
  constant int64_t & ne01,
@@ -1795,8 +1746,10 @@ kernel void kernel_rope(
1795
1746
 
1796
1747
  // simplified from `(ib * n_dims + ic) * inv_ndims`
1797
1748
  const float cur_rot = inv_ndims*ic - ib;
1749
+ const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f;
1750
+
1751
+ const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor;
1798
1752
 
1799
- const float theta = theta_0 * pow(freq_base, cur_rot);
1800
1753
  float cos_theta, sin_theta;
1801
1754
  rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
1802
1755
 
@@ -1903,7 +1856,10 @@ kernel void kernel_upscale_f32(
1903
1856
  constant uint64_t & nb1,
1904
1857
  constant uint64_t & nb2,
1905
1858
  constant uint64_t & nb3,
1906
- constant int32_t & sf,
1859
+ constant float & sf0,
1860
+ constant float & sf1,
1861
+ constant float & sf2,
1862
+ constant float & sf3,
1907
1863
  uint3 tgpig[[threadgroup_position_in_grid]],
1908
1864
  uint3 tpitg[[thread_position_in_threadgroup]],
1909
1865
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -1912,15 +1868,17 @@ kernel void kernel_upscale_f32(
1912
1868
  const int64_t i2 = tgpig.y;
1913
1869
  const int64_t i1 = tgpig.x;
1914
1870
 
1915
- const int64_t i03 = i3;
1916
- const int64_t i02 = i2;
1917
- const int64_t i01 = i1/sf;
1918
-
1919
- device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
1920
- device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
1871
+ const int64_t i03 = i3/sf3;
1872
+ const int64_t i02 = i2/sf2;
1873
+ const int64_t i01 = i1/sf1;
1921
1874
 
1922
1875
  for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
1923
- dst_ptr[i0] = src0_ptr[i0/sf];
1876
+ const int64_t i00 = i0/sf0;
1877
+
1878
+ device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1879
+ device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1880
+
1881
+ dst_ptr[0] = src0_ptr[0];
1924
1882
  }
1925
1883
  }
1926
1884
 
@@ -2100,29 +2058,29 @@ typedef void (flash_attn_ext_f16_t)(
2100
2058
  device const char * v,
2101
2059
  device const char * mask,
2102
2060
  device float * dst,
2103
- constant int64_t & ne00,
2104
2061
  constant int64_t & ne01,
2105
2062
  constant int64_t & ne02,
2106
2063
  constant int64_t & ne03,
2107
- constant uint64_t & nb00,
2108
2064
  constant uint64_t & nb01,
2109
2065
  constant uint64_t & nb02,
2110
2066
  constant uint64_t & nb03,
2111
- constant int64_t & ne10,
2112
2067
  constant int64_t & ne11,
2113
2068
  constant int64_t & ne12,
2114
2069
  constant int64_t & ne13,
2115
- constant uint64_t & nb10,
2116
2070
  constant uint64_t & nb11,
2117
2071
  constant uint64_t & nb12,
2118
2072
  constant uint64_t & nb13,
2119
- constant int64_t & ne31,
2073
+ constant uint64_t & nb21,
2074
+ constant uint64_t & nb22,
2075
+ constant uint64_t & nb23,
2120
2076
  constant uint64_t & nb31,
2121
- constant int64_t & ne0,
2122
2077
  constant int64_t & ne1,
2123
2078
  constant int64_t & ne2,
2124
- constant int64_t & ne3,
2125
2079
  constant float & scale,
2080
+ constant float & max_bias,
2081
+ constant float & m0,
2082
+ constant float & m1,
2083
+ constant uint32_t & n_head_log2,
2126
2084
  threadgroup half * shared,
2127
2085
  uint3 tgpig[[threadgroup_position_in_grid]],
2128
2086
  uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2138,29 +2096,29 @@ kernel void kernel_flash_attn_ext_f16(
2138
2096
  device const char * v,
2139
2097
  device const char * mask,
2140
2098
  device float * dst,
2141
- constant int64_t & ne00,
2142
2099
  constant int64_t & ne01,
2143
2100
  constant int64_t & ne02,
2144
2101
  constant int64_t & ne03,
2145
- constant uint64_t & nb00,
2146
2102
  constant uint64_t & nb01,
2147
2103
  constant uint64_t & nb02,
2148
2104
  constant uint64_t & nb03,
2149
- constant int64_t & ne10,
2150
2105
  constant int64_t & ne11,
2151
2106
  constant int64_t & ne12,
2152
2107
  constant int64_t & ne13,
2153
- constant uint64_t & nb10,
2154
2108
  constant uint64_t & nb11,
2155
2109
  constant uint64_t & nb12,
2156
2110
  constant uint64_t & nb13,
2157
- constant int64_t & ne31,
2111
+ constant uint64_t & nb21,
2112
+ constant uint64_t & nb22,
2113
+ constant uint64_t & nb23,
2158
2114
  constant uint64_t & nb31,
2159
- constant int64_t & ne0,
2160
2115
  constant int64_t & ne1,
2161
2116
  constant int64_t & ne2,
2162
- constant int64_t & ne3,
2163
2117
  constant float & scale,
2118
+ constant float & max_bias,
2119
+ constant float & m0,
2120
+ constant float & m1,
2121
+ constant uint32_t & n_head_log2,
2164
2122
  threadgroup half * shared [[threadgroup(0)]],
2165
2123
  uint3 tgpig[[threadgroup_position_in_grid]],
2166
2124
  uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2225,10 +2183,6 @@ kernel void kernel_flash_attn_ext_f16(
2225
2183
  const short ne22 = ne12;
2226
2184
  const short ne23 = ne13;
2227
2185
 
2228
- const uint nb21 = nb11;
2229
- const uint nb22 = nb12;
2230
- const uint nb23 = nb13;
2231
-
2232
2186
  // broadcast
2233
2187
  const short rk2 = ne02/ne12;
2234
2188
  const short rk3 = ne03/ne13;
@@ -2254,8 +2208,17 @@ kernel void kernel_flash_attn_ext_f16(
2254
2208
  // pointer to the mask
2255
2209
  device const half * mp = (device const half *) (mask + iq1*nb31);
2256
2210
 
2257
- // prepare diagonal scale matrix
2258
- simdgroup_float8x8 mscale(scale);
2211
+ float slope = 1.0f;
2212
+
2213
+ // ALiBi
2214
+ if (max_bias > 0.0f) {
2215
+ const uint32_t h = iq2;
2216
+
2217
+ const float base = h < n_head_log2 ? m0 : m1;
2218
+ const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
2219
+
2220
+ slope = pow(base, exph);
2221
+ }
2259
2222
 
2260
2223
  // loop over the KV cache
2261
2224
  // each simdgroup handles blocks of Q rows and C columns
@@ -2279,12 +2242,20 @@ kernel void kernel_flash_attn_ext_f16(
2279
2242
  simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
2280
2243
  }
2281
2244
 
2282
- // mqk = mqk*scale + mask
2283
- simdgroup_half8x8 mm;
2284
- simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
2285
- simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
2286
-
2287
2245
  simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
2246
+
2247
+ const short tx = tiisg%4;
2248
+ const short ty = tiisg/4;
2249
+
2250
+ if (mask != q) {
2251
+ // mqk = mqk*scale + mask*slope
2252
+ ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
2253
+ ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
2254
+ } else {
2255
+ // mqk = mqk*scale
2256
+ ss[8*cc + ty*TF + 2*tx + 0] *= scale;
2257
+ ss[8*cc + ty*TF + 2*tx + 1] *= scale;
2258
+ }
2288
2259
  }
2289
2260
  }
2290
2261
 
@@ -2456,29 +2427,29 @@ kernel void kernel_flash_attn_ext_vec_f16(
2456
2427
  device const char * v,
2457
2428
  device const char * mask,
2458
2429
  device float * dst,
2459
- constant int64_t & ne00,
2460
2430
  constant int64_t & ne01,
2461
2431
  constant int64_t & ne02,
2462
2432
  constant int64_t & ne03,
2463
- constant uint64_t & nb00,
2464
2433
  constant uint64_t & nb01,
2465
2434
  constant uint64_t & nb02,
2466
2435
  constant uint64_t & nb03,
2467
- constant int64_t & ne10,
2468
2436
  constant int64_t & ne11,
2469
2437
  constant int64_t & ne12,
2470
2438
  constant int64_t & ne13,
2471
- constant uint64_t & nb10,
2472
2439
  constant uint64_t & nb11,
2473
2440
  constant uint64_t & nb12,
2474
2441
  constant uint64_t & nb13,
2475
- constant int64_t & ne31,
2442
+ constant uint64_t & nb21,
2443
+ constant uint64_t & nb22,
2444
+ constant uint64_t & nb23,
2476
2445
  constant uint64_t & nb31,
2477
- constant int64_t & ne0,
2478
2446
  constant int64_t & ne1,
2479
2447
  constant int64_t & ne2,
2480
- constant int64_t & ne3,
2481
2448
  constant float & scale,
2449
+ constant float & max_bias,
2450
+ constant float & m0,
2451
+ constant float & m1,
2452
+ constant uint32_t & n_head_log2,
2482
2453
  threadgroup half * shared [[threadgroup(0)]],
2483
2454
  uint3 tgpig[[threadgroup_position_in_grid]],
2484
2455
  uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2497,6 +2468,18 @@ kernel void kernel_flash_attn_ext_vec_f16(
2497
2468
 
2498
2469
  const short T = D + 2*nsg*SH; // shared memory size per query in (half)
2499
2470
 
2471
+ float slope = 1.0f;
2472
+
2473
+ // ALiBi
2474
+ if (max_bias > 0.0f) {
2475
+ const uint32_t h = iq2;
2476
+
2477
+ const float base = h < n_head_log2 ? m0 : m1;
2478
+ const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
2479
+
2480
+ slope = pow(base, exp);
2481
+ }
2482
+
2500
2483
  //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
2501
2484
  threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
2502
2485
  threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
@@ -2537,10 +2520,6 @@ kernel void kernel_flash_attn_ext_vec_f16(
2537
2520
  const short ne22 = ne12;
2538
2521
  const short ne23 = ne13;
2539
2522
 
2540
- const uint nb21 = nb11;
2541
- const uint nb22 = nb12;
2542
- const uint nb23 = nb13;
2543
-
2544
2523
  // broadcast
2545
2524
  const short rk2 = ne02/ne12;
2546
2525
  const short rk3 = ne03/ne13;
@@ -2603,10 +2582,9 @@ kernel void kernel_flash_attn_ext_vec_f16(
2603
2582
  mqk += simd_shuffle_down(mqk, 2);
2604
2583
  mqk += simd_shuffle_down(mqk, 1);
2605
2584
 
2606
- // mqk = mqk*scale + mask
2585
+ // mqk = mqk*scale + mask*slope
2607
2586
  if (tiisg == 0) {
2608
- float4 mm = (float4) mp4[ic/4 + cc];
2609
- mqk = mqk*scale + mm;
2587
+ mqk = mqk*scale + ((mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f);
2610
2588
 
2611
2589
  ss4[cc] = mqk;
2612
2590
  }
@@ -3408,7 +3386,6 @@ void kernel_mul_mv_q2_K_f32_impl(
3408
3386
 
3409
3387
  const int step = sizeof(block_q2_K) * nb;
3410
3388
 
3411
- #if QK_K == 256
3412
3389
  const int ix = tiisg/8; // 0...3
3413
3390
  const int it = tiisg%8; // 0...7
3414
3391
  const int iq = it/4; // 0 or 1
@@ -3460,57 +3437,6 @@ void kernel_mul_mv_q2_K_f32_impl(
3460
3437
 
3461
3438
  y4 += 4 * QK_K;
3462
3439
  }
3463
- #else
3464
- const int ix = tiisg/2; // 0...15
3465
- const int it = tiisg%2; // 0...1
3466
-
3467
- device const float * y4 = y + ix * QK_K + 8 * it;
3468
-
3469
- for (int ib = ix; ib < nb; ib += 16) {
3470
-
3471
- float4 sumy = {0.f, 0.f, 0.f, 0.f};
3472
- for (int i = 0; i < 8; ++i) {
3473
- yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
3474
- yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8];
3475
- yl[i+16] = y4[i+32]; sumy[2] += yl[i+16];
3476
- yl[i+24] = y4[i+48]; sumy[3] += yl[i+24];
3477
- }
3478
-
3479
- device const uint8_t * sc = (device const uint8_t *)x[ib].scales;
3480
- device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
3481
- device const half * dh = &x[ib].d;
3482
-
3483
- for (int row = 0; row < N_DST; row++) {
3484
-
3485
- float4 acc1 = {0.f, 0.f, 0.f, 0.f};
3486
- float4 acc2 = {0.f, 0.f, 0.f, 0.f};
3487
- for (int i = 0; i < 8; i += 2) {
3488
- acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
3489
- acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
3490
- acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
3491
- acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
3492
- acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
3493
- acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
3494
- acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
3495
- acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
3496
- }
3497
-
3498
- float dall = dh[0];
3499
- float dmin = dh[1];
3500
- sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
3501
- (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
3502
- (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
3503
- (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
3504
- dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));
3505
-
3506
- qs += step/2;
3507
- sc += step;
3508
- dh += step/2;
3509
- }
3510
-
3511
- y4 += 16 * QK_K;
3512
- }
3513
- #endif
3514
3440
 
3515
3441
  for (int row = 0; row < N_DST; ++row) {
3516
3442
  all_sum = simd_sum(sumf[row]);
@@ -3548,7 +3474,6 @@ kernel void kernel_mul_mv_q2_K_f32(
3548
3474
  kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
3549
3475
  }
3550
3476
 
3551
- #if QK_K == 256
3552
3477
  void kernel_mul_mv_q3_K_f32_impl(
3553
3478
  device const void * src0,
3554
3479
  device const float * src1,
@@ -3707,84 +3632,6 @@ void kernel_mul_mv_q3_K_f32_impl(
3707
3632
  }
3708
3633
  }
3709
3634
  }
3710
- #else
3711
- void kernel_mul_mv_q3_K_f32_impl(
3712
- device const void * src0,
3713
- device const float * src1,
3714
- device float * dst,
3715
- constant int64_t & ne00,
3716
- constant int64_t & ne01,
3717
- constant int64_t & ne02,
3718
- constant int64_t & ne10,
3719
- constant int64_t & ne12,
3720
- constant int64_t & ne0,
3721
- constant int64_t & ne1,
3722
- constant uint & r2,
3723
- constant uint & r3,
3724
- threadgroup int8_t * shared_values [[threadgroup(0)]],
3725
- uint3 tgpig[[threadgroup_position_in_grid]],
3726
- uint tiisg[[thread_index_in_simdgroup]],
3727
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3728
-
3729
- const int nb = ne00/QK_K;
3730
-
3731
- const int64_t r0 = tgpig.x;
3732
- const int64_t r1 = tgpig.y;
3733
- const int64_t im = tgpig.z;
3734
-
3735
- const int row = 2 * r0 + sgitg;
3736
-
3737
- const uint i12 = im%ne12;
3738
- const uint i13 = im/ne12;
3739
-
3740
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
3741
-
3742
- device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
3743
- device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
3744
-
3745
- const int ix = tiisg/4;
3746
- const int il = 4 * (tiisg%4);// 0, 4, 8, 12
3747
- const int iq = il/8; // 0, 0, 1, 1
3748
- const int in = il%8; // 0, 4, 0, 4
3749
-
3750
- float2 sum = {0.f, 0.f};
3751
-
3752
- for (int i = ix; i < nb; i += 8) {
3753
-
3754
- const float d_all = (float)(x[i].d);
3755
-
3756
- device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
3757
- device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
3758
- device const uint16_t * s = (device const uint16_t *)(x[i].scales);
3759
- device const float * y = yy + i * QK_K + il;
3760
-
3761
- const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
3762
- const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
3763
- const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
3764
- const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
3765
-
3766
- for (int l = 0; l < 4; l += 2) {
3767
- const uint16_t hm = h[l/2] >> iq;
3768
- sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
3769
- + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
3770
- + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
3771
- + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
3772
- sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
3773
- + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
3774
- + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
3775
- + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
3776
- }
3777
-
3778
- }
3779
- const float sumf = sum[0] + sum[1] * 1.f/256.f;
3780
-
3781
- const float tot = simd_sum(sumf);
3782
- if (tiisg == 0) {
3783
- dst[r1*ne0 + im*ne0*ne1 + row] = tot;
3784
- }
3785
-
3786
- }
3787
- #endif
3788
3635
 
3789
3636
  [[host_name("kernel_mul_mv_q3_K_f32")]]
3790
3637
  kernel void kernel_mul_mv_q3_K_f32(
@@ -3814,7 +3661,6 @@ kernel void kernel_mul_mv_q3_K_f32(
3814
3661
  kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
3815
3662
  }
3816
3663
 
3817
- #if QK_K == 256
3818
3664
  void kernel_mul_mv_q4_K_f32_impl(
3819
3665
  device const void * src0,
3820
3666
  device const float * src1,
@@ -3928,103 +3774,6 @@ void kernel_mul_mv_q4_K_f32_impl(
3928
3774
  }
3929
3775
  }
3930
3776
  }
3931
- #else
3932
- void kernel_mul_mv_q4_K_f32_impl(
3933
- device const void * src0,
3934
- device const float * src1,
3935
- device float * dst,
3936
- constant int64_t & ne00,
3937
- constant int64_t & ne01,
3938
- constant int64_t & ne02,
3939
- constant int64_t & ne10,
3940
- constant int64_t & ne12,
3941
- constant int64_t & ne0,
3942
- constant int64_t & ne1,
3943
- constant uint & r2,
3944
- constant uint & r3,
3945
- threadgroup int8_t * shared_values [[threadgroup(0)]],
3946
- uint3 tgpig[[threadgroup_position_in_grid]],
3947
- uint tiisg[[thread_index_in_simdgroup]],
3948
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3949
-
3950
- const int ix = tiisg/4; // 0...7
3951
- const int it = tiisg%4; // 0...3
3952
-
3953
- const int nb = ne00/QK_K;
3954
- const int r0 = tgpig.x;
3955
- const int r1 = tgpig.y;
3956
- const int im = tgpig.z;
3957
- const int first_row = r0 * N_DST;
3958
- const int ib_row = first_row * nb;
3959
-
3960
- const uint i12 = im%ne12;
3961
- const uint i13 = im/ne12;
3962
-
3963
- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
3964
-
3965
- device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
3966
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
3967
-
3968
- float yl[8];
3969
- float yh[8];
3970
- float sumf[N_DST]={0.f}, all_sum;
3971
-
3972
- const int step = sizeof(block_q4_K) * nb / 2;
3973
-
3974
- device const float * y4 = y + ix * QK_K + 8 * it;
3975
-
3976
- uint16_t sc16[4];
3977
-
3978
- for (int ib = ix; ib < nb; ib += 8) {
3979
-
3980
- float2 sumy = {0.f, 0.f};
3981
- for (int i = 0; i < 8; ++i) {
3982
- yl[i] = y4[i+ 0]; sumy[0] += yl[i];
3983
- yh[i] = y4[i+32]; sumy[1] += yh[i];
3984
- }
3985
-
3986
- device const uint16_t * sc = (device const uint16_t *)x[ib].scales;
3987
- device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
3988
- device const half * dh = x[ib].d;
3989
-
3990
- for (int row = 0; row < N_DST; row++) {
3991
-
3992
- sc16[0] = sc[0] & 0x000f;
3993
- sc16[1] = sc[0] & 0x0f00;
3994
- sc16[2] = sc[0] & 0x00f0;
3995
- sc16[3] = sc[0] & 0xf000;
3996
-
3997
- float2 acc1 = {0.f, 0.f};
3998
- float2 acc2 = {0.f, 0.f};
3999
- for (int i = 0; i < 8; i += 2) {
4000
- acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
4001
- acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
4002
- acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
4003
- acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
4004
- }
4005
-
4006
- float dall = dh[0];
4007
- float dmin = dh[1];
4008
- sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
4009
- (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
4010
- dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);
4011
-
4012
- qs += step;
4013
- sc += step;
4014
- dh += step;
4015
- }
4016
-
4017
- y4 += 8 * QK_K;
4018
- }
4019
-
4020
- for (int row = 0; row < N_DST; ++row) {
4021
- all_sum = simd_sum(sumf[row]);
4022
- if (tiisg == 0) {
4023
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
4024
- }
4025
- }
4026
- }
4027
- #endif
4028
3777
 
4029
3778
  [[host_name("kernel_mul_mv_q4_K_f32")]]
4030
3779
  kernel void kernel_mul_mv_q4_K_f32(
@@ -4092,8 +3841,6 @@ void kernel_mul_mv_q5_K_f32_impl(
4092
3841
 
4093
3842
  const int step = sizeof(block_q5_K) * nb;
4094
3843
 
4095
- #if QK_K == 256
4096
- #
4097
3844
  float yl[16], yh[16];
4098
3845
 
4099
3846
  const uint16_t kmask1 = 0x3f3f;
@@ -4176,54 +3923,6 @@ void kernel_mul_mv_q5_K_f32_impl(
4176
3923
  y1 += 4 * QK_K;
4177
3924
 
4178
3925
  }
4179
- #else
4180
- float yl[8], yh[8];
4181
-
4182
- const int il = 4 * (tiisg/8); // 0, 4, 8, 12
4183
- const int ix = tiisg%8;
4184
- const int iq = il/8; // 0, 0, 1, 1
4185
- const int in = il%8; // 0, 4, 0, 4
4186
-
4187
- device const float * y = yy + ix*QK_K + il;
4188
-
4189
- for (int i = ix; i < nb; i += 8) {
4190
-
4191
- for (int l = 0; l < 4; ++l) {
4192
- yl[l+0] = y[l+ 0];
4193
- yl[l+4] = y[l+16];
4194
- yh[l+0] = y[l+32];
4195
- yh[l+4] = y[l+48];
4196
- }
4197
-
4198
- device const half * dh = &x[i].d;
4199
- device const uint8_t * q = x[i].qs + il;
4200
- device const uint8_t * h = x[i].qh + in;
4201
- device const int8_t * s = x[i].scales;
4202
-
4203
- for (int row = 0; row < 2; ++row) {
4204
-
4205
- const float d = dh[0];
4206
-
4207
- float2 acc = {0.f, 0.f};
4208
- for (int l = 0; l < 4; ++l) {
4209
- const uint8_t hl = h[l] >> iq;
4210
- acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
4211
- + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
4212
- acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
4213
- + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256));
4214
- }
4215
- sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]);
4216
-
4217
- q += step;
4218
- h += step;
4219
- s += step;
4220
- dh += step/2;
4221
-
4222
- }
4223
-
4224
- y += 8 * QK_K;
4225
- }
4226
- #endif
4227
3926
 
4228
3927
  for (int row = 0; row < 2; ++row) {
4229
3928
  const float tot = simd_sum(sumf[row]);
@@ -4302,7 +4001,6 @@ void kernel_mul_mv_q6_K_f32_impl(
4302
4001
 
4303
4002
  float sumf = 0;
4304
4003
 
4305
- #if QK_K == 256
4306
4004
  const int tid = tiisg/2;
4307
4005
  const int ix = tiisg%2;
4308
4006
  const int ip = tid/8; // 0 or 1
@@ -4338,30 +4036,6 @@ void kernel_mul_mv_q6_K_f32_impl(
4338
4036
 
4339
4037
  }
4340
4038
 
4341
- #else
4342
- const int ix = tiisg/4;
4343
- const int il = 4*(tiisg%4);
4344
-
4345
- for (int i = ix; i < nb; i += 8) {
4346
- device const float * y = yy + i * QK_K + il;
4347
- device const uint8_t * ql = x[i].ql + il;
4348
- device const uint8_t * qh = x[i].qh + il;
4349
- device const int8_t * s = x[i].scales;
4350
-
4351
- const float d = x[i].d;
4352
-
4353
- float4 sums = {0.f, 0.f, 0.f, 0.f};
4354
- for (int l = 0; l < 4; ++l) {
4355
- sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
4356
- sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
4357
- sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32);
4358
- sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
4359
- }
4360
- sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
4361
- }
4362
-
4363
- #endif
4364
-
4365
4039
  const float tot = simd_sum(sumf);
4366
4040
  if (tiisg == 0) {
4367
4041
  dst[r1*ne0 + im*ne0*ne1 + row] = tot;
@@ -5195,9 +4869,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
5195
4869
 
5196
4870
  device const float * y4 = y + 32 * ix;
5197
4871
 
5198
- #if QK_K != 64
5199
4872
  iq1m_scale_t scale;
5200
- #endif
5201
4873
 
5202
4874
  for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
5203
4875
 
@@ -5218,10 +4890,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
5218
4890
  device const uint16_t * sc = (device const uint16_t *)xr->scales;
5219
4891
 
5220
4892
  for (int row = 0; row < N_DST; row++) {
5221
-
5222
- #if QK_K != 64
5223
4893
  scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
5224
- #endif
5225
4894
 
5226
4895
  constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
5227
4896
  constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
@@ -5237,14 +4906,9 @@ void kernel_mul_mv_iq1_m_f32_impl(
5237
4906
  }
5238
4907
  const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
5239
4908
  const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
5240
- #if QK_K == 64
5241
- const float d = (float) *((device const half *)(sc - 1));
5242
- sumf[row] += d * ((sum[0] + delta1) * (2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1) +
5243
- (sum[1] + delta2) * (2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1));
5244
- #else
4909
+
5245
4910
  sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
5246
4911
  (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
5247
- #endif
5248
4912
 
5249
4913
  sc += nb*sizeof(block_iq1_m)/2;
5250
4914
  qs += nb*sizeof(block_iq1_m);
@@ -5356,7 +5020,6 @@ void kernel_mul_mv_iq4_nl_f32_impl(
5356
5020
  }
5357
5021
  }
5358
5022
 
5359
- #if QK_K != 64
5360
5023
  void kernel_mul_mv_iq4_xs_f32_impl(
5361
5024
  device const void * src0,
5362
5025
  device const float * src1,
@@ -5451,7 +5114,6 @@ void kernel_mul_mv_iq4_xs_f32_impl(
5451
5114
  }
5452
5115
  }
5453
5116
  }
5454
- #endif
5455
5117
 
5456
5118
  [[host_name("kernel_mul_mv_iq1_s_f32")]]
5457
5119
  kernel void kernel_mul_mv_iq1_s_f32(
@@ -5564,11 +5226,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
5564
5226
  uint tiisg[[thread_index_in_simdgroup]],
5565
5227
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5566
5228
 
5567
- #if QK_K == 64
5568
- kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
5569
- #else
5570
5229
  kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
5571
- #endif
5572
5230
  }
5573
5231
 
5574
5232
  //============================= templates and their specializations =============================
@@ -5694,10 +5352,9 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg
5694
5352
  float dl, ml;
5695
5353
  uint8_t sc = xb->scales[il];
5696
5354
 
5697
- #if QK_K == 256
5698
5355
  q = q + 32*(il/8) + 16*(il&1);
5699
5356
  il = (il/2)%4;
5700
- #endif
5357
+
5701
5358
  half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
5702
5359
  uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
5703
5360
  dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
@@ -5713,7 +5370,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
5713
5370
  device const uint8_t * h = (device const uint8_t *)xb->hmask;
5714
5371
  device const int8_t * scales = (device const int8_t *)xb->scales;
5715
5372
 
5716
- #if QK_K == 256
5717
5373
  q = q + 32 * (il/8) + 16 * (il&1);
5718
5374
  h = h + 16 * (il&1);
5719
5375
  uint8_t m = 1 << (il/2);
@@ -5734,17 +5390,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
5734
5390
  for (int i = 0; i < 16; ++i) {
5735
5391
  reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
5736
5392
  }
5737
- #else
5738
- float kcoef = il&1 ? 1.f/16.f : 1.f;
5739
- uint16_t kmask = il&1 ? 0xF0 : 0x0F;
5740
- float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8);
5741
- float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
5742
- uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
5743
- uint8_t m = 1<<(il*2);
5744
- for (int i = 0; i < 16; ++i) {
5745
- reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef));
5746
- }
5747
- #endif
5748
5393
  }
5749
5394
 
5750
5395
  static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
@@ -5756,7 +5401,6 @@ template <typename type4x4>
5756
5401
  void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
5757
5402
  device const uchar * q = xb->qs;
5758
5403
 
5759
- #if QK_K == 256
5760
5404
  short is = (il/4) * 2;
5761
5405
  q = q + (il/4) * 32 + 16 * (il&1);
5762
5406
  il = il & 3;
@@ -5765,16 +5409,7 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
5765
5409
  const float min = xb->dmin;
5766
5410
  const float dl = d * sc[0];
5767
5411
  const float ml = min * sc[1];
5768
- #else
5769
- (void) get_scale_min_k4_just2;
5770
-
5771
- q = q + 16 * (il&1);
5772
- device const uint8_t * s = xb->scales;
5773
- device const half2 * dh = (device const half2 *)xb->d;
5774
- const float2 d = (float2)dh[0];
5775
- const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
5776
- const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
5777
- #endif
5412
+
5778
5413
  const ushort mask = il<2 ? 0x0F : 0xF0;
5779
5414
  for (int i = 0; i < 16; ++i) {
5780
5415
  reg[i/4][i%4] = dl * (q[i] & mask) - ml;
@@ -5786,7 +5421,6 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
5786
5421
  device const uint8_t * q = xb->qs;
5787
5422
  device const uint8_t * qh = xb->qh;
5788
5423
 
5789
- #if QK_K == 256
5790
5424
  short is = (il/4) * 2;
5791
5425
  q = q + 32 * (il/4) + 16 * (il&1);
5792
5426
  qh = qh + 16 * (il&1);
@@ -5803,17 +5437,6 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
5803
5437
  for (int i = 0; i < 16; ++i) {
5804
5438
  reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
5805
5439
  }
5806
- #else
5807
- q = q + 16 * (il&1);
5808
- device const int8_t * s = xb->scales;
5809
- const float dl = xb->d * s[il];
5810
- uint8_t m = 1<<(il*2);
5811
- const float coef = il<2 ? 1.f : 1.f/16.f;
5812
- const ushort mask = il<2 ? 0x0F : 0xF0;
5813
- for (int i = 0; i < 16; ++i) {
5814
- reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef));
5815
- }
5816
- #endif
5817
5440
  }
5818
5441
 
5819
5442
  template <typename type4x4>
@@ -5823,15 +5446,11 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
5823
5446
  device const uint8_t * qh = (device const uint8_t *)xb->qh;
5824
5447
  device const int8_t * scales = (device const int8_t *)xb->scales;
5825
5448
 
5826
- #if QK_K == 256
5827
5449
  ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
5828
5450
  qh = qh + 32*(il/8) + 16*(il&1);
5829
5451
  float sc = scales[(il%2) + 2 * ((il/2))];
5830
5452
  il = (il/2) & 3;
5831
- #else
5832
- ql = ql + 16 * (il&1);
5833
- float sc = scales[il];
5834
- #endif
5453
+
5835
5454
  const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
5836
5455
  const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
5837
5456
  const float coef = il>1 ? 1.f/16.f : 1.f;
@@ -5988,20 +5607,15 @@ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 &
5988
5607
  const int ib32 = il/2;
5989
5608
  il = il%2;
5990
5609
  device const uint16_t * sc = (device const uint16_t *)xb->scales;
5991
- #if QK_K == 64
5992
- const float d = xb->d;
5993
- #else
5610
+
5994
5611
  iq1m_scale_t scale;
5995
5612
  scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
5996
5613
  const float d = scale.f16;
5997
- #endif
5614
+
5998
5615
  device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
5999
5616
  device const uint8_t * qh = xb->qh + 2*ib32 + il;
6000
- #if QK_K == 64
6001
- const float dl = d * (2*((sc[ib32/2] >> (8*(ib32%2)+4*il)) & 0xf) + 1);
6002
- #else
5617
+
6003
5618
  const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
6004
- #endif
6005
5619
  const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
6006
5620
  const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
6007
5621
  constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
@@ -6031,9 +5645,6 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4
6031
5645
 
6032
5646
  template <typename type4x4>
6033
5647
  void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
6034
- #if QK_K == 64
6035
- dequantize_iq4_nl(xb, il, reg);
6036
- #else
6037
5648
  // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
6038
5649
  const int ib32 = il/2;
6039
5650
  il = il%2;
@@ -6050,7 +5661,6 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4
6050
5661
  reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
6051
5662
  reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
6052
5663
  }
6053
- #endif
6054
5664
  }
6055
5665
 
6056
5666
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
@@ -6555,11 +6165,7 @@ kernel void kernel_mul_mm_id(
6555
6165
  sgitg);
6556
6166
  }
6557
6167
 
6558
- #if QK_K == 256
6559
6168
  #define QK_NL 16
6560
- #else
6561
- #define QK_NL 4
6562
- #endif
6563
6169
 
6564
6170
  //
6565
6171
  // get rows
@@ -6599,11 +6205,7 @@ template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_r
6599
6205
  template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
6600
6206
  template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows<block_iq1_m, QK_NL, dequantize_iq1_m>;
6601
6207
  template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
6602
- #if QK_K == 64
6603
- template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, 2, dequantize_iq4_xs>;
6604
- #else
6605
6208
  template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6606
- #endif
6607
6209
 
6608
6210
  //
6609
6211
  // matrix-matrix multiplication
@@ -6631,11 +6233,7 @@ template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_m
6631
6233
  template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
6632
6234
  template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_m, QK_NL, dequantize_iq1_m>;
6633
6235
  template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
6634
- #if QK_K == 64
6635
- template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_xs>;
6636
- #else
6637
6236
  template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6638
- #endif
6639
6237
 
6640
6238
  //
6641
6239
  // indirect matrix-matrix multiplication
@@ -6663,11 +6261,7 @@ template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel
6663
6261
  template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
6664
6262
  template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m, QK_NL, dequantize_iq1_m>;
6665
6263
  template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
6666
- #if QK_K == 64
6667
- template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, 2, dequantize_iq4_xs>;
6668
- #else
6669
6264
  template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6670
- #endif
6671
6265
 
6672
6266
  //
6673
6267
  // matrix-vector multiplication
@@ -6876,7 +6470,5 @@ template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t
6876
6470
  template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl>>;
6877
6471
  template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
6878
6472
  template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
6879
- #if QK_K != 64
6880
6473
  template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
6881
- #endif
6882
6474