llama_cpp 0.12.0 → 0.12.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,6 +5,8 @@
5
5
  #include <string.h>
6
6
  #include <assert.h>
7
7
  #include <float.h>
8
+ #include <stdlib.h> // for qsort
9
+ #include <stdio.h> // for GGML_ASSERT
8
10
 
9
11
  #ifdef __ARM_NEON
10
12
 
@@ -272,10 +274,13 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128
272
274
 
273
275
  // vaddvq_s16
274
276
  // vpaddq_s16
277
+ // vpaddq_s32
275
278
  // vaddvq_s32
276
279
  // vaddvq_f32
277
280
  // vmaxvq_f32
278
281
  // vcvtnq_s32_f32
282
+ // vzip1_u8
283
+ // vzip2_u8
279
284
 
280
285
  inline static int32_t vaddvq_s16(int16x8_t v) {
281
286
  return
@@ -291,6 +296,12 @@ inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
291
296
  return vcombine_s16(a0, b0);
292
297
  }
293
298
 
299
+ inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
300
+ int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a));
301
+ int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b));
302
+ return vcombine_s32(a0, b0);
303
+ }
304
+
294
305
  inline static int32_t vaddvq_s32(int32x4_t v) {
295
306
  return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
296
307
  }
@@ -316,6 +327,28 @@ inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
316
327
  return res;
317
328
  }
318
329
 
330
+ inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
331
+ uint8x8_t res;
332
+
333
+ res[0] = a[0]; res[1] = b[0];
334
+ res[2] = a[1]; res[3] = b[1];
335
+ res[4] = a[2]; res[5] = b[2];
336
+ res[6] = a[3]; res[7] = b[3];
337
+
338
+ return res;
339
+ }
340
+
341
+ inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
342
+ uint8x8_t res;
343
+
344
+ res[0] = a[4]; res[1] = b[4];
345
+ res[2] = a[5]; res[3] = b[5];
346
+ res[4] = a[6]; res[5] = b[6];
347
+ res[6] = a[7]; res[7] = b[7];
348
+
349
+ return res;
350
+ }
351
+
319
352
  // vld1q_s16_x2
320
353
  // vld1q_u8_x2
321
354
  // vld1q_u8_x4
@@ -482,6 +515,7 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
482
515
  quantize_row_q4_0_reference(x, y, k);
483
516
  }
484
517
 
518
+
485
519
  void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k) {
486
520
  const int qk = QK4_1;
487
521
 
@@ -1211,7 +1245,8 @@ static inline int nearest_int(float fval) {
1211
1245
  return (i & 0x007fffff) - 0x00400000;
1212
1246
  }
1213
1247
 
1214
- static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type) {
1248
+ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type,
1249
+ const float * restrict qw) {
1215
1250
  float max = 0;
1216
1251
  float amax = 0;
1217
1252
  for (int i = 0; i < n; ++i) {
@@ -1237,14 +1272,13 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
1237
1272
  rmse_type = -rmse_type;
1238
1273
  return_early = true;
1239
1274
  }
1240
- int weight_type = rmse_type%2;
1241
1275
  float sumlx = 0;
1242
1276
  float suml2 = 0;
1243
1277
  for (int i = 0; i < n; ++i) {
1244
1278
  int l = nearest_int(iscale * x[i]);
1245
1279
  l = MAX(-nmax, MIN(nmax-1, l));
1246
1280
  L[i] = l + nmax;
1247
- float w = weight_type == 1 ? x[i] * x[i] : 1;
1281
+ float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i]));
1248
1282
  sumlx += w*x[i]*l;
1249
1283
  suml2 += w*l*l;
1250
1284
  }
@@ -1260,7 +1294,7 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
1260
1294
  for (int i = 0; i < n; ++i) {
1261
1295
  int l = nearest_int(iscale * x[i]);
1262
1296
  l = MAX(-nmax, MIN(nmax-1, l));
1263
- float w = weight_type == 1 ? x[i] * x[i] : 1;
1297
+ float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i]));
1264
1298
  sumlx += w*x[i]*l;
1265
1299
  suml2 += w*l*l;
1266
1300
  }
@@ -1608,6 +1642,241 @@ size_t ggml_quantize_q2_K(const float * restrict src, void * restrict dst, int n
1608
1642
  return (n/QK_K*sizeof(block_q2_K));
1609
1643
  }
1610
1644
 
1645
+ static float make_qkx3_quants(int n, int nmax, const float * restrict x, const float * restrict weights,
1646
+ uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux,
1647
+ float rmin, float rdelta, int nstep, bool use_mad) {
1648
+ float min = x[0];
1649
+ float max = x[0];
1650
+ float sum_w = weights ? weights[0] : x[0]*x[0];
1651
+ float sum_x = sum_w * x[0];
1652
+ for (int i = 1; i < n; ++i) {
1653
+ if (x[i] < min) min = x[i];
1654
+ if (x[i] > max) max = x[i];
1655
+ float w = weights ? weights[i] : x[i]*x[i];
1656
+ sum_w += w;
1657
+ sum_x += w * x[i];
1658
+ }
1659
+ if (min > 0) {
1660
+ min = 0;
1661
+ }
1662
+ if (max <= min) {
1663
+ for (int i = 0; i < n; ++i) L[i] = 0;
1664
+ *the_min = -min;
1665
+ return 0.f;
1666
+ }
1667
+ float iscale = nmax/(max - min);
1668
+ float scale = 1/iscale;
1669
+ float best_mad = 0;
1670
+ for (int i = 0; i < n; ++i) {
1671
+ int l = nearest_int(iscale*(x[i] - min));
1672
+ L[i] = MAX(0, MIN(nmax, l));
1673
+ float diff = scale * L[i] + min - x[i];
1674
+ diff = use_mad ? fabsf(diff) : diff*diff;
1675
+ float w = weights ? weights[i] : x[i]*x[i];
1676
+ best_mad += w * diff;
1677
+ }
1678
+ if (nstep < 1) {
1679
+ *the_min = -min;
1680
+ return scale;
1681
+ }
1682
+ for (int is = 0; is <= nstep; ++is) {
1683
+ iscale = (rmin + rdelta*is + nmax)/(max - min);
1684
+ float sum_l = 0, sum_l2 = 0, sum_xl = 0;
1685
+ for (int i = 0; i < n; ++i) {
1686
+ int l = nearest_int(iscale*(x[i] - min));
1687
+ l = MAX(0, MIN(nmax, l));
1688
+ Laux[i] = l;
1689
+ float w = weights ? weights[i] : x[i]*x[i];
1690
+ sum_l += w*l;
1691
+ sum_l2 += w*l*l;
1692
+ sum_xl += w*l*x[i];
1693
+ }
1694
+ float D = sum_w * sum_l2 - sum_l * sum_l;
1695
+ if (D > 0) {
1696
+ float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
1697
+ float this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D;
1698
+ if (this_min > 0) {
1699
+ this_min = 0;
1700
+ this_scale = sum_xl / sum_l2;
1701
+ }
1702
+ float mad = 0;
1703
+ for (int i = 0; i < n; ++i) {
1704
+ float diff = this_scale * Laux[i] + this_min - x[i];
1705
+ diff = use_mad ? fabsf(diff) : diff*diff;
1706
+ float w = weights ? weights[i] : x[i]*x[i];
1707
+ mad += w * diff;
1708
+ }
1709
+ if (mad < best_mad) {
1710
+ for (int i = 0; i < n; ++i) {
1711
+ L[i] = Laux[i];
1712
+ }
1713
+ best_mad = mad;
1714
+ scale = this_scale;
1715
+ min = this_min;
1716
+ }
1717
+ }
1718
+ }
1719
+ *the_min = -min;
1720
+ return scale;
1721
+ }
1722
+
1723
+ static float make_qp_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, const float * quant_weights) {
1724
+ float max = 0;
1725
+ for (int i = 0; i < n; ++i) {
1726
+ max = MAX(max, x[i]);
1727
+ }
1728
+ if (!max) { // all zero
1729
+ for (int i = 0; i < n; ++i) { L[i] = 0; }
1730
+ return 0.f;
1731
+ }
1732
+ float iscale = nmax / max;
1733
+ for (int i = 0; i < n; ++i) {
1734
+ L[i] = nearest_int(iscale * x[i]);
1735
+ }
1736
+ float scale = 1/iscale;
1737
+ float best_mse = 0;
1738
+ for (int i = 0; i < n; ++i) {
1739
+ float diff = x[i] - scale*L[i];
1740
+ float w = quant_weights[i];
1741
+ best_mse += w*diff*diff;
1742
+ }
1743
+ for (int is = -4; is <= 4; ++is) {
1744
+ if (is == 0) continue;
1745
+ float iscale_is = (0.1f*is + nmax)/max;
1746
+ float scale_is = 1/iscale_is;
1747
+ float mse = 0;
1748
+ for (int i = 0; i < n; ++i) {
1749
+ int l = nearest_int(iscale_is*x[i]);
1750
+ l = MIN(nmax, l);
1751
+ float diff = x[i] - scale_is*l;
1752
+ float w = quant_weights[i];
1753
+ mse += w*diff*diff;
1754
+ }
1755
+ if (mse < best_mse) {
1756
+ best_mse = mse;
1757
+ iscale = iscale_is;
1758
+ }
1759
+ }
1760
+ float sumlx = 0;
1761
+ float suml2 = 0;
1762
+ for (int i = 0; i < n; ++i) {
1763
+ int l = nearest_int(iscale * x[i]);
1764
+ l = MIN(nmax, l);
1765
+ L[i] = l;
1766
+ float w = quant_weights[i];
1767
+ sumlx += w*x[i]*l;
1768
+ suml2 += w*l*l;
1769
+ }
1770
+ for (int itry = 0; itry < 5; ++itry) {
1771
+ int n_changed = 0;
1772
+ for (int i = 0; i < n; ++i) {
1773
+ float w = quant_weights[i];
1774
+ float slx = sumlx - w*x[i]*L[i];
1775
+ float sl2 = suml2 - w*L[i]*L[i];
1776
+ if (slx > 0 && sl2 > 0) {
1777
+ int new_l = nearest_int(x[i] * sl2 / slx);
1778
+ new_l = MIN(nmax, new_l);
1779
+ if (new_l != L[i]) {
1780
+ slx += w*x[i]*new_l;
1781
+ sl2 += w*new_l*new_l;
1782
+ if (slx*slx*suml2 > sumlx*sumlx*sl2) {
1783
+ L[i] = new_l; sumlx = slx; suml2 = sl2;
1784
+ ++n_changed;
1785
+ }
1786
+ }
1787
+ }
1788
+ }
1789
+ if (!n_changed) {
1790
+ break;
1791
+ }
1792
+ }
1793
+ return sumlx / suml2;
1794
+ }
1795
+
1796
+ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restrict y, int k, const float * restrict quant_weights) {
1797
+ GGML_ASSERT(quant_weights);
1798
+ assert(k % QK_K == 0);
1799
+ const int nb = k / QK_K;
1800
+ const bool requantize = true;
1801
+
1802
+ uint8_t L[QK_K];
1803
+ uint8_t Laux[16];
1804
+ float mins[QK_K/16];
1805
+ float scales[QK_K/16];
1806
+ float sw[QK_K/16];
1807
+ float weight[QK_K/16];
1808
+ uint8_t Ls[QK_K/16], Lm[QK_K/16];
1809
+
1810
+ for (int i = 0; i < nb; i++) {
1811
+ memset(sw, 0, QK_K/16*sizeof(float));
1812
+ float sumx2 = 0;
1813
+ for (int j = 0; j < QK_K; ++j) sumx2 += x[j]*x[j];
1814
+ float sigma2 = sumx2/QK_K;
1815
+ for (int j = 0; j < QK_K/16; ++j) {
1816
+ const float * restrict qw = quant_weights + QK_K * i + 16*j;
1817
+ for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]);
1818
+ for (int l = 0; l < 16; ++l) sw[j] += weight[l];
1819
+ scales[j] = make_qkx3_quants(16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
1820
+ }
1821
+
1822
+ float dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw);
1823
+ float mm = make_qp_quants(QK_K/16, 15, mins, Lm, sw);
1824
+ y[i].d = GGML_FP32_TO_FP16(dm);
1825
+ y[i].dmin = GGML_FP32_TO_FP16(mm);
1826
+ dm = GGML_FP16_TO_FP32(y[i].d);
1827
+ mm = GGML_FP16_TO_FP32(y[i].dmin);
1828
+
1829
+ for (int j = 0; j < QK_K/16; ++j) {
1830
+ y[i].scales[j] = Ls[j] | (Lm[j] << 4);
1831
+ }
1832
+
1833
+ if (requantize) {
1834
+ for (int j = 0; j < QK_K/16; ++j) {
1835
+ const float d = dm * (y[i].scales[j] & 0xF);
1836
+ if (!d) continue;
1837
+ const float m = mm * (y[i].scales[j] >> 4);
1838
+ for (int ii = 0; ii < 16; ++ii) {
1839
+ int l = nearest_int((x[16*j + ii] + m)/d);
1840
+ l = MAX(0, MIN(3, l));
1841
+ L[16*j + ii] = l;
1842
+ }
1843
+ }
1844
+ }
1845
+
1846
+ #if QK_K == 256
1847
+ for (int j = 0; j < QK_K; j += 128) {
1848
+ for (int l = 0; l < 32; ++l) {
1849
+ y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
1850
+ }
1851
+ }
1852
+ #else
1853
+ for (int l = 0; l < 16; ++l) {
1854
+ y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
1855
+ }
1856
+ #endif
1857
+
1858
+ x += QK_K;
1859
+
1860
+ }
1861
+ }
1862
+
1863
+ size_t quantize_q2_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
1864
+ (void)hist;
1865
+ int row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row);
1866
+ if (!quant_weights) {
1867
+ quantize_row_q2_K_reference(src, dst, nrow*n_per_row);
1868
+ }
1869
+ else {
1870
+ char * qrow = (char *)dst;
1871
+ for (int row = 0; row < nrow; ++row) {
1872
+ quantize_row_q2_K_impl(src, (block_q2_K*)qrow, n_per_row, quant_weights);
1873
+ src += n_per_row;
1874
+ qrow += row_size;
1875
+ }
1876
+ }
1877
+ return nrow * row_size;
1878
+ }
1879
+
1611
1880
  //========================= 3-bit (de)-quantization
1612
1881
 
1613
1882
  void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k) {
@@ -1821,6 +2090,112 @@ size_t ggml_quantize_q3_K(const float * restrict src, void * restrict dst, int n
1821
2090
  return (n/QK_K*sizeof(block_q3_K));
1822
2091
  }
1823
2092
 
2093
+ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restrict y, int n_per_row, const float * restrict quant_weights) {
2094
+ #if QK_K != 256
2095
+ (void)quant_weights;
2096
+ quantize_row_q3_K_reference(x, y, n_per_row);
2097
+ #else
2098
+ assert(n_per_row % QK_K == 0);
2099
+ const int nb = n_per_row / QK_K;
2100
+
2101
+ int8_t L[QK_K];
2102
+ float scales[QK_K / 16];
2103
+ float weight[16];
2104
+ float sw[QK_K / 16];
2105
+ int8_t Ls[QK_K / 16];
2106
+
2107
+ for (int i = 0; i < nb; i++) {
2108
+
2109
+ float sumx2 = 0;
2110
+ for (int j = 0; j < QK_K; ++j) sumx2 += x[j]*x[j];
2111
+ float sigma2 = 2*sumx2/QK_K;
2112
+
2113
+ for (int j = 0; j < QK_K/16; ++j) {
2114
+ if (quant_weights) {
2115
+ const float * qw = quant_weights ? quant_weights + QK_K * i + 16*j : NULL;
2116
+ for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j+l]*x[16*j+l]);
2117
+ } else {
2118
+ for (int l = 0; l < 16; ++l) weight[l] = x[16*j+l]*x[16*j+l];
2119
+ }
2120
+ float sumw = 0;
2121
+ for (int l = 0; l < 16; ++l) sumw += weight[l];
2122
+ sw[j] = sumw;
2123
+
2124
+ scales[j] = make_qx_quants(16, 4, x + 16*j, L + 16*j, 1, weight);
2125
+
2126
+ }
2127
+
2128
+ memset(y[i].scales, 0, 12);
2129
+
2130
+ float d_block = make_qx_quants(QK_K/16, 32, scales, Ls, 1, sw);
2131
+ for (int j = 0; j < QK_K/16; ++j) {
2132
+ int l = Ls[j];
2133
+ if (j < 8) {
2134
+ y[i].scales[j] = l & 0xF;
2135
+ } else {
2136
+ y[i].scales[j-8] |= ((l & 0xF) << 4);
2137
+ }
2138
+ l >>= 4;
2139
+ y[i].scales[j%4 + 8] |= (l << (2*(j/4)));
2140
+ }
2141
+ y[i].d = GGML_FP32_TO_FP16(d_block);
2142
+
2143
+ int8_t sc;
2144
+ for (int j = 0; j < QK_K/16; ++j) {
2145
+ sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4;
2146
+ sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32;
2147
+ float d = GGML_FP16_TO_FP32(y[i].d) * sc;
2148
+ if (!d) {
2149
+ continue;
2150
+ }
2151
+ for (int ii = 0; ii < 16; ++ii) {
2152
+ int l = nearest_int(x[16*j + ii]/d);
2153
+ l = MAX(-4, MIN(3, l));
2154
+ L[16*j + ii] = l + 4;
2155
+ }
2156
+ }
2157
+
2158
+ memset(y[i].hmask, 0, QK_K/8);
2159
+ // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc.
2160
+ int m = 0;
2161
+ uint8_t hm = 1;
2162
+ for (int j = 0; j < QK_K; ++j) {
2163
+ if (L[j] > 3) {
2164
+ y[i].hmask[m] |= hm;
2165
+ L[j] -= 4;
2166
+ }
2167
+ if (++m == QK_K/8) {
2168
+ m = 0; hm <<= 1;
2169
+ }
2170
+ }
2171
+ for (int j = 0; j < QK_K; j += 128) {
2172
+ for (int l = 0; l < 32; ++l) {
2173
+ y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
2174
+ }
2175
+ }
2176
+
2177
+ x += QK_K;
2178
+ }
2179
+ #endif
2180
+ }
2181
+
2182
+ size_t quantize_q3_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
2183
+ (void)hist;
2184
+ int row_size = ggml_row_size(GGML_TYPE_Q3_K, n_per_row);
2185
+ if (!quant_weights) {
2186
+ quantize_row_q3_K_reference(src, dst, nrow*n_per_row);
2187
+ }
2188
+ else {
2189
+ char * qrow = (char *)dst;
2190
+ for (int row = 0; row < nrow; ++row) {
2191
+ quantize_row_q3_K_impl(src, (block_q3_K*)qrow, n_per_row, quant_weights);
2192
+ src += n_per_row;
2193
+ qrow += row_size;
2194
+ }
2195
+ }
2196
+ return nrow * row_size;
2197
+ }
2198
+
1824
2199
  // ====================== 4-bit (de)-quantization
1825
2200
 
1826
2201
  void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k) {
@@ -1986,36 +2361,38 @@ size_t ggml_quantize_q4_K(const float * restrict src, void * restrict dst, int n
1986
2361
  return (n/QK_K*sizeof(block_q4_K));
1987
2362
  }
1988
2363
 
1989
- // ====================== 5-bit (de)-quantization
1990
-
1991
- void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k) {
1992
- assert(k % QK_K == 0);
1993
- const int nb = k / QK_K;
2364
+ static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restrict y, int n_per_row, const float * quant_weights) {
2365
+ #if QK_K != 256
2366
+ (void)quant_weights;
2367
+ quantize_row_q4_K_reference(x, y, n_per_row);
2368
+ #else
2369
+ assert(n_per_row % QK_K == 0);
2370
+ const int nb = n_per_row / QK_K;
1994
2371
 
1995
- #if QK_K == 256
1996
2372
  uint8_t L[QK_K];
2373
+ uint8_t Laux[32];
2374
+ float weights[32];
1997
2375
  float mins[QK_K/32];
1998
2376
  float scales[QK_K/32];
1999
- float weights[32];
2000
- uint8_t Laux[32];
2001
- #else
2002
- int8_t L[QK_K];
2003
- float scales[QK_K/16];
2004
- #endif
2005
2377
 
2006
2378
  for (int i = 0; i < nb; i++) {
2007
2379
 
2008
- #if QK_K == 256
2380
+ float sum_x2 = 0;
2381
+ for (int l = 0; l < QK_K; ++l) sum_x2 += x[l] * x[l];
2382
+ float sigma2 = sum_x2/QK_K;
2383
+ float av_x = sqrtf(sigma2);
2009
2384
 
2010
2385
  float max_scale = 0; // as we are deducting the min, scales are always positive
2011
2386
  float max_min = 0;
2012
2387
  for (int j = 0; j < QK_K/32; ++j) {
2013
- //scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
2014
- float sum_x2 = 0;
2015
- for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
2016
- float av_x = sqrtf(sum_x2/32);
2017
- for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
2018
- scales[j] = make_qkx2_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.5f, 0.1f, 15, false);
2388
+ if (quant_weights) {
2389
+ const float * qw = quant_weights + QK_K*i + 32*j;
2390
+ for (int l = 0; l < 32; ++l) weights[l] = qw[l] * sqrtf(sigma2 + x[32*j + l]*x[32*j + l]);
2391
+ } else {
2392
+ for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
2393
+ }
2394
+ scales[j] = make_qkx3_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
2395
+ //scales[j] = make_qkx2_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -1.f, 0.1f, 20, false);
2019
2396
  float scale = scales[j];
2020
2397
  if (scale > max_scale) {
2021
2398
  max_scale = scale;
@@ -2053,18 +2430,118 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
2053
2430
  const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m;
2054
2431
  for (int ii = 0; ii < 32; ++ii) {
2055
2432
  int l = nearest_int((x[32*j + ii] + dm)/d);
2056
- l = MAX(0, MIN(31, l));
2433
+ l = MAX(0, MIN(15, l));
2057
2434
  L[32*j + ii] = l;
2058
2435
  }
2059
2436
  }
2437
+ uint8_t * q = y[i].qs;
2438
+ for (int j = 0; j < QK_K; j += 64) {
2439
+ for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4);
2440
+ q += 32;
2441
+ }
2060
2442
 
2061
- uint8_t * restrict qh = y[i].qh;
2062
- uint8_t * restrict ql = y[i].qs;
2063
- memset(qh, 0, QK_K/8);
2443
+ x += QK_K;
2064
2444
 
2065
- uint8_t m1 = 1, m2 = 2;
2066
- for (int n = 0; n < QK_K; n += 64) {
2067
- for (int j = 0; j < 32; ++j) {
2445
+ }
2446
+ #endif
2447
+ }
2448
+
2449
+ size_t quantize_q4_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
2450
+ (void)hist;
2451
+ int row_size = ggml_row_size(GGML_TYPE_Q4_K, n_per_row);
2452
+ if (!quant_weights) {
2453
+ quantize_row_q4_K_reference(src, dst, nrow*n_per_row);
2454
+ }
2455
+ else {
2456
+ char * qrow = (char *)dst;
2457
+ for (int row = 0; row < nrow; ++row) {
2458
+ quantize_row_q4_K_impl(src, (block_q4_K*)qrow, n_per_row, quant_weights);
2459
+ src += n_per_row;
2460
+ qrow += row_size;
2461
+ }
2462
+ }
2463
+ return nrow * row_size;
2464
+ }
2465
+
2466
+ // ====================== 5-bit (de)-quantization
2467
+
2468
+ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k) {
2469
+ assert(k % QK_K == 0);
2470
+ const int nb = k / QK_K;
2471
+
2472
+ #if QK_K == 256
2473
+ uint8_t L[QK_K];
2474
+ float mins[QK_K/32];
2475
+ float scales[QK_K/32];
2476
+ float weights[32];
2477
+ uint8_t Laux[32];
2478
+ #else
2479
+ int8_t L[QK_K];
2480
+ float scales[QK_K/16];
2481
+ #endif
2482
+
2483
+ for (int i = 0; i < nb; i++) {
2484
+
2485
+ #if QK_K == 256
2486
+
2487
+ float max_scale = 0; // as we are deducting the min, scales are always positive
2488
+ float max_min = 0;
2489
+ for (int j = 0; j < QK_K/32; ++j) {
2490
+ //scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
2491
+ float sum_x2 = 0;
2492
+ for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
2493
+ float av_x = sqrtf(sum_x2/32);
2494
+ for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
2495
+ scales[j] = make_qkx2_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.5f, 0.1f, 15, false);
2496
+ float scale = scales[j];
2497
+ if (scale > max_scale) {
2498
+ max_scale = scale;
2499
+ }
2500
+ float min = mins[j];
2501
+ if (min > max_min) {
2502
+ max_min = min;
2503
+ }
2504
+ }
2505
+
2506
+ float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
2507
+ float inv_min = max_min > 0 ? 63.f/max_min : 0.f;
2508
+ for (int j = 0; j < QK_K/32; ++j) {
2509
+ uint8_t ls = nearest_int(inv_scale*scales[j]);
2510
+ uint8_t lm = nearest_int(inv_min*mins[j]);
2511
+ ls = MIN(63, ls);
2512
+ lm = MIN(63, lm);
2513
+ if (j < 4) {
2514
+ y[i].scales[j] = ls;
2515
+ y[i].scales[j+4] = lm;
2516
+ } else {
2517
+ y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
2518
+ y[i].scales[j-4] |= ((ls >> 4) << 6);
2519
+ y[i].scales[j-0] |= ((lm >> 4) << 6);
2520
+ }
2521
+ }
2522
+ y[i].d = GGML_FP32_TO_FP16(max_scale/63.f);
2523
+ y[i].dmin = GGML_FP32_TO_FP16(max_min/63.f);
2524
+
2525
+ uint8_t sc, m;
2526
+ for (int j = 0; j < QK_K/32; ++j) {
2527
+ get_scale_min_k4(j, y[i].scales, &sc, &m);
2528
+ const float d = GGML_FP16_TO_FP32(y[i].d) * sc;
2529
+ if (!d) continue;
2530
+ const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m;
2531
+ for (int ii = 0; ii < 32; ++ii) {
2532
+ int l = nearest_int((x[32*j + ii] + dm)/d);
2533
+ l = MAX(0, MIN(31, l));
2534
+ L[32*j + ii] = l;
2535
+ }
2536
+ }
2537
+
2538
+ uint8_t * restrict qh = y[i].qh;
2539
+ uint8_t * restrict ql = y[i].qs;
2540
+ memset(qh, 0, QK_K/8);
2541
+
2542
+ uint8_t m1 = 1, m2 = 2;
2543
+ for (int n = 0; n < QK_K; n += 64) {
2544
+ for (int j = 0; j < 32; ++j) {
2068
2545
  int l1 = L[n + j];
2069
2546
  if (l1 > 15) {
2070
2547
  l1 -= 16; qh[j] |= m1;
@@ -2081,7 +2558,7 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
2081
2558
  #else
2082
2559
  float max_scale = 0, amax = 0;
2083
2560
  for (int j = 0; j < QK_K/16; ++j) {
2084
- scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1);
2561
+ scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1, NULL);
2085
2562
  float abs_scale = fabsf(scales[j]);
2086
2563
  if (abs_scale > amax) {
2087
2564
  amax = abs_scale;
@@ -2192,6 +2669,123 @@ size_t ggml_quantize_q5_K(const float * restrict src, void * restrict dst, int n
2192
2669
  return (n/QK_K*sizeof(block_q5_K));
2193
2670
  }
2194
2671
 
2672
+ static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restrict y, int n_per_row, const float * quant_weights) {
2673
+ #if QK_K != 256
2674
+ (void)quant_weights;
2675
+ quantize_row_q5_K_reference(x, y, n_per_row);
2676
+ #else
2677
+ assert(n_per_row % QK_K == 0);
2678
+ const int nb = n_per_row / QK_K;
2679
+
2680
+ uint8_t L[QK_K];
2681
+ float mins[QK_K/32];
2682
+ float scales[QK_K/32];
2683
+ float weights[32];
2684
+ uint8_t Laux[32];
2685
+
2686
+ for (int i = 0; i < nb; i++) {
2687
+
2688
+ float sum_x2 = 0;
2689
+ for (int l = 0; l < QK_K; ++l) sum_x2 += x[l] * x[l];
2690
+ float sigma2 = sum_x2/QK_K;
2691
+ float av_x = sqrtf(sigma2);
2692
+
2693
+ float max_scale = 0; // as we are deducting the min, scales are always positive
2694
+ float max_min = 0;
2695
+ for (int j = 0; j < QK_K/32; ++j) {
2696
+ if (quant_weights) {
2697
+ const float * qw = quant_weights + QK_K*i + 32*j;
2698
+ for (int l = 0; l < 32; ++l) weights[l] = qw[l] * sqrtf(sigma2 + x[32*j + l]*x[32*j + l]);
2699
+ } else {
2700
+ for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
2701
+ }
2702
+ scales[j] = make_qkx3_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
2703
+ float scale = scales[j];
2704
+ if (scale > max_scale) {
2705
+ max_scale = scale;
2706
+ }
2707
+ float min = mins[j];
2708
+ if (min > max_min) {
2709
+ max_min = min;
2710
+ }
2711
+ }
2712
+
2713
+ float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
2714
+ float inv_min = max_min > 0 ? 63.f/max_min : 0.f;
2715
+ for (int j = 0; j < QK_K/32; ++j) {
2716
+ uint8_t ls = nearest_int(inv_scale*scales[j]);
2717
+ uint8_t lm = nearest_int(inv_min*mins[j]);
2718
+ ls = MIN(63, ls);
2719
+ lm = MIN(63, lm);
2720
+ if (j < 4) {
2721
+ y[i].scales[j] = ls;
2722
+ y[i].scales[j+4] = lm;
2723
+ } else {
2724
+ y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
2725
+ y[i].scales[j-4] |= ((ls >> 4) << 6);
2726
+ y[i].scales[j-0] |= ((lm >> 4) << 6);
2727
+ }
2728
+ }
2729
+ y[i].d = GGML_FP32_TO_FP16(max_scale/63.f);
2730
+ y[i].dmin = GGML_FP32_TO_FP16(max_min/63.f);
2731
+
2732
+ uint8_t sc, m;
2733
+ for (int j = 0; j < QK_K/32; ++j) {
2734
+ get_scale_min_k4(j, y[i].scales, &sc, &m);
2735
+ const float d = GGML_FP16_TO_FP32(y[i].d) * sc;
2736
+ if (!d) continue;
2737
+ const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m;
2738
+ for (int ii = 0; ii < 32; ++ii) {
2739
+ int l = nearest_int((x[32*j + ii] + dm)/d);
2740
+ l = MAX(0, MIN(31, l));
2741
+ L[32*j + ii] = l;
2742
+ }
2743
+ }
2744
+
2745
+ uint8_t * restrict qh = y[i].qh;
2746
+ uint8_t * restrict ql = y[i].qs;
2747
+ memset(qh, 0, QK_K/8);
2748
+
2749
+ uint8_t m1 = 1, m2 = 2;
2750
+ for (int n = 0; n < QK_K; n += 64) {
2751
+ for (int j = 0; j < 32; ++j) {
2752
+ int l1 = L[n + j];
2753
+ if (l1 > 15) {
2754
+ l1 -= 16; qh[j] |= m1;
2755
+ }
2756
+ int l2 = L[n + j + 32];
2757
+ if (l2 > 15) {
2758
+ l2 -= 16; qh[j] |= m2;
2759
+ }
2760
+ ql[j] = l1 | (l2 << 4);
2761
+ }
2762
+ m1 <<= 2; m2 <<= 2;
2763
+ ql += 32;
2764
+ }
2765
+
2766
+ x += QK_K;
2767
+
2768
+ }
2769
+ #endif
2770
+ }
2771
+
2772
+ size_t quantize_q5_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
2773
+ (void)hist;
2774
+ int row_size = ggml_row_size(GGML_TYPE_Q5_K, n_per_row);
2775
+ if (!quant_weights) {
2776
+ quantize_row_q5_K_reference(src, dst, nrow*n_per_row);
2777
+ }
2778
+ else {
2779
+ char * qrow = (char *)dst;
2780
+ for (int row = 0; row < nrow; ++row) {
2781
+ quantize_row_q5_K_impl(src, (block_q5_K*)qrow, n_per_row, quant_weights);
2782
+ src += n_per_row;
2783
+ qrow += row_size;
2784
+ }
2785
+ }
2786
+ return nrow * row_size;
2787
+ }
2788
+
2195
2789
  // ====================== 6-bit (de)-quantization
2196
2790
 
2197
2791
  void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k) {
@@ -2208,7 +2802,7 @@ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict
2208
2802
 
2209
2803
  for (int ib = 0; ib < QK_K/16; ++ib) {
2210
2804
 
2211
- const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1);
2805
+ const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, NULL);
2212
2806
  scales[ib] = scale;
2213
2807
 
2214
2808
  const float abs_scale = fabsf(scale);
@@ -2317,27 +2911,590 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int
2317
2911
  y[l+32] = d * sc[2] * q3;
2318
2912
  y[l+48] = d * sc[3] * q4;
2319
2913
  }
2320
- y += 64;
2321
- #endif
2914
+ y += 64;
2915
+ #endif
2916
+
2917
+ }
2918
+ }
2919
+
2920
+ void quantize_row_q6_K(const float * restrict x, void * restrict vy, int k) {
2921
+ assert(k % QK_K == 0);
2922
+ block_q6_K * restrict y = vy;
2923
+ quantize_row_q6_K_reference(x, y, k);
2924
+ }
2925
+
2926
+ size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist) {
2927
+ assert(k % QK_K == 0);
2928
+ (void)hist; // TODO: collect histograms
2929
+
2930
+ for (int j = 0; j < n; j += k) {
2931
+ block_q6_K * restrict y = (block_q6_K *)dst + j/QK_K;
2932
+ quantize_row_q6_K_reference(src + j, y, k);
2933
+ }
2934
+ return (n/QK_K*sizeof(block_q6_K));
2935
+ }
2936
+
2937
+ static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restrict y, int n_per_row, const float * quant_weights) {
2938
+ #if QK_K != 256
2939
+ (void)quant_weights;
2940
+ quantize_row_q6_K_reference(x, y, n_per_row);
2941
+ #else
2942
+ assert(n_per_row % QK_K == 0);
2943
+ const int nb = n_per_row / QK_K;
2944
+
2945
+ int8_t L[QK_K];
2946
+ float scales[QK_K/16];
2947
+ //float weights[16];
2948
+
2949
+ for (int i = 0; i < nb; i++) {
2950
+
2951
+ //float sum_x2 = 0;
2952
+ //for (int j = 0; j < QK_K; ++j) sum_x2 += x[j]*x[j];
2953
+ //float sigma2 = sum_x2/QK_K;
2954
+
2955
+ float max_scale = 0;
2956
+ float max_abs_scale = 0;
2957
+
2958
+ for (int ib = 0; ib < QK_K/16; ++ib) {
2959
+
2960
+ float scale;
2961
+ if (quant_weights) {
2962
+ const float * qw = quant_weights + QK_K*i + 16*ib;
2963
+ //for (int j = 0; j < 16; ++j) weights[j] = qw[j] * sqrtf(sigma2 + x[16*ib + j]*x[16*ib + j]);
2964
+ //scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, weights);
2965
+ scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, qw);
2966
+ } else {
2967
+ scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, NULL);
2968
+ }
2969
+ scales[ib] = scale;
2970
+
2971
+ const float abs_scale = fabsf(scale);
2972
+ if (abs_scale > max_abs_scale) {
2973
+ max_abs_scale = abs_scale;
2974
+ max_scale = scale;
2975
+ }
2976
+
2977
+ }
2978
+
2979
+ if (!max_abs_scale) {
2980
+ memset(&y[i], 0, sizeof(block_q6_K));
2981
+ y[i].d = GGML_FP32_TO_FP16(0.f);
2982
+ x += QK_K;
2983
+ continue;
2984
+ }
2985
+
2986
+ float iscale = -128.f/max_scale;
2987
+ y[i].d = GGML_FP32_TO_FP16(1/iscale);
2988
+ for (int ib = 0; ib < QK_K/16; ++ib) {
2989
+ y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib]));
2990
+ }
2991
+
2992
+ for (int j = 0; j < QK_K/16; ++j) {
2993
+ float d = GGML_FP16_TO_FP32(y[i].d) * y[i].scales[j];
2994
+ if (!d) {
2995
+ continue;
2996
+ }
2997
+ for (int ii = 0; ii < 16; ++ii) {
2998
+ int l = nearest_int(x[16*j + ii]/d);
2999
+ l = MAX(-32, MIN(31, l));
3000
+ L[16*j + ii] = l + 32;
3001
+ }
3002
+ }
3003
+
3004
+ uint8_t * restrict ql = y[i].ql;
3005
+ uint8_t * restrict qh = y[i].qh;
3006
+ for (int j = 0; j < QK_K; j += 128) {
3007
+ for (int l = 0; l < 32; ++l) {
3008
+ const uint8_t q1 = L[j + l + 0] & 0xF;
3009
+ const uint8_t q2 = L[j + l + 32] & 0xF;
3010
+ const uint8_t q3 = L[j + l + 64] & 0xF;
3011
+ const uint8_t q4 = L[j + l + 96] & 0xF;
3012
+ ql[l+ 0] = q1 | (q3 << 4);
3013
+ ql[l+32] = q2 | (q4 << 4);
3014
+ qh[l] = (L[j + l] >> 4) | ((L[j + l + 32] >> 4) << 2) | ((L[j + l + 64] >> 4) << 4) | ((L[j + l + 96] >> 4) << 6);
3015
+ }
3016
+ ql += 64;
3017
+ qh += 32;
3018
+ }
3019
+
3020
+ x += QK_K;
3021
+
3022
+ }
3023
+ #endif
3024
+ }
3025
+
3026
+ size_t quantize_q6_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
3027
+ (void)hist;
3028
+ int row_size = ggml_row_size(GGML_TYPE_Q6_K, n_per_row);
3029
+ if (!quant_weights) {
3030
+ quantize_row_q6_K_reference(src, dst, nrow*n_per_row);
3031
+ }
3032
+ else {
3033
+ char * qrow = (char *)dst;
3034
+ for (int row = 0; row < nrow; ++row) {
3035
+ quantize_row_q6_K_impl(src, (block_q6_K*)qrow, n_per_row, quant_weights);
3036
+ src += n_per_row;
3037
+ qrow += row_size;
3038
+ }
3039
+ }
3040
+ return nrow * row_size;
3041
+ }
3042
+
3043
+ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restrict y, int n_per_row, const float * quant_weights) {
3044
+ static_assert(QK4_0 == 32, "QK4_0 must be 32");
3045
+
3046
+ if (!quant_weights) {
3047
+ quantize_row_q4_0_reference(x, y, n_per_row);
3048
+ return;
3049
+ }
3050
+
3051
+ float weight[QK4_0];
3052
+ int8_t L[QK4_0];
3053
+
3054
+ float sum_x2 = 0;
3055
+ for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
3056
+ float sigma2 = sum_x2/n_per_row;
3057
+
3058
+ const int nb = n_per_row/QK4_0;
3059
+ for (int ib = 0; ib < nb; ++ib) {
3060
+ const float * xb = x + QK4_0 * ib;
3061
+ const float * qw = quant_weights + QK4_0 * ib;
3062
+ for (int j = 0; j < QK4_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
3063
+ float d = make_qx_quants(QK4_0, 8, xb, L, 1, weight);
3064
+ y[ib].d = GGML_FP32_TO_FP16(d);
3065
+ for (int j = 0; j < 16; ++j) {
3066
+ y[ib].qs[j] = L[j] | (L[j+16] << 4);
3067
+ }
3068
+ }
3069
+ }
3070
+
3071
+ size_t quantize_q4_0(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
3072
+ if (!quant_weights) {
3073
+ return ggml_quantize_q4_0(src, dst, nrow*n_per_row, n_per_row, hist);
3074
+ }
3075
+ int row_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row);
3076
+ char * qrow = (char *)dst;
3077
+ for (int row = 0; row < nrow; ++row) {
3078
+ quantize_row_q4_0_impl(src, (block_q4_0*)qrow, n_per_row, quant_weights);
3079
+ src += n_per_row;
3080
+ qrow += row_size;
3081
+ }
3082
+ return nrow * row_size;
3083
+ }
3084
+
3085
+ static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restrict y, int n_per_row, const float * quant_weights) {
3086
+ static_assert(QK4_1 == 32, "QK4_1 must be 32");
3087
+
3088
+ if (!quant_weights) {
3089
+ quantize_row_q4_1_reference(x, y, n_per_row);
3090
+ return;
3091
+ }
3092
+
3093
+ float weight[QK4_1];
3094
+ uint8_t L[QK4_1], Laux[QK4_1];
3095
+
3096
+ float sum_x2 = 0;
3097
+ for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
3098
+ float sigma2 = sum_x2/n_per_row;
3099
+
3100
+ const int nb = n_per_row/QK4_1;
3101
+ for (int ib = 0; ib < nb; ++ib) {
3102
+ const float * xb = x + QK4_1 * ib;
3103
+ const float * qw = quant_weights + QK4_1 * ib;
3104
+ for (int j = 0; j < QK4_1; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
3105
+ float min;
3106
+ float d = make_qkx3_quants(QK4_1, 15, xb, weight, L, &min, Laux, -0.9f, 0.05f, 36, false);
3107
+ y[ib].d = GGML_FP32_TO_FP16(d);
3108
+ y[ib].m = GGML_FP32_TO_FP16(-min);
3109
+ for (int j = 0; j < 16; ++j) {
3110
+ y[ib].qs[j] = L[j] | (L[j+16] << 4);
3111
+ }
3112
+ }
3113
+ }
3114
+
3115
+ size_t quantize_q4_1(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
3116
+ if (!quant_weights) {
3117
+ return ggml_quantize_q4_1(src, dst, nrow*n_per_row, n_per_row, hist);
3118
+ }
3119
+ int row_size = ggml_row_size(GGML_TYPE_Q4_1, n_per_row);
3120
+ char * qrow = (char *)dst;
3121
+ for (int row = 0; row < nrow; ++row) {
3122
+ quantize_row_q4_1_impl(src, (block_q4_1*)qrow, n_per_row, quant_weights);
3123
+ src += n_per_row;
3124
+ qrow += row_size;
3125
+ }
3126
+ return nrow * row_size;
3127
+ }
3128
+
3129
+ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restrict y, int n_per_row, const float * quant_weights) {
3130
+ static_assert(QK5_0 == 32, "QK5_0 must be 32");
3131
+
3132
+ if (!quant_weights) {
3133
+ quantize_row_q5_0_reference(x, y, n_per_row);
3134
+ return;
3135
+ }
3136
+
3137
+ float weight[QK5_0];
3138
+ int8_t L[QK5_0];
3139
+
3140
+ float sum_x2 = 0;
3141
+ for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
3142
+ float sigma2 = sum_x2/n_per_row;
3143
+
3144
+ const int nb = n_per_row/QK5_0;
3145
+ for (int ib = 0; ib < nb; ++ib) {
3146
+ const float * xb = x + QK5_0 * ib;
3147
+ const float * qw = quant_weights + QK5_0 * ib;
3148
+ for (int j = 0; j < QK5_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
3149
+ float d = make_qx_quants(QK5_0, 16, xb, L, 1, weight);
3150
+ y[ib].d = GGML_FP32_TO_FP16(d);
3151
+
3152
+ uint32_t qh = 0;
3153
+
3154
+ for (int j = 0; j < 16; ++j) {
3155
+ const uint8_t xi0 = L[j];
3156
+ const uint8_t xi1 = L[j+16];
3157
+ y[ib].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
3158
+
3159
+ // get the 5-th bit and store it in qh at the right position
3160
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
3161
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
3162
+ }
3163
+
3164
+ memcpy(&y[ib].qh, &qh, sizeof(qh));
3165
+ }
3166
+ }
3167
+
3168
+ size_t quantize_q5_0(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
3169
+ if (!quant_weights) {
3170
+ return ggml_quantize_q5_0(src, dst, nrow*n_per_row, n_per_row, hist);
3171
+ }
3172
+ int row_size = ggml_row_size(GGML_TYPE_Q5_0, n_per_row);
3173
+ char * qrow = (char *)dst;
3174
+ for (int row = 0; row < nrow; ++row) {
3175
+ quantize_row_q5_0_impl(src, (block_q5_0*)qrow, n_per_row, quant_weights);
3176
+ src += n_per_row;
3177
+ qrow += row_size;
3178
+ }
3179
+ return nrow * row_size;
3180
+ }
3181
+
3182
+ static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restrict y, int n_per_row, const float * quant_weights) {
3183
+ static_assert(QK5_1 == 32, "QK5_1 must be 32");
3184
+
3185
+ if (!quant_weights) {
3186
+ quantize_row_q5_1_reference(x, y, n_per_row);
3187
+ return;
3188
+ }
3189
+
3190
+ float weight[QK5_1];
3191
+ uint8_t L[QK5_1], Laux[QK5_1];
3192
+
3193
+ float sum_x2 = 0;
3194
+ for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
3195
+ float sigma2 = sum_x2/n_per_row;
3196
+
3197
+ const int nb = n_per_row/QK5_1;
3198
+ for (int ib = 0; ib < nb; ++ib) {
3199
+ const float * xb = x + QK5_1 * ib;
3200
+ const float * qw = quant_weights + QK5_1 * ib;
3201
+ for (int j = 0; j < QK5_1; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
3202
+ float min;
3203
+ float d = make_qkx3_quants(QK5_1, 31, xb, weight, L, &min, Laux, -0.9f, 0.05f, 36, false);
3204
+ y[ib].d = GGML_FP32_TO_FP16(d);
3205
+ y[ib].m = GGML_FP32_TO_FP16(-min);
3206
+
3207
+ uint32_t qh = 0;
3208
+ for (int j = 0; j < 16; ++j) {
3209
+ const uint8_t xi0 = L[j];
3210
+ const uint8_t xi1 = L[j+16];
3211
+ y[ib].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
3212
+ // get the 5-th bit and store it in qh at the right position
3213
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
3214
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
3215
+ }
3216
+ memcpy(&y[ib].qh, &qh, sizeof(qh));
3217
+ }
3218
+ }
2322
3219
 
3220
+ size_t quantize_q5_1(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
3221
+ if (!quant_weights) {
3222
+ return ggml_quantize_q5_1(src, dst, nrow*n_per_row, n_per_row, hist);
3223
+ }
3224
+ int row_size = ggml_row_size(GGML_TYPE_Q5_1, n_per_row);
3225
+ char * qrow = (char *)dst;
3226
+ for (int row = 0; row < nrow; ++row) {
3227
+ quantize_row_q5_1_impl(src, (block_q5_1*)qrow, n_per_row, quant_weights);
3228
+ src += n_per_row;
3229
+ qrow += row_size;
2323
3230
  }
3231
+ return nrow * row_size;
2324
3232
  }
2325
3233
 
2326
- void quantize_row_q6_K(const float * restrict x, void * restrict vy, int k) {
3234
+ // ====================== "True" 2-bit (de)-quantization
3235
+
3236
// IQ2_XXS codebook: 256 entries, each packing 8 byte-sized magnitudes.
// Every byte in the table is one of 0x08, 0x19 or 0x2b (8, 25, 43).
// dequantize_row_iq2_xxs() selects an entry with an 8-bit index and reads it
// through a uint8_t pointer as 8 consecutive magnitudes. The element order
// therefore follows the uint64_t's in-memory byte order (lowest byte first on
// little-endian hosts).
static const uint64_t iq2xxs_grid[256] = {
    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
    0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
    0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
    0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
    0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
    0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
    0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
    0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
    0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
    0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
    0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
    0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
    0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
    0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
    0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
    0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
    0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
    0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
    0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
    0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
    0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
    0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
    0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
    0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
    0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
    0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
    0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
    0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
    0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
    0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
    0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
    0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
    0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
    0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
    0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
    0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
    0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
    0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
    0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
    0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
    0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
    0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
    0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
    0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
    0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
    0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
    0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
    0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
    0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
    0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
    0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
    0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
    0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
    0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
    0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
    0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
    0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
    0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
    0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
    0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
    0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
    0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
    0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
};
3302
+
3303
// IQ2_XS codebook: 512 entries, each packing 8 byte-sized magnitudes.
// Every byte in the table is one of 0x08, 0x19 or 0x2b (8, 25, 43).
// dequantize_row_iq2_xs() selects an entry with the low 9 bits of each qs
// word and reads it through a uint8_t pointer as 8 consecutive magnitudes.
// The element order follows the uint64_t's in-memory byte order.
static const uint64_t iq2xs_grid[512] = {
    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
    0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
    0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
    0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
    0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808,
    0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819,
    0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819,
    0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808,
    0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b,
    0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b,
    0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908,
    0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908,
    0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919,
    0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808,
    0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919,
    0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908,
    0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b,
    0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908,
    0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08,
    0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808,
    0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808,
    0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819,
    0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908,
    0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819,
    0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808,
    0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b,
    0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819,
    0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819,
    0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808,
    0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908,
    0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19,
    0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b,
    0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b,
    0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919,
    0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808,
    0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819,
    0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819,
    0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b,
    0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908,
    0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808,
    0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819,
    0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808,
    0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919,
    0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808,
    0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808,
    0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908,
    0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908,
    0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808,
    0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b,
    0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819,
    0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919,
    0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908,
    0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808,
    0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908,
    0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919,
    0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08,
    0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19,
    0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b,
    0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b,
    0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808,
    0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08,
    0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b,
    0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908,
    0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b,
    0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908,
    0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08,
    0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808,
    0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808,
    0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08,
    0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819,
    0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919,
    0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808,
    0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808,
    0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819,
    0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819,
    0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908,
    0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908,
    0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b,
    0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908,
    0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908,
    0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908,
    0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808,
    0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819,
    0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819,
    0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819,
    0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808,
    0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b,
    0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819,
    0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819,
    0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08,
    0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808,
    0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19,
    0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919,
    0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808,
    0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19,
    0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b,
    0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808,
    0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b,
    0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b,
    0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08,
    0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b,
    0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808,
    0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819,
    0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,
    0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808,
    0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08,
    0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b,
    0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19,
    0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08,
    0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919,
    0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08,
    0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08,
    0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908,
    0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908,
    0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b,
    0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908,
    0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808,
    0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b,
    0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808,
    0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808,
    0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19,
    0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08,
    0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808,
    0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b,
    0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808,
    0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b,
    0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
};
3433
+
3434
// Sign-pattern table shared by the IQ2 dequantizers, indexed by a 7-bit value.
// Entry i equals i with bit 7 additionally set exactly when i has an odd
// number of set bits. Every 8-bit pattern therefore has an even number of
// set bits, so only 7 bits need to be stored. Bit j of the pattern (tested
// with kmask_iq2xs[j]) negates element j of the 8-value group.
static const uint8_t ksigns_iq2xs[128] = {
    0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
    144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
    160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175,
    48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63,
    192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207,
    80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95,
    96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
    240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
};
3444
+
3445
// kmask_iq2xs[j] == 1 << j: isolates the sign bit of element j in an 8-bit
// sign pattern taken from ksigns_iq2xs.
static const uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
3446
+
3447
+ void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k) {
2327
3448
  assert(k % QK_K == 0);
2328
- block_q6_K * restrict y = vy;
2329
- quantize_row_q6_K_reference(x, y, k);
3449
+ const int nb = k / QK_K;
3450
+
3451
+ uint32_t aux32[2];
3452
+ const uint8_t * aux8 = (const uint8_t *)aux32;
3453
+
3454
+ for (int i = 0; i < nb; i++) {
3455
+
3456
+ const float d = GGML_FP16_TO_FP32(x[i].d);
3457
+
3458
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
3459
+ memcpy(aux32, x[i].qs + 4*ib32, 2*sizeof(uint32_t));
3460
+ const float db = d * (0.5f + (aux32[1] >> 28)) * 0.25f;
3461
+ for (int l = 0; l < 4; ++l) {
3462
+ const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
3463
+ const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];
3464
+ for (int j = 0; j < 8; ++j) {
3465
+ y[j] = db * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
3466
+ }
3467
+ y += 8;
3468
+ }
3469
+ }
3470
+ }
2330
3471
  }
2331
3472
 
2332
- size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist) {
3473
+ // ====================== 2.3125 bpw (de)-quantization
3474
+
3475
+ void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y, int k) {
2333
3476
  assert(k % QK_K == 0);
2334
- (void)hist; // TODO: collect histograms
3477
+ const int nb = k / QK_K;
2335
3478
 
2336
- for (int j = 0; j < n; j += k) {
2337
- block_q6_K * restrict y = (block_q6_K *)dst + j/QK_K;
2338
- quantize_row_q6_K_reference(src + j, y, k);
3479
+ float db[2];
3480
+
3481
+ for (int i = 0; i < nb; i++) {
3482
+
3483
+ const float d = GGML_FP16_TO_FP32(x[i].d);
3484
+
3485
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
3486
+ db[0] = d * (0.5f + (x[i].scales[ib32] & 0xf)) * 0.25f;
3487
+ db[1] = d * (0.5f + (x[i].scales[ib32] >> 4)) * 0.25f;
3488
+ for (int l = 0; l < 4; ++l) {
3489
+ const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (x[i].qs[4*ib32 + l] & 511));
3490
+ const uint8_t signs = ksigns_iq2xs[x[i].qs[4*ib32 + l] >> 9];
3491
+ for (int j = 0; j < 8; ++j) {
3492
+ y[j] = db[l/2] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
3493
+ }
3494
+ y += 8;
3495
+ }
3496
+ }
2339
3497
  }
2340
- return (n/QK_K*sizeof(block_q6_K));
2341
3498
  }
2342
3499
 
2343
3500
  //===================================== Q8_K ==============================================
@@ -2362,7 +3519,9 @@ void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict
2362
3519
  x += QK_K;
2363
3520
  continue;
2364
3521
  }
2365
- const float iscale = -128.f/max;
3522
+ //const float iscale = -128.f/max;
3523
+ // We need this change for IQ2_XXS, else the AVX implementation becomes very awkward
3524
+ const float iscale = -127.f/max;
2366
3525
  for (int j = 0; j < QK_K; ++j) {
2367
3526
  int v = nearest_int(iscale*x[j]);
2368
3527
  y[i].qs[j] = MIN(127, v);
@@ -7065,3 +8224,982 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
7065
8224
  }
7066
8225
 
7067
8226
  #endif
8227
+
8228
// Sign table for the IQ2 kernels: 128 entries of 8 int8 signs (1024 values), one entry per row.
// For entry i, lane j (0..6) is -1 when bit j of i is set and +1 otherwise. Lane 7 is -1 exactly
// when i has an odd number of set bits, so every entry carries an even number of -1 signs.
// The kernels read it as uint64_t words indexed by a 7-bit sign code.
static const int8_t keven_signs_q2xs[1024] = {
    //  lanes 0..3          lanes 4..6      parity    entry
     1,  1,  1,  1,  1,  1,  1,  1,  //   0
    -1,  1,  1,  1,  1,  1,  1, -1,  //   1
     1, -1,  1,  1,  1,  1,  1, -1,  //   2
    -1, -1,  1,  1,  1,  1,  1,  1,  //   3
     1,  1, -1,  1,  1,  1,  1, -1,  //   4
    -1,  1, -1,  1,  1,  1,  1,  1,  //   5
     1, -1, -1,  1,  1,  1,  1,  1,  //   6
    -1, -1, -1,  1,  1,  1,  1, -1,  //   7
     1,  1,  1, -1,  1,  1,  1, -1,  //   8
    -1,  1,  1, -1,  1,  1,  1,  1,  //   9
     1, -1,  1, -1,  1,  1,  1,  1,  //  10
    -1, -1,  1, -1,  1,  1,  1, -1,  //  11
     1,  1, -1, -1,  1,  1,  1,  1,  //  12
    -1,  1, -1, -1,  1,  1,  1, -1,  //  13
     1, -1, -1, -1,  1,  1,  1, -1,  //  14
    -1, -1, -1, -1,  1,  1,  1,  1,  //  15
     1,  1,  1,  1, -1,  1,  1, -1,  //  16
    -1,  1,  1,  1, -1,  1,  1,  1,  //  17
     1, -1,  1,  1, -1,  1,  1,  1,  //  18
    -1, -1,  1,  1, -1,  1,  1, -1,  //  19
     1,  1, -1,  1, -1,  1,  1,  1,  //  20
    -1,  1, -1,  1, -1,  1,  1, -1,  //  21
     1, -1, -1,  1, -1,  1,  1, -1,  //  22
    -1, -1, -1,  1, -1,  1,  1,  1,  //  23
     1,  1,  1, -1, -1,  1,  1,  1,  //  24
    -1,  1,  1, -1, -1,  1,  1, -1,  //  25
     1, -1,  1, -1, -1,  1,  1, -1,  //  26
    -1, -1,  1, -1, -1,  1,  1,  1,  //  27
     1,  1, -1, -1, -1,  1,  1, -1,  //  28
    -1,  1, -1, -1, -1,  1,  1,  1,  //  29
     1, -1, -1, -1, -1,  1,  1,  1,  //  30
    -1, -1, -1, -1, -1,  1,  1, -1,  //  31
     1,  1,  1,  1,  1, -1,  1, -1,  //  32
    -1,  1,  1,  1,  1, -1,  1,  1,  //  33
     1, -1,  1,  1,  1, -1,  1,  1,  //  34
    -1, -1,  1,  1,  1, -1,  1, -1,  //  35
     1,  1, -1,  1,  1, -1,  1,  1,  //  36
    -1,  1, -1,  1,  1, -1,  1, -1,  //  37
     1, -1, -1,  1,  1, -1,  1, -1,  //  38
    -1, -1, -1,  1,  1, -1,  1,  1,  //  39
     1,  1,  1, -1,  1, -1,  1,  1,  //  40
    -1,  1,  1, -1,  1, -1,  1, -1,  //  41
     1, -1,  1, -1,  1, -1,  1, -1,  //  42
    -1, -1,  1, -1,  1, -1,  1,  1,  //  43
     1,  1, -1, -1,  1, -1,  1, -1,  //  44
    -1,  1, -1, -1,  1, -1,  1,  1,  //  45
     1, -1, -1, -1,  1, -1,  1,  1,  //  46
    -1, -1, -1, -1,  1, -1,  1, -1,  //  47
     1,  1,  1,  1, -1, -1,  1,  1,  //  48
    -1,  1,  1,  1, -1, -1,  1, -1,  //  49
     1, -1,  1,  1, -1, -1,  1, -1,  //  50
    -1, -1,  1,  1, -1, -1,  1,  1,  //  51
     1,  1, -1,  1, -1, -1,  1, -1,  //  52
    -1,  1, -1,  1, -1, -1,  1,  1,  //  53
     1, -1, -1,  1, -1, -1,  1,  1,  //  54
    -1, -1, -1,  1, -1, -1,  1, -1,  //  55
     1,  1,  1, -1, -1, -1,  1, -1,  //  56
    -1,  1,  1, -1, -1, -1,  1,  1,  //  57
     1, -1,  1, -1, -1, -1,  1,  1,  //  58
    -1, -1,  1, -1, -1, -1,  1, -1,  //  59
     1,  1, -1, -1, -1, -1,  1,  1,  //  60
    -1,  1, -1, -1, -1, -1,  1, -1,  //  61
     1, -1, -1, -1, -1, -1,  1, -1,  //  62
    -1, -1, -1, -1, -1, -1,  1,  1,  //  63
     1,  1,  1,  1,  1,  1, -1, -1,  //  64
    -1,  1,  1,  1,  1,  1, -1,  1,  //  65
     1, -1,  1,  1,  1,  1, -1,  1,  //  66
    -1, -1,  1,  1,  1,  1, -1, -1,  //  67
     1,  1, -1,  1,  1,  1, -1,  1,  //  68
    -1,  1, -1,  1,  1,  1, -1, -1,  //  69
     1, -1, -1,  1,  1,  1, -1, -1,  //  70
    -1, -1, -1,  1,  1,  1, -1,  1,  //  71
     1,  1,  1, -1,  1,  1, -1,  1,  //  72
    -1,  1,  1, -1,  1,  1, -1, -1,  //  73
     1, -1,  1, -1,  1,  1, -1, -1,  //  74
    -1, -1,  1, -1,  1,  1, -1,  1,  //  75
     1,  1, -1, -1,  1,  1, -1, -1,  //  76
    -1,  1, -1, -1,  1,  1, -1,  1,  //  77
     1, -1, -1, -1,  1,  1, -1,  1,  //  78
    -1, -1, -1, -1,  1,  1, -1, -1,  //  79
     1,  1,  1,  1, -1,  1, -1,  1,  //  80
    -1,  1,  1,  1, -1,  1, -1, -1,  //  81
     1, -1,  1,  1, -1,  1, -1, -1,  //  82
    -1, -1,  1,  1, -1,  1, -1,  1,  //  83
     1,  1, -1,  1, -1,  1, -1, -1,  //  84
    -1,  1, -1,  1, -1,  1, -1,  1,  //  85
     1, -1, -1,  1, -1,  1, -1,  1,  //  86
    -1, -1, -1,  1, -1,  1, -1, -1,  //  87
     1,  1,  1, -1, -1,  1, -1, -1,  //  88
    -1,  1,  1, -1, -1,  1, -1,  1,  //  89
     1, -1,  1, -1, -1,  1, -1,  1,  //  90
    -1, -1,  1, -1, -1,  1, -1, -1,  //  91
     1,  1, -1, -1, -1,  1, -1,  1,  //  92
    -1,  1, -1, -1, -1,  1, -1, -1,  //  93
     1, -1, -1, -1, -1,  1, -1, -1,  //  94
    -1, -1, -1, -1, -1,  1, -1,  1,  //  95
     1,  1,  1,  1,  1, -1, -1,  1,  //  96
    -1,  1,  1,  1,  1, -1, -1, -1,  //  97
     1, -1,  1,  1,  1, -1, -1, -1,  //  98
    -1, -1,  1,  1,  1, -1, -1,  1,  //  99
     1,  1, -1,  1,  1, -1, -1, -1,  // 100
    -1,  1, -1,  1,  1, -1, -1,  1,  // 101
     1, -1, -1,  1,  1, -1, -1,  1,  // 102
    -1, -1, -1,  1,  1, -1, -1, -1,  // 103
     1,  1,  1, -1,  1, -1, -1, -1,  // 104
    -1,  1,  1, -1,  1, -1, -1,  1,  // 105
     1, -1,  1, -1,  1, -1, -1,  1,  // 106
    -1, -1,  1, -1,  1, -1, -1, -1,  // 107
     1,  1, -1, -1,  1, -1, -1,  1,  // 108
    -1,  1, -1, -1,  1, -1, -1, -1,  // 109
     1, -1, -1, -1,  1, -1, -1, -1,  // 110
    -1, -1, -1, -1,  1, -1, -1,  1,  // 111
     1,  1,  1,  1, -1, -1, -1, -1,  // 112
    -1,  1,  1,  1, -1, -1, -1,  1,  // 113
     1, -1,  1,  1, -1, -1, -1,  1,  // 114
    -1, -1,  1,  1, -1, -1, -1, -1,  // 115
     1,  1, -1,  1, -1, -1, -1,  1,  // 116
    -1,  1, -1,  1, -1, -1, -1, -1,  // 117
     1, -1, -1,  1, -1, -1, -1, -1,  // 118
    -1, -1, -1,  1, -1, -1, -1,  1,  // 119
     1,  1,  1, -1, -1, -1, -1,  1,  // 120
    -1,  1,  1, -1, -1, -1, -1, -1,  // 121
     1, -1,  1, -1, -1, -1, -1, -1,  // 122
    -1, -1,  1, -1, -1, -1, -1,  1,  // 123
     1,  1, -1, -1, -1, -1, -1, -1,  // 124
    -1,  1, -1, -1, -1, -1, -1,  1,  // 125
     1, -1, -1, -1, -1, -1, -1,  1,  // 126
    -1, -1, -1, -1, -1, -1, -1, -1,  // 127
};
8262
+
8263
+ void ggml_vec_dot_iq2_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
8264
+ assert(n % QK_K == 0);
8265
+
8266
+ const block_iq2_xxs * restrict x = vx;
8267
+ const block_q8_K * restrict y = vy;
8268
+
8269
+ const int nb = n / QK_K;
8270
+
8271
+ #if defined(__ARM_NEON)
8272
+
8273
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
8274
+
8275
+ uint32_t aux32[4];
8276
+ const uint8_t * aux8 = (const uint8_t *)aux32;
8277
+
8278
+ ggml_int8x16x4_t q2u;
8279
+ ggml_int8x16x4_t q2s;
8280
+ ggml_int8x16x4_t q8b;
8281
+
8282
+ float sumf = 0;
8283
+ for (int i = 0; i < nb; ++i) {
8284
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
8285
+ const uint16_t * restrict q2 = x[i].qs;
8286
+ const int8_t * restrict q8 = y[i].qs;
8287
+ float sumf1 = 0, sumf2 = 0;
8288
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
8289
+ q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
8290
+ memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
8291
+ q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 0])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 1])));
8292
+ q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 2])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 3])));
8293
+ q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 8])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 9])));
8294
+ q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[10])), vld1_s8((const void *)(iq2xxs_grid + aux8[11])));
8295
+ q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127))));
8296
+ q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127))));
8297
+ q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 7) & 127))));
8298
+ q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 21) & 127))));
8299
+ q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]);
8300
+ q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]);
8301
+ q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]);
8302
+ q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]);
8303
+ const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]), q2u.val[1], q8b.val[1]);
8304
+ const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]), q2u.val[3], q8b.val[3]);
8305
+ sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[1] >> 28));
8306
+ sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[3] >> 28));
8307
+ }
8308
+ sumf += d*(sumf1 + sumf2);
8309
+ }
8310
+ *s = 0.25f * sumf;
8311
+
8312
+ #elif defined(__AVX2__)
8313
+
8314
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
8315
+
8316
+ uint32_t aux32[4];
8317
+ const uint8_t * aux8 = (const uint8_t *)aux32;
8318
+
8319
+ __m256 accumf = _mm256_setzero_ps();
8320
+ for (int i = 0; i < nb; ++i) {
8321
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
8322
+ const uint16_t * restrict q2 = x[i].qs;
8323
+ const int8_t * restrict q8 = y[i].qs;
8324
+ __m256i sumi1 = _mm256_setzero_si256();
8325
+ __m256i sumi2 = _mm256_setzero_si256();
8326
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
8327
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
8328
+ const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
8329
+ memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
8330
+ const __m256i q2_1 = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
8331
+ const __m256i q2_2 = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
8332
+ const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
8333
+ signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
8334
+ const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127],
8335
+ signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]);
8336
+ const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
8337
+ const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
8338
+ const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
8339
+ const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
8340
+ const uint16_t ls1 = aux32[1] >> 28;
8341
+ const uint16_t ls2 = aux32[3] >> 28;
8342
+ const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
8343
+ const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
8344
+ sumi1 = _mm256_add_epi32(sumi1, p1);
8345
+ sumi2 = _mm256_add_epi32(sumi2, p2);
8346
+ }
8347
+
8348
+ accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
8349
+
8350
+ }
8351
+
8352
+ *s = 0.125f * hsum_float_8(accumf);
8353
+
8354
+ #else
8355
+
8356
+ uint32_t aux32[2];
8357
+ const uint8_t * aux8 = (const uint8_t *)aux32;
8358
+
8359
+ float sumf = 0.f;
8360
+ for (int i = 0; i < nb; ++i) {
8361
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
8362
+ const uint16_t * restrict q2 = x[i].qs;
8363
+ const int8_t * restrict q8 = y[i].qs;
8364
+ int32_t bsum = 0;
8365
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
8366
+ memcpy(aux32, q2, 2*sizeof(uint32_t));
8367
+ q2 += 4;
8368
+ const uint32_t ls = 2*(aux32[1] >> 28) + 1;
8369
+ int32_t sumi = 0;
8370
+ for (int l = 0; l < 4; ++l) {
8371
+ const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
8372
+ const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];
8373
+ for (int j = 0; j < 8; ++j) {
8374
+ sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
8375
+ }
8376
+ q8 += 8;
8377
+ }
8378
+ bsum += sumi * ls;
8379
+ }
8380
+ sumf += d * bsum;
8381
+ }
8382
+ *s = 0.125f * sumf;
8383
+ #endif
8384
+ }
8385
+
8386
+ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
8387
+ assert(n % QK_K == 0);
8388
+
8389
+ const block_iq2_xs * restrict x = vx;
8390
+ const block_q8_K * restrict y = vy;
8391
+
8392
+ const int nb = n / QK_K;
8393
+
8394
+ #if defined(__ARM_NEON)
8395
+
8396
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
8397
+
8398
+ ggml_int8x16x4_t q2u;
8399
+ ggml_int8x16x4_t q2s;
8400
+ ggml_int8x16x4_t q8b;
8401
+
8402
+ int32x4x4_t scales32;
8403
+
8404
+ float sumf = 0;
8405
+ for (int i = 0; i < nb; ++i) {
8406
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
8407
+ const uint16_t * restrict q2 = x[i].qs;
8408
+ const int8_t * restrict q8 = y[i].qs;
8409
+ const uint8x8_t scales8 = vld1_u8(x[i].scales);
8410
+ const uint8x8_t scales_l = vand_u8(scales8, vdup_n_u8(0xf));
8411
+ const uint8x8_t scales_h = vshr_n_u8(scales8, 4);
8412
+ uint8x16_t scales = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h));
8413
+ scales = vaddq_u8(vshlq_n_u8(scales, 1), vdupq_n_u8(1));
8414
+ const uint16x8_t scales1 = vmovl_u8(vget_low_u8(scales));
8415
+ const uint16x8_t scales2 = vmovl_u8(vget_high_u8(scales));
8416
+ scales32.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales1)));
8417
+ scales32.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales1)));
8418
+ scales32.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales2)));
8419
+ scales32.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales2)));
8420
+ int32x4_t sumi = vdupq_n_s32(0);
8421
+ for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
8422
+ q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
8423
+ q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[0] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[1] & 511))));
8424
+ q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[2] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[3] & 511))));
8425
+ q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[4] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[5] & 511))));
8426
+ q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[6] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[7] & 511))));
8427
+ q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[0] >> 9))), vld1_s8((const void *)(signs64 + (q2[1] >> 9))));
8428
+ q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[2] >> 9))), vld1_s8((const void *)(signs64 + (q2[3] >> 9))));
8429
+ q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[4] >> 9))), vld1_s8((const void *)(signs64 + (q2[5] >> 9))));
8430
+ q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[6] >> 9))), vld1_s8((const void *)(signs64 + (q2[7] >> 9))));
8431
+ q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]);
8432
+ q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]);
8433
+ q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]);
8434
+ q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]);
8435
+ const int32x4_t p1 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]);
8436
+ const int32x4_t p2 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[1], q8b.val[1]);
8437
+ const int32x4_t p3 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]);
8438
+ const int32x4_t p4 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[3], q8b.val[3]);
8439
+ const int32x4_t p = vpaddq_s32(vpaddq_s32(p1, p2), vpaddq_s32(p3, p4));
8440
+ sumi = vmlaq_s32(sumi, p, scales32.val[ib64]);
8441
+ q2 += 8;
8442
+ }
8443
+ sumf += d*vaddvq_s32(sumi);
8444
+ }
8445
+ *s = 0.125f * sumf;
8446
+
8447
+ #elif defined(__AVX2__)
8448
+
8449
+ const __m128i m4 = _mm_set1_epi8(0xf);
8450
+ const __m128i m1 = _mm_set1_epi8(1);
8451
+ const __m128i m511 = _mm_set1_epi16(511);
8452
+ const __m128i m127 = _mm_set1_epi16(127);
8453
+
8454
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
8455
+
8456
+ uint64_t aux64;
8457
+
8458
+ // somewhat hacky, but gives a significant boost in performance
8459
+ __m128i aux_gindex, aux_sindex;
8460
+ const uint16_t * gindex = (const uint16_t *)&aux_gindex;
8461
+ const uint16_t * sindex = (const uint16_t *)&aux_sindex;
8462
+
8463
+ __m256 accumf = _mm256_setzero_ps();
8464
+ for (int i = 0; i < nb; ++i) {
8465
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
8466
+ const uint16_t * restrict q2 = x[i].qs;
8467
+ const int8_t * restrict q8 = y[i].qs;
8468
+
8469
+ memcpy(&aux64, x[i].scales, 8);
8470
+ __m128i stmp = _mm_set1_epi64x(aux64);
8471
+ stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4));
8472
+ const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1);
8473
+
8474
+ __m256i sumi1 = _mm256_setzero_si256();
8475
+ __m256i sumi2 = _mm256_setzero_si256();
8476
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
8477
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
8478
+ const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
8479
+ const __m128i q2_data = _mm_loadu_si128((const __m128i*)q2); q2 += 8;
8480
+ aux_gindex = _mm_and_si128(q2_data, m511);
8481
+ aux_sindex = _mm_and_si128(_mm_srli_epi16(q2_data, 9), m127);
8482
+ const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]], iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]);
8483
+ const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]], iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]);
8484
+ const __m256i s2_1 = _mm256_set_epi64x(signs64[sindex[3]], signs64[sindex[2]], signs64[sindex[1]], signs64[sindex[0]]);
8485
+ const __m256i s2_2 = _mm256_set_epi64x(signs64[sindex[7]], signs64[sindex[6]], signs64[sindex[5]], signs64[sindex[4]]);
8486
+ const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
8487
+ const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
8488
+ const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
8489
+ const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
8490
+
8491
+ const __m256i sc1 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0)));
8492
+ const __m256i sc2 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1)));
8493
+
8494
+ sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot1, sc1));
8495
+ sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot2, sc2));
8496
+ }
8497
+
8498
+ accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
8499
+
8500
+ }
8501
+
8502
+ *s = 0.125f * hsum_float_8(accumf);
8503
+
8504
+ #else
8505
+
8506
+ float sumf = 0.f;
8507
+ for (int i = 0; i < nb; ++i) {
8508
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
8509
+ const uint16_t * restrict q2 = x[i].qs;
8510
+ const uint8_t * restrict sc = x[i].scales;
8511
+ const int8_t * restrict q8 = y[i].qs;
8512
+ int32_t bsum = 0;
8513
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
8514
+ const uint16_t ls1 = 2*(sc[ib32] & 0xf) + 1;
8515
+ const uint16_t ls2 = 2*(sc[ib32] >> 4) + 1;
8516
+ int32_t sumi = 0;
8517
+ for (int l = 0; l < 2; ++l) {
8518
+ const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
8519
+ const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
8520
+ for (int j = 0; j < 8; ++j) {
8521
+ sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
8522
+ }
8523
+ q8 += 8;
8524
+ }
8525
+ bsum += sumi * ls1;
8526
+ sumi = 0;
8527
+ for (int l = 2; l < 4; ++l) {
8528
+ const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
8529
+ const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
8530
+ for (int j = 0; j < 8; ++j) {
8531
+ sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
8532
+ }
8533
+ q8 += 8;
8534
+ }
8535
+ bsum += sumi * ls2;
8536
+ q2 += 4;
8537
+ }
8538
+ sumf += d * bsum;
8539
+ }
8540
+ *s = 0.125f * sumf;
8541
+ #endif
8542
+ }
8543
+
8544
+ // ================================ IQ2 quantization =============================================
8545
+
8546
// Lookup tables for one IQ2 grid size. They are built by q2xs_init_impl and released by
// q2xs_deinit_impl; all three pointers stay NULL until then.
typedef struct {
    uint64_t * grid;        // grid_size points, each 8 signed bytes with values 2*l+1
    int      * map;         // 2-bit-per-coordinate code -> grid index, or -(offset+1) into neighbours
    uint16_t * neighbours;  // repeated [count, grid_index_0, ..., grid_index_{count-1}] lists
} iq2_entry_t;

// Slot 0 holds the 256-point grid (IQ2_XXS) and slot 1 the 512-point grid (IQ2_XS);
// see iq2_data_index().
static iq2_entry_t iq2_data[2] = {
    { .grid = NULL, .map = NULL, .neighbours = NULL },
    { .grid = NULL, .map = NULL, .neighbours = NULL },
};
8556
+
8557
// Map a supported grid size to its slot in iq2_data: 256 -> 0, 512 -> 1.
// Any other size fails the assertion.
static inline int iq2_data_index(int grid_size) {
    GGML_ASSERT(grid_size == 256 || grid_size == 512);
    if (grid_size == 256) {
        return 0;
    }
    return 1;
}
8561
+
8562
// qsort comparator for (distance, index) int pairs: orders by distance, then by index.
// Returns -1, 0 or 1.
static int iq2_compare_func(const void * left, const void * right) {
    const int * a = (const int *)left;
    const int * b = (const int *)right;
    const int by_distance = (a[0] > b[0]) - (a[0] < b[0]);
    if (by_distance != 0) {
        return by_distance;
    }
    return (a[1] > b[1]) - (a[1] < b[1]);
}
8567
+
8568
+ static void q2xs_init_impl(int grid_size) {
8569
+ const int gindex = iq2_data_index(grid_size);
8570
+ if (iq2_data[gindex].grid) {
8571
+ return;
8572
+ }
8573
+ static const uint16_t kgrid_256[256] = {
8574
+ 0, 2, 5, 8, 10, 17, 20, 32, 34, 40, 42, 65, 68, 80, 88, 97,
8575
+ 100, 128, 130, 138, 162, 257, 260, 272, 277, 320, 388, 408, 512, 514, 546, 642,
8576
+ 1025, 1028, 1040, 1057, 1060, 1088, 1090, 1096, 1120, 1153, 1156, 1168, 1188, 1280, 1282, 1288,
8577
+ 1312, 1350, 1385, 1408, 1425, 1545, 1552, 1600, 1668, 1700, 2048, 2053, 2056, 2068, 2088, 2113,
8578
+ 2116, 2128, 2130, 2184, 2308, 2368, 2562, 2580, 4097, 4100, 4112, 4129, 4160, 4192, 4228, 4240,
8579
+ 4245, 4352, 4360, 4384, 4432, 4442, 4480, 4644, 4677, 5120, 5128, 5152, 5157, 5193, 5248, 5400,
8580
+ 5474, 5632, 5654, 6145, 6148, 6160, 6208, 6273, 6400, 6405, 6560, 6737, 8192, 8194, 8202, 8260,
8581
+ 8289, 8320, 8322, 8489, 8520, 8704, 8706, 9217, 9220, 9232, 9280, 9302, 9472, 9537, 9572, 9872,
8582
+ 10248, 10272, 10388, 10820, 16385, 16388, 16400, 16408, 16417, 16420, 16448, 16456, 16470, 16480, 16513, 16516,
8583
+ 16528, 16640, 16672, 16737, 16768, 16773, 16897, 16912, 16968, 16982, 17000, 17408, 17416, 17440, 17536, 17561,
8584
+ 17682, 17700, 17920, 18433, 18436, 18448, 18496, 18501, 18688, 18776, 18785, 18818, 19013, 19088, 20480, 20488,
8585
+ 20497, 20505, 20512, 20608, 20616, 20740, 20802, 20900, 21137, 21648, 21650, 21770, 22017, 22100, 22528, 22545,
8586
+ 22553, 22628, 22848, 23048, 24580, 24592, 24640, 24680, 24832, 24917, 25112, 25184, 25600, 25605, 25872, 25874,
8587
+ 25988, 26690, 32768, 32770, 32778, 32833, 32898, 33028, 33048, 33088, 33297, 33793, 33796, 33808, 33813, 33856,
8588
+ 33888, 34048, 34118, 34196, 34313, 34368, 34400, 34818, 35076, 35345, 36868, 36880, 36900, 36928, 37025, 37142,
8589
+ 37248, 37445, 37888, 37922, 37956, 38225, 39041, 39200, 40962, 41040, 41093, 41225, 41472, 42008, 43088, 43268,
8590
+ };
8591
+ static const uint16_t kgrid_512[512] = {
8592
+ 0, 2, 5, 8, 10, 17, 20, 22, 25, 32, 34, 37, 40, 65, 68, 70,
8593
+ 73, 80, 82, 85, 88, 97, 100, 128, 130, 133, 136, 145, 148, 153, 160, 257,
8594
+ 260, 262, 265, 272, 274, 277, 280, 282, 289, 292, 320, 322, 325, 328, 337, 340,
8595
+ 352, 360, 385, 388, 400, 512, 514, 517, 520, 529, 532, 544, 577, 580, 592, 597,
8596
+ 640, 650, 1025, 1028, 1030, 1033, 1040, 1042, 1045, 1048, 1057, 1060, 1088, 1090, 1093, 1096,
8597
+ 1105, 1108, 1110, 1120, 1153, 1156, 1168, 1280, 1282, 1285, 1288, 1297, 1300, 1312, 1345, 1348,
8598
+ 1360, 1377, 1408, 1537, 1540, 1552, 1574, 1600, 1602, 1668, 2048, 2050, 2053, 2056, 2058, 2065,
8599
+ 2068, 2080, 2085, 2113, 2116, 2128, 2136, 2176, 2208, 2218, 2305, 2308, 2320, 2368, 2433, 2441,
8600
+ 2560, 2592, 2600, 2710, 2720, 4097, 4100, 4102, 4105, 4112, 4114, 4117, 4120, 4129, 4132, 4160,
8601
+ 4162, 4165, 4168, 4177, 4180, 4192, 4202, 4225, 4228, 4240, 4352, 4354, 4357, 4360, 4369, 4372,
8602
+ 4384, 4417, 4420, 4432, 4480, 4500, 4502, 4609, 4612, 4614, 4624, 4672, 4704, 5120, 5122, 5125,
8603
+ 5128, 5137, 5140, 5152, 5185, 5188, 5193, 5200, 5220, 5248, 5377, 5380, 5392, 5440, 5632, 5652,
8604
+ 5705, 6145, 6148, 6160, 6162, 6208, 6228, 6278, 6400, 6405, 6502, 6737, 6825, 8192, 8194, 8197,
8605
+ 8200, 8202, 8209, 8212, 8224, 8257, 8260, 8272, 8320, 8352, 8449, 8452, 8464, 8512, 8520, 8549,
8606
+ 8704, 8738, 8832, 8872, 9217, 9220, 9232, 9257, 9280, 9472, 9537, 9554, 9625, 9729, 9754, 9894,
8607
+ 10240, 10248, 10250, 10272, 10325, 10376, 10402, 10600, 10640, 10760, 10784, 10882, 10888, 10890, 16385, 16388,
8608
+ 16390, 16393, 16400, 16402, 16405, 16408, 16417, 16420, 16448, 16450, 16453, 16456, 16458, 16465, 16468, 16480,
8609
+ 16485, 16513, 16516, 16528, 16640, 16642, 16645, 16648, 16657, 16660, 16672, 16705, 16708, 16720, 16768, 16773,
8610
+ 16802, 16897, 16900, 16912, 16914, 16937, 16960, 17408, 17410, 17413, 17416, 17425, 17428, 17433, 17440, 17473,
8611
+ 17476, 17488, 17536, 17556, 17665, 17668, 17680, 17700, 17728, 17818, 17920, 17930, 17988, 18000, 18433, 18436,
8612
+ 18448, 18496, 18501, 18516, 18530, 18688, 18705, 18756, 18768, 18793, 18948, 20480, 20482, 20485, 20488, 20497,
8613
+ 20500, 20512, 20520, 20545, 20548, 20560, 20608, 20737, 20740, 20752, 20757, 20800, 20802, 20992, 21060, 21162,
8614
+ 21505, 21508, 21520, 21537, 21568, 21600, 21633, 21665, 21760, 21768, 21888, 21896, 22049, 22120, 22177, 22528,
8615
+ 22548, 22593, 22608, 22681, 22810, 22848, 22850, 23173, 24577, 24580, 24592, 24640, 24660, 24674, 24710, 24745,
8616
+ 24832, 25124, 25162, 25234, 25600, 25622, 25872, 25920, 25925, 26020, 26625, 26730, 26917, 27142, 27220, 27234,
8617
+ 32768, 32770, 32773, 32776, 32785, 32788, 32800, 32810, 32833, 32836, 32848, 32896, 32898, 32936, 32938, 33025,
8618
+ 33028, 33030, 33040, 33088, 33105, 33113, 33280, 33312, 33408, 33410, 33440, 33448, 33793, 33796, 33808, 33810,
8619
+ 33813, 33856, 33888, 33929, 34048, 34116, 34213, 34328, 34410, 34816, 34824, 34853, 34906, 34944, 34946, 34984,
8620
+ 35078, 35362, 35456, 35464, 35478, 35496, 36865, 36868, 36880, 36928, 36950, 36996, 37120, 37154, 37220, 37462,
8621
+ 37513, 37888, 37893, 37956, 37968, 37976, 38185, 38288, 38290, 38465, 38993, 39078, 39241, 39445, 39520, 40960,
8622
+ 40962, 40968, 40970, 40992, 41002, 41120, 41297, 41305, 41382, 41472, 41474, 41480, 41514, 41600, 41632, 42048,
8623
+ 42133, 42597, 42648, 43018, 43040, 43042, 43048, 43168, 43176, 43268, 43396, 43398, 43560, 43562, 43665, 43690,
8624
+ };
8625
+ const int kmap_size = 43692;
8626
+ const int nwant = 2;
8627
+ const uint16_t * kgrid = grid_size == 256 ? kgrid_256 : kgrid_512;
8628
+ uint64_t * kgrid_q2xs;
8629
+ int * kmap_q2xs;
8630
+ uint16_t * kneighbors_q2xs;
8631
+
8632
+ printf("================================================================= %s(grid_size = %d)\n", __func__, grid_size);
8633
+ uint64_t * the_grid = (uint64_t *)malloc(grid_size*sizeof(uint64_t));
8634
+ for (int k = 0; k < grid_size; ++k) {
8635
+ int8_t * pos = (int8_t *)(the_grid + k);
8636
+ for (int i = 0; i < 8; ++i) {
8637
+ int l = (kgrid[k] >> 2*i) & 0x3;
8638
+ pos[i] = 2*l + 1;
8639
+ }
8640
+ }
8641
+ kgrid_q2xs = the_grid;
8642
+ iq2_data[gindex].grid = the_grid;
8643
+ kmap_q2xs = (int *)malloc(kmap_size*sizeof(int));
8644
+ iq2_data[gindex].map = kmap_q2xs;
8645
+ for (int i = 0; i < kmap_size; ++i) kmap_q2xs[i] = -1;
8646
+ uint64_t aux64;
8647
+ uint8_t * aux8 = (uint8_t *)&aux64;
8648
+ for (int i = 0; i < grid_size; ++i) {
8649
+ aux64 = kgrid_q2xs[i];
8650
+ uint16_t index = 0;
8651
+ for (int k=0; k<8; ++k) {
8652
+ uint16_t q = (aux8[k] - 1)/2;
8653
+ index |= (q << 2*k);
8654
+ }
8655
+ kmap_q2xs[index] = i;
8656
+ }
8657
+ int8_t pos[8];
8658
+ int * dist2 = (int *)malloc(2*grid_size*sizeof(int));
8659
+ int num_neighbors = 0, num_not_in_map = 0;
8660
+ for (int i = 0; i < kmap_size; ++i) {
8661
+ if (kmap_q2xs[i] >= 0) continue;
8662
+ ++num_not_in_map;
8663
+ for (int k = 0; k < 8; ++k) {
8664
+ int l = (i >> 2*k) & 0x3;
8665
+ pos[k] = 2*l + 1;
8666
+ }
8667
+ for (int j = 0; j < grid_size; ++j) {
8668
+ const int8_t * pg = (const int8_t *)(kgrid_q2xs + j);
8669
+ int d2 = 0;
8670
+ for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
8671
+ dist2[2*j+0] = d2;
8672
+ dist2[2*j+1] = j;
8673
+ }
8674
+ qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func);
8675
+ int n = 0; int d2 = dist2[0];
8676
+ int nhave = 1;
8677
+ for (int j = 0; j < grid_size; ++j) {
8678
+ if (dist2[2*j] > d2) {
8679
+ if (nhave == nwant) break;
8680
+ d2 = dist2[2*j];
8681
+ ++nhave;
8682
+ }
8683
+ ++n;
8684
+ }
8685
+ num_neighbors += n;
8686
+ }
8687
+ printf("%s: %d neighbours in total\n", __func__, num_neighbors);
8688
+ kneighbors_q2xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t));
8689
+ iq2_data[gindex].neighbours = kneighbors_q2xs;
8690
+ int counter = 0;
8691
+ for (int i = 0; i < kmap_size; ++i) {
8692
+ if (kmap_q2xs[i] >= 0) continue;
8693
+ for (int k = 0; k < 8; ++k) {
8694
+ int l = (i >> 2*k) & 0x3;
8695
+ pos[k] = 2*l + 1;
8696
+ }
8697
+ for (int j = 0; j < grid_size; ++j) {
8698
+ const int8_t * pg = (const int8_t *)(kgrid_q2xs + j);
8699
+ int d2 = 0;
8700
+ for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
8701
+ dist2[2*j+0] = d2;
8702
+ dist2[2*j+1] = j;
8703
+ }
8704
+ qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func);
8705
+ kmap_q2xs[i] = -(counter + 1);
8706
+ int d2 = dist2[0];
8707
+ uint16_t * start = &kneighbors_q2xs[counter++];
8708
+ int n = 0, nhave = 1;
8709
+ for (int j = 0; j < grid_size; ++j) {
8710
+ if (dist2[2*j] > d2) {
8711
+ if (nhave == nwant) break;
8712
+ d2 = dist2[2*j];
8713
+ ++nhave;
8714
+ }
8715
+ kneighbors_q2xs[counter++] = dist2[2*j+1];
8716
+ ++n;
8717
+ }
8718
+ *start = n;
8719
+ }
8720
+ free(dist2);
8721
+ }
8722
+
8723
+ void ggml_init_iq2_quantization(enum ggml_type type) {
8724
+ if (type == GGML_TYPE_IQ2_XXS) {
8725
+ q2xs_init_impl(256);
8726
+ }
8727
+ else if (type == GGML_TYPE_IQ2_XS) {
8728
+ q2xs_init_impl(512);
8729
+ }
8730
+ else {
8731
+ fprintf(stderr, "======================== Why are you calling %s with type %d?\n", __func__, (int)type);
8732
+ }
8733
+ }
8734
+
8735
+ static void q2xs_deinit_impl(int grid_size) {
8736
+ GGML_ASSERT(grid_size == 256 || grid_size == 512 || grid_size == 1024);
8737
+ const int gindex = iq2_data_index(grid_size);
8738
+ if (iq2_data[gindex].grid) {
8739
+ free(iq2_data[gindex].grid); iq2_data[gindex].grid = NULL;
8740
+ free(iq2_data[gindex].map); iq2_data[gindex].map = NULL;
8741
+ free(iq2_data[gindex].neighbours); iq2_data[gindex].neighbours = NULL;
8742
+ }
8743
+ }
8744
+
8745
+ void ggml_deinit_iq2_quantization(enum ggml_type type) {
8746
+ if (type == GGML_TYPE_IQ2_XXS) {
8747
+ q2xs_deinit_impl(256);
8748
+ }
8749
+ else if (type == GGML_TYPE_IQ2_XS) {
8750
+ q2xs_deinit_impl(512);
8751
+ }
8752
+ else {
8753
+ fprintf(stderr, "======================== Why are you calling %s with type %d?\n", __func__, (int)type);
8754
+ }
8755
+ }
8756
+
8757
// Pick, from a neighbour list [count, idx_0, ..., idx_{count-1}], the grid point that minimises
// the weighted squared error sum_i weight[i] * (scale*grid[i] - xval[i])^2 over 8 values.
// The chosen point's 2-bit levels (grid - 1)/2 are written to L[0..7], and its grid index is
// returned. The first candidate wins on ties. Asserts that the list is non-empty and that some
// candidate beats FLT_MAX.
static int iq2_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid,
        const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L) {
    const int count = neighbours[0];
    GGML_ASSERT(count > 0);
    const uint16_t * candidates = neighbours + 1;

    float min_err = FLT_MAX;
    int best = -1;
    for (int c = 0; c < count; ++c) {
        const int idx = candidates[c];
        const int8_t * point = (const int8_t *)(grid + idx);
        float err = 0;
        for (int i = 0; i < 8; ++i) {
            const float diff = scale*point[i] - xval[i];
            err += weight[i]*diff*diff;
        }
        if (err < min_err) {
            min_err = err;
            best = idx;
        }
    }
    GGML_ASSERT(best >= 0);

    const int8_t * chosen = (const int8_t *)(grid + best);
    for (int i = 0; i < 8; ++i) {
        L[i] = (chosen[i] - 1)/2;
    }
    return best;
}
8780
+
8781
+ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
8782
+
8783
+ const int gindex = iq2_data_index(256);
8784
+
8785
+ const uint64_t * kgrid_q2xs = iq2_data[gindex].grid;
8786
+ const int * kmap_q2xs = iq2_data[gindex].map;
8787
+ const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
8788
+
8789
+ GGML_ASSERT(quant_weights);
8790
+ GGML_ASSERT(kgrid_q2xs);
8791
+ GGML_ASSERT(kmap_q2xs);
8792
+ GGML_ASSERT(kneighbors_q2xs);
8793
+ GGML_ASSERT(n%QK_K == 0);
8794
+
8795
+ const int kMaxQ = 3;
8796
+
8797
+ const int nbl = n/256;
8798
+
8799
+ block_iq2_xxs * y = vy;
8800
+
8801
+ float scales[QK_K/32];
8802
+ float weight[32];
8803
+ float xval[32];
8804
+ int8_t L[32];
8805
+ int8_t Laux[32];
8806
+ float waux[32];
8807
+ bool is_on_grid[4];
8808
+ bool is_on_grid_aux[4];
8809
+ uint8_t block_signs[4];
8810
+ uint32_t q2[2*(QK_K/32)];
8811
+
8812
+ for (int ibl = 0; ibl < nbl; ++ibl) {
8813
+
8814
+ y[ibl].d = GGML_FP32_TO_FP16(0.f);
8815
+ memset(q2, 0, QK_K/4);
8816
+
8817
+ float max_scale = 0;
8818
+
8819
+ const float * xbl = x + QK_K*ibl;
8820
+ float sumx2 = 0;
8821
+ for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
8822
+ float sigma2 = sumx2/QK_K;
8823
+
8824
+ for (int ib = 0; ib < QK_K/32; ++ib) {
8825
+ const float * xb = xbl + 32*ib;
8826
+ const float * qw = quant_weights + QK_K*ibl + 32*ib;
8827
+ for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
8828
+ for (int i = 0; i < 32; ++i) waux[i] = sqrtf(weight[i]);
8829
+ for (int k = 0; k < 4; ++k) {
8830
+ int nflip = 0;
8831
+ uint8_t s = 0;
8832
+ for (int i = 0; i < 8; ++i) {
8833
+ if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
8834
+ else {
8835
+ xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i);
8836
+ }
8837
+ }
8838
+ if (nflip%2) {
8839
+ int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin];
8840
+ for (int i = 1; i < 8; ++i) {
8841
+ float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i];
8842
+ if (ax < min) {
8843
+ min = ax; imin = i;
8844
+ }
8845
+ }
8846
+ xval[8*k+imin] = -xval[8*k+imin];
8847
+ s ^= (1 << imin);
8848
+ }
8849
+ block_signs[k] = s & 127;
8850
+ }
8851
+ float max = xval[0];
8852
+ for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]);
8853
+ if (!max) {
8854
+ scales[ib] = 0;
8855
+ memset(L, 0, 32);
8856
+ continue;
8857
+ }
8858
+ float best = 0;
8859
+ float scale = max/(2*kMaxQ-1);
8860
+ for (int is = -9; is <= 9; ++is) {
8861
+ float id = (2*kMaxQ-1+is*0.1f)/max;
8862
+ float this_scale = 1/id;
8863
+ for (int k = 0; k < 4; ++k) {
8864
+ for (int i = 0; i < 8; ++i) {
8865
+ int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
8866
+ Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l));
8867
+ }
8868
+ uint16_t u = 0;
8869
+ for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i);
8870
+ int grid_index = kmap_q2xs[u];
8871
+ is_on_grid_aux[k] = true;
8872
+ if (grid_index < 0) {
8873
+ is_on_grid_aux[k] = false;
8874
+ const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
8875
+ grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k);
8876
+ }
8877
+ }
8878
+ float sumqx = 0, sumq2 = 0;
8879
+ for (int i = 0; i < 32; ++i) {
8880
+ float w = weight[i];
8881
+ float q = 2*Laux[i] + 1;
8882
+ sumqx += w*xval[i]*q;
8883
+ sumq2 += w*q*q;
8884
+ }
8885
+ if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
8886
+ scale = sumqx/sumq2; best = scale*sumqx;
8887
+ for (int i = 0; i < 32; ++i) L[i] = Laux[i];
8888
+ for (int k = 0; k < 4; ++k) is_on_grid[k] = is_on_grid_aux[k];
8889
+ }
8890
+ }
8891
+ int n_not_ongrid = 0;
8892
+ for (int k = 0; k < 4; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
8893
+ if (n_not_ongrid > 0 && scale > 0) {
8894
+ float id = 1/scale;
8895
+ for (int k = 0; k < 4; ++k) {
8896
+ if (is_on_grid[k]) continue;
8897
+ uint16_t u = 0;
8898
+ for (int i = 0; i < 8; ++i) {
8899
+ int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
8900
+ l = MAX(0, MIN(kMaxQ-1, l));
8901
+ u |= (l << 2*i);
8902
+ }
8903
+ int grid_index = kmap_q2xs[u];
8904
+ if (grid_index < 0) {
8905
+ const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
8906
+ grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k);
8907
+ }
8908
+ const int8_t * pg = (const int8_t *)(kgrid_q2xs + grid_index);
8909
+ for (int i = 0; i < 8; ++i) L[8*k+i] = (pg[i] - 1)/2;
8910
+ }
8911
+ float sumqx = 0, sumq2 = 0;
8912
+ for (int i = 0; i < 32; ++i) {
8913
+ float w = weight[i];
8914
+ float q = 2*L[i] + 1;
8915
+ sumqx += w*xval[i]*q;
8916
+ sumq2 += w*q*q;
8917
+ }
8918
+ if (sumq2 > 0) scale = sumqx/sumq2;
8919
+ }
8920
+ if (scale < 0) {
8921
+ // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale)
8922
+ // and correspondingly flip quant signs.
8923
+ scale = -scale;
8924
+ for (int k = 0; k < 4; ++k) block_signs[k] = (~block_signs[k]) & 127;
8925
+ }
8926
+ for (int k = 0; k < 4; ++k) {
8927
+ uint16_t u = 0;
8928
+ for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i);
8929
+ int grid_index = kmap_q2xs[u];
8930
+ if (grid_index < 0) {
8931
+ printf("Oops: found point %u not on grid:", u);
8932
+ for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]);
8933
+ printf("\n");
8934
+ GGML_ASSERT(false);
8935
+ }
8936
+ q2[2*ib+0] |= (grid_index << 8*k);
8937
+ q2[2*ib+1] |= (block_signs[k] << 7*k);
8938
+ }
8939
+ GGML_ASSERT(scale >= 0);
8940
+ scales[ib] = scale;
8941
+ max_scale = MAX(max_scale, scale);
8942
+ }
8943
+
8944
+ if (!max_scale) {
8945
+ memset(y[ibl].qs, 0, QK_K/4);
8946
+ continue;
8947
+ }
8948
+
8949
+ float d = max_scale/31;
8950
+ y[ibl].d = GGML_FP32_TO_FP16(d);
8951
+ float id = 1/d;
8952
+ float sumqx = 0, sumq2 = 0;
8953
+ for (int ib = 0; ib < QK_K/32; ++ib) {
8954
+ int l = nearest_int(0.5f*(id*scales[ib]-1));
8955
+ l = MAX(0, MIN(15, l));
8956
+ q2[2*ib+1] |= ((uint32_t)l << 28);
8957
+ const float * xb = xbl + 32*ib;
8958
+ const float * qw = quant_weights + QK_K*ibl + 32*ib;
8959
+ for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
8960
+ const uint8_t * aux8 = (const uint8_t *)(q2 + 2*ib);
8961
+ const float db = d * (1 + 2*l);
8962
+ uint32_t u = 0;
8963
+ for (int k = 0; k < 4; ++k) {
8964
+ const int8_t * signs = keven_signs_q2xs + 8*((q2[2*ib+1] >> 7*k) & 127);
8965
+ const float * xk = xb + 8*k;
8966
+ const float * wk = weight + 8*k;
8967
+ const uint8_t * grid = (const uint8_t *)(kgrid_q2xs + aux8[k]);
8968
+ float best_mse = 0; int best_index = aux8[k];
8969
+ for (int j = 0; j < 8; ++j) {
8970
+ float diff = db * grid[j] * signs[j] - xk[j];
8971
+ best_mse += wk[j] * diff * diff;
8972
+ }
8973
+ for (int idx = 0; idx < 256; ++idx) {
8974
+ grid = (const uint8_t *)(kgrid_q2xs + idx);
8975
+ float mse = 0;
8976
+ for (int j = 0; j < 8; ++j) {
8977
+ float diff = db * grid[j] * signs[j] - xk[j];
8978
+ mse += wk[j] * diff * diff;
8979
+ }
8980
+ if (mse < best_mse) {
8981
+ best_mse = mse; best_index = idx;
8982
+ }
8983
+ }
8984
+ u |= (best_index << 8*k);
8985
+ grid = (const uint8_t *)(kgrid_q2xs + best_index);
8986
+ //grid = (const uint8_t *)(kgrid_q2xs + aux8[k]);
8987
+ for (int j = 0; j < 8; ++j) {
8988
+ float q = db * grid[j] * signs[j];
8989
+ sumqx += wk[j] * q * xk[j];
8990
+ sumq2 += wk[j] * q * q;
8991
+ }
8992
+ }
8993
+ q2[2*ib] = u;
8994
+ if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(d*sumqx/sumq2);
8995
+ }
8996
+ memcpy(y[ibl].qs, q2, QK_K/4);
8997
+ }
8998
+ }
8999
+
9000
+ static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
9001
+
9002
+ const int gindex = iq2_data_index(512);
9003
+
9004
+ const uint64_t * kgrid_q2xs = iq2_data[gindex].grid;
9005
+ const int * kmap_q2xs = iq2_data[gindex].map;
9006
+ const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
9007
+
9008
+ GGML_ASSERT(quant_weights);
9009
+ GGML_ASSERT(kmap_q2xs);
9010
+ GGML_ASSERT(kgrid_q2xs);
9011
+ GGML_ASSERT(kneighbors_q2xs);
9012
+ GGML_ASSERT(n%QK_K == 0);
9013
+
9014
+ const int kMaxQ = 3;
9015
+
9016
+ const int nbl = n/256;
9017
+
9018
+ block_iq2_xs * y = vy;
9019
+
9020
+ float scales[QK_K/16];
9021
+ float weight[16];
9022
+ float xval[16];
9023
+ int8_t L[16];
9024
+ int8_t Laux[16];
9025
+ float waux[16];
9026
+ bool is_on_grid[2];
9027
+ bool is_on_grid_aux[2];
9028
+ uint8_t block_signs[2];
9029
+ uint16_t q2[2*(QK_K/16)];
9030
+
9031
+ for (int ibl = 0; ibl < nbl; ++ibl) {
9032
+
9033
+ y[ibl].d = GGML_FP32_TO_FP16(0.f);
9034
+ memset(q2, 0, QK_K/4);
9035
+ memset(y[ibl].scales, 0, QK_K/32);
9036
+
9037
+ float max_scale = 0;
9038
+
9039
+ const float * xbl = x + QK_K*ibl;
9040
+ float sumx2 = 0;
9041
+ for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
9042
+ float sigma2 = sumx2/QK_K;
9043
+
9044
+ for (int ib = 0; ib < QK_K/16; ++ib) {
9045
+ const float * xb = xbl + 16*ib;
9046
+ const float * qw = quant_weights + QK_K*ibl + 16*ib;
9047
+ for (int i = 0; i < 16; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
9048
+ for (int i = 0; i < 16; ++i) waux[i] = sqrtf(weight[i]);
9049
+ for (int k = 0; k < 2; ++k) {
9050
+ int nflip = 0;
9051
+ uint8_t s = 0;
9052
+ for (int i = 0; i < 8; ++i) {
9053
+ if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
9054
+ else {
9055
+ xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i);
9056
+ }
9057
+ }
9058
+ if (nflip%2) {
9059
+ int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin];
9060
+ for (int i = 1; i < 8; ++i) {
9061
+ float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i];
9062
+ if (ax < min) {
9063
+ min = ax; imin = i;
9064
+ }
9065
+ }
9066
+ xval[8*k+imin] = -xval[8*k+imin];
9067
+ s ^= (1 << imin);
9068
+ }
9069
+ block_signs[k] = s & 127;
9070
+ }
9071
+ float max = xval[0];
9072
+ for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]);
9073
+ if (!max) {
9074
+ scales[ib] = 0;
9075
+ memset(L, 0, 16);
9076
+ continue;
9077
+ }
9078
+ float best = 0;
9079
+ float scale = max/(2*kMaxQ-1);
9080
+ is_on_grid[0] = is_on_grid[1] = true;
9081
+ for (int is = -9; is <= 9; ++is) {
9082
+ float id = (2*kMaxQ-1+is*0.1f)/max;
9083
+ float this_scale = 1/id;
9084
+ for (int k = 0; k < 2; ++k) {
9085
+ for (int i = 0; i < 8; ++i) {
9086
+ int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
9087
+ Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l));
9088
+ }
9089
+ uint16_t u = 0;
9090
+ for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i);
9091
+ int grid_index = kmap_q2xs[u];
9092
+ is_on_grid_aux[k] = true;
9093
+ if (grid_index < 0) {
9094
+ is_on_grid_aux[k] = false;
9095
+ const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
9096
+ grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k);
9097
+ }
9098
+ }
9099
+ float sumqx = 0, sumq2 = 0;
9100
+ for (int i = 0; i < 16; ++i) {
9101
+ float w = weight[i];
9102
+ float q = 2*Laux[i] + 1;
9103
+ sumqx += w*xval[i]*q;
9104
+ sumq2 += w*q*q;
9105
+ }
9106
+ if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
9107
+ scale = sumqx/sumq2; best = scale*sumqx;
9108
+ for (int i = 0; i < 16; ++i) L[i] = Laux[i];
9109
+ for (int k = 0; k < 2; ++k) is_on_grid[k] = is_on_grid_aux[k];
9110
+ }
9111
+ }
9112
+ int n_not_ongrid = 0;
9113
+ for (int k = 0; k < 2; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
9114
+ if (n_not_ongrid > 0 && scale > 0) {
9115
+ float id = 1/scale;
9116
+ for (int k = 0; k < 2; ++k) {
9117
+ if (is_on_grid[k]) continue;
9118
+ uint16_t u = 0;
9119
+ for (int i = 0; i < 8; ++i) {
9120
+ int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
9121
+ l = MAX(0, MIN(kMaxQ-1, l));
9122
+ u |= (l << 2*i);
9123
+ L[8*k + i] = l;
9124
+ }
9125
+ int grid_index = kmap_q2xs[u];
9126
+ if (grid_index < 0) {
9127
+ const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
9128
+ grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k);
9129
+ }
9130
+ }
9131
+ float sumqx = 0, sumq2 = 0;
9132
+ for (int i = 0; i < 16; ++i) {
9133
+ float w = weight[i];
9134
+ float q = 2*L[i] + 1;
9135
+ sumqx += w*xval[i]*q;
9136
+ sumq2 += w*q*q;
9137
+ }
9138
+ if (sumq2 > 0) scale = sumqx/sumq2;
9139
+ }
9140
+ if (scale < 0) {
9141
+ scale = -scale;
9142
+ for (int k = 0; k < 2; ++k) block_signs[k] = (~block_signs[k]) & 127;
9143
+ }
9144
+ for (int k = 0; k < 2; ++k) {
9145
+ uint16_t u = 0;
9146
+ for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i);
9147
+ int grid_index = kmap_q2xs[u];
9148
+ if (grid_index < 0) {
9149
+ printf("Oops: found point %u not on grid:", u);
9150
+ for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]);
9151
+ printf("\n");
9152
+ GGML_ASSERT(false);
9153
+ }
9154
+ q2[2*ib+k] = grid_index | (block_signs[k] << 9);
9155
+ }
9156
+ GGML_ASSERT(scale >= 0);
9157
+ scales[ib] = scale;
9158
+ max_scale = MAX(max_scale, scale);
9159
+ }
9160
+
9161
+ if (!max_scale) {
9162
+ memset(y[ibl].qs, 0, QK_K/4);
9163
+ continue;
9164
+ }
9165
+
9166
+ float d = max_scale/31;
9167
+ y[ibl].d = GGML_FP32_TO_FP16(d);
9168
+ float id = 1/d;
9169
+ for (int ib = 0; ib < QK_K/16; ++ib) {
9170
+ int l = nearest_int(0.5f*(id*scales[ib]-1));
9171
+ l = MAX(0, MIN(15, l));
9172
+ if (ib%2 == 0) y[ibl].scales[ib/2] = l;
9173
+ else y[ibl].scales[ib/2] |= (l << 4);
9174
+ }
9175
+ memcpy(y[ibl].qs, q2, QK_K/4);
9176
+
9177
+ }
9178
+ }
9179
+
9180
+ size_t quantize_iq2_xxs(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
9181
+ (void)hist;
9182
+ GGML_ASSERT(n_per_row%QK_K == 0);
9183
+ int nblock = n_per_row/QK_K;
9184
+ char * qrow = (char *)dst;
9185
+ for (int row = 0; row < nrow; ++row) {
9186
+ quantize_row_iq2_xxs_impl(src, qrow, n_per_row, quant_weights);
9187
+ src += n_per_row;
9188
+ qrow += nblock*sizeof(block_iq2_xxs);
9189
+ }
9190
+ return nrow * nblock * sizeof(block_iq2_xxs);
9191
+ }
9192
+
9193
+ size_t quantize_iq2_xs(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
9194
+ (void)hist;
9195
+ GGML_ASSERT(n_per_row%QK_K == 0);
9196
+ int nblock = n_per_row/QK_K;
9197
+ char * qrow = (char *)dst;
9198
+ for (int row = 0; row < nrow; ++row) {
9199
+ quantize_row_iq2_xs_impl(src, qrow, n_per_row, quant_weights);
9200
+ src += n_per_row;
9201
+ qrow += nblock*sizeof(block_iq2_xs);
9202
+ }
9203
+ return nrow * nblock * sizeof(block_iq2_xs);
9204
+ }
9205
+