whisper.rn 0.4.0-rc.6 → 0.4.0-rc.8

This diff shows the content of publicly released package versions as they appear in their public registry, and is provided for informational purposes only.
package/cpp/ggml-quants.c CHANGED
@@ -5,6 +5,8 @@
  #include <string.h>
  #include <assert.h>
  #include <float.h>
+ #include <stdlib.h> // for qsort
+ #include <stdio.h>  // for WSP_GGML_ASSERT

  #ifdef __ARM_NEON

@@ -272,10 +274,13 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128

  // vaddvq_s16
  // vpaddq_s16
+ // vpaddq_s32
  // vaddvq_s32
  // vaddvq_f32
  // vmaxvq_f32
  // vcvtnq_s32_f32
+ // vzip1_u8
+ // vzip2_u8

  inline static int32_t vaddvq_s16(int16x8_t v) {
      return
@@ -291,6 +296,12 @@ inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
      return vcombine_s16(a0, b0);
  }

+ inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
+     int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a));
+     int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b));
+     return vcombine_s32(a0, b0);
+ }
+
  inline static int32_t vaddvq_s32(int32x4_t v) {
      return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
  }
@@ -316,6 +327,28 @@ inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
      return res;
  }

+ inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
+     uint8x8_t res;
+
+     res[0] = a[0]; res[1] = b[0];
+     res[2] = a[1]; res[3] = b[1];
+     res[4] = a[2]; res[5] = b[2];
+     res[6] = a[3]; res[7] = b[3];
+
+     return res;
+ }
+
+ inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
+     uint8x8_t res;
+
+     res[0] = a[4]; res[1] = b[4];
+     res[2] = a[5]; res[3] = b[5];
+     res[4] = a[6]; res[5] = b[6];
+     res[6] = a[7]; res[7] = b[7];
+
+     return res;
+ }
+
  // vld1q_s16_x2
  // vld1q_u8_x2
  // vld1q_u8_x4
@@ -407,6 +440,22 @@ inline static wsp_ggml_int8x16x4_t wsp_ggml_vld1q_s8_x4(const int8_t * ptr) {
  #define wsp_ggml_vld1q_s8_x4 vld1q_s8_x4

  #endif
+
+ #if !defined(__ARM_FEATURE_DOTPROD)
+
+ inline static int32x4_t wsp_ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
+     const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
+     const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
+
+     return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
+ }
+
+ #else
+
+ #define wsp_ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
+
+ #endif
+
  #endif

  #if defined(__ARM_NEON) || defined(__wasm_simd128__)
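Note: the new `wsp_ggml_vdotq_s32` shim above does not reproduce `vdotq_s32` lane-for-lane. A real SDOT sums byte products in groups of four per 32-bit lane, while the fallback pairs them via `vpaddlq_s16`, so individual lanes differ; this is fine because every call site reduces the vector horizontally, and the horizontal total is identical. A minimal scalar reference for that total, which one could use to unit-test the shim (hypothetical helper, not part of the package):

    #include <stdint.h>

    // Horizontal total of a 16-way signed 8-bit dot product; equals the
    // across-lane sum of either wsp_ggml_vdotq_s32 variant with acc = 0.
    static int32_t dot_s8x16_ref(const int8_t a[16], const int8_t b[16]) {
        int32_t sum = 0;
        for (int k = 0; k < 16; ++k) {
            sum += (int32_t)a[k] * (int32_t)b[k];
        }
        return sum;
    }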
@@ -466,6 +515,7 @@ void wsp_quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
      wsp_quantize_row_q4_0_reference(x, y, k);
  }

+
  void wsp_quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k) {
      const int qk = QK4_1;

@@ -1195,7 +1245,8 @@ static inline int nearest_int(float fval) {
      return (i & 0x007fffff) - 0x00400000;
  }

- static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type) {
+ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type,
+         const float * restrict qw) {
      float max = 0;
      float amax = 0;
      for (int i = 0; i < n; ++i) {
@@ -1221,14 +1272,18 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
          rmse_type = -rmse_type;
          return_early = true;
      }
-     int weight_type = rmse_type%2;
      float sumlx = 0;
      float suml2 = 0;
+ #ifdef HAVE_BUGGY_APPLE_LINKER
+     // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
+     for (volatile int i = 0; i < n; ++i) {
+ #else
      for (int i = 0; i < n; ++i) {
+ #endif
          int l = nearest_int(iscale * x[i]);
          l = MAX(-nmax, MIN(nmax-1, l));
          L[i] = l + nmax;
-         float w = weight_type == 1 ? x[i] * x[i] : 1;
+         float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i]));
          sumlx += w*x[i]*l;
          suml2 += w*l*l;
      }
@@ -1244,7 +1299,7 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
      for (int i = 0; i < n; ++i) {
          int l = nearest_int(iscale * x[i]);
          l = MAX(-nmax, MIN(nmax-1, l));
-         float w = weight_type == 1 ? x[i] * x[i] : 1;
+         float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i]));
          sumlx += w*x[i]*l;
          suml2 += w*l*l;
      }
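Note: the new `qw` parameter lets callers pass per-value importance weights into `make_qx_quants`; when it is NULL, the weight falls back to one of the `rmse_type` heuristics. Either way, the returned scale is the closed-form weighted least-squares optimum for the chosen integer levels. A minimal sketch of that closed form (hypothetical helper; it mirrors the sumlx/suml2 arithmetic above):

    // For fixed levels l[i], the scale minimizing sum_i w[i]*(x[i] - s*l[i])^2
    // is s = sum(w*x*l) / sum(w*l*l), the sumlx/suml2 ratio returned above.
    static float best_scale(int n, const float * x, const int8_t * l, const float * w) {
        float sumlx = 0, suml2 = 0;
        for (int i = 0; i < n; ++i) {
            sumlx += w[i]*x[i]*l[i];
            suml2 += w[i]*l[i]*l[i];
        }
        return suml2 > 0 ? sumlx/suml2 : 0.f;
    }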
@@ -1592,6 +1647,246 @@ size_t wsp_ggml_wsp_quantize_q2_K(const float * restrict src, void * restrict ds
      return (n/QK_K*sizeof(block_q2_K));
  }

+ static float make_qkx3_quants(int n, int nmax, const float * restrict x, const float * restrict weights,
+         uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux,
+         float rmin, float rdelta, int nstep, bool use_mad) {
+     float min = x[0];
+     float max = x[0];
+     float sum_w = weights ? weights[0] : x[0]*x[0];
+     float sum_x = sum_w * x[0];
+ #ifdef HAVE_BUGGY_APPLE_LINKER
+     // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
+     for (volatile int i = 1; i < n; ++i) {
+ #else
+     for (int i = 1; i < n; ++i) {
+ #endif
+         if (x[i] < min) min = x[i];
+         if (x[i] > max) max = x[i];
+         float w = weights ? weights[i] : x[i]*x[i];
+         sum_w += w;
+         sum_x += w * x[i];
+     }
+     if (min > 0) {
+         min = 0;
+     }
+     if (max <= min) {
+         memset(L, 0, n);
+         *the_min = -min;
+         return 0.f;
+     }
+     float iscale = nmax/(max - min);
+     float scale = 1/iscale;
+     float best_mad = 0;
+     for (int i = 0; i < n; ++i) {
+         int l = nearest_int(iscale*(x[i] - min));
+         L[i] = MAX(0, MIN(nmax, l));
+         float diff = scale * L[i] + min - x[i];
+         diff = use_mad ? fabsf(diff) : diff*diff;
+         float w = weights ? weights[i] : x[i]*x[i];
+         best_mad += w * diff;
+     }
+     if (nstep < 1) {
+         *the_min = -min;
+         return scale;
+     }
+     for (int is = 0; is <= nstep; ++is) {
+         iscale = (rmin + rdelta*is + nmax)/(max - min);
+         float sum_l = 0, sum_l2 = 0, sum_xl = 0;
+         for (int i = 0; i < n; ++i) {
+             int l = nearest_int(iscale*(x[i] - min));
+             l = MAX(0, MIN(nmax, l));
+             Laux[i] = l;
+             float w = weights ? weights[i] : x[i]*x[i];
+             sum_l += w*l;
+             sum_l2 += w*l*l;
+             sum_xl += w*l*x[i];
+         }
+         float D = sum_w * sum_l2 - sum_l * sum_l;
+         if (D > 0) {
+             float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
+             float this_min   = (sum_l2 * sum_x - sum_l * sum_xl)/D;
+             if (this_min > 0) {
+                 this_min = 0;
+                 this_scale = sum_xl / sum_l2;
+             }
+             float mad = 0;
+             for (int i = 0; i < n; ++i) {
+                 float diff = this_scale * Laux[i] + this_min - x[i];
+                 diff = use_mad ? fabsf(diff) : diff*diff;
+                 float w = weights ? weights[i] : x[i]*x[i];
+                 mad += w * diff;
+             }
+             if (mad < best_mad) {
+                 for (int i = 0; i < n; ++i) {
+                     L[i] = Laux[i];
+                 }
+                 best_mad = mad;
+                 scale = this_scale;
+                 min = this_min;
+             }
+         }
+     }
+     *the_min = -min;
+     return scale;
+ }
+
+ static float make_qp_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, const float * quant_weights) {
+     float max = 0;
+     for (int i = 0; i < n; ++i) {
+         max = MAX(max, x[i]);
+     }
+     if (!max) { // all zero
+         for (int i = 0; i < n; ++i) { L[i] = 0; }
+         return 0.f;
+     }
+     float iscale = nmax / max;
+     for (int i = 0; i < n; ++i) {
+         L[i] = nearest_int(iscale * x[i]);
+     }
+     float scale = 1/iscale;
+     float best_mse = 0;
+     for (int i = 0; i < n; ++i) {
+         float diff = x[i] - scale*L[i];
+         float w = quant_weights[i];
+         best_mse += w*diff*diff;
+     }
+     for (int is = -4; is <= 4; ++is) {
+         if (is == 0) continue;
+         float iscale_is = (0.1f*is + nmax)/max;
+         float scale_is = 1/iscale_is;
+         float mse = 0;
+         for (int i = 0; i < n; ++i) {
+             int l = nearest_int(iscale_is*x[i]);
+             l = MIN(nmax, l);
+             float diff = x[i] - scale_is*l;
+             float w = quant_weights[i];
+             mse += w*diff*diff;
+         }
+         if (mse < best_mse) {
+             best_mse = mse;
+             iscale = iscale_is;
+         }
+     }
+     float sumlx = 0;
+     float suml2 = 0;
+     for (int i = 0; i < n; ++i) {
+         int l = nearest_int(iscale * x[i]);
+         l = MIN(nmax, l);
+         L[i] = l;
+         float w = quant_weights[i];
+         sumlx += w*x[i]*l;
+         suml2 += w*l*l;
+     }
+     for (int itry = 0; itry < 5; ++itry) {
+         int n_changed = 0;
+         for (int i = 0; i < n; ++i) {
+             float w = quant_weights[i];
+             float slx = sumlx - w*x[i]*L[i];
+             float sl2 = suml2 - w*L[i]*L[i];
+             if (slx > 0 && sl2 > 0) {
+                 int new_l = nearest_int(x[i] * sl2 / slx);
+                 new_l = MIN(nmax, new_l);
+                 if (new_l != L[i]) {
+                     slx += w*x[i]*new_l;
+                     sl2 += w*new_l*new_l;
+                     if (slx*slx*suml2 > sumlx*sumlx*sl2) {
+                         L[i] = new_l; sumlx = slx; suml2 = sl2;
+                         ++n_changed;
+                     }
+                 }
+             }
+         }
+         if (!n_changed) {
+             break;
+         }
+     }
+     return sumlx / suml2;
+ }
+
+ static void wsp_quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restrict y, int k, const float * restrict quant_weights) {
+     WSP_GGML_ASSERT(quant_weights);
+     assert(k % QK_K == 0);
+     const int nb = k / QK_K;
+     const bool requantize = true;
+
+     uint8_t L[QK_K];
+     uint8_t Laux[16];
+     float mins[QK_K/16];
+     float scales[QK_K/16];
+     float sw[QK_K/16];
+     float weight[QK_K/16];
+     uint8_t Ls[QK_K/16], Lm[QK_K/16];
+
+     for (int i = 0; i < nb; i++) {
+         memset(sw, 0, QK_K/16*sizeof(float));
+         float sumx2 = 0;
+         for (int j = 0; j < QK_K; ++j) sumx2 += x[j]*x[j];
+         float sigma2 = sumx2/QK_K;
+         for (int j = 0; j < QK_K/16; ++j) {
+             const float * restrict qw = quant_weights + QK_K * i + 16*j;
+             for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]);
+             for (int l = 0; l < 16; ++l) sw[j] += weight[l];
+             scales[j] = make_qkx3_quants(16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
+         }
+
+         float dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw);
+         float mm = make_qp_quants(QK_K/16, 15, mins,   Lm, sw);
+         y[i].d    = WSP_GGML_FP32_TO_FP16(dm);
+         y[i].dmin = WSP_GGML_FP32_TO_FP16(mm);
+         dm = WSP_GGML_FP16_TO_FP32(y[i].d);
+         mm = WSP_GGML_FP16_TO_FP32(y[i].dmin);
+
+         for (int j = 0; j < QK_K/16; ++j) {
+             y[i].scales[j] = Ls[j] | (Lm[j] << 4);
+         }
+
+         if (requantize) {
+             for (int j = 0; j < QK_K/16; ++j) {
+                 const float d = dm * (y[i].scales[j] & 0xF);
+                 if (!d) continue;
+                 const float m = mm * (y[i].scales[j] >> 4);
+                 for (int ii = 0; ii < 16; ++ii) {
+                     int l = nearest_int((x[16*j + ii] + m)/d);
+                     l = MAX(0, MIN(3, l));
+                     L[16*j + ii] = l;
+                 }
+             }
+         }
+
+ #if QK_K == 256
+         for (int j = 0; j < QK_K; j += 128) {
+             for (int l = 0; l < 32; ++l) {
+                 y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
+             }
+         }
+ #else
+         for (int l = 0; l < 16; ++l) {
+             y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
+         }
+ #endif
+
+         x += QK_K;
+
+     }
+ }
+
+ size_t wsp_quantize_q2_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
+     (void)hist;
+     size_t row_size = wsp_ggml_row_size(WSP_GGML_TYPE_Q2_K, n_per_row);
+     if (!quant_weights) {
+         wsp_quantize_row_q2_K_reference(src, dst, nrow*n_per_row);
+     }
+     else {
+         char * qrow = (char *)dst;
+         for (int row = 0; row < nrow; ++row) {
+             wsp_quantize_row_q2_K_impl(src, (block_q2_K*)qrow, n_per_row, quant_weights);
+             src += n_per_row;
+             qrow += row_size;
+         }
+     }
+     return nrow * row_size;
+ }
+
  //========================= 3-bit (de)-quantization

  void wsp_quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k) {
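Note: `wsp_quantize_q2_K` is the new row-oriented entry point: it quantizes `nrow` rows of `n_per_row` values and switches between the reference path and the importance-weighted `_impl` path based on whether `quant_weights` is supplied (the same pattern repeats for the other formats below). A hypothetical caller sketch (the helper name and importance-weight source are illustrative, not part of the package):

    #include <stddef.h>
    #include <stdint.h>

    // Declaration matching the entry point added above.
    size_t wsp_quantize_q2_K(const float * src, void * dst, int nrow, int n_per_row,
                             int64_t * hist, const float * quant_weights);

    // Quantize nrow rows, sharing one importance vector (e.g. derived from
    // an importance matrix) across all rows; hist is unused by the new API.
    static size_t quantize_rows_q2_K(const float * src, void * dst,
                                     int nrow, int n_per_row, const float * imp) {
        return wsp_quantize_q2_K(src, dst, nrow, n_per_row, NULL, imp);
    }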
@@ -1805,6 +2100,112 @@ size_t wsp_ggml_wsp_quantize_q3_K(const float * restrict src, void * restrict ds
      return (n/QK_K*sizeof(block_q3_K));
  }

+ static void wsp_quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restrict y, int n_per_row, const float * restrict quant_weights) {
+ #if QK_K != 256
+     (void)quant_weights;
+     wsp_quantize_row_q3_K_reference(x, y, n_per_row);
+ #else
+     assert(n_per_row % QK_K == 0);
+     const int nb = n_per_row / QK_K;
+
+     int8_t L[QK_K];
+     float scales[QK_K / 16];
+     float weight[16];
+     float sw[QK_K / 16];
+     int8_t Ls[QK_K / 16];
+
+     for (int i = 0; i < nb; i++) {
+
+         float sumx2 = 0;
+         for (int j = 0; j < QK_K; ++j) sumx2 += x[j]*x[j];
+         float sigma2 = 2*sumx2/QK_K;
+
+         for (int j = 0; j < QK_K/16; ++j) {
+             if (quant_weights) {
+                 const float * qw = quant_weights ? quant_weights + QK_K * i + 16*j : NULL;
+                 for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j+l]*x[16*j+l]);
+             } else {
+                 for (int l = 0; l < 16; ++l) weight[l] = x[16*j+l]*x[16*j+l];
+             }
+             float sumw = 0;
+             for (int l = 0; l < 16; ++l) sumw += weight[l];
+             sw[j] = sumw;
+
+             scales[j] = make_qx_quants(16, 4, x + 16*j, L + 16*j, 1, weight);
+
+         }
+
+         memset(y[i].scales, 0, 12);
+
+         float d_block = make_qx_quants(QK_K/16, 32, scales, Ls, 1, sw);
+         for (int j = 0; j < QK_K/16; ++j) {
+             int l = Ls[j];
+             if (j < 8) {
+                 y[i].scales[j] = l & 0xF;
+             } else {
+                 y[i].scales[j-8] |= ((l & 0xF) << 4);
+             }
+             l >>= 4;
+             y[i].scales[j%4 + 8] |= (l << (2*(j/4)));
+         }
+         y[i].d = WSP_GGML_FP32_TO_FP16(d_block);
+
+         int8_t sc;
+         for (int j = 0; j < QK_K/16; ++j) {
+             sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4;
+             sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32;
+             float d = WSP_GGML_FP16_TO_FP32(y[i].d) * sc;
+             if (!d) {
+                 continue;
+             }
+             for (int ii = 0; ii < 16; ++ii) {
+                 int l = nearest_int(x[16*j + ii]/d);
+                 l = MAX(-4, MIN(3, l));
+                 L[16*j + ii] = l + 4;
+             }
+         }
+
+         memset(y[i].hmask, 0, QK_K/8);
+         // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc.
+         int m = 0;
+         uint8_t hm = 1;
+         for (int j = 0; j < QK_K; ++j) {
+             if (L[j] > 3) {
+                 y[i].hmask[m] |= hm;
+                 L[j] -= 4;
+             }
+             if (++m == QK_K/8) {
+                 m = 0; hm <<= 1;
+             }
+         }
+         for (int j = 0; j < QK_K; j += 128) {
+             for (int l = 0; l < 32; ++l) {
+                 y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
+             }
+         }
+
+         x += QK_K;
+     }
+ #endif
+ }
+
+ size_t wsp_quantize_q3_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
+     (void)hist;
+     size_t row_size = wsp_ggml_row_size(WSP_GGML_TYPE_Q3_K, n_per_row);
+     if (!quant_weights) {
+         wsp_quantize_row_q3_K_reference(src, dst, nrow*n_per_row);
+     }
+     else {
+         char * qrow = (char *)dst;
+         for (int row = 0; row < nrow; ++row) {
+             wsp_quantize_row_q3_K_impl(src, (block_q3_K*)qrow, n_per_row, quant_weights);
+             src += n_per_row;
+             qrow += row_size;
+         }
+     }
+     return nrow * row_size;
+ }
+
  // ====================== 4-bit (de)-quantization

  void wsp_quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k) {
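Note: the q3_K scale packing above stores 16 six-bit block scales (biased by +32) in 12 bytes: low nibbles of scales 0-7 in the low halves of bytes 0-7, low nibbles of scales 8-15 in the high halves of the same bytes, and the two high bits of every scale packed four-per-byte into bytes 8-11. A decode sketch mirroring the inline read-back loop in the diff (hypothetical helper name):

    #include <stdint.h>

    static int8_t q3_K_get_scale(const uint8_t scales[12], int j) {
        int8_t sc = j < 8 ? scales[j] & 0xF : scales[j - 8] >> 4;
        sc |= ((scales[8 + j % 4] >> (2 * (j / 4))) & 3) << 4;
        return sc - 32; // stored biased so 6 bits cover [-32, 31]
    }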
@@ -1970,6 +2371,108 @@ size_t wsp_ggml_wsp_quantize_q4_K(const float * restrict src, void * restrict ds
      return (n/QK_K*sizeof(block_q4_K));
  }

+ static void wsp_quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restrict y, int n_per_row, const float * quant_weights) {
+ #if QK_K != 256
+     (void)quant_weights;
+     wsp_quantize_row_q4_K_reference(x, y, n_per_row);
+ #else
+     assert(n_per_row % QK_K == 0);
+     const int nb = n_per_row / QK_K;
+
+     uint8_t L[QK_K];
+     uint8_t Laux[32];
+     float weights[32];
+     float mins[QK_K/32];
+     float scales[QK_K/32];
+
+     for (int i = 0; i < nb; i++) {
+
+         float sum_x2 = 0;
+         for (int l = 0; l < QK_K; ++l) sum_x2 += x[l] * x[l];
+         float sigma2 = sum_x2/QK_K;
+         float av_x = sqrtf(sigma2);
+
+         float max_scale = 0; // as we are deducting the min, scales are always positive
+         float max_min = 0;
+         for (int j = 0; j < QK_K/32; ++j) {
+             if (quant_weights) {
+                 const float * qw = quant_weights + QK_K*i + 32*j;
+                 for (int l = 0; l < 32; ++l) weights[l] = qw[l] * sqrtf(sigma2 + x[32*j + l]*x[32*j + l]);
+             } else {
+                 for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
+             }
+             scales[j] = make_qkx3_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
+             //scales[j] = make_qkx2_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -1.f, 0.1f, 20, false);
+             float scale = scales[j];
+             if (scale > max_scale) {
+                 max_scale = scale;
+             }
+             float min = mins[j];
+             if (min > max_min) {
+                 max_min = min;
+             }
+         }
+
+         float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
+         float inv_min   = max_min   > 0 ? 63.f/max_min   : 0.f;
+         for (int j = 0; j < QK_K/32; ++j) {
+             uint8_t ls = nearest_int(inv_scale*scales[j]);
+             uint8_t lm = nearest_int(inv_min*mins[j]);
+             ls = MIN(63, ls);
+             lm = MIN(63, lm);
+             if (j < 4) {
+                 y[i].scales[j] = ls;
+                 y[i].scales[j+4] = lm;
+             } else {
+                 y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
+                 y[i].scales[j-4] |= ((ls >> 4) << 6);
+                 y[i].scales[j-0] |= ((lm >> 4) << 6);
+             }
+         }
+         y[i].d = WSP_GGML_FP32_TO_FP16(max_scale/63.f);
+         y[i].dmin = WSP_GGML_FP32_TO_FP16(max_min/63.f);
+
+         uint8_t sc, m;
+         for (int j = 0; j < QK_K/32; ++j) {
+             get_scale_min_k4(j, y[i].scales, &sc, &m);
+             const float d = WSP_GGML_FP16_TO_FP32(y[i].d) * sc;
+             if (!d) continue;
+             const float dm = WSP_GGML_FP16_TO_FP32(y[i].dmin) * m;
+             for (int ii = 0; ii < 32; ++ii) {
+                 int l = nearest_int((x[32*j + ii] + dm)/d);
+                 l = MAX(0, MIN(15, l));
+                 L[32*j + ii] = l;
+             }
+         }
+         uint8_t * q = y[i].qs;
+         for (int j = 0; j < QK_K; j += 64) {
+             for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4);
+             q += 32;
+         }
+
+         x += QK_K;
+
+     }
+ #endif
+ }
+
+ size_t wsp_quantize_q4_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
+     (void)hist;
+     size_t row_size = wsp_ggml_row_size(WSP_GGML_TYPE_Q4_K, n_per_row);
+     if (!quant_weights) {
+         wsp_quantize_row_q4_K_reference(src, dst, nrow*n_per_row);
+     }
+     else {
+         char * qrow = (char *)dst;
+         for (int row = 0; row < nrow; ++row) {
+             wsp_quantize_row_q4_K_impl(src, (block_q4_K*)qrow, n_per_row, quant_weights);
+             src += n_per_row;
+             qrow += row_size;
+         }
+     }
+     return nrow * row_size;
+ }
+
  // ====================== 5-bit (de)-quantization

  void wsp_quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k) {
@@ -2065,7 +2568,7 @@ void wsp_quantize_row_q5_K_reference(const float * restrict x, block_q5_K * rest
  #else
      float max_scale = 0, amax = 0;
      for (int j = 0; j < QK_K/16; ++j) {
-         scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1);
+         scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1, NULL);
          float abs_scale = fabsf(scales[j]);
          if (abs_scale > amax) {
              amax = abs_scale;
@@ -2159,21 +2662,138 @@ void wsp_dewsp_quantize_row_q5_K(const block_q5_K * restrict x, float * restrict
      }
  }

- void wsp_quantize_row_q5_K(const float * restrict x, void * restrict vy, int k) {
-     assert(k % QK_K == 0);
-     block_q5_K * restrict y = vy;
-     wsp_quantize_row_q5_K_reference(x, y, k);
- }
+ void wsp_quantize_row_q5_K(const float * restrict x, void * restrict vy, int k) {
+     assert(k % QK_K == 0);
+     block_q5_K * restrict y = vy;
+     wsp_quantize_row_q5_K_reference(x, y, k);
+ }
+
+ size_t wsp_ggml_wsp_quantize_q5_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
+     assert(k % QK_K == 0);
+     (void)hist; // TODO: collect histograms
+
+     for (int j = 0; j < n; j += k) {
+         block_q5_K * restrict y = (block_q5_K *)dst + j/QK_K;
+         wsp_quantize_row_q5_K_reference(src + j, y, k);
+     }
+     return (n/QK_K*sizeof(block_q5_K));
+ }
+
+ static void wsp_quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restrict y, int n_per_row, const float * quant_weights) {
+ #if QK_K != 256
+     (void)quant_weights;
+     wsp_quantize_row_q5_K_reference(x, y, n_per_row);
+ #else
+     assert(n_per_row % QK_K == 0);
+     const int nb = n_per_row / QK_K;
+
+     uint8_t L[QK_K];
+     float mins[QK_K/32];
+     float scales[QK_K/32];
+     float weights[32];
+     uint8_t Laux[32];
+
+     for (int i = 0; i < nb; i++) {
+
+         float sum_x2 = 0;
+         for (int l = 0; l < QK_K; ++l) sum_x2 += x[l] * x[l];
+         float sigma2 = sum_x2/QK_K;
+         float av_x = sqrtf(sigma2);
+
+         float max_scale = 0; // as we are deducting the min, scales are always positive
+         float max_min = 0;
+         for (int j = 0; j < QK_K/32; ++j) {
+             if (quant_weights) {
+                 const float * qw = quant_weights + QK_K*i + 32*j;
+                 for (int l = 0; l < 32; ++l) weights[l] = qw[l] * sqrtf(sigma2 + x[32*j + l]*x[32*j + l]);
+             } else {
+                 for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
+             }
+             scales[j] = make_qkx3_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
+             float scale = scales[j];
+             if (scale > max_scale) {
+                 max_scale = scale;
+             }
+             float min = mins[j];
+             if (min > max_min) {
+                 max_min = min;
+             }
+         }
+
+         float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
+         float inv_min   = max_min   > 0 ? 63.f/max_min   : 0.f;
+         for (int j = 0; j < QK_K/32; ++j) {
+             uint8_t ls = nearest_int(inv_scale*scales[j]);
+             uint8_t lm = nearest_int(inv_min*mins[j]);
+             ls = MIN(63, ls);
+             lm = MIN(63, lm);
+             if (j < 4) {
+                 y[i].scales[j] = ls;
+                 y[i].scales[j+4] = lm;
+             } else {
+                 y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
+                 y[i].scales[j-4] |= ((ls >> 4) << 6);
+                 y[i].scales[j-0] |= ((lm >> 4) << 6);
+             }
+         }
+         y[i].d = WSP_GGML_FP32_TO_FP16(max_scale/63.f);
+         y[i].dmin = WSP_GGML_FP32_TO_FP16(max_min/63.f);
+
+         uint8_t sc, m;
+         for (int j = 0; j < QK_K/32; ++j) {
+             get_scale_min_k4(j, y[i].scales, &sc, &m);
+             const float d = WSP_GGML_FP16_TO_FP32(y[i].d) * sc;
+             if (!d) continue;
+             const float dm = WSP_GGML_FP16_TO_FP32(y[i].dmin) * m;
+             for (int ii = 0; ii < 32; ++ii) {
+                 int l = nearest_int((x[32*j + ii] + dm)/d);
+                 l = MAX(0, MIN(31, l));
+                 L[32*j + ii] = l;
+             }
+         }
+
+         uint8_t * restrict qh = y[i].qh;
+         uint8_t * restrict ql = y[i].qs;
+         memset(qh, 0, QK_K/8);
+
+         uint8_t m1 = 1, m2 = 2;
+         for (int n = 0; n < QK_K; n += 64) {
+             for (int j = 0; j < 32; ++j) {
+                 int l1 = L[n + j];
+                 if (l1 > 15) {
+                     l1 -= 16; qh[j] |= m1;
+                 }
+                 int l2 = L[n + j + 32];
+                 if (l2 > 15) {
+                     l2 -= 16; qh[j] |= m2;
+                 }
+                 ql[j] = l1 | (l2 << 4);
+             }
+             m1 <<= 2; m2 <<= 2;
+             ql += 32;
+         }
+
+         x += QK_K;

- size_t wsp_ggml_wsp_quantize_q5_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
-     assert(k % QK_K == 0);
-     (void)hist; // TODO: collect histograms
+     }
+ #endif
+ }

- for (int j = 0; j < n; j += k) {
-     block_q5_K * restrict y = (block_q5_K *)dst + j/QK_K;
-     wsp_quantize_row_q5_K_reference(src + j, y, k);
+ size_t wsp_quantize_q5_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
+     (void)hist;
+     size_t row_size = wsp_ggml_row_size(WSP_GGML_TYPE_Q5_K, n_per_row);
+     if (!quant_weights) {
+         wsp_quantize_row_q5_K_reference(src, dst, nrow*n_per_row);
      }
- return (n/QK_K*sizeof(block_q5_K));
+     else {
+         char * qrow = (char *)dst;
+         for (int row = 0; row < nrow; ++row) {
+             wsp_quantize_row_q5_K_impl(src, (block_q5_K*)qrow, n_per_row, quant_weights);
+             src += n_per_row;
+             qrow += row_size;
+         }
+     }
+     return nrow * row_size;
  }

  // ====================== 6-bit (de)-quantization
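Note: in the q5_K encode loop above, each 5-bit value is split into a 4-bit low part in `ql` and one high bit in `qh`; group g = n/64 uses qh bits 2g (first 32 values of the group) and 2g+1 (second 32). A read-back sketch (hypothetical helper; the index conventions follow the encode loop):

    #include <stdint.h>

    static int q5_K_get(const uint8_t * ql, const uint8_t * qh, int idx) {
        int g = idx / 64;               // 64-value group
        int j = idx % 64;               // position within the group
        int lo, hi;
        if (j < 32) {
            lo = ql[32*g + j] & 0xF;
            hi = (qh[j] >> (2*g)) & 1;
        } else {
            lo = ql[32*g + j - 32] >> 4;
            hi = (qh[j - 32] >> (2*g + 1)) & 1;
        }
        return lo | (hi << 4);
    }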
@@ -2192,7 +2812,7 @@ void wsp_quantize_row_q6_K_reference(const float * restrict x, block_q6_K * rest

      for (int ib = 0; ib < QK_K/16; ++ib) {

-         const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1);
+         const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, NULL);
          scales[ib] = scale;

          const float abs_scale = fabsf(scale);
@@ -2324,6 +2944,569 @@ size_t wsp_ggml_wsp_quantize_q6_K(const float * src, void * dst, int n, int k, i
2324
2944
  return (n/QK_K*sizeof(block_q6_K));
2325
2945
  }
2326
2946
 
2947
+ static void wsp_quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restrict y, int n_per_row, const float * quant_weights) {
2948
+ #if QK_K != 256
2949
+ (void)quant_weights;
2950
+ wsp_quantize_row_q6_K_reference(x, y, n_per_row);
2951
+ #else
2952
+ assert(n_per_row % QK_K == 0);
2953
+ const int nb = n_per_row / QK_K;
2954
+
2955
+ int8_t L[QK_K];
2956
+ float scales[QK_K/16];
2957
+ //float weights[16];
2958
+
2959
+ for (int i = 0; i < nb; i++) {
2960
+
2961
+ //float sum_x2 = 0;
2962
+ //for (int j = 0; j < QK_K; ++j) sum_x2 += x[j]*x[j];
2963
+ //float sigma2 = sum_x2/QK_K;
2964
+
2965
+ float max_scale = 0;
2966
+ float max_abs_scale = 0;
2967
+
2968
+ for (int ib = 0; ib < QK_K/16; ++ib) {
2969
+
2970
+ float scale;
2971
+ if (quant_weights) {
2972
+ const float * qw = quant_weights + QK_K*i + 16*ib;
2973
+ //for (int j = 0; j < 16; ++j) weights[j] = qw[j] * sqrtf(sigma2 + x[16*ib + j]*x[16*ib + j]);
2974
+ //scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, weights);
2975
+ scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, qw);
2976
+ } else {
2977
+ scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, NULL);
2978
+ }
2979
+ scales[ib] = scale;
2980
+
2981
+ const float abs_scale = fabsf(scale);
2982
+ if (abs_scale > max_abs_scale) {
2983
+ max_abs_scale = abs_scale;
2984
+ max_scale = scale;
2985
+ }
2986
+
2987
+ }
2988
+
2989
+ if (!max_abs_scale) {
2990
+ memset(&y[i], 0, sizeof(block_q6_K));
2991
+ y[i].d = WSP_GGML_FP32_TO_FP16(0.f);
2992
+ x += QK_K;
2993
+ continue;
2994
+ }
2995
+
2996
+ float iscale = -128.f/max_scale;
2997
+ y[i].d = WSP_GGML_FP32_TO_FP16(1/iscale);
2998
+ for (int ib = 0; ib < QK_K/16; ++ib) {
2999
+ y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib]));
3000
+ }
3001
+
3002
+ for (int j = 0; j < QK_K/16; ++j) {
3003
+ float d = WSP_GGML_FP16_TO_FP32(y[i].d) * y[i].scales[j];
3004
+ if (!d) {
3005
+ continue;
3006
+ }
3007
+ for (int ii = 0; ii < 16; ++ii) {
3008
+ int l = nearest_int(x[16*j + ii]/d);
3009
+ l = MAX(-32, MIN(31, l));
3010
+ L[16*j + ii] = l + 32;
3011
+ }
3012
+ }
3013
+
3014
+ uint8_t * restrict ql = y[i].ql;
3015
+ uint8_t * restrict qh = y[i].qh;
3016
+ for (int j = 0; j < QK_K; j += 128) {
3017
+ for (int l = 0; l < 32; ++l) {
3018
+ const uint8_t q1 = L[j + l + 0] & 0xF;
3019
+ const uint8_t q2 = L[j + l + 32] & 0xF;
3020
+ const uint8_t q3 = L[j + l + 64] & 0xF;
3021
+ const uint8_t q4 = L[j + l + 96] & 0xF;
3022
+ ql[l+ 0] = q1 | (q3 << 4);
3023
+ ql[l+32] = q2 | (q4 << 4);
3024
+ qh[l] = (L[j + l] >> 4) | ((L[j + l + 32] >> 4) << 2) | ((L[j + l + 64] >> 4) << 4) | ((L[j + l + 96] >> 4) << 6);
3025
+ }
3026
+ ql += 64;
3027
+ qh += 32;
3028
+ }
3029
+
3030
+ x += QK_K;
3031
+
3032
+ }
3033
+ #endif
3034
+ }
3035
+
3036
+ size_t wsp_quantize_q6_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
3037
+ (void)hist;
3038
+ size_t row_size = wsp_ggml_row_size(WSP_GGML_TYPE_Q6_K, n_per_row);
3039
+ if (!quant_weights) {
3040
+ wsp_quantize_row_q6_K_reference(src, dst, nrow*n_per_row);
3041
+ }
3042
+ else {
3043
+ char * qrow = (char *)dst;
3044
+ for (int row = 0; row < nrow; ++row) {
3045
+ wsp_quantize_row_q6_K_impl(src, (block_q6_K*)qrow, n_per_row, quant_weights);
3046
+ src += n_per_row;
3047
+ qrow += row_size;
3048
+ }
3049
+ }
3050
+ return nrow * row_size;
3051
+ }
3052
+
3053
+ static void wsp_quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restrict y, int n_per_row, const float * quant_weights) {
3054
+ static_assert(QK4_0 == 32, "QK4_0 must be 32");
3055
+
3056
+ if (!quant_weights) {
3057
+ wsp_quantize_row_q4_0_reference(x, y, n_per_row);
3058
+ return;
3059
+ }
3060
+
3061
+ float weight[QK4_0];
3062
+ int8_t L[QK4_0];
3063
+
3064
+ float sum_x2 = 0;
3065
+ for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
3066
+ float sigma2 = sum_x2/n_per_row;
3067
+
3068
+ const int nb = n_per_row/QK4_0;
3069
+ for (int ib = 0; ib < nb; ++ib) {
3070
+ const float * xb = x + QK4_0 * ib;
3071
+ const float * qw = quant_weights + QK4_0 * ib;
3072
+ for (int j = 0; j < QK4_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
3073
+ float d = make_qx_quants(QK4_0, 8, xb, L, 1, weight);
3074
+ y[ib].d = WSP_GGML_FP32_TO_FP16(d);
3075
+ for (int j = 0; j < 16; ++j) {
3076
+ y[ib].qs[j] = L[j] | (L[j+16] << 4);
3077
+ }
3078
+ }
3079
+ }
3080
+
3081
+ size_t wsp_quantize_q4_0(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
3082
+ if (!quant_weights) {
3083
+ return wsp_ggml_wsp_quantize_q4_0(src, dst, nrow*n_per_row, n_per_row, hist);
3084
+ }
3085
+ size_t row_size = wsp_ggml_row_size(WSP_GGML_TYPE_Q4_0, n_per_row);
3086
+ char * qrow = (char *)dst;
3087
+ for (int row = 0; row < nrow; ++row) {
3088
+ wsp_quantize_row_q4_0_impl(src, (block_q4_0*)qrow, n_per_row, quant_weights);
3089
+ src += n_per_row;
3090
+ qrow += row_size;
3091
+ }
3092
+ return nrow * row_size;
3093
+ }
3094
+
3095
+ static void wsp_quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restrict y, int n_per_row, const float * quant_weights) {
3096
+ static_assert(QK4_1 == 32, "QK4_1 must be 32");
3097
+
3098
+ if (!quant_weights) {
3099
+ wsp_quantize_row_q4_1_reference(x, y, n_per_row);
3100
+ return;
3101
+ }
3102
+
3103
+ float weight[QK4_1];
3104
+ uint8_t L[QK4_1], Laux[QK4_1];
3105
+
3106
+ float sum_x2 = 0;
3107
+ for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
3108
+ float sigma2 = sum_x2/n_per_row;
3109
+
3110
+ const int nb = n_per_row/QK4_1;
3111
+ for (int ib = 0; ib < nb; ++ib) {
3112
+ const float * xb = x + QK4_1 * ib;
3113
+ const float * qw = quant_weights + QK4_1 * ib;
3114
+ for (int j = 0; j < QK4_1; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
3115
+ float min;
3116
+ float d = make_qkx3_quants(QK4_1, 15, xb, weight, L, &min, Laux, -0.9f, 0.05f, 36, false);
3117
+ y[ib].d = WSP_GGML_FP32_TO_FP16(d);
3118
+ y[ib].m = WSP_GGML_FP32_TO_FP16(-min);
3119
+ for (int j = 0; j < 16; ++j) {
3120
+ y[ib].qs[j] = L[j] | (L[j+16] << 4);
3121
+ }
3122
+ }
3123
+ }
3124
+
3125
+ size_t wsp_quantize_q4_1(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
3126
+ if (!quant_weights) {
3127
+ return wsp_ggml_wsp_quantize_q4_1(src, dst, nrow*n_per_row, n_per_row, hist);
3128
+ }
3129
+ size_t row_size = wsp_ggml_row_size(WSP_GGML_TYPE_Q4_1, n_per_row);
3130
+ char * qrow = (char *)dst;
3131
+ for (int row = 0; row < nrow; ++row) {
3132
+ wsp_quantize_row_q4_1_impl(src, (block_q4_1*)qrow, n_per_row, quant_weights);
3133
+ src += n_per_row;
3134
+ qrow += row_size;
3135
+ }
3136
+ return nrow * row_size;
3137
+ }
3138
+
3139
+ static void wsp_quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restrict y, int n_per_row, const float * quant_weights) {
3140
+ static_assert(QK5_0 == 32, "QK5_0 must be 32");
3141
+
3142
+ if (!quant_weights) {
3143
+ wsp_quantize_row_q5_0_reference(x, y, n_per_row);
3144
+ return;
3145
+ }
3146
+
3147
+ float weight[QK5_0];
3148
+ int8_t L[QK5_0];
3149
+
3150
+ float sum_x2 = 0;
3151
+ for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
3152
+ float sigma2 = sum_x2/n_per_row;
3153
+
3154
+ const int nb = n_per_row/QK5_0;
3155
+ for (int ib = 0; ib < nb; ++ib) {
3156
+ const float * xb = x + QK5_0 * ib;
3157
+ const float * qw = quant_weights + QK5_0 * ib;
3158
+ for (int j = 0; j < QK5_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
3159
+ float d = make_qx_quants(QK5_0, 16, xb, L, 1, weight);
3160
+ y[ib].d = WSP_GGML_FP32_TO_FP16(d);
3161
+
3162
+ uint32_t qh = 0;
3163
+
3164
+ for (int j = 0; j < 16; ++j) {
3165
+ const uint8_t xi0 = L[j];
3166
+ const uint8_t xi1 = L[j+16];
3167
+ y[ib].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
3168
+
3169
+ // get the 5-th bit and store it in qh at the right position
3170
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
3171
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
3172
+ }
3173
+
3174
+ memcpy(&y[ib].qh, &qh, sizeof(qh));
3175
+ }
3176
+ }
3177
+
3178
+ size_t wsp_quantize_q5_0(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
3179
+ if (!quant_weights) {
3180
+ return wsp_ggml_wsp_quantize_q5_0(src, dst, nrow*n_per_row, n_per_row, hist);
3181
+ }
3182
+ size_t row_size = wsp_ggml_row_size(WSP_GGML_TYPE_Q5_0, n_per_row);
3183
+ char * qrow = (char *)dst;
3184
+ for (int row = 0; row < nrow; ++row) {
3185
+ wsp_quantize_row_q5_0_impl(src, (block_q5_0*)qrow, n_per_row, quant_weights);
3186
+ src += n_per_row;
3187
+ qrow += row_size;
3188
+ }
3189
+ return nrow * row_size;
3190
+ }
3191
+
3192
+ static void wsp_quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restrict y, int n_per_row, const float * quant_weights) {
3193
+ static_assert(QK5_1 == 32, "QK5_1 must be 32");
3194
+
3195
+ if (!quant_weights) {
3196
+ wsp_quantize_row_q5_1_reference(x, y, n_per_row);
3197
+ return;
3198
+ }
3199
+
3200
+ float weight[QK5_1];
3201
+ uint8_t L[QK5_1], Laux[QK5_1];
3202
+
3203
+ float sum_x2 = 0;
3204
+ for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
3205
+ float sigma2 = sum_x2/n_per_row;
3206
+
3207
+ const int nb = n_per_row/QK5_1;
3208
+ for (int ib = 0; ib < nb; ++ib) {
3209
+ const float * xb = x + QK5_1 * ib;
3210
+ const float * qw = quant_weights + QK5_1 * ib;
3211
+ for (int j = 0; j < QK5_1; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
3212
+ float min;
3213
+ float d = make_qkx3_quants(QK5_1, 31, xb, weight, L, &min, Laux, -0.9f, 0.05f, 36, false);
3214
+ y[ib].d = WSP_GGML_FP32_TO_FP16(d);
3215
+ y[ib].m = WSP_GGML_FP32_TO_FP16(-min);
3216
+
3217
+ uint32_t qh = 0;
3218
+ for (int j = 0; j < 16; ++j) {
3219
+ const uint8_t xi0 = L[j];
3220
+ const uint8_t xi1 = L[j+16];
3221
+ y[ib].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
3222
+ // get the 5-th bit and store it in qh at the right position
3223
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
3224
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
3225
+ }
3226
+ memcpy(&y[ib].qh, &qh, sizeof(qh));
3227
+ }
3228
+ }
3229
+
3230
+ size_t wsp_quantize_q5_1(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
3231
+ if (!quant_weights) {
3232
+ return wsp_ggml_wsp_quantize_q5_1(src, dst, nrow*n_per_row, n_per_row, hist);
3233
+ }
3234
+ size_t row_size = wsp_ggml_row_size(WSP_GGML_TYPE_Q5_1, n_per_row);
3235
+ char * qrow = (char *)dst;
3236
+ for (int row = 0; row < nrow; ++row) {
3237
+ wsp_quantize_row_q5_1_impl(src, (block_q5_1*)qrow, n_per_row, quant_weights);
3238
+ src += n_per_row;
3239
+ qrow += row_size;
3240
+ }
3241
+ return nrow * row_size;
3242
+ }
3243
+
3244
+ // ====================== "True" 2-bit (de)-quantization
3245
+
3246
+ static const uint64_t iq2xxs_grid[256] = {
3247
+ 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
3248
+ 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
3249
+ 0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
3250
+ 0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
3251
+ 0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
3252
+ 0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
3253
+ 0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
3254
+ 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
3255
+ 0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
3256
+ 0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
3257
+ 0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
3258
+ 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
3259
+ 0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
3260
+ 0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
3261
+ 0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
3262
+ 0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
3263
+ 0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
3264
+ 0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
3265
+ 0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
3266
+ 0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
3267
+ 0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
3268
+ 0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
3269
+ 0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
3270
+ 0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
3271
+ 0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
3272
+ 0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
3273
+ 0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
3274
+ 0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
3275
+ 0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
3276
+ 0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
3277
+ 0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
3278
+ 0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
3279
+ 0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
3280
+ 0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
3281
+ 0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
3282
+ 0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
3283
+ 0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
3284
+ 0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
3285
+ 0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
3286
+ 0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
3287
+ 0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
3288
+ 0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
3289
+ 0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
3290
+ 0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
3291
+ 0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
3292
+ 0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
3293
+ 0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
3294
+ 0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
3295
+ 0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
3296
+ 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
3297
+ 0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
3298
+ 0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
3299
+ 0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
3300
+ 0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
3301
+ 0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
3302
+ 0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
3303
+ 0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
3304
+ 0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
3305
+ 0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
3306
+ 0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
3307
+ 0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
3308
+ 0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
3309
+ 0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
3310
+ 0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
3311
+ };
3312
+
3313
+ static const uint64_t iq2xs_grid[512] = {
3314
+ 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
3315
+ 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
3316
+ 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
3317
+ 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
3318
+ 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
3319
+ 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808,
3320
+ 0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819,
3321
+ 0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819,
3322
+ 0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808,
3323
+ 0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b,
3324
+ 0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b,
3325
+ 0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908,
3326
+ 0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908,
3327
+ 0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919,
3328
+ 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808,
3329
+ 0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919,
3330
+ 0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908,
3331
+ 0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b,
3332
+ 0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908,
3333
+ 0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08,
3334
+ 0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808,
3335
+ 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808,
3336
+ 0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819,
3337
+ 0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908,
3338
+ 0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819,
3339
+ 0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808,
3340
+ 0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b,
3341
+ 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819,
3342
+ 0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819,
3343
+ 0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808,
3344
+ 0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908,
3345
+ 0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19,
3346
+ 0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b,
3347
+ 0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b,
3348
+ 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919,
3349
+ 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808,
3350
+ 0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819,
3351
+ 0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819,
3352
+ 0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b,
3353
+ 0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908,
3354
+ 0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808,
3355
+ 0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819,
3356
+ 0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808,
3357
+ 0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919,
3358
+ 0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808,
3359
+ 0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808,
3360
+ 0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908,
3361
+ 0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908,
3362
+ 0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808,
3363
+ 0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b,
3364
+ 0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819,
3365
+ 0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919,
3366
+ 0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908,
3367
+ 0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808,
3368
+ 0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908,
3369
+ 0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919,
3370
+ 0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08,
3371
+ 0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19,
3372
+ 0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b,
3373
+ 0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b,
3374
+ 0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808,
3375
+ 0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08,
3376
+ 0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b,
3377
+ 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908,
3378
+ 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b,
3379
+ 0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908,
3380
+ 0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08,
3381
+ 0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808,
3382
+ 0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808,
3383
+ 0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08,
3384
+ 0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819,
3385
+ 0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919,
3386
+ 0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808,
3387
+ 0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808,
3388
+ 0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819,
3389
+ 0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819,
3390
+ 0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908,
3391
+ 0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908,
3392
+ 0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b,
3393
+ 0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908,
3394
+ 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908,
3395
+ 0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908,
3396
+ 0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808,
3397
+ 0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819,
3398
+ 0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819,
3399
+ 0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819,
3400
+ 0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808,
3401
+ 0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b,
3402
+ 0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819,
3403
+ 0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819,
3404
+ 0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08,
3405
+ 0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808,
3406
+ 0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19,
3407
+ 0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919,
3408
+ 0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808,
3409
+ 0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19,
3410
+ 0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b,
3411
+ 0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808,
3412
+ 0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b,
3413
+ 0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b,
3414
+ 0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08,
3415
+ 0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b,
3416
+ 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808,
3417
+ 0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819,
3418
+ 0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,
3419
+ 0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808,
3420
+ 0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08,
3421
+ 0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b,
3422
+ 0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19,
3423
+ 0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08,
3424
+ 0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919,
3425
+ 0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08,
3426
+ 0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08,
3427
+ 0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908,
3428
+ 0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908,
3429
+ 0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b,
3430
+ 0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908,
3431
+ 0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808,
3432
+ 0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b,
3433
+ 0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808,
3434
+ 0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808,
3435
+ 0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19,
3436
+ 0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08,
3437
+ 0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808,
3438
+ 0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b,
3439
+ 0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808,
3440
+ 0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b,
3441
+ 0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
3442
+ };
3443
+
3444
+ static const uint8_t ksigns_iq2xs[128] = {
3445
+ 0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
3446
+ 144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
3447
+ 160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175,
3448
+ 48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63,
3449
+ 192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207,
3450
+ 80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95,
3451
+ 96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
3452
+ 240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
3453
+ };
3454
+
3455
+ static const uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
3456
+
3457
+ void wsp_dewsp_quantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k) {
3458
+ assert(k % QK_K == 0);
3459
+ const int nb = k / QK_K;
3460
+
3461
+ uint32_t aux32[2];
3462
+ const uint8_t * aux8 = (const uint8_t *)aux32;
3463
+
3464
+ for (int i = 0; i < nb; i++) {
3465
+
3466
+ const float d = WSP_GGML_FP16_TO_FP32(x[i].d);
3467
+
3468
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
3469
+ memcpy(aux32, x[i].qs + 4*ib32, 2*sizeof(uint32_t));
3470
+ const float db = d * (0.5f + (aux32[1] >> 28)) * 0.25f;
3471
+ for (int l = 0; l < 4; ++l) {
3472
+ const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
3473
+ const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];
3474
+ for (int j = 0; j < 8; ++j) {
3475
+ y[j] = db * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
3476
+ }
3477
+ y += 8;
3478
+ }
3479
+ }
3480
+ }
3481
+ }
3482
+
3483
+ // ====================== 2.3125 bpw (de)-quantization
3484
+
3485
+ void wsp_dewsp_quantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y, int k) {
3486
+ assert(k % QK_K == 0);
3487
+ const int nb = k / QK_K;
3488
+
3489
+ float db[2];
3490
+
3491
+ for (int i = 0; i < nb; i++) {
3492
+
3493
+ const float d = WSP_GGML_FP16_TO_FP32(x[i].d);
3494
+
3495
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
3496
+ db[0] = d * (0.5f + (x[i].scales[ib32] & 0xf)) * 0.25f;
3497
+ db[1] = d * (0.5f + (x[i].scales[ib32] >> 4)) * 0.25f;
3498
+ for (int l = 0; l < 4; ++l) {
3499
+ const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (x[i].qs[4*ib32 + l] & 511));
3500
+ const uint8_t signs = ksigns_iq2xs[x[i].qs[4*ib32 + l] >> 9];
3501
+ for (int j = 0; j < 8; ++j) {
3502
+ y[j] = db[l/2] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
3503
+ }
3504
+ y += 8;
3505
+ }
3506
+ }
3507
+ }
3508
+ }
3509
+
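My reading of the IQ2_XS layout used above (not stated in the diff): each 16-bit qs entry packs a 9-bit index into the 512-entry iq2xs_grid in its low bits and a 7-bit sign selector in its high bits, e.g. q = 0x1234 gives grid index 0x1234 & 511 = 52 and sign row 0x1234 >> 9 = 9. The per-32 scales come two 4-bit values to a byte, expanded with the same d * (0.5 + nibble) * 0.25 rule as IQ2_XXS.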
2327
3510
  //===================================== Q8_K ==============================================
2328
3511
 
2329
3512
  void wsp_quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
@@ -2346,7 +3529,9 @@ void wsp_quantize_row_q8_K_reference(const float * restrict x, block_q8_K * rest
2346
3529
  x += QK_K;
2347
3530
  continue;
2348
3531
  }
2349
- const float iscale = -128.f/max;
3532
+ //const float iscale = -128.f/max;
3533
+ // We need this change for IQ2_XXS: without it the AVX implementation becomes very awkward
3534
+ const float iscale = -127.f/max;
2350
3535
  for (int j = 0; j < QK_K; ++j) {
2351
3536
  int v = nearest_int(iscale*x[j]);
2352
3537
  y[i].qs[j] = MIN(127, v);
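A sketch of the reasoning behind the -127 change, under my reading of the AVX2 IQ2_XXS path further down: with iscale = -128.f/max the largest-magnitude input quantizes to -128, and negating -128 overflows int8, so applying the sign tables to q8 values with _mm256_sign_epi8 would need special-casing. With -127 every quantized value lands in [-127, 127]:

    // e.g. max = 2.0f  ->  iscale = -63.5f  ->  nearest_int(-63.5f * 2.0f) = -127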
@@ -2468,32 +3653,12 @@ void wsp_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict
2468
3653
  const int8x16_t v1_1l = vld1q_s8(y1->qs);
2469
3654
  const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2470
3655
 
2471
- #if defined(__ARM_FEATURE_DOTPROD)
2472
3656
  // dot product into int32x4_t
2473
- const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
2474
- const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
3657
+ const int32x4_t p_0 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
3658
+ const int32x4_t p_1 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
2475
3659
 
2476
3660
  sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d));
2477
3661
  sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d));
2478
- #else
2479
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l));
2480
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
2481
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h));
2482
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
2483
-
2484
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l));
2485
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
2486
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h));
2487
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
2488
-
2489
- const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2490
- const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2491
- const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2492
- const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
2493
-
2494
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d));
2495
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d));
2496
- #endif
2497
3662
  }
2498
3663
 
2499
3664
  *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
@@ -2776,32 +3941,12 @@ void wsp_ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * re
2776
3941
  const int8x16_t v1_1l = vld1q_s8(y1->qs);
2777
3942
  const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2778
3943
 
2779
- #if defined(__ARM_FEATURE_DOTPROD)
2780
3944
  // dot product into int32x4_t
2781
- const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
2782
- const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
3945
+ const int32x4_t p_0 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
3946
+ const int32x4_t p_1 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
2783
3947
 
2784
3948
  sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), WSP_GGML_FP16_TO_FP32(x0->d)*y0->d);
2785
3949
  sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), WSP_GGML_FP16_TO_FP32(x1->d)*y1->d);
2786
- #else
2787
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l));
2788
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l));
2789
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0h));
2790
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0h));
2791
-
2792
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1l));
2793
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1l));
2794
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1h));
2795
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1h));
2796
-
2797
- const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2798
- const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2799
- const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2800
- const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
2801
-
2802
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), WSP_GGML_FP16_TO_FP32(x0->d)*y0->d);
2803
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), WSP_GGML_FP16_TO_FP32(x1->d)*y1->d);
2804
- #endif
2805
3950
  }
2806
3951
 
2807
3952
  *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
@@ -2963,32 +4108,12 @@ void wsp_ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * re
2963
4108
  const int8x16_t v1_1l = vld1q_s8(y1->qs);
2964
4109
  const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2965
4110
 
2966
- #if defined(__ARM_FEATURE_DOTPROD)
2967
4111
  sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
2968
- vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
2969
- vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d));
4112
+ wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
4113
+ wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d));
2970
4114
  sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
2971
- vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
2972
- vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d));
2973
- #else
2974
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
2975
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
2976
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
2977
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
2978
-
2979
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
2980
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
2981
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
2982
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
2983
-
2984
- const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2985
- const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2986
- const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2987
- const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
2988
-
2989
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d));
2990
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d));
2991
- #endif
4115
+ wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
4116
+ wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d));
2992
4117
  }
2993
4118
 
2994
4119
  *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
@@ -3275,32 +4400,12 @@ void wsp_ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * re
3275
4400
  const int8x16_t v1_1l = vld1q_s8(y1->qs);
3276
4401
  const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
3277
4402
 
3278
- #if defined(__ARM_FEATURE_DOTPROD)
3279
4403
  sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
3280
- vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
3281
- vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), WSP_GGML_FP16_TO_FP32(x0->d)*y0->d);
4404
+ wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
4405
+ wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), WSP_GGML_FP16_TO_FP32(x0->d)*y0->d);
3282
4406
  sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
3283
- vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
3284
- vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), WSP_GGML_FP16_TO_FP32(x1->d)*y1->d);
3285
- #else
3286
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
3287
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
3288
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
3289
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
3290
-
3291
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
3292
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
3293
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
3294
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
3295
-
3296
- const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
3297
- const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
3298
- const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
3299
- const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
3300
-
3301
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), WSP_GGML_FP16_TO_FP32(x0->d)*y0->d);
3302
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), WSP_GGML_FP16_TO_FP32(x1->d)*y1->d);
3303
- #endif
4407
+ wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
4408
+ wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), WSP_GGML_FP16_TO_FP32(x1->d)*y1->d);
3304
4409
  }
3305
4410
 
3306
4411
  *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
@@ -3550,34 +4655,13 @@ void wsp_ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * re
3550
4655
  const int8x16_t y1_0 = vld1q_s8(y1->qs);
3551
4656
  const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
3552
4657
 
3553
- #if defined(__ARM_FEATURE_DOTPROD)
3554
4658
  sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
3555
- vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
3556
- vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d));
4659
+ wsp_ggml_vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
4660
+ wsp_ggml_vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d));
3557
4661
 
3558
4662
  sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
3559
- vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
3560
- vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d));
3561
-
3562
- #else
3563
- const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0));
3564
- const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
3565
- const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1));
3566
- const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
3567
-
3568
- const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0));
3569
- const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
3570
- const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1));
3571
- const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
3572
-
3573
- const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
3574
- const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
3575
- const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
3576
- const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
3577
-
3578
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d));
3579
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d));
3580
- #endif
4663
+ wsp_ggml_vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
4664
+ wsp_ggml_vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d));
3581
4665
  }
3582
4666
 
3583
4667
  *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
@@ -3650,12 +4734,10 @@ void wsp_ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * re
3650
4734
  const int nb = n / QK_K;
3651
4735
 
3652
4736
  #ifdef __ARM_NEON
3653
-
3654
4737
  const uint8x16_t m3 = vdupq_n_u8(0x3);
3655
4738
  const uint8x16_t m4 = vdupq_n_u8(0xF);
3656
- #if defined(__ARM_FEATURE_DOTPROD)
3657
- const int32x4_t vzero = vdupq_n_s32(0);
3658
- #endif
4739
+
4740
+ const int32x4_t vzero = vdupq_n_s32(0);
3659
4741
 
3660
4742
  wsp_ggml_int8x16x2_t q2bytes;
3661
4743
  uint8_t aux[16];
@@ -3663,7 +4745,6 @@ void wsp_ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * re
3663
4745
  float sum = 0;
3664
4746
 
3665
4747
  for (int i = 0; i < nb; ++i) {
3666
-
3667
4748
  const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d);
3668
4749
  const float dmin = -y[i].d * WSP_GGML_FP16_TO_FP32(x[i].dmin);
3669
4750
 
@@ -3677,7 +4758,7 @@ void wsp_ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * re
3677
4758
 
3678
4759
  const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
3679
4760
  const wsp_ggml_int16x8x2_t q8sums = wsp_ggml_vld1q_s16_x2(y[i].bsums);
3680
- const wsp_ggml_int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))};
4761
+ const wsp_ggml_int16x8x2_t mins16 = {{vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}};
3681
4762
  const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
3682
4763
  vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
3683
4764
  const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
@@ -3689,20 +4770,9 @@ void wsp_ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * re
3689
4770
 
3690
4771
  // We use this macro instead of a function call because for some reason
3691
4772
  // the code runs 2-3% slower, even if the function is declared inline
3692
- #if defined(__ARM_FEATURE_DOTPROD)
3693
- #define MULTIPLY_ACCUM_WITH_SCALE(index)\
3694
- isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
3695
- isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
3696
- #else
3697
4773
  #define MULTIPLY_ACCUM_WITH_SCALE(index)\
3698
- {\
3699
- const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),\
3700
- vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));\
3701
- const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),\
3702
- vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));\
3703
- isum += vaddvq_s16(p1) * aux[is+(index)] + vaddvq_s16(p2) * aux[is+1+(index)];\
3704
- }
3705
- #endif
4774
+ isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
4775
+ isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
3706
4776
 
3707
4777
  #define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
3708
4778
  q8bytes = wsp_ggml_vld1q_s8_x2(q8); q8 += 32;\
@@ -3710,26 +4780,23 @@ void wsp_ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * re
3710
4780
  q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
3711
4781
  MULTIPLY_ACCUM_WITH_SCALE((index));
3712
4782
 
3713
-
3714
4783
  for (int j = 0; j < QK_K/128; ++j) {
3715
-
3716
4784
  const wsp_ggml_uint8x16x2_t q2bits = wsp_ggml_vld1q_u8_x2(q2); q2 += 32;
3717
4785
 
3718
4786
  wsp_ggml_int8x16x2_t q8bytes = wsp_ggml_vld1q_s8_x2(q8); q8 += 32;
3719
4787
  q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
3720
4788
  q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
4789
+
3721
4790
  MULTIPLY_ACCUM_WITH_SCALE(0);
3722
4791
 
3723
4792
  SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2);
3724
-
3725
4793
  SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4);
3726
-
3727
4794
  SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6);
3728
4795
 
3729
4796
  is += 8;
3730
4797
  }
3731
- sum += d * isum;
3732
4798
 
4799
+ sum += d * isum;
3733
4800
  }
3734
4801
 
3735
4802
  *s = sum;
@@ -4043,11 +5110,9 @@ void wsp_ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * re
4043
5110
  const int nb = n / QK_K;
4044
5111
 
4045
5112
  #ifdef __ARM_NEON
4046
-
4047
5113
  const uint8x16_t m3 = vdupq_n_u8(0x3);
4048
- #if defined(__ARM_FEATURE_DOTPROD)
4049
- const int32x4_t vzero = vdupq_n_s32(0);
4050
- #endif
5114
+
5115
+ const int32x4_t vzero = vdupq_n_s32(0);
4051
5116
 
4052
5117
  wsp_ggml_int8x16x4_t q2bytes;
4053
5118
 
@@ -4081,28 +5146,12 @@ void wsp_ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * re
4081
5146
  q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3));
4082
5147
  q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3));
4083
5148
 
4084
- #if defined(__ARM_FEATURE_DOTPROD)
4085
- isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0];
4086
- isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1];
4087
- isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2];
4088
- isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3];
4089
- #else
4090
- const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
4091
- vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));
4092
- const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
4093
- vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));
4094
- isum1 += vaddvq_s16(p1) * scales[0];
4095
- isum2 += vaddvq_s16(p2) * scales[1];
4096
-
4097
- const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
4098
- vmull_s8(vget_high_s8(q2bytes.val[2]), vget_high_s8(q8bytes.val[2])));
4099
- const int16x8_t p4 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
4100
- vmull_s8(vget_high_s8(q2bytes.val[3]), vget_high_s8(q8bytes.val[3])));
4101
- isum1 += vaddvq_s16(p3) * scales[2];
4102
- isum2 += vaddvq_s16(p4) * scales[3];
4103
- #endif
4104
- sum += d * (isum1 + isum2);
5149
+ isum1 += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0];
5150
+ isum2 += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1];
5151
+ isum1 += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2];
5152
+ isum2 += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3];
4105
5153
 
5154
+ sum += d * (isum1 + isum2);
4106
5155
  }
4107
5156
 
4108
5157
  *s = sum;
@@ -4328,9 +5377,7 @@ void wsp_ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * re
4328
5377
  uint32_t utmp[4];
4329
5378
 
4330
5379
  const uint8x16_t m3b = vdupq_n_u8(0x3);
4331
- #ifdef __ARM_FEATURE_DOTPROD
4332
5380
  const int32x4_t vzero = vdupq_n_s32(0);
4333
- #endif
4334
5381
 
4335
5382
  const uint8x16_t m0 = vdupq_n_u8(1);
4336
5383
  const uint8x16_t m1 = vshlq_n_u8(m0, 1);
@@ -4382,22 +5429,11 @@ void wsp_ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * re
4382
5429
  q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
4383
5430
  q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
4384
5431
 
4385
- #if defined(__ARM_FEATURE_DOTPROD)
4386
- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
4387
- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
4388
- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
4389
- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
4390
- #else
4391
- int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_1.val[0])),
4392
- vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_1.val[0])));
4393
- int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_1.val[1])),
4394
- vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_1.val[1])));
4395
- int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_1.val[2])),
4396
- vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_1.val[2])));
4397
- int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_1.val[3])),
4398
- vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_1.val[3])));
4399
- isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
4400
- #endif
5432
+ isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
5433
+ isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
5434
+ isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
5435
+ isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
5436
+
4401
5437
  scale += 4;
4402
5438
 
4403
5439
  q3h.val[0] = vbicq_u8(m2, qhbits.val[0]);
@@ -4410,22 +5446,11 @@ void wsp_ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * re
4410
5446
  q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
4411
5447
  q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
4412
5448
 
4413
- #if defined(__ARM_FEATURE_DOTPROD)
4414
- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
4415
- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
4416
- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
4417
- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
4418
- #else
4419
- p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_2.val[0])),
4420
- vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_2.val[0])));
4421
- p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_2.val[1])),
4422
- vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_2.val[1])));
4423
- p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_2.val[2])),
4424
- vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_2.val[2])));
4425
- p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_2.val[3])),
4426
- vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_2.val[3])));
4427
- isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
4428
- #endif
5449
+ isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
5450
+ isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
5451
+ isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
5452
+ isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
5453
+
4429
5454
  scale += 4;
4430
5455
 
4431
5456
  if (j == 0) {
@@ -4864,10 +5889,7 @@ void wsp_ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * re
4864
5889
  const int nb = n / QK_K;
4865
5890
 
4866
5891
  #ifdef __ARM_NEON
4867
-
4868
- #ifdef __ARM_FEATURE_DOTPROD
4869
- const int32x4_t vzero = vdupq_n_s32(0);
4870
- #endif
5892
+ const int32x4_t vzero = vdupq_n_s32(0);
4871
5893
 
4872
5894
  const uint8x16_t m3b = vdupq_n_u8(0x3);
4873
5895
  const uint8x16_t mh = vdupq_n_u8(4);
@@ -4908,22 +5930,10 @@ void wsp_ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * re
4908
5930
  q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2]));
4909
5931
  q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6), q3h.val[3]));
4910
5932
 
4911
- #if defined(__ARM_FEATURE_DOTPROD)
4912
- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
4913
- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2];
4914
- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1];
4915
- isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];
4916
- #else
4917
- const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
4918
- vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes.val[0])));
4919
- const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
4920
- vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes.val[1])));
4921
- const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
4922
- vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes.val[2])));
4923
- const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
4924
- vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes.val[3])));
4925
- isum += vaddvq_s16(p0) * scales[0] + vaddvq_s16(p1) * scales[2] + vaddvq_s16(p2) * scales[1] + vaddvq_s16(p3) * scales[3];
4926
- #endif
5933
+ isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
5934
+ isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2];
5935
+ isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1];
5936
+ isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];
4927
5937
 
4928
5938
  sum += d * isum;
4929
5939
 
@@ -5228,11 +6238,8 @@ void wsp_ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * re
5228
6238
  uint32_t utmp[4];
5229
6239
 
5230
6240
  #ifdef __ARM_NEON
5231
-
5232
6241
  const uint8x16_t m4b = vdupq_n_u8(0xf);
5233
- #ifdef __ARM_FEATURE_DOTPROD
5234
6242
  const int32x4_t mzero = vdupq_n_s32(0);
5235
- #endif
5236
6243
 
5237
6244
  wsp_ggml_int8x16x2_t q4bytes;
5238
6245
  wsp_ggml_int8x16x2_t q8bytes;
@@ -5269,44 +6276,22 @@ void wsp_ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * re
5269
6276
  int32_t sumi2 = 0;
5270
6277
 
5271
6278
  for (int j = 0; j < QK_K/64; ++j) {
5272
-
5273
6279
  const wsp_ggml_uint8x16x2_t q4bits = wsp_ggml_vld1q_u8_x2(q4); q4 += 32;
5274
6280
 
5275
- #ifdef __ARM_FEATURE_DOTPROD
5276
6281
  q8bytes = wsp_ggml_vld1q_s8_x2(q8); q8 += 32;
5277
6282
  q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
5278
6283
  q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
5279
6284
 
5280
- const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
5281
- sumi1 += vaddvq_s32(p1) * scales[2*j+0];
5282
-
5283
- q8bytes = wsp_ggml_vld1q_s8_x2(q8); q8 += 32;
5284
- q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
5285
- q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
5286
-
5287
- const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
5288
-
5289
- sumi2 += vaddvq_s32(p2) * scales[2*j+1];
5290
- #else
5291
- q8bytes = wsp_ggml_vld1q_s8_x2(q8); q8 += 32;
5292
- q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
5293
- q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
5294
- const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
5295
- vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
5296
- const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
5297
- vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
5298
- sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0];
6285
+ const int32x4_t p1 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
6286
+ sumi1 += vaddvq_s32(p1) * scales[2*j+0];
5299
6287
 
5300
6288
  q8bytes = wsp_ggml_vld1q_s8_x2(q8); q8 += 32;
5301
6289
  q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
5302
6290
  q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
5303
- const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
5304
- vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
5305
- const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
5306
- vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
5307
- sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) * scales[2*j+1];
5308
6291
 
5309
- #endif
6292
+ const int32x4_t p2 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
6293
+
6294
+ sumi2 += vaddvq_s32(p2) * scales[2*j+1];
5310
6295
  }
5311
6296
 
5312
6297
  sumf += d * (sumi1 + sumi2);
@@ -5603,12 +6588,9 @@ void wsp_ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * re
5603
6588
  const int nb = n / QK_K;
5604
6589
 
5605
6590
  #ifdef __ARM_NEON
5606
-
5607
6591
  const uint8x16_t m4b = vdupq_n_u8(0xf);
5608
6592
 
5609
- #ifdef __ARM_FEATURE_DOTPROD
5610
6593
  const int32x4_t mzero = vdupq_n_s32(0);
5611
- #endif
5612
6594
 
5613
6595
  float sumf = 0;
5614
6596
 
@@ -5636,41 +6618,20 @@ void wsp_ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * re
5636
6618
 
5637
6619
  const wsp_ggml_uint8x16x2_t q4bits = wsp_ggml_vld1q_u8_x2(q4);
5638
6620
 
5639
- #ifdef __ARM_FEATURE_DOTPROD
5640
6621
  q8bytes = wsp_ggml_vld1q_s8_x4(q8);
5641
6622
  q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
5642
6623
  q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
5643
6624
 
5644
- const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
6625
+ const int32x4_t p1 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
5645
6626
  const int32_t sumi1 = vaddvq_s32(p1) * scales[0];
5646
6627
 
5647
6628
  q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
5648
6629
  q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
5649
6630
 
5650
- const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
6631
+ const int32x4_t p2 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
5651
6632
  const int32_t sumi2 = vaddvq_s32(p2) * scales[1];
5652
6633
 
5653
- #else
5654
- q8bytes = wsp_ggml_vld1q_s8_x4(q8);
5655
- q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
5656
- q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
5657
- const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
5658
- vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
5659
- const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
5660
- vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
5661
- int32_t sumi1 = vaddvq_s16(vaddq_s16(p0, p1)) * scales[0];
5662
-
5663
- q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
5664
- q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
5665
- const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[2])),
5666
- vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[2])));
5667
- const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[3])),
5668
- vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[3])));
5669
- int32_t sumi2 = vaddvq_s16(vaddq_s16(p2, p3)) * scales[1];
5670
-
5671
- #endif
5672
6634
  sumf += d * (sumi1 + sumi2);
5673
-
5674
6635
  }
5675
6636
 
5676
6637
  *s = sumf - sum_mins;
@@ -5875,15 +6836,11 @@ void wsp_ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * re
5875
6836
 
5876
6837
  uint32_t utmp[4];
5877
6838
 
5878
-
5879
6839
  #ifdef __ARM_NEON
5880
-
5881
6840
  const uint8x16_t m4b = vdupq_n_u8(0xf);
5882
6841
  const uint8x16_t mone = vdupq_n_u8(1);
5883
6842
  const uint8x16_t mtwo = vdupq_n_u8(2);
5884
- #if defined(__ARM_FEATURE_DOTPROD)
5885
6843
  const int32x4_t mzero = vdupq_n_s32(0);
5886
- #endif
5887
6844
 
5888
6845
  wsp_ggml_int8x16x4_t q5bytes;
5889
6846
 
@@ -5938,28 +6895,11 @@ void wsp_ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * re
5938
6895
  q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2]));
5939
6896
  q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3]));
5940
6897
 
5941
- #if defined(__ARM_FEATURE_DOTPROD)
5942
-
5943
- sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
5944
- sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
5945
- #else
5946
-
5947
- const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
5948
- vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0])));
5949
- const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
5950
- vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1])));
5951
- sumi += vaddvq_s16(vaddq_s16(p0, p1)) * *scales++;
5952
-
5953
- const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
5954
- vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2])));
5955
- const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
5956
- vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3])));
5957
- sumi += vaddvq_s16(vaddq_s16(p2, p3)) * *scales++;
5958
- #endif
6898
+ sumi += vaddvq_s32(wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
6899
+ sumi += vaddvq_s32(wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
5959
6900
  }
5960
6901
 
5961
6902
  sumf += d * sumi - dmin * sumi_mins;
5962
-
5963
6903
  }
5964
6904
 
5965
6905
  *s = sumf;
@@ -6311,12 +7251,9 @@ void wsp_ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * re
6311
7251
  const int nb = n / QK_K;
6312
7252
 
6313
7253
  #ifdef __ARM_NEON
6314
-
6315
7254
  const uint8x16_t m4b = vdupq_n_u8(0xf);
6316
7255
  const uint8x16_t mh = vdupq_n_u8(16);
6317
- #if defined(__ARM_FEATURE_DOTPROD)
6318
7256
  const int32x4_t mzero = vdupq_n_s32(0);
6319
- #endif
6320
7257
 
6321
7258
  wsp_ggml_int8x16x4_t q5bytes;
6322
7259
  wsp_ggml_uint8x16x4_t q5h;
@@ -6348,32 +7285,12 @@ void wsp_ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * re
6348
7285
  q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2]));
6349
7286
  q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3]));
6350
7287
 
6351
- #if defined(__ARM_FEATURE_DOTPROD)
6352
-
6353
- int32_t sumi1 = sc[0] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]));
6354
- int32_t sumi2 = sc[1] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1]));
6355
- int32_t sumi3 = sc[2] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]));
6356
- int32_t sumi4 = sc[3] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3]));
7288
+ int32_t sumi1 = sc[0] * vaddvq_s32(wsp_ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]));
7289
+ int32_t sumi2 = sc[1] * vaddvq_s32(wsp_ggml_vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1]));
7290
+ int32_t sumi3 = sc[2] * vaddvq_s32(wsp_ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]));
7291
+ int32_t sumi4 = sc[3] * vaddvq_s32(wsp_ggml_vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3]));
6357
7292
 
6358
7293
  sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
6359
-
6360
- #else
6361
-
6362
- const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
6363
- vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0])));
6364
- const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
6365
- vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1])));
6366
- int32_t sumi = sc[0] * vaddvq_s16(p0) + sc[1] * vaddvq_s16(p1);
6367
-
6368
- const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
6369
- vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2])));
6370
- const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
6371
- vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3])));
6372
- sumi += sc[2] * vaddvq_s16(p2) + sc[3] * vaddvq_s16(p3);
6373
-
6374
- sumf += d*sumi;
6375
- #endif
6376
-
6377
7294
  }
6378
7295
 
6379
7296
  *s = sumf;
@@ -6600,13 +7517,10 @@ void wsp_ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * re
6600
7517
  const int nb = n / QK_K;
6601
7518
 
6602
7519
  #ifdef __ARM_NEON
6603
-
6604
7520
  float sum = 0;
6605
7521
 
6606
7522
  const uint8x16_t m4b = vdupq_n_u8(0xF);
6607
- #if defined(__ARM_FEATURE_DOTPROD)
6608
7523
  const int32x4_t vzero = vdupq_n_s32(0);
6609
- #endif
6610
7524
  //const int8x16_t m32s = vdupq_n_s8(32);
6611
7525
 
6612
7526
  const uint8x16_t mone = vdupq_n_u8(3);
@@ -6626,7 +7540,7 @@ void wsp_ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * re
6626
7540
 
6627
7541
  const wsp_ggml_int16x8x2_t q8sums = wsp_ggml_vld1q_s16_x2(y[i].bsums);
6628
7542
  const int8x16_t scales = vld1q_s8(scale);
6629
- const wsp_ggml_int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))};
7543
+ const wsp_ggml_int16x8x2_t q6scales = {{vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}};
6630
7544
 
6631
7545
  const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
6632
7546
  vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
@@ -6658,31 +7572,13 @@ void wsp_ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * re
6658
7572
  q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2]));
6659
7573
  q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3]));
6660
7574
 
6661
- #if defined(__ARM_FEATURE_DOTPROD)
7575
+ isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
7576
+ vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
7577
+ vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
7578
+ vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
6662
7579
 
6663
- isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
6664
- vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
6665
- vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
6666
- vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
6667
7580
  scale += 4;
6668
7581
 
6669
- #else
6670
-
6671
- int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
6672
- vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
6673
- int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
6674
- vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
6675
- isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
6676
- scale += 2;
6677
-
6678
- int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
6679
- vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
6680
- int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
6681
- vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
6682
- isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1];
6683
- scale += 2;
6684
- #endif
6685
-
6686
7582
  q8bytes = wsp_ggml_vld1q_s8_x4(q8); q8 += 64;
6687
7583
 
6688
7584
  shifted = vshrq_n_u8(qhbits.val[0], 4);
@@ -6703,34 +7599,11 @@ void wsp_ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * re
6703
7599
  q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2]));
6704
7600
  q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3]));
6705
7601
 
6706
- #if defined(__ARM_FEATURE_DOTPROD)
6707
-
6708
- isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
6709
- vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
6710
- vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
6711
- vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
7602
+ isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
7603
+ vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
7604
+ vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
7605
+ vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
6712
7606
  scale += 4;
6713
-
6714
- //for (int l = 0; l < 4; ++l) {
6715
- // const int32x4_t p = vdotq_s32(vzero, q6bytes.val[l], q8bytes.val[l]);
6716
- // isum += vaddvq_s32(p) * *scale++;
6717
- //}
6718
- #else
6719
- p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
6720
- vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
6721
- p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
6722
- vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
6723
- isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
6724
- scale += 2;
6725
-
6726
- p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
6727
- vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
6728
- p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
6729
- vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
6730
- isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1];
6731
- scale += 2;
6732
- #endif
6733
-
6734
7607
  }
6735
7608
  //sum += isum * d_all * y[i].d;
6736
7609
  sum += d_all * y[i].d * (isum - 32 * isum_mins);
@@ -7076,14 +7949,11 @@ void wsp_ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * re
7076
7949
  const int nb = n / QK_K;
7077
7950
 
7078
7951
  #ifdef __ARM_NEON
7079
-
7080
7952
  float sum = 0;
7081
7953
 
7082
7954
  const uint8x16_t m4b = vdupq_n_u8(0xF);
7083
7955
  const int8x16_t m32s = vdupq_n_s8(32);
7084
- #if defined(__ARM_FEATURE_DOTPROD)
7085
7956
  const int32x4_t vzero = vdupq_n_s32(0);
7086
- #endif
7087
7957
 
7088
7958
  const uint8x16_t mone = vdupq_n_u8(3);
7089
7959
 
@@ -7119,26 +7989,10 @@ void wsp_ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * re
7119
7989
  q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s);
7120
7990
  q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s);
7121
7991
 
7122
- #if defined(__ARM_FEATURE_DOTPROD)
7123
-
7124
- isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
7125
- vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
7126
- vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
7127
- vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
7128
- #else
7129
-
7130
- int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
7131
- vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
7132
- int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
7133
- vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
7134
- isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
7135
-
7136
- int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
7137
- vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
7138
- int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
7139
- vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
7140
- isum += vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
7141
- #endif
7992
+ isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
7993
+ vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
7994
+ vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
7995
+ vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
7142
7996
 
7143
7997
  sum += isum * d_all * y[i].d;
7144
7998
 
@@ -7380,3 +8234,958 @@ void wsp_ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * re
7380
8234
  }
7381
8235
 
7382
8236
  #endif
8237
+
8238
+ static const int8_t keven_signs_q2xs[1024] = {
8239
+ 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1,
8240
+ 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1,
8241
+ 1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1,
8242
+ 1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1,
8243
+ 1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1,
8244
+ 1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1,
8245
+ 1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1,
8246
+ 1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1,
8247
+ 1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1,
8248
+ 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1,
8249
+ 1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1,
8250
+ 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1,
8251
+ 1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1,
8252
+ 1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1,
8253
+ 1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1,
8254
+ 1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1,
8255
+ 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1,
8256
+ 1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1,
8257
+ 1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1,
8258
+ 1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1,
8259
+ 1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1,
8260
+ 1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1,
8261
+ 1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1,
8262
+ 1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1,
8263
+ 1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1,
8264
+ 1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1,
8265
+ 1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1,
8266
+ 1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1,
8267
+ 1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1,
8268
+ 1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1,
8269
+ 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1,
8270
+ 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
8271
+ };
8272
+
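The 1024-entry table above is 128 rows of eight int8 signs: bits 0..6 of the row index give the first seven signs directly and the eighth is chosen so the count of -1s per row is even, hence "even signs". A hypothetical generator, again assuming __builtin_popcount:

    int8_t signs[1024];
    for (int k = 0; k < 128; ++k) {
        for (int j = 0; j < 7; ++j)
            signs[8*k + j] = (k >> j) & 1 ? -1 : 1;          // explicit sign bits
        signs[8*k + 7] = __builtin_popcount(k) & 1 ? -1 : 1; // parity sign
    }

Reading rows as 64-bit words through signs64 below then applies all eight signs with one vector multiply.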
8273
+ void wsp_ggml_vec_dot_iq2_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
8274
+ assert(n % QK_K == 0);
8275
+
8276
+ const block_iq2_xxs * restrict x = vx;
8277
+ const block_q8_K * restrict y = vy;
8278
+
8279
+ const int nb = n / QK_K;
8280
+
8281
+ #if defined(__ARM_NEON)
8282
+
8283
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
8284
+
8285
+ uint32_t aux32[4];
8286
+ const uint8_t * aux8 = (const uint8_t *)aux32;
8287
+
8288
+ wsp_ggml_int8x16x4_t q2u;
8289
+ wsp_ggml_int8x16x4_t q2s;
8290
+ wsp_ggml_int8x16x4_t q8b;
8291
+
8292
+ float sumf = 0;
8293
+ for (int i = 0; i < nb; ++i) {
8294
+ const float d = WSP_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
8295
+ const uint16_t * restrict q2 = x[i].qs;
8296
+ const int8_t * restrict q8 = y[i].qs;
8297
+ float sumf1 = 0, sumf2 = 0;
8298
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
8299
+ q8b = wsp_ggml_vld1q_s8_x4(q8); q8 += 64;
8300
+ memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
8301
+ q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 0])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 1])));
8302
+ q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 2])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 3])));
8303
+ q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 8])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 9])));
8304
+ q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[10])), vld1_s8((const void *)(iq2xxs_grid + aux8[11])));
8305
+ q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127))));
8306
+ q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127))));
8307
+ q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 7) & 127))));
8308
+ q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 21) & 127))));
8309
+ q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]);
8310
+ q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]);
8311
+ q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]);
8312
+ q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]);
8313
+ const int32x4_t p1 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]), q2u.val[1], q8b.val[1]);
8314
+ const int32x4_t p2 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]), q2u.val[3], q8b.val[3]);
8315
+ sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[1] >> 28));
8316
+ sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[3] >> 28));
8317
+ }
8318
+ sumf += d*(sumf1 + sumf2);
8319
+ }
8320
+ *s = 0.25f * sumf;
8321
+
8322
+ #elif defined(__AVX2__)
8323
+
8324
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
8325
+
8326
+ uint32_t aux32[4];
8327
+ const uint8_t * aux8 = (const uint8_t *)aux32;
8328
+
8329
+ __m256 accumf = _mm256_setzero_ps();
8330
+ for (int i = 0; i < nb; ++i) {
8331
+ const float d = WSP_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
8332
+ const uint16_t * restrict q2 = x[i].qs;
8333
+ const int8_t * restrict q8 = y[i].qs;
8334
+ __m256i sumi1 = _mm256_setzero_si256();
8335
+ __m256i sumi2 = _mm256_setzero_si256();
8336
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
8337
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
8338
+ const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
8339
+ memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
8340
+ const __m256i q2_1 = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
8341
+ const __m256i q2_2 = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
8342
+ const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
8343
+ signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
8344
+ const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127],
8345
+ signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]);
8346
+ const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
8347
+ const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
8348
+ const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
8349
+ const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
8350
+ const uint16_t ls1 = aux32[1] >> 28;
8351
+ const uint16_t ls2 = aux32[3] >> 28;
8352
+ const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
8353
+ const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
8354
+ sumi1 = _mm256_add_epi32(sumi1, p1);
8355
+ sumi2 = _mm256_add_epi32(sumi2, p2);
8356
+ }
8357
+
8358
+ accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
8359
+
8360
+ }
8361
+
8362
+ *s = 0.125f * hsum_float_8(accumf);
+
+ #else
+
+ uint32_t aux32[2];
+ const uint8_t * aux8 = (const uint8_t *)aux32;
+
+ float sumf = 0.f;
+ for (int i = 0; i < nb; ++i) {
+ const float d = WSP_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint16_t * restrict q2 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+ int32_t bsum = 0;
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+ memcpy(aux32, q2, 2*sizeof(uint32_t));
+ q2 += 4;
+ const uint32_t ls = 2*(aux32[1] >> 28) + 1;
+ int32_t sumi = 0;
+ for (int l = 0; l < 4; ++l) {
+ const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
+ const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];
+ for (int j = 0; j < 8; ++j) {
+ sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+ }
+ q8 += 8;
+ }
+ bsum += sumi * ls;
+ }
+ sumf += d * bsum;
+ }
+ *s = 0.125f * sumf;
+ #endif
+ }
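+ // Layout recap for block_iq2_xxs (as read by the scalar path above): each group
+ // of 32 weights occupies two uint32 words — aux32[0] holds four 8-bit grid
+ // indices, aux32[1] packs four 7-bit sign-pattern indices (bits 0..27) plus a
+ // 4-bit block scale (bits 28..31).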
+
+ void wsp_ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+ assert(n % QK_K == 0);
+
+ const block_iq2_xs * restrict x = vx;
+ const block_q8_K * restrict y = vy;
+
+ const int nb = n / QK_K;
+
+ #if defined(__ARM_NEON)
+
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+ wsp_ggml_int8x16x4_t q2u;
+ wsp_ggml_int8x16x4_t q2s;
+ wsp_ggml_int8x16x4_t q8b;
+
+ int32x4x4_t scales32;
+
+ float sumf = 0;
+ for (int i = 0; i < nb; ++i) {
+ const float d = WSP_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint16_t * restrict q2 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+ const uint8x8_t scales8 = vld1_u8(x[i].scales);
+ const uint8x8_t scales_l = vand_u8(scales8, vdup_n_u8(0xf));
+ const uint8x8_t scales_h = vshr_n_u8(scales8, 4);
+ uint8x16_t scales = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h));
+ scales = vaddq_u8(vshlq_n_u8(scales, 1), vdupq_n_u8(1));
+ const uint16x8_t scales1 = vmovl_u8(vget_low_u8(scales));
+ const uint16x8_t scales2 = vmovl_u8(vget_high_u8(scales));
+ scales32.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales1)));
+ scales32.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales1)));
+ scales32.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales2)));
+ scales32.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales2)));
+ int32x4_t sumi = vdupq_n_s32(0);
+ for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
+ q8b = wsp_ggml_vld1q_s8_x4(q8); q8 += 64;
+ q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[0] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[1] & 511))));
+ q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[2] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[3] & 511))));
+ q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[4] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[5] & 511))));
+ q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[6] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[7] & 511))));
+ q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[0] >> 9))), vld1_s8((const void *)(signs64 + (q2[1] >> 9))));
+ q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[2] >> 9))), vld1_s8((const void *)(signs64 + (q2[3] >> 9))));
+ q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[4] >> 9))), vld1_s8((const void *)(signs64 + (q2[5] >> 9))));
+ q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[6] >> 9))), vld1_s8((const void *)(signs64 + (q2[7] >> 9))));
+ q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]);
+ q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]);
+ q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]);
+ q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]);
+ const int32x4_t p1 = wsp_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]);
+ const int32x4_t p2 = wsp_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[1], q8b.val[1]);
+ const int32x4_t p3 = wsp_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]);
+ const int32x4_t p4 = wsp_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[3], q8b.val[3]);
+ const int32x4_t p = vpaddq_s32(vpaddq_s32(p1, p2), vpaddq_s32(p3, p4));
+ sumi = vmlaq_s32(sumi, p, scales32.val[ib64]);
+ q2 += 8;
+ }
+ sumf += d*vaddvq_s32(sumi);
+ }
+ *s = 0.125f * sumf;
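+ // Note: for iq2_xs each uint16 in x[i].qs packs a 9-bit grid index (bits 0..8)
+ // and a 7-bit sign-pattern index (bits 9..15); the 4-bit scale nibbles in
+ // x[i].scales are expanded to (2*s + 1) above via vshlq_n_u8/vaddq_u8.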
+
+ #elif defined(__AVX2__)
+
+ const __m128i m4 = _mm_set1_epi8(0xf);
+ const __m128i m1 = _mm_set1_epi8(1);
+ const __m128i m511 = _mm_set1_epi16(511);
+ const __m128i m127 = _mm_set1_epi16(127);
+
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+ uint64_t aux64;
+
+ // somewhat hacky, but gives a significant boost in performance
+ __m128i aux_gindex, aux_sindex;
+ const uint16_t * gindex = (const uint16_t *)&aux_gindex;
+ const uint16_t * sindex = (const uint16_t *)&aux_sindex;
+
+ __m256 accumf = _mm256_setzero_ps();
+ for (int i = 0; i < nb; ++i) {
+ const float d = WSP_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint16_t * restrict q2 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ memcpy(&aux64, x[i].scales, 8);
+ __m128i stmp = _mm_set1_epi64x(aux64);
+ stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4));
+ const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1);
+
+ __m256i sumi1 = _mm256_setzero_si256();
+ __m256i sumi2 = _mm256_setzero_si256();
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+ const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+ const __m128i q2_data = _mm_loadu_si128((const __m128i*)q2); q2 += 8;
+ aux_gindex = _mm_and_si128(q2_data, m511);
+ aux_sindex = _mm_and_si128(_mm_srli_epi16(q2_data, 9), m127);
+ const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]], iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]);
+ const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]], iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]);
+ const __m256i s2_1 = _mm256_set_epi64x(signs64[sindex[3]], signs64[sindex[2]], signs64[sindex[1]], signs64[sindex[0]]);
+ const __m256i s2_2 = _mm256_set_epi64x(signs64[sindex[7]], signs64[sindex[6]], signs64[sindex[5]], signs64[sindex[4]]);
+ const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
+ const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
+ const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
+ const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
+
+ const __m256i sc1 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0)));
+ const __m256i sc2 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1)));
+
+ sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot1, sc1));
+ sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot2, sc2));
+ }
+
+ accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
+
+ }
+
+ *s = 0.125f * hsum_float_8(accumf);
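+ // Note on the "somewhat hacky" part above: the 16-bit grid/sign indices are
+ // obtained by spilling the SSE registers aux_gindex/aux_sindex to memory and
+ // re-reading them through scalar uint16_t pointers, so the grid and sign
+ // tables can be gathered with plain loads.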
+
+ #else
+
+ float sumf = 0.f;
+ for (int i = 0; i < nb; ++i) {
+ const float d = WSP_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint16_t * restrict q2 = x[i].qs;
+ const uint8_t * restrict sc = x[i].scales;
+ const int8_t * restrict q8 = y[i].qs;
+ int32_t bsum = 0;
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+ const uint16_t ls1 = 2*(sc[ib32] & 0xf) + 1;
+ const uint16_t ls2 = 2*(sc[ib32] >> 4) + 1;
+ int32_t sumi = 0;
+ for (int l = 0; l < 2; ++l) {
+ const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
+ const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
+ for (int j = 0; j < 8; ++j) {
+ sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+ }
+ q8 += 8;
+ }
+ bsum += sumi * ls1;
+ sumi = 0;
+ for (int l = 2; l < 4; ++l) {
+ const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
+ const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
+ for (int j = 0; j < 8; ++j) {
+ sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+ }
+ q8 += 8;
+ }
+ bsum += sumi * ls2;
+ q2 += 4;
+ }
+ sumf += d * bsum;
+ }
+ *s = 0.125f * sumf;
+ #endif
+ }
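+ // The sign tables store only 7 bits per group of 8 weights: the quantizers
+ // below force an even number of negative values per group (see the nflip%2
+ // handling), so the 8th sign is implied by parity.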
+
+ // ================================ IQ2 quantization =============================================
+
+ typedef struct {
+ uint64_t * grid;
+ int * map;
+ uint16_t * neighbours;
+ } iq2_entry_t;
+
+ static iq2_entry_t iq2_data[2] = {
+ {NULL, NULL, NULL},
+ {NULL, NULL, NULL},
+ };
+
+ static inline int iq2_data_index(int grid_size) {
+ WSP_GGML_ASSERT(grid_size == 256 || grid_size == 512);
+ return grid_size == 256 ? 0 : 1;
+ }
+
+ static int iq2_compare_func(const void * left, const void * right) {
+ const int * l = (const int *)left;
+ const int * r = (const int *)right;
+ return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 1 : 0;
+ }
+
+ void iq2xs_init_impl(int grid_size) {
+ const int gindex = iq2_data_index(grid_size);
+ if (iq2_data[gindex].grid) {
+ return;
+ }
+ static const uint16_t kgrid_256[256] = {
+ 0, 2, 5, 8, 10, 17, 20, 32, 34, 40, 42, 65, 68, 80, 88, 97,
+ 100, 128, 130, 138, 162, 257, 260, 272, 277, 320, 388, 408, 512, 514, 546, 642,
+ 1025, 1028, 1040, 1057, 1060, 1088, 1090, 1096, 1120, 1153, 1156, 1168, 1188, 1280, 1282, 1288,
+ 1312, 1350, 1385, 1408, 1425, 1545, 1552, 1600, 1668, 1700, 2048, 2053, 2056, 2068, 2088, 2113,
+ 2116, 2128, 2130, 2184, 2308, 2368, 2562, 2580, 4097, 4100, 4112, 4129, 4160, 4192, 4228, 4240,
+ 4245, 4352, 4360, 4384, 4432, 4442, 4480, 4644, 4677, 5120, 5128, 5152, 5157, 5193, 5248, 5400,
+ 5474, 5632, 5654, 6145, 6148, 6160, 6208, 6273, 6400, 6405, 6560, 6737, 8192, 8194, 8202, 8260,
+ 8289, 8320, 8322, 8489, 8520, 8704, 8706, 9217, 9220, 9232, 9280, 9302, 9472, 9537, 9572, 9872,
+ 10248, 10272, 10388, 10820, 16385, 16388, 16400, 16408, 16417, 16420, 16448, 16456, 16470, 16480, 16513, 16516,
+ 16528, 16640, 16672, 16737, 16768, 16773, 16897, 16912, 16968, 16982, 17000, 17408, 17416, 17440, 17536, 17561,
+ 17682, 17700, 17920, 18433, 18436, 18448, 18496, 18501, 18688, 18776, 18785, 18818, 19013, 19088, 20480, 20488,
+ 20497, 20505, 20512, 20608, 20616, 20740, 20802, 20900, 21137, 21648, 21650, 21770, 22017, 22100, 22528, 22545,
+ 22553, 22628, 22848, 23048, 24580, 24592, 24640, 24680, 24832, 24917, 25112, 25184, 25600, 25605, 25872, 25874,
+ 25988, 26690, 32768, 32770, 32778, 32833, 32898, 33028, 33048, 33088, 33297, 33793, 33796, 33808, 33813, 33856,
+ 33888, 34048, 34118, 34196, 34313, 34368, 34400, 34818, 35076, 35345, 36868, 36880, 36900, 36928, 37025, 37142,
+ 37248, 37445, 37888, 37922, 37956, 38225, 39041, 39200, 40962, 41040, 41093, 41225, 41472, 42008, 43088, 43268,
+ };
+ static const uint16_t kgrid_512[512] = {
+ 0, 2, 5, 8, 10, 17, 20, 22, 25, 32, 34, 37, 40, 65, 68, 70,
+ 73, 80, 82, 85, 88, 97, 100, 128, 130, 133, 136, 145, 148, 153, 160, 257,
+ 260, 262, 265, 272, 274, 277, 280, 282, 289, 292, 320, 322, 325, 328, 337, 340,
+ 352, 360, 385, 388, 400, 512, 514, 517, 520, 529, 532, 544, 577, 580, 592, 597,
+ 640, 650, 1025, 1028, 1030, 1033, 1040, 1042, 1045, 1048, 1057, 1060, 1088, 1090, 1093, 1096,
+ 1105, 1108, 1110, 1120, 1153, 1156, 1168, 1280, 1282, 1285, 1288, 1297, 1300, 1312, 1345, 1348,
+ 1360, 1377, 1408, 1537, 1540, 1552, 1574, 1600, 1602, 1668, 2048, 2050, 2053, 2056, 2058, 2065,
+ 2068, 2080, 2085, 2113, 2116, 2128, 2136, 2176, 2208, 2218, 2305, 2308, 2320, 2368, 2433, 2441,
+ 2560, 2592, 2600, 2710, 2720, 4097, 4100, 4102, 4105, 4112, 4114, 4117, 4120, 4129, 4132, 4160,
+ 4162, 4165, 4168, 4177, 4180, 4192, 4202, 4225, 4228, 4240, 4352, 4354, 4357, 4360, 4369, 4372,
+ 4384, 4417, 4420, 4432, 4480, 4500, 4502, 4609, 4612, 4614, 4624, 4672, 4704, 5120, 5122, 5125,
+ 5128, 5137, 5140, 5152, 5185, 5188, 5193, 5200, 5220, 5248, 5377, 5380, 5392, 5440, 5632, 5652,
+ 5705, 6145, 6148, 6160, 6162, 6208, 6228, 6278, 6400, 6405, 6502, 6737, 6825, 8192, 8194, 8197,
+ 8200, 8202, 8209, 8212, 8224, 8257, 8260, 8272, 8320, 8352, 8449, 8452, 8464, 8512, 8520, 8549,
+ 8704, 8738, 8832, 8872, 9217, 9220, 9232, 9257, 9280, 9472, 9537, 9554, 9625, 9729, 9754, 9894,
+ 10240, 10248, 10250, 10272, 10325, 10376, 10402, 10600, 10640, 10760, 10784, 10882, 10888, 10890, 16385, 16388,
+ 16390, 16393, 16400, 16402, 16405, 16408, 16417, 16420, 16448, 16450, 16453, 16456, 16458, 16465, 16468, 16480,
+ 16485, 16513, 16516, 16528, 16640, 16642, 16645, 16648, 16657, 16660, 16672, 16705, 16708, 16720, 16768, 16773,
+ 16802, 16897, 16900, 16912, 16914, 16937, 16960, 17408, 17410, 17413, 17416, 17425, 17428, 17433, 17440, 17473,
+ 17476, 17488, 17536, 17556, 17665, 17668, 17680, 17700, 17728, 17818, 17920, 17930, 17988, 18000, 18433, 18436,
+ 18448, 18496, 18501, 18516, 18530, 18688, 18705, 18756, 18768, 18793, 18948, 20480, 20482, 20485, 20488, 20497,
+ 20500, 20512, 20520, 20545, 20548, 20560, 20608, 20737, 20740, 20752, 20757, 20800, 20802, 20992, 21060, 21162,
+ 21505, 21508, 21520, 21537, 21568, 21600, 21633, 21665, 21760, 21768, 21888, 21896, 22049, 22120, 22177, 22528,
+ 22548, 22593, 22608, 22681, 22810, 22848, 22850, 23173, 24577, 24580, 24592, 24640, 24660, 24674, 24710, 24745,
+ 24832, 25124, 25162, 25234, 25600, 25622, 25872, 25920, 25925, 26020, 26625, 26730, 26917, 27142, 27220, 27234,
+ 32768, 32770, 32773, 32776, 32785, 32788, 32800, 32810, 32833, 32836, 32848, 32896, 32898, 32936, 32938, 33025,
+ 33028, 33030, 33040, 33088, 33105, 33113, 33280, 33312, 33408, 33410, 33440, 33448, 33793, 33796, 33808, 33810,
+ 33813, 33856, 33888, 33929, 34048, 34116, 34213, 34328, 34410, 34816, 34824, 34853, 34906, 34944, 34946, 34984,
+ 35078, 35362, 35456, 35464, 35478, 35496, 36865, 36868, 36880, 36928, 36950, 36996, 37120, 37154, 37220, 37462,
+ 37513, 37888, 37893, 37956, 37968, 37976, 38185, 38288, 38290, 38465, 38993, 39078, 39241, 39445, 39520, 40960,
+ 40962, 40968, 40970, 40992, 41002, 41120, 41297, 41305, 41382, 41472, 41474, 41480, 41514, 41600, 41632, 42048,
+ 42133, 42597, 42648, 43018, 43040, 43042, 43048, 43168, 43176, 43268, 43396, 43398, 43560, 43562, 43665, 43690,
+ };
+ const int kmap_size = 43692;
+ const int nwant = 2;
+ const uint16_t * kgrid = grid_size == 256 ? kgrid_256 : kgrid_512;
+ uint64_t * kgrid_q2xs;
+ int * kmap_q2xs;
+ uint16_t * kneighbors_q2xs;
+
+ printf("================================================================= %s(grid_size = %d)\n", __func__, grid_size);
+ uint64_t * the_grid = (uint64_t *)malloc(grid_size*sizeof(uint64_t));
+ for (int k = 0; k < grid_size; ++k) {
+ int8_t * pos = (int8_t *)(the_grid + k);
+ for (int i = 0; i < 8; ++i) {
+ int l = (kgrid[k] >> 2*i) & 0x3;
+ pos[i] = 2*l + 1;
+ }
+ }
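+ // Each grid entry is 8 bytes: coordinate i is decoded from 2 bits of the
+ // kgrid value as 2*l + 1, i.e. every lattice point lives in {1,3,5,7}^8.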
+ kgrid_q2xs = the_grid;
+ iq2_data[gindex].grid = the_grid;
+ kmap_q2xs = (int *)malloc(kmap_size*sizeof(int));
+ iq2_data[gindex].map = kmap_q2xs;
+ for (int i = 0; i < kmap_size; ++i) kmap_q2xs[i] = -1;
+ uint64_t aux64;
+ uint8_t * aux8 = (uint8_t *)&aux64;
+ for (int i = 0; i < grid_size; ++i) {
+ aux64 = kgrid_q2xs[i];
+ uint16_t index = 0;
+ for (int k=0; k<8; ++k) {
+ uint16_t q = (aux8[k] - 1)/2;
+ index |= (q << 2*k);
+ }
+ kmap_q2xs[index] = i;
+ }
+ int8_t pos[8];
+ int * dist2 = (int *)malloc(2*grid_size*sizeof(int));
+ int num_neighbors = 0, num_not_in_map = 0;
+ for (int i = 0; i < kmap_size; ++i) {
+ if (kmap_q2xs[i] >= 0) continue;
+ ++num_not_in_map;
+ for (int k = 0; k < 8; ++k) {
+ int l = (i >> 2*k) & 0x3;
+ pos[k] = 2*l + 1;
+ }
+ for (int j = 0; j < grid_size; ++j) {
+ const int8_t * pg = (const int8_t *)(kgrid_q2xs + j);
+ int d2 = 0;
+ for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
+ dist2[2*j+0] = d2;
+ dist2[2*j+1] = j;
+ }
+ qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func);
+ int n = 0; int d2 = dist2[0];
+ int nhave = 1;
+ for (int j = 0; j < grid_size; ++j) {
+ if (dist2[2*j] > d2) {
+ if (nhave == nwant) break;
+ d2 = dist2[2*j];
+ ++nhave;
+ }
+ ++n;
+ }
+ num_neighbors += n;
+ }
+ printf("%s: %d neighbours in total\n", __func__, num_neighbors);
+ kneighbors_q2xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t));
+ iq2_data[gindex].neighbours = kneighbors_q2xs;
+ int counter = 0;
+ for (int i = 0; i < kmap_size; ++i) {
+ if (kmap_q2xs[i] >= 0) continue;
+ for (int k = 0; k < 8; ++k) {
+ int l = (i >> 2*k) & 0x3;
+ pos[k] = 2*l + 1;
+ }
+ for (int j = 0; j < grid_size; ++j) {
+ const int8_t * pg = (const int8_t *)(kgrid_q2xs + j);
+ int d2 = 0;
+ for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
+ dist2[2*j+0] = d2;
+ dist2[2*j+1] = j;
+ }
+ qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func);
+ kmap_q2xs[i] = -(counter + 1);
+ int d2 = dist2[0];
+ uint16_t * start = &kneighbors_q2xs[counter++];
+ int n = 0, nhave = 1;
+ for (int j = 0; j < grid_size; ++j) {
+ if (dist2[2*j] > d2) {
+ if (nhave == nwant) break;
+ d2 = dist2[2*j];
+ ++nhave;
+ }
+ kneighbors_q2xs[counter++] = dist2[2*j+1];
+ ++n;
+ }
+ *start = n;
+ }
+ free(dist2);
+ }
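+ // Lookup convention established above (a sketch of how consumers decode it):
+ //   int gi = kmap_q2xs[u];
+ //   if (gi >= 0) { /* u is on the grid; gi indexes kgrid_q2xs */ }
+ //   else {
+ //       const uint16_t * nb = kneighbors_q2xs - gi - 1; // == &kneighbors_q2xs[counter]
+ //       int n = nb[0];                                  // neighbour count
+ //       /* candidates are nb[1] .. nb[n] */
+ //   }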
+
+ void iq2xs_free_impl(int grid_size) {
+ WSP_GGML_ASSERT(grid_size == 256 || grid_size == 512 || grid_size == 1024);
+ const int gindex = iq2_data_index(grid_size);
+ if (iq2_data[gindex].grid) {
+ free(iq2_data[gindex].grid); iq2_data[gindex].grid = NULL;
+ free(iq2_data[gindex].map); iq2_data[gindex].map = NULL;
+ free(iq2_data[gindex].neighbours); iq2_data[gindex].neighbours = NULL;
+ }
+ }
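+ // Note: the assert above also admits grid_size == 1024, but iq2_data_index()
+ // only accepts 256 or 512, so a 1024-entry grid would trip its assert first.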
+
+ static int iq2_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid,
+ const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L) {
+ int num_neighbors = neighbours[0];
+ WSP_GGML_ASSERT(num_neighbors > 0);
+ float best_d2 = FLT_MAX;
+ int grid_index = -1;
+ for (int j = 1; j <= num_neighbors; ++j) {
+ const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
+ float d2 = 0;
+ for (int i = 0; i < 8; ++i) {
+ float q = pg[i];
+ float diff = scale*q - xval[i];
+ d2 += weight[i]*diff*diff;
+ }
+ if (d2 < best_d2) {
+ best_d2 = d2; grid_index = neighbours[j];
+ }
+ }
+ WSP_GGML_ASSERT(grid_index >= 0);
+ const int8_t * pg = (const int8_t *)(grid + grid_index);
+ for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2;
+ return grid_index;
+ }
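+ // Since grid coordinates are odd values in {1,3,5,7}, L[i] = (pg[i] - 1)/2
+ // recovers the 2-bit quant index in [0,3].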
+
+ static void wsp_quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
+
+ const int gindex = iq2_data_index(256);
+
+ const uint64_t * kgrid_q2xs = iq2_data[gindex].grid;
+ const int * kmap_q2xs = iq2_data[gindex].map;
+ const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
+
+ WSP_GGML_ASSERT(quant_weights && "missing quantization weights");
+ WSP_GGML_ASSERT(kgrid_q2xs && "forgot to call wsp_ggml_wsp_quantize_init()?");
+ WSP_GGML_ASSERT(kmap_q2xs && "forgot to call wsp_ggml_wsp_quantize_init()?");
+ WSP_GGML_ASSERT(kneighbors_q2xs && "forgot to call wsp_ggml_wsp_quantize_init()?");
+ WSP_GGML_ASSERT(n%QK_K == 0);
+
+ const int kMaxQ = 3;
+
+ const int nbl = n/256;
+
+ block_iq2_xxs * y = vy;
+
+ float scales[QK_K/32];
+ float weight[32];
+ float xval[32];
+ int8_t L[32];
+ int8_t Laux[32];
+ float waux[32];
+ bool is_on_grid[4];
+ bool is_on_grid_aux[4];
+ uint8_t block_signs[4];
+ uint32_t q2[2*(QK_K/32)];
+
+ for (int ibl = 0; ibl < nbl; ++ibl) {
+
+ y[ibl].d = WSP_GGML_FP32_TO_FP16(0.f);
+ memset(q2, 0, QK_K/4);
+
+ float max_scale = 0;
+
+ const float * xbl = x + QK_K*ibl;
+ float sumx2 = 0;
+ for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
+ float sigma2 = sumx2/QK_K;
+
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ const float * xb = xbl + 32*ib;
+ const float * qw = quant_weights + QK_K*ibl + 32*ib;
+ for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+ for (int i = 0; i < 32; ++i) waux[i] = sqrtf(weight[i]);
+ for (int k = 0; k < 4; ++k) {
+ int nflip = 0;
+ uint8_t s = 0;
+ for (int i = 0; i < 8; ++i) {
+ if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
+ else {
+ xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i);
+ }
+ }
+ if (nflip%2) {
+ int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin];
+ for (int i = 1; i < 8; ++i) {
+ float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i];
+ if (ax < min) {
+ min = ax; imin = i;
+ }
+ }
+ xval[8*k+imin] = -xval[8*k+imin];
+ s ^= (1 << imin);
+ }
+ block_signs[k] = s & 127;
+ }
+ float max = xval[0];
+ for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]);
+ if (!max) {
+ scales[ib] = 0;
+ memset(L, 0, 32);
+ continue;
+ }
+ float best = 0;
+ float scale = max/(2*kMaxQ-1);
+ for (int is = -9; is <= 9; ++is) {
+ float id = (2*kMaxQ-1+is*0.1f)/max;
+ float this_scale = 1/id;
+ for (int k = 0; k < 4; ++k) {
+ for (int i = 0; i < 8; ++i) {
+ int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
+ Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l));
+ }
+ uint16_t u = 0;
+ for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i);
+ int grid_index = kmap_q2xs[u];
+ is_on_grid_aux[k] = true;
+ if (grid_index < 0) {
+ is_on_grid_aux[k] = false;
+ const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+ grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k);
+ }
+ }
+ float sumqx = 0, sumq2 = 0;
+ for (int i = 0; i < 32; ++i) {
+ float w = weight[i];
+ float q = 2*Laux[i] + 1;
+ sumqx += w*xval[i]*q;
+ sumq2 += w*q*q;
+ }
+ if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+ scale = sumqx/sumq2; best = scale*sumqx;
+ for (int i = 0; i < 32; ++i) L[i] = Laux[i];
+ for (int k = 0; k < 4; ++k) is_on_grid[k] = is_on_grid_aux[k];
+ }
+ }
+ int n_not_ongrid = 0;
+ for (int k = 0; k < 4; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
+ if (n_not_ongrid > 0 && scale > 0) {
+ float id = 1/scale;
+ for (int k = 0; k < 4; ++k) {
+ if (is_on_grid[k]) continue;
+ uint16_t u = 0;
+ for (int i = 0; i < 8; ++i) {
+ int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
+ l = MAX(0, MIN(kMaxQ-1, l));
+ u |= (l << 2*i);
+ }
+ int grid_index = kmap_q2xs[u];
+ if (grid_index < 0) {
+ const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+ grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k);
+ }
+ const int8_t * pg = (const int8_t *)(kgrid_q2xs + grid_index);
+ for (int i = 0; i < 8; ++i) L[8*k+i] = (pg[i] - 1)/2;
+ }
+ float sumqx = 0, sumq2 = 0;
+ for (int i = 0; i < 32; ++i) {
+ float w = weight[i];
+ float q = 2*L[i] + 1;
+ sumqx += w*xval[i]*q;
+ sumq2 += w*q*q;
+ }
+ if (sumq2 > 0) scale = sumqx/sumq2;
+ }
+ if (scale < 0) {
+ // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale)
+ // and correspondingly flip quant signs.
+ scale = -scale;
+ for (int k = 0; k < 4; ++k) block_signs[k] = (~block_signs[k]) & 127;
+ }
+ for (int k = 0; k < 4; ++k) {
+ uint16_t u = 0;
+ for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i);
+ int grid_index = kmap_q2xs[u];
+ if (grid_index < 0) {
+ printf("Oops: found point %u not on grid:", u);
+ for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]);
+ printf("\n");
+ WSP_GGML_ASSERT(false);
+ }
+ q2[2*ib+0] |= (grid_index << 8*k);
+ q2[2*ib+1] |= (block_signs[k] << 7*k);
+ }
+ WSP_GGML_ASSERT(scale >= 0);
+ scales[ib] = scale;
+ max_scale = MAX(max_scale, scale);
+ }
+
+ if (!max_scale) {
+ memset(y[ibl].qs, 0, QK_K/4);
+ continue;
+ }
+
+ float d = max_scale/31;
+ y[ibl].d = WSP_GGML_FP32_TO_FP16(d);
+ float id = 1/d;
+ float sumqx = 0, sumq2 = 0;
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ int l = nearest_int(0.5f*(id*scales[ib]-1));
+ l = MAX(0, MIN(15, l));
+ q2[2*ib+1] |= ((uint32_t)l << 28);
+ const float * xb = xbl + 32*ib;
+ const float * qw = quant_weights + QK_K*ibl + 32*ib;
+ for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+ const uint8_t * aux8 = (const uint8_t *)(q2 + 2*ib);
+ const float db = d * (1 + 2*l);
+ uint32_t u = 0;
+ for (int k = 0; k < 4; ++k) {
+ const int8_t * signs = keven_signs_q2xs + 8*((q2[2*ib+1] >> 7*k) & 127);
+ const float * xk = xb + 8*k;
+ const float * wk = weight + 8*k;
+ const uint8_t * grid = (const uint8_t *)(kgrid_q2xs + aux8[k]);
+ float best_mse = 0; int best_index = aux8[k];
+ for (int j = 0; j < 8; ++j) {
+ float diff = db * grid[j] * signs[j] - xk[j];
+ best_mse += wk[j] * diff * diff;
+ }
+ for (int idx = 0; idx < 256; ++idx) {
+ grid = (const uint8_t *)(kgrid_q2xs + idx);
+ float mse = 0;
+ for (int j = 0; j < 8; ++j) {
+ float diff = db * grid[j] * signs[j] - xk[j];
+ mse += wk[j] * diff * diff;
+ }
+ if (mse < best_mse) {
+ best_mse = mse; best_index = idx;
+ }
+ }
+ u |= (best_index << 8*k);
+ grid = (const uint8_t *)(kgrid_q2xs + best_index);
+ //grid = (const uint8_t *)(kgrid_q2xs + aux8[k]);
+ for (int j = 0; j < 8; ++j) {
+ float q = db * grid[j] * signs[j];
+ sumqx += wk[j] * q * xk[j];
+ sumq2 += wk[j] * q * q;
+ }
+ }
+ q2[2*ib] = u;
+ if (sumq2 > 0) y[ibl].d = WSP_GGML_FP32_TO_FP16(d*sumqx/sumq2);
+ }
+ memcpy(y[ibl].qs, q2, QK_K/4);
+ }
+ }
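+ // The second pass above re-derives the 4-bit block scales from d = max_scale/31,
+ // then exhaustively re-searches all 256 grid points per group of 8 (with the
+ // signs held fixed) and rescales d by sumqx/sumq2 — a final least-squares
+ // refinement of the block scale.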
+
+ static void wsp_quantize_row_iq2_xs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
+
+ const int gindex = iq2_data_index(512);
+
+ const uint64_t * kgrid_q2xs = iq2_data[gindex].grid;
+ const int * kmap_q2xs = iq2_data[gindex].map;
+ const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
+
+ WSP_GGML_ASSERT(quant_weights && "missing quantization weights");
+ WSP_GGML_ASSERT(kmap_q2xs && "forgot to call wsp_ggml_wsp_quantize_init()?");
+ WSP_GGML_ASSERT(kgrid_q2xs && "forgot to call wsp_ggml_wsp_quantize_init()?");
+ WSP_GGML_ASSERT(kneighbors_q2xs && "forgot to call wsp_ggml_wsp_quantize_init()?");
+ WSP_GGML_ASSERT(n%QK_K == 0);
+
+ const int kMaxQ = 3;
+
+ const int nbl = n/256;
+
+ block_iq2_xs * y = vy;
+
+ float scales[QK_K/16];
+ float weight[16];
+ float xval[16];
+ int8_t L[16];
+ int8_t Laux[16];
+ float waux[16];
+ bool is_on_grid[2];
+ bool is_on_grid_aux[2];
+ uint8_t block_signs[2];
+ uint16_t q2[2*(QK_K/16)];
+
+ for (int ibl = 0; ibl < nbl; ++ibl) {
+
+ y[ibl].d = WSP_GGML_FP32_TO_FP16(0.f);
+ memset(q2, 0, QK_K/4);
+ memset(y[ibl].scales, 0, QK_K/32);
+
+ float max_scale = 0;
+
+ const float * xbl = x + QK_K*ibl;
+ float sumx2 = 0;
+ for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
+ float sigma2 = sumx2/QK_K;
+
+ for (int ib = 0; ib < QK_K/16; ++ib) {
+ const float * xb = xbl + 16*ib;
+ const float * qw = quant_weights + QK_K*ibl + 16*ib;
+ for (int i = 0; i < 16; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+ for (int i = 0; i < 16; ++i) waux[i] = sqrtf(weight[i]);
+ for (int k = 0; k < 2; ++k) {
+ int nflip = 0;
+ uint8_t s = 0;
+ for (int i = 0; i < 8; ++i) {
+ if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
+ else {
+ xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i);
+ }
+ }
+ if (nflip%2) {
+ int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin];
+ for (int i = 1; i < 8; ++i) {
+ float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i];
+ if (ax < min) {
+ min = ax; imin = i;
+ }
+ }
+ xval[8*k+imin] = -xval[8*k+imin];
+ s ^= (1 << imin);
+ }
+ block_signs[k] = s & 127;
+ }
+ float max = xval[0];
+ for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]);
+ if (!max) {
+ scales[ib] = 0;
+ memset(L, 0, 16);
+ continue;
+ }
+ float best = 0;
+ float scale = max/(2*kMaxQ-1);
+ is_on_grid[0] = is_on_grid[1] = true;
+ for (int is = -9; is <= 9; ++is) {
+ float id = (2*kMaxQ-1+is*0.1f)/max;
+ float this_scale = 1/id;
+ for (int k = 0; k < 2; ++k) {
+ for (int i = 0; i < 8; ++i) {
+ int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
+ Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l));
+ }
+ uint16_t u = 0;
+ for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i);
+ int grid_index = kmap_q2xs[u];
+ is_on_grid_aux[k] = true;
+ if (grid_index < 0) {
+ is_on_grid_aux[k] = false;
+ const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+ grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k);
+ }
+ }
+ float sumqx = 0, sumq2 = 0;
+ for (int i = 0; i < 16; ++i) {
+ float w = weight[i];
+ float q = 2*Laux[i] + 1;
+ sumqx += w*xval[i]*q;
+ sumq2 += w*q*q;
+ }
+ if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+ scale = sumqx/sumq2; best = scale*sumqx;
+ for (int i = 0; i < 16; ++i) L[i] = Laux[i];
+ for (int k = 0; k < 2; ++k) is_on_grid[k] = is_on_grid_aux[k];
+ }
+ }
+ int n_not_ongrid = 0;
+ for (int k = 0; k < 2; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
+ if (n_not_ongrid > 0 && scale > 0) {
+ float id = 1/scale;
+ for (int k = 0; k < 2; ++k) {
+ if (is_on_grid[k]) continue;
+ uint16_t u = 0;
+ for (int i = 0; i < 8; ++i) {
+ int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
+ l = MAX(0, MIN(kMaxQ-1, l));
+ u |= (l << 2*i);
+ L[8*k + i] = l;
+ }
+ int grid_index = kmap_q2xs[u];
+ if (grid_index < 0) {
+ const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+ grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k);
+ }
+ }
+ float sumqx = 0, sumq2 = 0;
+ for (int i = 0; i < 16; ++i) {
+ float w = weight[i];
+ float q = 2*L[i] + 1;
+ sumqx += w*xval[i]*q;
+ sumq2 += w*q*q;
+ }
+ if (sumq2 > 0) scale = sumqx/sumq2;
+ }
+ if (scale < 0) {
+ scale = -scale;
+ for (int k = 0; k < 2; ++k) block_signs[k] = (~block_signs[k]) & 127;
+ }
+ for (int k = 0; k < 2; ++k) {
+ uint16_t u = 0;
+ for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i);
+ int grid_index = kmap_q2xs[u];
+ if (grid_index < 0) {
+ printf("Oops: found point %u not on grid:", u);
+ for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]);
+ printf("\n");
+ WSP_GGML_ASSERT(false);
+ }
+ q2[2*ib+k] = grid_index | (block_signs[k] << 9);
+ }
+ WSP_GGML_ASSERT(scale >= 0);
+ scales[ib] = scale;
+ max_scale = MAX(max_scale, scale);
+ }
+
+ if (!max_scale) {
+ memset(y[ibl].qs, 0, QK_K/4);
+ continue;
+ }
+
+ float d = max_scale/31;
+ y[ibl].d = WSP_GGML_FP32_TO_FP16(d);
+ float id = 1/d;
+ for (int ib = 0; ib < QK_K/16; ++ib) {
+ int l = nearest_int(0.5f*(id*scales[ib]-1));
+ l = MAX(0, MIN(15, l));
+ if (ib%2 == 0) y[ibl].scales[ib/2] = l;
+ else y[ibl].scales[ib/2] |= (l << 4);
+ }
+ memcpy(y[ibl].qs, q2, QK_K/4);
+
+ }
+ }
+
+ size_t wsp_quantize_iq2_xxs(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
+ (void)hist;
+ WSP_GGML_ASSERT(n_per_row%QK_K == 0);
+ int nblock = n_per_row/QK_K;
+ char * qrow = (char *)dst;
+ for (int row = 0; row < nrow; ++row) {
+ wsp_quantize_row_iq2_xxs_impl(src, qrow, n_per_row, quant_weights);
+ src += n_per_row;
+ qrow += nblock*sizeof(block_iq2_xxs);
+ }
+ return nrow * nblock * sizeof(block_iq2_xxs);
+ }
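+ // Minimal usage sketch (an assumed call order, based on the asserts above —
+ // wsp_ggml_wsp_quantize_init() appears to be the expected public entry point
+ // for the table setup):
+ //   iq2xs_init_impl(256);                    // build grid/map/neighbour tables
+ //   size_t bytes = wsp_quantize_iq2_xxs(src, dst, nrow, n_per_row,
+ //                                       /*hist=*/NULL, quant_weights); // quant_weights must be non-NULL
+ //   iq2xs_free_impl(256);
+ // wsp_quantize_iq2_xs below is analogous, using the 512-entry grid.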
+
+ size_t wsp_quantize_iq2_xs(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
+ (void)hist;
+ WSP_GGML_ASSERT(n_per_row%QK_K == 0);
+ int nblock = n_per_row/QK_K;
+ char * qrow = (char *)dst;
+ for (int row = 0; row < nrow; ++row) {
+ wsp_quantize_row_iq2_xs_impl(src, qrow, n_per_row, quant_weights);
+ src += n_per_row;
+ qrow += nblock*sizeof(block_iq2_xs);
+ }
+ return nrow * nblock * sizeof(block_iq2_xs);
+ }
+