whisper.rn 0.4.0-rc.7 → 0.4.0-rc.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/cpp/coreml/whisper-encoder.mm +1 -1
- package/cpp/ggml-alloc.c +41 -11
- package/cpp/ggml-alloc.h +3 -1
- package/cpp/ggml-backend-impl.h +38 -34
- package/cpp/ggml-backend.c +630 -269
- package/cpp/ggml-backend.h +58 -30
- package/cpp/ggml-impl.h +3 -0
- package/cpp/ggml-metal-whisper.metal +1253 -341
- package/cpp/ggml-metal.h +6 -54
- package/cpp/ggml-metal.m +2004 -1987
- package/cpp/ggml-quants.c +2230 -421
- package/cpp/ggml-quants.h +39 -1
- package/cpp/ggml.c +735 -265
- package/cpp/ggml.h +94 -43
- package/cpp/whisper.cpp +118 -86
- package/ios/RNWhisperContext.mm +2 -1
- package/lib/commonjs/version.json +1 -1
- package/lib/module/version.json +1 -1
- package/package.json +1 -1
- package/src/version.json +1 -1
package/cpp/ggml-quants.c
CHANGED
@@ -5,6 +5,8 @@
 #include <string.h>
 #include <assert.h>
 #include <float.h>
+#include <stdlib.h> // for qsort
+#include <stdio.h> // for WSP_GGML_ASSERT
 
 #ifdef __ARM_NEON
 
@@ -272,10 +274,13 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128
 
 // vaddvq_s16
 // vpaddq_s16
+// vpaddq_s32
 // vaddvq_s32
 // vaddvq_f32
 // vmaxvq_f32
 // vcvtnq_s32_f32
+// vzip1_u8
+// vzip2_u8
 
 inline static int32_t vaddvq_s16(int16x8_t v) {
     return
@@ -291,6 +296,12 @@ inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
     return vcombine_s16(a0, b0);
 }
 
+inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
+    int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a));
+    int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b));
+    return vcombine_s32(a0, b0);
+}
+
 inline static int32_t vaddvq_s32(int32x4_t v) {
     return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
 }
@@ -316,6 +327,28 @@ inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
     return res;
 }
 
+inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
+    uint8x8_t res;
+
+    res[0] = a[0]; res[1] = b[0];
+    res[2] = a[1]; res[3] = b[1];
+    res[4] = a[2]; res[5] = b[2];
+    res[6] = a[3]; res[7] = b[3];
+
+    return res;
+}
+
+inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
+    uint8x8_t res;
+
+    res[0] = a[4]; res[1] = b[4];
+    res[2] = a[5]; res[3] = b[5];
+    res[4] = a[6]; res[5] = b[6];
+    res[6] = a[7]; res[7] = b[7];
+
+    return res;
+}
+
 // vld1q_s16_x2
 // vld1q_u8_x2
 // vld1q_u8_x4
@@ -407,6 +440,22 @@ inline static wsp_ggml_int8x16x4_t wsp_ggml_vld1q_s8_x4(const int8_t * ptr) {
 #define wsp_ggml_vld1q_s8_x4 vld1q_s8_x4
 
 #endif
+
+#if !defined(__ARM_FEATURE_DOTPROD)
+
+inline static int32x4_t wsp_ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
+    const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
+    const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
+
+    return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
+}
+
+#else
+
+#define wsp_ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
+
+#endif
+
 #endif
 
 #if defined(__ARM_NEON) || defined(__wasm_simd128__)
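
Note: the hunk above makes wsp_ggml_vdotq_s32 a portable stand-in for the ARMv8.2 sdot instruction. The widening-multiply fallback groups its partial sums into the four 32-bit lanes differently than vdotq_s32 does, but the quantized dot-product kernels only consume the horizontal sum of those lanes, where both variants agree. A minimal scalar model of that reduced semantics (standalone C; the helper name dot16_s8 is invented, not part of the package):

#include <stdint.h>
#include <stdio.h>

// Accumulate a 16-element int8 dot product into a 32-bit accumulator:
// the value both the sdot path and the fallback yield after vaddvq_s32.
static int32_t dot16_s8(int32_t acc, const int8_t a[16], const int8_t b[16]) {
    for (int i = 0; i < 16; ++i) acc += (int32_t)a[i] * (int32_t)b[i];
    return acc;
}

int main(void) {
    int8_t a[16], b[16];
    for (int i = 0; i < 16; ++i) { a[i] = (int8_t)(i - 8); b[i] = (int8_t)(3*i - 20); }
    printf("dot = %d\n", dot16_s8(0, a, b));
    return 0;
}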
@@ -466,6 +515,7 @@ void wsp_quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
     wsp_quantize_row_q4_0_reference(x, y, k);
 }
 
+
 void wsp_quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k) {
     const int qk = QK4_1;
 
@@ -1195,7 +1245,8 @@ static inline int nearest_int(float fval) {
     return (i & 0x007fffff) - 0x00400000;
 }
 
-static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type) {
+static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type,
+        const float * restrict qw) {
     float max = 0;
     float amax = 0;
     for (int i = 0; i < n; ++i) {
@@ -1221,14 +1272,18 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
         rmse_type = -rmse_type;
         return_early = true;
     }
-    int weight_type = rmse_type%2;
     float sumlx = 0;
     float suml2 = 0;
+#ifdef HAVE_BUGGY_APPLE_LINKER
+    // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
+    for (volatile int i = 0; i < n; ++i) {
+#else
     for (int i = 0; i < n; ++i) {
+#endif
        int l = nearest_int(iscale * x[i]);
        l = MAX(-nmax, MIN(nmax-1, l));
        L[i] = l + nmax;
-       float w = weight_type == 1 ? x[i] * x[i] : 1;
+       float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i]));
        sumlx += w*x[i]*l;
        suml2 += w*l*l;
     }
@@ -1244,7 +1299,7 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
         for (int i = 0; i < n; ++i) {
             int l = nearest_int(iscale * x[i]);
             l = MAX(-nmax, MIN(nmax-1, l));
-            float w = weight_type == 1 ? x[i] * x[i] : 1;
+            float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i]));
             sumlx += w*x[i]*l;
             suml2 += w*l*l;
         }
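
Note: the new qw parameter lets callers pass per-coefficient importance weights into make_qx_quants instead of relying on the rmse_type heuristics. For a fixed assignment of integer levels l[i], the scale d minimizing sum_i w[i]*(x[i] - d*l[i])^2 has the closed form d = sum(w*x*l) / sum(w*l*l), which is exactly what the sumlx/suml2 accumulators above compute. A standalone sketch of that closed form (hypothetical helper, not the library's function):

#include <stddef.h>

// Weighted least-squares optimal scale for fixed quant levels l[i]:
// d = sum(w*x*l) / sum(w*l*l). With w == NULL this degenerates to the
// unweighted fit, mirroring the qw == NULL fallback in the diff.
static float optimal_scale(size_t n, const float * x, const signed char * l, const float * w) {
    float sumlx = 0.0f, suml2 = 0.0f;
    for (size_t i = 0; i < n; ++i) {
        float wi = w ? w[i] : 1.0f;
        sumlx += wi * x[i] * l[i];
        suml2 += wi * (float)l[i] * l[i];
    }
    return suml2 > 0.0f ? sumlx / suml2 : 0.0f;
}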
@@ -1592,6 +1647,246 @@ size_t wsp_ggml_wsp_quantize_q2_K(const float * restrict src, void * restrict ds
     return (n/QK_K*sizeof(block_q2_K));
 }
 
+static float make_qkx3_quants(int n, int nmax, const float * restrict x, const float * restrict weights,
+        uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux,
+        float rmin, float rdelta, int nstep, bool use_mad) {
+    float min = x[0];
+    float max = x[0];
+    float sum_w = weights ? weights[0] : x[0]*x[0];
+    float sum_x = sum_w * x[0];
+#ifdef HAVE_BUGGY_APPLE_LINKER
+    // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
+    for (volatile int i = 1; i < n; ++i) {
+#else
+    for (int i = 1; i < n; ++i) {
+#endif
+        if (x[i] < min) min = x[i];
+        if (x[i] > max) max = x[i];
+        float w = weights ? weights[i] : x[i]*x[i];
+        sum_w += w;
+        sum_x += w * x[i];
+    }
+    if (min > 0) {
+        min = 0;
+    }
+    if (max <= min) {
+        memset(L, 0, n);
+        *the_min = -min;
+        return 0.f;
+    }
+    float iscale = nmax/(max - min);
+    float scale = 1/iscale;
+    float best_mad = 0;
+    for (int i = 0; i < n; ++i) {
+        int l = nearest_int(iscale*(x[i] - min));
+        L[i] = MAX(0, MIN(nmax, l));
+        float diff = scale * L[i] + min - x[i];
+        diff = use_mad ? fabsf(diff) : diff*diff;
+        float w = weights ? weights[i] : x[i]*x[i];
+        best_mad += w * diff;
+    }
+    if (nstep < 1) {
+        *the_min = -min;
+        return scale;
+    }
+    for (int is = 0; is <= nstep; ++is) {
+        iscale = (rmin + rdelta*is + nmax)/(max - min);
+        float sum_l = 0, sum_l2 = 0, sum_xl = 0;
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale*(x[i] - min));
+            l = MAX(0, MIN(nmax, l));
+            Laux[i] = l;
+            float w = weights ? weights[i] : x[i]*x[i];
+            sum_l += w*l;
+            sum_l2 += w*l*l;
+            sum_xl += w*l*x[i];
+        }
+        float D = sum_w * sum_l2 - sum_l * sum_l;
+        if (D > 0) {
+            float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
+            float this_min   = (sum_l2 * sum_x - sum_l * sum_xl)/D;
+            if (this_min > 0) {
+                this_min = 0;
+                this_scale = sum_xl / sum_l2;
+            }
+            float mad = 0;
+            for (int i = 0; i < n; ++i) {
+                float diff = this_scale * Laux[i] + this_min - x[i];
+                diff = use_mad ? fabsf(diff) : diff*diff;
+                float w = weights ? weights[i] : x[i]*x[i];
+                mad += w * diff;
+            }
+            if (mad < best_mad) {
+                for (int i = 0; i < n; ++i) {
+                    L[i] = Laux[i];
+                }
+                best_mad = mad;
+                scale = this_scale;
+                min = this_min;
+            }
+        }
+    }
+    *the_min = -min;
+    return scale;
+}
+
+static float make_qp_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, const float * quant_weights) {
+    float max = 0;
+    for (int i = 0; i < n; ++i) {
+        max = MAX(max, x[i]);
+    }
+    if (!max) { // all zero
+        for (int i = 0; i < n; ++i) { L[i] = 0; }
+        return 0.f;
+    }
+    float iscale = nmax / max;
+    for (int i = 0; i < n; ++i) {
+        L[i] = nearest_int(iscale * x[i]);
+    }
+    float scale = 1/iscale;
+    float best_mse = 0;
+    for (int i = 0; i < n; ++i) {
+        float diff = x[i] - scale*L[i];
+        float w = quant_weights[i];
+        best_mse += w*diff*diff;
+    }
+    for (int is = -4; is <= 4; ++is) {
+        if (is == 0) continue;
+        float iscale_is = (0.1f*is + nmax)/max;
+        float scale_is = 1/iscale_is;
+        float mse = 0;
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale_is*x[i]);
+            l = MIN(nmax, l);
+            float diff = x[i] - scale_is*l;
+            float w = quant_weights[i];
+            mse += w*diff*diff;
+        }
+        if (mse < best_mse) {
+            best_mse = mse;
+            iscale = iscale_is;
+        }
+    }
+    float sumlx = 0;
+    float suml2 = 0;
+    for (int i = 0; i < n; ++i) {
+        int l = nearest_int(iscale * x[i]);
+        l = MIN(nmax, l);
+        L[i] = l;
+        float w = quant_weights[i];
+        sumlx += w*x[i]*l;
+        suml2 += w*l*l;
+    }
+    for (int itry = 0; itry < 5; ++itry) {
+        int n_changed = 0;
+        for (int i = 0; i < n; ++i) {
+            float w = quant_weights[i];
+            float slx = sumlx - w*x[i]*L[i];
+            float sl2 = suml2 - w*L[i]*L[i];
+            if (slx > 0 && sl2 > 0) {
+                int new_l = nearest_int(x[i] * sl2 / slx);
+                new_l = MIN(nmax, new_l);
+                if (new_l != L[i]) {
+                    slx += w*x[i]*new_l;
+                    sl2 += w*new_l*new_l;
+                    if (slx*slx*suml2 > sumlx*sumlx*sl2) {
+                        L[i] = new_l; sumlx = slx; suml2 = sl2;
+                        ++n_changed;
+                    }
+                }
+            }
+        }
+        if (!n_changed) {
+            break;
+        }
+    }
+    return sumlx / suml2;
+}
+
+static void wsp_quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restrict y, int k, const float * restrict quant_weights) {
+    WSP_GGML_ASSERT(quant_weights);
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+    const bool requantize = true;
+
+    uint8_t L[QK_K];
+    uint8_t Laux[16];
+    float mins[QK_K/16];
+    float scales[QK_K/16];
+    float sw[QK_K/16];
+    float weight[QK_K/16];
+    uint8_t Ls[QK_K/16], Lm[QK_K/16];
+
+    for (int i = 0; i < nb; i++) {
+        memset(sw, 0, QK_K/16*sizeof(float));
+        float sumx2 = 0;
+        for (int j = 0; j < QK_K; ++j) sumx2 += x[j]*x[j];
+        float sigma2 = sumx2/QK_K;
+        for (int j = 0; j < QK_K/16; ++j) {
+            const float * restrict qw = quant_weights + QK_K * i + 16*j;
+            for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]);
+            for (int l = 0; l < 16; ++l) sw[j] += weight[l];
+            scales[j] = make_qkx3_quants(16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
+        }
+
+        float dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw);
+        float mm = make_qp_quants(QK_K/16, 15, mins,   Lm, sw);
+        y[i].d    = WSP_GGML_FP32_TO_FP16(dm);
+        y[i].dmin = WSP_GGML_FP32_TO_FP16(mm);
+        dm = WSP_GGML_FP16_TO_FP32(y[i].d);
+        mm = WSP_GGML_FP16_TO_FP32(y[i].dmin);
+
+        for (int j = 0; j < QK_K/16; ++j) {
+            y[i].scales[j] = Ls[j] | (Lm[j] << 4);
+        }
+
+        if (requantize) {
+            for (int j = 0; j < QK_K/16; ++j) {
+                const float d = dm * (y[i].scales[j] & 0xF);
+                if (!d) continue;
+                const float m = mm * (y[i].scales[j] >> 4);
+                for (int ii = 0; ii < 16; ++ii) {
+                    int l = nearest_int((x[16*j + ii] + m)/d);
+                    l = MAX(0, MIN(3, l));
+                    L[16*j + ii] = l;
+                }
+            }
+        }
+
+#if QK_K == 256
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int l = 0; l < 32; ++l) {
+                y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
+            }
+        }
+#else
+        for (int l = 0; l < 16; ++l) {
+            y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
+        }
+#endif
+
+        x += QK_K;
+
+    }
+}
+
+size_t wsp_quantize_q2_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
+    (void)hist;
+    size_t row_size = wsp_ggml_row_size(WSP_GGML_TYPE_Q2_K, n_per_row);
+    if (!quant_weights) {
+        wsp_quantize_row_q2_K_reference(src, dst, nrow*n_per_row);
+    }
+    else {
+        char * qrow = (char *)dst;
+        for (int row = 0; row < nrow; ++row) {
+            wsp_quantize_row_q2_K_impl(src, (block_q2_K*)qrow, n_per_row, quant_weights);
+            src += n_per_row;
+            qrow += row_size;
+        }
+    }
+    return nrow * row_size;
+}
+
 //========================= 3-bit (de)-quantization
 
 void wsp_quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k) {
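
Note: wsp_quantize_q2_K is the first of several new row-wise entry points sharing the signature (src, dst, nrow, n_per_row, hist, quant_weights). Passing NULL importance weights falls back to the reference quantizer over the whole buffer; otherwise each row is quantized with its weights and the destination pointer advances by wsp_ggml_row_size(type, n_per_row) bytes. A hedged usage sketch (buffer shapes invented; assumes the package's vendored ggml.h/ggml-quants.h declarations):

#include <stdlib.h>

// Quantize nrow rows of n_per_row floats to Q2_K without importance weights.
void quantize_q2_K_example(const float * data, int nrow, int n_per_row) {
    size_t row_size = wsp_ggml_row_size(WSP_GGML_TYPE_Q2_K, n_per_row);
    void * q = malloc((size_t)nrow * row_size);
    // NULL quant_weights selects the reference (unweighted) path
    size_t written = wsp_quantize_q2_K(data, q, nrow, n_per_row, /*hist=*/NULL, /*quant_weights=*/NULL);
    (void)written; // equals nrow * row_size
    free(q);
}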
@@ -1805,6 +2100,112 @@ size_t wsp_ggml_wsp_quantize_q3_K(const float * restrict src, void * restrict ds
     return (n/QK_K*sizeof(block_q3_K));
 }
 
+static void wsp_quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restrict y, int n_per_row, const float * restrict quant_weights) {
+#if QK_K != 256
+    (void)quant_weights;
+    wsp_quantize_row_q3_K_reference(x, y, n_per_row);
+#else
+    assert(n_per_row % QK_K == 0);
+    const int nb = n_per_row / QK_K;
+
+    int8_t L[QK_K];
+    float scales[QK_K / 16];
+    float weight[16];
+    float sw[QK_K / 16];
+    int8_t Ls[QK_K / 16];
+
+    for (int i = 0; i < nb; i++) {
+
+        float sumx2 = 0;
+        for (int j = 0; j < QK_K; ++j) sumx2 += x[j]*x[j];
+        float sigma2 = 2*sumx2/QK_K;
+
+        for (int j = 0; j < QK_K/16; ++j) {
+            if (quant_weights) {
+                const float * qw = quant_weights ? quant_weights + QK_K * i + 16*j : NULL;
+                for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j+l]*x[16*j+l]);
+            } else {
+                for (int l = 0; l < 16; ++l) weight[l] = x[16*j+l]*x[16*j+l];
+            }
+            float sumw = 0;
+            for (int l = 0; l < 16; ++l) sumw += weight[l];
+            sw[j] = sumw;
+
+            scales[j] = make_qx_quants(16, 4, x + 16*j, L + 16*j, 1, weight);
+
+        }
+
+        memset(y[i].scales, 0, 12);
+
+        float d_block = make_qx_quants(QK_K/16, 32, scales, Ls, 1, sw);
+        for (int j = 0; j < QK_K/16; ++j) {
+            int l = Ls[j];
+            if (j < 8) {
+                y[i].scales[j] = l & 0xF;
+            } else {
+                y[i].scales[j-8] |= ((l & 0xF) << 4);
+            }
+            l >>= 4;
+            y[i].scales[j%4 + 8] |= (l << (2*(j/4)));
+        }
+        y[i].d = WSP_GGML_FP32_TO_FP16(d_block);
+
+        int8_t sc;
+        for (int j = 0; j < QK_K/16; ++j) {
+            sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4;
+            sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32;
+            float d = WSP_GGML_FP16_TO_FP32(y[i].d) * sc;
+            if (!d) {
+                continue;
+            }
+            for (int ii = 0; ii < 16; ++ii) {
+                int l = nearest_int(x[16*j + ii]/d);
+                l = MAX(-4, MIN(3, l));
+                L[16*j + ii] = l + 4;
+            }
+        }
+
+        memset(y[i].hmask, 0, QK_K/8);
+        // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc.
+        int m = 0;
+        uint8_t hm = 1;
+        for (int j = 0; j < QK_K; ++j) {
+            if (L[j] > 3) {
+                y[i].hmask[m] |= hm;
+                L[j] -= 4;
+            }
+            if (++m == QK_K/8) {
+                m = 0; hm <<= 1;
+            }
+        }
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int l = 0; l < 32; ++l) {
+                y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
+            }
+        }
+
+        x += QK_K;
+    }
+#endif
+}
+
+size_t wsp_quantize_q3_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
+    (void)hist;
+    size_t row_size = wsp_ggml_row_size(WSP_GGML_TYPE_Q3_K, n_per_row);
+    if (!quant_weights) {
+        wsp_quantize_row_q3_K_reference(src, dst, nrow*n_per_row);
+    }
+    else {
+        char * qrow = (char *)dst;
+        for (int row = 0; row < nrow; ++row) {
+            wsp_quantize_row_q3_K_impl(src, (block_q3_K*)qrow, n_per_row, quant_weights);
+            src += n_per_row;
+            qrow += row_size;
+        }
+    }
+    return nrow * row_size;
+}
+
 // ====================== 4-bit (de)-quantization
 
 void wsp_quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k) {
@@ -1970,6 +2371,108 @@ size_t wsp_ggml_wsp_quantize_q4_K(const float * restrict src, void * restrict ds
     return (n/QK_K*sizeof(block_q4_K));
 }
 
+static void wsp_quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restrict y, int n_per_row, const float * quant_weights) {
+#if QK_K != 256
+    (void)quant_weights;
+    wsp_quantize_row_q4_K_reference(x, y, n_per_row);
+#else
+    assert(n_per_row % QK_K == 0);
+    const int nb = n_per_row / QK_K;
+
+    uint8_t L[QK_K];
+    uint8_t Laux[32];
+    float weights[32];
+    float mins[QK_K/32];
+    float scales[QK_K/32];
+
+    for (int i = 0; i < nb; i++) {
+
+        float sum_x2 = 0;
+        for (int l = 0; l < QK_K; ++l) sum_x2 += x[l] * x[l];
+        float sigma2 = sum_x2/QK_K;
+        float av_x = sqrtf(sigma2);
+
+        float max_scale = 0; // as we are deducting the min, scales are always positive
+        float max_min = 0;
+        for (int j = 0; j < QK_K/32; ++j) {
+            if (quant_weights) {
+                const float * qw = quant_weights + QK_K*i + 32*j;
+                for (int l = 0; l < 32; ++l) weights[l] = qw[l] * sqrtf(sigma2 + x[32*j + l]*x[32*j + l]);
+            } else {
+                for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
+            }
+            scales[j] = make_qkx3_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
+            //scales[j] = make_qkx2_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -1.f, 0.1f, 20, false);
+            float scale = scales[j];
+            if (scale > max_scale) {
+                max_scale = scale;
+            }
+            float min = mins[j];
+            if (min > max_min) {
+                max_min = min;
+            }
+        }
+
+        float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
+        float inv_min   = max_min   > 0 ? 63.f/max_min   : 0.f;
+        for (int j = 0; j < QK_K/32; ++j) {
+            uint8_t ls = nearest_int(inv_scale*scales[j]);
+            uint8_t lm = nearest_int(inv_min*mins[j]);
+            ls = MIN(63, ls);
+            lm = MIN(63, lm);
+            if (j < 4) {
+                y[i].scales[j] = ls;
+                y[i].scales[j+4] = lm;
+            } else {
+                y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
+                y[i].scales[j-4] |= ((ls >> 4) << 6);
+                y[i].scales[j-0] |= ((lm >> 4) << 6);
+            }
+        }
+        y[i].d = WSP_GGML_FP32_TO_FP16(max_scale/63.f);
+        y[i].dmin = WSP_GGML_FP32_TO_FP16(max_min/63.f);
+
+        uint8_t sc, m;
+        for (int j = 0; j < QK_K/32; ++j) {
+            get_scale_min_k4(j, y[i].scales, &sc, &m);
+            const float d = WSP_GGML_FP16_TO_FP32(y[i].d) * sc;
+            if (!d) continue;
+            const float dm = WSP_GGML_FP16_TO_FP32(y[i].dmin) * m;
+            for (int ii = 0; ii < 32; ++ii) {
+                int l = nearest_int((x[32*j + ii] + dm)/d);
+                l = MAX(0, MIN(15, l));
+                L[32*j + ii] = l;
+            }
+        }
+        uint8_t * q = y[i].qs;
+        for (int j = 0; j < QK_K; j += 64) {
+            for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4);
+            q += 32;
+        }
+
+        x += QK_K;
+
+    }
+#endif
+}
+
+size_t wsp_quantize_q4_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
+    (void)hist;
+    size_t row_size = wsp_ggml_row_size(WSP_GGML_TYPE_Q4_K, n_per_row);
+    if (!quant_weights) {
+        wsp_quantize_row_q4_K_reference(src, dst, nrow*n_per_row);
+    }
+    else {
+        char * qrow = (char *)dst;
+        for (int row = 0; row < nrow; ++row) {
+            wsp_quantize_row_q4_K_impl(src, (block_q4_K*)qrow, n_per_row, quant_weights);
+            src += n_per_row;
+            qrow += row_size;
+        }
+    }
+    return nrow * row_size;
+}
+
 // ====================== 5-bit (de)-quantization
 
 void wsp_quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k) {
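
Note: in the q4_K block written above, each of the eight 32-value sub-blocks stores a 6-bit scale and a 6-bit min inside the 12-byte scales array: the first four pairs occupy whole bytes, while the last four keep their low nibbles in bytes 8..11 and spill their top two bits into bits 6..7 of bytes 0..7. A sketch of the matching unpack rule (the library's own helper for this is get_scale_min_k4, which the requantization loop above calls):

#include <stdint.h>

// Recover the 6-bit scale (sc) and min (m) of sub-block j from the packed
// 12-byte scales array, inverting the packing loop in the diff above.
static void unpack_scale_min_k4(int j, const uint8_t * scales, uint8_t * sc, uint8_t * m) {
    if (j < 4) {
        *sc = scales[j] & 63;
        *m  = scales[j + 4] & 63;
    } else {
        *sc = (scales[j + 4] & 0xF) | ((scales[j - 4] >> 6) << 4);
        *m  = (scales[j + 4] >>  4) | ((scales[j - 0] >> 6) << 4);
    }
}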
@@ -2065,7 +2568,7 @@ void wsp_quantize_row_q5_K_reference(const float * restrict x, block_q5_K * rest
 #else
     float max_scale = 0, amax = 0;
     for (int j = 0; j < QK_K/16; ++j) {
-        scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1);
+        scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1, NULL);
         float abs_scale = fabsf(scales[j]);
         if (abs_scale > amax) {
             amax = abs_scale;
@@ -2159,21 +2662,138 @@ void wsp_dewsp_quantize_row_q5_K(const block_q5_K * restrict x, float * restrict
     }
 }
 
-void wsp_quantize_row_q5_K(const float * restrict x, void * restrict vy, int k) {
-    assert(k % QK_K == 0);
-    block_q5_K * restrict y = vy;
-    wsp_quantize_row_q5_K_reference(x, y, k);
-}
-
-size_t wsp_ggml_wsp_quantize_q5_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
-    assert(k % QK_K == 0);
-    (void)hist; // TODO: collect histograms
-
-    for (int j = 0; j < n; j += k) {
-        block_q5_K * restrict y = (block_q5_K *)dst + j/QK_K;
-        wsp_quantize_row_q5_K_reference(src + j, y, k);
-    }
-    return (n/QK_K*sizeof(block_q5_K));
-}
+void wsp_quantize_row_q5_K(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK_K == 0);
+    block_q5_K * restrict y = vy;
+    wsp_quantize_row_q5_K_reference(x, y, k);
+}
+
+size_t wsp_ggml_wsp_quantize_q5_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
+    assert(k % QK_K == 0);
+    (void)hist; // TODO: collect histograms
+
+    for (int j = 0; j < n; j += k) {
+        block_q5_K * restrict y = (block_q5_K *)dst + j/QK_K;
+        wsp_quantize_row_q5_K_reference(src + j, y, k);
+    }
+    return (n/QK_K*sizeof(block_q5_K));
+}
+
+static void wsp_quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restrict y, int n_per_row, const float * quant_weights) {
+#if QK_K != 256
+    (void)quant_weights;
+    wsp_quantize_row_q5_K_reference(x, y, n_per_row);
+#else
+    assert(n_per_row % QK_K == 0);
+    const int nb = n_per_row / QK_K;
+
+    uint8_t L[QK_K];
+    float mins[QK_K/32];
+    float scales[QK_K/32];
+    float weights[32];
+    uint8_t Laux[32];
+
+    for (int i = 0; i < nb; i++) {
+
+        float sum_x2 = 0;
+        for (int l = 0; l < QK_K; ++l) sum_x2 += x[l] * x[l];
+        float sigma2 = sum_x2/QK_K;
+        float av_x = sqrtf(sigma2);
+
+        float max_scale = 0; // as we are deducting the min, scales are always positive
+        float max_min = 0;
+        for (int j = 0; j < QK_K/32; ++j) {
+            if (quant_weights) {
+                const float * qw = quant_weights + QK_K*i + 32*j;
+                for (int l = 0; l < 32; ++l) weights[l] = qw[l] * sqrtf(sigma2 + x[32*j + l]*x[32*j + l]);
+            } else {
+                for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
+            }
+            scales[j] = make_qkx3_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
+            float scale = scales[j];
+            if (scale > max_scale) {
+                max_scale = scale;
+            }
+            float min = mins[j];
+            if (min > max_min) {
+                max_min = min;
+            }
+        }
+
+        float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
+        float inv_min   = max_min   > 0 ? 63.f/max_min   : 0.f;
+        for (int j = 0; j < QK_K/32; ++j) {
+            uint8_t ls = nearest_int(inv_scale*scales[j]);
+            uint8_t lm = nearest_int(inv_min*mins[j]);
+            ls = MIN(63, ls);
+            lm = MIN(63, lm);
+            if (j < 4) {
+                y[i].scales[j] = ls;
+                y[i].scales[j+4] = lm;
+            } else {
+                y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
+                y[i].scales[j-4] |= ((ls >> 4) << 6);
+                y[i].scales[j-0] |= ((lm >> 4) << 6);
+            }
+        }
+        y[i].d = WSP_GGML_FP32_TO_FP16(max_scale/63.f);
+        y[i].dmin = WSP_GGML_FP32_TO_FP16(max_min/63.f);
+
+        uint8_t sc, m;
+        for (int j = 0; j < QK_K/32; ++j) {
+            get_scale_min_k4(j, y[i].scales, &sc, &m);
+            const float d = WSP_GGML_FP16_TO_FP32(y[i].d) * sc;
+            if (!d) continue;
+            const float dm = WSP_GGML_FP16_TO_FP32(y[i].dmin) * m;
+            for (int ii = 0; ii < 32; ++ii) {
+                int l = nearest_int((x[32*j + ii] + dm)/d);
+                l = MAX(0, MIN(31, l));
+                L[32*j + ii] = l;
+            }
+        }
+
+        uint8_t * restrict qh = y[i].qh;
+        uint8_t * restrict ql = y[i].qs;
+        memset(qh, 0, QK_K/8);
+
+        uint8_t m1 = 1, m2 = 2;
+        for (int n = 0; n < QK_K; n += 64) {
+            for (int j = 0; j < 32; ++j) {
+                int l1 = L[n + j];
+                if (l1 > 15) {
+                    l1 -= 16; qh[j] |= m1;
+                }
+                int l2 = L[n + j + 32];
+                if (l2 > 15) {
+                    l2 -= 16; qh[j] |= m2;
+                }
+                ql[j] = l1 | (l2 << 4);
+            }
+            m1 <<= 2; m2 <<= 2;
+            ql += 32;
+        }
+
+        x += QK_K;
+
+    }
+#endif
+}
+
+size_t wsp_quantize_q5_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
+    (void)hist;
+    size_t row_size = wsp_ggml_row_size(WSP_GGML_TYPE_Q5_K, n_per_row);
+    if (!quant_weights) {
+        wsp_quantize_row_q5_K_reference(src, dst, nrow*n_per_row);
+    }
+    else {
+        char * qrow = (char *)dst;
+        for (int row = 0; row < nrow; ++row) {
+            wsp_quantize_row_q5_K_impl(src, (block_q5_K*)qrow, n_per_row, quant_weights);
+            src += n_per_row;
+            qrow += row_size;
+        }
+    }
+    return nrow * row_size;
 }
 
 // ====================== 6-bit (de)-quantization
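
Note: q5_K levels span 0..31, one bit more than a nibble, so the quantizer above splits each level: the low four bits are packed two per byte into qs, and the fifth bit goes into the qh bitmask, with the masks m1/m2 shifted left by two for every 64-value chunk. A small sketch of the inverse mapping (hypothetical helper, not from the package):

#include <stdint.h>

// Rebuild one 5-bit level from the split storage: low nibble from ql,
// fifth bit from the qh bitmask (m is 1 or 2, shifted left by 2 per chunk).
static inline int q5_level(uint8_t ql_byte, uint8_t qh_byte, int high_nibble, uint8_t m) {
    int low = high_nibble ? (ql_byte >> 4) : (ql_byte & 0xF);
    return low + ((qh_byte & m) ? 16 : 0);
}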
@@ -2192,7 +2812,7 @@ void wsp_quantize_row_q6_K_reference(const float * restrict x, block_q6_K * rest
 
     for (int ib = 0; ib < QK_K/16; ++ib) {
 
-        const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1);
+        const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, NULL);
         scales[ib] = scale;
 
         const float abs_scale = fabsf(scale);
@@ -2324,6 +2944,569 @@ size_t wsp_ggml_wsp_quantize_q6_K(const float * src, void * dst, int n, int k, i
     return (n/QK_K*sizeof(block_q6_K));
 }
 
+static void wsp_quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restrict y, int n_per_row, const float * quant_weights) {
+#if QK_K != 256
+    (void)quant_weights;
+    wsp_quantize_row_q6_K_reference(x, y, n_per_row);
+#else
+    assert(n_per_row % QK_K == 0);
+    const int nb = n_per_row / QK_K;
+
+    int8_t L[QK_K];
+    float scales[QK_K/16];
+    //float weights[16];
+
+    for (int i = 0; i < nb; i++) {
+
+        //float sum_x2 = 0;
+        //for (int j = 0; j < QK_K; ++j) sum_x2 += x[j]*x[j];
+        //float sigma2 = sum_x2/QK_K;
+
+        float max_scale = 0;
+        float max_abs_scale = 0;
+
+        for (int ib = 0; ib < QK_K/16; ++ib) {
+
+            float scale;
+            if (quant_weights) {
+                const float * qw = quant_weights + QK_K*i + 16*ib;
+                //for (int j = 0; j < 16; ++j) weights[j] = qw[j] * sqrtf(sigma2 + x[16*ib + j]*x[16*ib + j]);
+                //scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, weights);
+                scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, qw);
+            } else {
+                scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, NULL);
+            }
+            scales[ib] = scale;
+
+            const float abs_scale = fabsf(scale);
+            if (abs_scale > max_abs_scale) {
+                max_abs_scale = abs_scale;
+                max_scale = scale;
+            }
+
+        }
+
+        if (!max_abs_scale) {
+            memset(&y[i], 0, sizeof(block_q6_K));
+            y[i].d = WSP_GGML_FP32_TO_FP16(0.f);
+            x += QK_K;
+            continue;
+        }
+
+        float iscale = -128.f/max_scale;
+        y[i].d = WSP_GGML_FP32_TO_FP16(1/iscale);
+        for (int ib = 0; ib < QK_K/16; ++ib) {
+            y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib]));
+        }
+
+        for (int j = 0; j < QK_K/16; ++j) {
+            float d = WSP_GGML_FP16_TO_FP32(y[i].d) * y[i].scales[j];
+            if (!d) {
+                continue;
+            }
+            for (int ii = 0; ii < 16; ++ii) {
+                int l = nearest_int(x[16*j + ii]/d);
+                l = MAX(-32, MIN(31, l));
+                L[16*j + ii] = l + 32;
+            }
+        }
+
+        uint8_t * restrict ql = y[i].ql;
+        uint8_t * restrict qh = y[i].qh;
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int l = 0; l < 32; ++l) {
+                const uint8_t q1 = L[j + l +  0] & 0xF;
+                const uint8_t q2 = L[j + l + 32] & 0xF;
+                const uint8_t q3 = L[j + l + 64] & 0xF;
+                const uint8_t q4 = L[j + l + 96] & 0xF;
+                ql[l+ 0] = q1 | (q3 << 4);
+                ql[l+32] = q2 | (q4 << 4);
+                qh[l] = (L[j + l] >> 4) | ((L[j + l + 32] >> 4) << 2) | ((L[j + l + 64] >> 4) << 4) | ((L[j + l + 96] >> 4) << 6);
+            }
+            ql += 64;
+            qh += 32;
+        }
+
+        x += QK_K;
+
+    }
+#endif
+}
+
+size_t wsp_quantize_q6_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
+    (void)hist;
+    size_t row_size = wsp_ggml_row_size(WSP_GGML_TYPE_Q6_K, n_per_row);
+    if (!quant_weights) {
+        wsp_quantize_row_q6_K_reference(src, dst, nrow*n_per_row);
+    }
+    else {
+        char * qrow = (char *)dst;
+        for (int row = 0; row < nrow; ++row) {
+            wsp_quantize_row_q6_K_impl(src, (block_q6_K*)qrow, n_per_row, quant_weights);
+            src += n_per_row;
+            qrow += row_size;
+        }
+    }
+    return nrow * row_size;
+}
+
+static void wsp_quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restrict y, int n_per_row, const float * quant_weights) {
+    static_assert(QK4_0 == 32, "QK4_0 must be 32");
+
+    if (!quant_weights) {
+        wsp_quantize_row_q4_0_reference(x, y, n_per_row);
+        return;
+    }
+
+    float weight[QK4_0];
+    int8_t L[QK4_0];
+
+    float sum_x2 = 0;
+    for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
+    float sigma2 = sum_x2/n_per_row;
+
+    const int nb = n_per_row/QK4_0;
+    for (int ib = 0; ib < nb; ++ib) {
+        const float * xb = x + QK4_0 * ib;
+        const float * qw = quant_weights + QK4_0 * ib;
+        for (int j = 0; j < QK4_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
+        float d = make_qx_quants(QK4_0, 8, xb, L, 1, weight);
+        y[ib].d = WSP_GGML_FP32_TO_FP16(d);
+        for (int j = 0; j < 16; ++j) {
+            y[ib].qs[j] = L[j] | (L[j+16] << 4);
+        }
+    }
+}
+
+size_t wsp_quantize_q4_0(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
+    if (!quant_weights) {
+        return wsp_ggml_wsp_quantize_q4_0(src, dst, nrow*n_per_row, n_per_row, hist);
+    }
+    size_t row_size = wsp_ggml_row_size(WSP_GGML_TYPE_Q4_0, n_per_row);
+    char * qrow = (char *)dst;
+    for (int row = 0; row < nrow; ++row) {
+        wsp_quantize_row_q4_0_impl(src, (block_q4_0*)qrow, n_per_row, quant_weights);
+        src += n_per_row;
+        qrow += row_size;
+    }
+    return nrow * row_size;
+}
+
+static void wsp_quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restrict y, int n_per_row, const float * quant_weights) {
+    static_assert(QK4_1 == 32, "QK4_1 must be 32");
+
+    if (!quant_weights) {
+        wsp_quantize_row_q4_1_reference(x, y, n_per_row);
+        return;
+    }
+
+    float weight[QK4_1];
+    uint8_t L[QK4_1], Laux[QK4_1];
+
+    float sum_x2 = 0;
+    for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
+    float sigma2 = sum_x2/n_per_row;
+
+    const int nb = n_per_row/QK4_1;
+    for (int ib = 0; ib < nb; ++ib) {
+        const float * xb = x + QK4_1 * ib;
+        const float * qw = quant_weights + QK4_1 * ib;
+        for (int j = 0; j < QK4_1; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
+        float min;
+        float d = make_qkx3_quants(QK4_1, 15, xb, weight, L, &min, Laux, -0.9f, 0.05f, 36, false);
+        y[ib].d = WSP_GGML_FP32_TO_FP16(d);
+        y[ib].m = WSP_GGML_FP32_TO_FP16(-min);
+        for (int j = 0; j < 16; ++j) {
+            y[ib].qs[j] = L[j] | (L[j+16] << 4);
+        }
+    }
+}
+
+size_t wsp_quantize_q4_1(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
+    if (!quant_weights) {
+        return wsp_ggml_wsp_quantize_q4_1(src, dst, nrow*n_per_row, n_per_row, hist);
+    }
+    size_t row_size = wsp_ggml_row_size(WSP_GGML_TYPE_Q4_1, n_per_row);
+    char * qrow = (char *)dst;
+    for (int row = 0; row < nrow; ++row) {
+        wsp_quantize_row_q4_1_impl(src, (block_q4_1*)qrow, n_per_row, quant_weights);
+        src += n_per_row;
+        qrow += row_size;
+    }
+    return nrow * row_size;
+}
+
+static void wsp_quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restrict y, int n_per_row, const float * quant_weights) {
+    static_assert(QK5_0 == 32, "QK5_0 must be 32");
+
+    if (!quant_weights) {
+        wsp_quantize_row_q5_0_reference(x, y, n_per_row);
+        return;
+    }
+
+    float weight[QK5_0];
+    int8_t L[QK5_0];
+
+    float sum_x2 = 0;
+    for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
+    float sigma2 = sum_x2/n_per_row;
+
+    const int nb = n_per_row/QK5_0;
+    for (int ib = 0; ib < nb; ++ib) {
+        const float * xb = x + QK5_0 * ib;
+        const float * qw = quant_weights + QK5_0 * ib;
+        for (int j = 0; j < QK5_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
+        float d = make_qx_quants(QK5_0, 16, xb, L, 1, weight);
+        y[ib].d = WSP_GGML_FP32_TO_FP16(d);
+
+        uint32_t qh = 0;
+
+        for (int j = 0; j < 16; ++j) {
+            const uint8_t xi0 = L[j];
+            const uint8_t xi1 = L[j+16];
+            y[ib].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
+
+            // get the 5-th bit and store it in qh at the right position
+            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+            qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
+        }
+
+        memcpy(&y[ib].qh, &qh, sizeof(qh));
+    }
+}
+
+size_t wsp_quantize_q5_0(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
+    if (!quant_weights) {
+        return wsp_ggml_wsp_quantize_q5_0(src, dst, nrow*n_per_row, n_per_row, hist);
+    }
+    size_t row_size = wsp_ggml_row_size(WSP_GGML_TYPE_Q5_0, n_per_row);
+    char * qrow = (char *)dst;
+    for (int row = 0; row < nrow; ++row) {
+        wsp_quantize_row_q5_0_impl(src, (block_q5_0*)qrow, n_per_row, quant_weights);
+        src += n_per_row;
+        qrow += row_size;
+    }
+    return nrow * row_size;
+}
+
+static void wsp_quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restrict y, int n_per_row, const float * quant_weights) {
+    static_assert(QK5_1 == 32, "QK5_1 must be 32");
+
+    if (!quant_weights) {
+        wsp_quantize_row_q5_1_reference(x, y, n_per_row);
+        return;
+    }
+
+    float weight[QK5_1];
+    uint8_t L[QK5_1], Laux[QK5_1];
+
+    float sum_x2 = 0;
+    for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
+    float sigma2 = sum_x2/n_per_row;
+
+    const int nb = n_per_row/QK5_1;
+    for (int ib = 0; ib < nb; ++ib) {
+        const float * xb = x + QK5_1 * ib;
+        const float * qw = quant_weights + QK5_1 * ib;
+        for (int j = 0; j < QK5_1; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
+        float min;
+        float d = make_qkx3_quants(QK5_1, 31, xb, weight, L, &min, Laux, -0.9f, 0.05f, 36, false);
+        y[ib].d = WSP_GGML_FP32_TO_FP16(d);
+        y[ib].m = WSP_GGML_FP32_TO_FP16(-min);
+
+        uint32_t qh = 0;
+        for (int j = 0; j < 16; ++j) {
+            const uint8_t xi0 = L[j];
+            const uint8_t xi1 = L[j+16];
+            y[ib].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
+            // get the 5-th bit and store it in qh at the right position
+            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+            qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
+        }
+        memcpy(&y[ib].qh, &qh, sizeof(qh));
+    }
+}
+
+size_t wsp_quantize_q5_1(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
+    if (!quant_weights) {
+        return wsp_ggml_wsp_quantize_q5_1(src, dst, nrow*n_per_row, n_per_row, hist);
+    }
+    size_t row_size = wsp_ggml_row_size(WSP_GGML_TYPE_Q5_1, n_per_row);
+    char * qrow = (char *)dst;
+    for (int row = 0; row < nrow; ++row) {
+        wsp_quantize_row_q5_1_impl(src, (block_q5_1*)qrow, n_per_row, quant_weights);
+        src += n_per_row;
+        qrow += row_size;
+    }
+    return nrow * row_size;
+}
+
+// ====================== "True" 2-bit (de)-quantization
+
+static const uint64_t iq2xxs_grid[256] = {
+    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
+    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
+    0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
+    0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
+    0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
+    0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
+    0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
+    0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
+    0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
+    0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
+    0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
+    0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
+    0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
+    0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
+    0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
+    0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
+    0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
+    0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
+    0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
+    0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
+    0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
+    0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
+    0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
+    0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
+    0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
+    0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
+    0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
+    0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
+    0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
+    0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
+    0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
+    0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
+    0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
+    0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
+    0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
+    0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
+    0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
+    0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
+    0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
+    0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
+    0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
+    0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
+    0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
+    0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
+    0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
+    0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
+    0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
+    0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
+    0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
+    0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
+    0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
+    0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
+    0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
+    0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
+    0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
+    0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
+    0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
+    0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
+    0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
+    0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
+    0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
+    0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
+    0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
+    0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
+};
+
+static const uint64_t iq2xs_grid[512] = {
+    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
+    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
+    0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
+    0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
+    0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
+    0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808,
+    0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819,
+    0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819,
+    0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808,
+    0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b,
+    0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b,
+    0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908,
+    0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908,
+    0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919,
+    0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808,
+    0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919,
+    0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908,
+    0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b,
+    0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908,
+    0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08,
+    0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808,
+    0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808,
+    0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819,
+    0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908,
+    0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819,
+    0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808,
+    0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b,
+    0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819,
+    0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819,
+    0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808,
+    0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908,
+    0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19,
+    0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b,
+    0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b,
+    0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919,
+    0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808,
+    0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819,
+    0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819,
+    0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b,
+    0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908,
+    0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808,
+    0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819,
+    0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808,
+    0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919,
+    0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808,
+    0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808,
+    0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908,
+    0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908,
+    0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808,
+    0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b,
+    0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819,
+    0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919,
+    0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908,
+    0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808,
+    0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908,
+    0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919,
+    0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08,
+    0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19,
+    0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b,
+    0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b,
+    0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808,
+    0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08,
+    0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b,
+    0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908,
+    0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b,
+    0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908,
+    0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08,
+    0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808,
+    0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808,
+    0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08,
+    0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819,
+    0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919,
+    0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808,
+    0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808,
+    0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819,
+    0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819,
+    0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908,
+    0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908,
+    0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b,
+    0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908,
+    0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908,
+    0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908,
+    0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808,
+    0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819,
+    0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819,
+    0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819,
+    0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808,
+    0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b,
+    0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819,
+    0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819,
+    0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08,
+    0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808,
+    0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19,
+    0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919,
+    0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808,
+    0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19,
+    0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b,
+    0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808,
+    0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b,
+    0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b,
+    0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08,
+    0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b,
+    0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808,
+    0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819,
+    0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,
0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,
|
|
3419
|
+
0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808,
|
|
3420
|
+
0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08,
|
|
3421
|
+
0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b,
|
|
3422
|
+
0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19,
|
|
3423
|
+
0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08,
|
|
3424
|
+
0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919,
|
|
3425
|
+
0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08,
|
|
3426
|
+
0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08,
|
|
3427
|
+
0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908,
|
|
3428
|
+
0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908,
|
|
3429
|
+
0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b,
|
|
3430
|
+
0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908,
|
|
3431
|
+
0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808,
|
|
3432
|
+
0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b,
|
|
3433
|
+
0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808,
|
|
3434
|
+
0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808,
|
|
3435
|
+
0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19,
|
|
3436
|
+
0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08,
|
|
3437
|
+
0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808,
|
|
3438
|
+
0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b,
|
|
3439
|
+
0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808,
|
|
3440
|
+
0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b,
|
|
3441
|
+
0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
|
|
3442
|
+
};
|
|
3443
|
+
|
|
3444
|
+
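Note on the grid table above: each 64-bit entry of the iq2xs codebook packs eight byte-sized magnitudes, and every byte in the listing is one of 0x08, 0x19, or 0x2b (8, 25, 43). The dequantizers added further down read an entry straight through a byte pointer, so the eight magnitudes come out in memory (little-endian) order. A minimal illustration of that access pattern, with a hypothetical helper name that is not part of the package:

    #include <stdint.h>
    #include <stdio.h>

    // Hypothetical illustration: unpack one grid entry the same way the
    // dequantizers below do, i.e. by viewing the uint64_t as eight
    // consecutive bytes (assumes a little-endian target, as that code does).
    static void dump_grid_entry(const uint64_t * grid, int index) {
        const uint8_t * b = (const uint8_t *)(grid + index);
        for (int j = 0; j < 8; ++j) {
            printf("%u ", b[j]);   // each value is one of {8, 25, 43}
        }
        printf("\n");
    }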
+static const uint8_t ksigns_iq2xs[128] = {
+    0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
+    144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
+    160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175,
+    48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63,
+    192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207,
+    80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95,
+    96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
+    240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
+};
+
+static const uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
+
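Note on ksigns_iq2xs: entry i is i itself with bit 7 overwritten by the parity of the seven low bits, e.g. ksigns_iq2xs[1] == 129 (odd parity sets the top bit) while ksigns_iq2xs[3] == 3 (even parity leaves it clear). Bit j of an entry then selects the sign of lane j through kmask_iq2xs, so eight signs travel in seven stored bits and the full 8-bit mask always has an even number of set bits. A sketch of a generator that reproduces the table (the helper name is hypothetical; it assumes the GCC/Clang popcount builtin):

    #include <stdint.h>

    // Hypothetical generator for ksigns_iq2xs (illustration only): keep the
    // 7 stored sign bits and force bit 7 to their parity.
    static uint8_t ksigns_entry(uint8_t i) {   // i in [0, 127]
        uint8_t parity = (uint8_t)(__builtin_popcount(i & 127) & 1);
        return (uint8_t)((i & 127) | (parity << 7));
    }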
+void wsp_dewsp_quantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    uint32_t aux32[2];
+    const uint8_t * aux8 = (const uint8_t *)aux32;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d = WSP_GGML_FP16_TO_FP32(x[i].d);
+
+        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+            memcpy(aux32, x[i].qs + 4*ib32, 2*sizeof(uint32_t));
+            const float db = d * (0.5f + (aux32[1] >> 28)) * 0.25f;
+            for (int l = 0; l < 4; ++l) {
+                const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
+                const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];
+                for (int j = 0; j < 8; ++j) {
+                    y[j] = db * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+                }
+                y += 8;
+            }
+        }
+    }
+}
+
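A bit-budget check of the function above (assuming QK_K == 256, as in this build): each 32-weight sub-block is stored in two 32-bit words, aux32[0] holding four byte indices into iq2xxs_grid and aux32[1] holding four 7-bit sign words plus the 4-bit sub-block scale in its top nibble. That is 8 bytes per 32 weights, plus one fp16 d per 256-weight block: (2 + 8*8) * 8 / 256 = 2.0625 bits per weight, the advertised IQ2_XXS size. The effective sub-block scale simplifies to db = d * (0.5 + s) * 0.25 = d * (2s + 1) / 8 for the 4-bit nibble s. The iq2_xs variant that follows spends one extra scale byte per 32 weights, giving (2 + 64 + 8) * 8 / 256 = 2.3125 bits per weight, which matches the comment below.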
+// ====================== 2.3125 bpw (de)-quantization
+
+void wsp_dewsp_quantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    float db[2];
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d = WSP_GGML_FP16_TO_FP32(x[i].d);
+
+        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+            db[0] = d * (0.5f + (x[i].scales[ib32] & 0xf)) * 0.25f;
+            db[1] = d * (0.5f + (x[i].scales[ib32] >> 4)) * 0.25f;
+            for (int l = 0; l < 4; ++l) {
+                const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (x[i].qs[4*ib32 + l] & 511));
+                const uint8_t signs = ksigns_iq2xs[x[i].qs[4*ib32 + l] >> 9];
+                for (int j = 0; j < 8; ++j) {
+                    y[j] = db[l/2] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+                }
+                y += 8;
+            }
+        }
+    }
+}
+
 //===================================== Q8_K ==============================================
 
 void wsp_quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
@@ -2346,7 +3529,9 @@ void wsp_quantize_row_q8_K_reference(const float * restrict x, block_q8_K * rest
             x += QK_K;
             continue;
         }
-        const float iscale = -128.f/max;
+        //const float iscale = -128.f/max;
+        // We need this change for IQ2_XXS, else the AVX implementation becomes very awkward
+        const float iscale = -127.f/max;
         for (int j = 0; j < QK_K; ++j) {
             int v = nearest_int(iscale*x[j]);
             y[i].qs[j] = MIN(127, v);
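Worked example of the change above: with max = 8.0 the old factor gave iscale = -16, so x[j] = max quantized to nearest_int(-16 * 8) = -128; the new factor gives iscale = -15.875, so the extreme maps to -127. Keeping values in the symmetric range [-127, 127] avoids the int8 edge case that -128 cannot be negated in 8 bits, which is presumably the awkwardness the comment refers to for the AVX IQ2_XXS path; the positive side was already clamped by MIN(127, v).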
@@ -2468,32 +3653,12 @@ void wsp_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict
         const int8x16_t v1_1l = vld1q_s8(y1->qs);
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
-#if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
-        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
-        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
+        const int32x4_t p_0 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
+        const int32x4_t p_1 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d));
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d));
-#else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
-
-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
-
-        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
-        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
-        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
-        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
-
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d));
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d));
-#endif
     }
 
     *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
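The same mechanical rewrite repeats through the hunks that follow: every per-call-site `#if defined(__ARM_FEATURE_DOTPROD)` / `#else` pair is collapsed into a single call to wsp_ggml_vdotq_s32, so the widening-multiply fallback now lives in one shared shim instead of being open-coded at each dot product. A minimal sketch of what such a shim looks like, modeled on upstream ggml's compat helpers; the actual definition ships elsewhere in this release, so treat the exact shape as an assumption:

    #include <arm_neon.h>

    #if !defined(__ARM_FEATURE_DOTPROD)
    // Fallback: widen signed bytes to 16-bit products, then pairwise-add
    // into 32-bit lanes -- the same math the deleted #else branches did.
    inline static int32x4_t wsp_ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
        const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
        const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
        return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
    }
    #else
    // On dotprod-capable CPUs the shim is just the native instruction.
    #define wsp_ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
    #endif

Moving the CPU-feature dispatch into the helper keeps each kernel to a single code path, which is what lets the hunks below delete the scalar branches wholesale.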
@@ -2776,32 +3941,12 @@ void wsp_ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * re
         const int8x16_t v1_1l = vld1q_s8(y1->qs);
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
-#if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
-        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
-        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
+        const int32x4_t p_0 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
+        const int32x4_t p_1 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), WSP_GGML_FP16_TO_FP32(x0->d)*y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), WSP_GGML_FP16_TO_FP32(x1->d)*y1->d);
-#else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0h));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0h));
-
-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1l));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1l));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1h));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1h));
-
-        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
-        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
-        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
-        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
-
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), WSP_GGML_FP16_TO_FP32(x0->d)*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), WSP_GGML_FP16_TO_FP32(x1->d)*y1->d);
-#endif
     }
 
     *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
@@ -2963,32 +4108,12 @@ void wsp_ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * re
         const int8x16_t v1_1l = vld1q_s8(y1->qs);
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
-#if defined(__ARM_FEATURE_DOTPROD)
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
-                        vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
-                        vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d));
+                        wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
+                        wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d));
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
-                        vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
-                        vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d));
-#else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
-
-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
-
-        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
-        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
-        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
-        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
-
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d));
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d));
-#endif
+                        wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
+                        wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d));
     }
 
     *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
@@ -3275,32 +4400,12 @@ void wsp_ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * re
         const int8x16_t v1_1l = vld1q_s8(y1->qs);
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
-#if defined(__ARM_FEATURE_DOTPROD)
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
-                        vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
-                        vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), WSP_GGML_FP16_TO_FP32(x0->d)*y0->d);
+                        wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
+                        wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), WSP_GGML_FP16_TO_FP32(x0->d)*y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
-                        vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
-                        vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), WSP_GGML_FP16_TO_FP32(x1->d)*y1->d);
-#else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
-
-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
-
-        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
-        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
-        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
-        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
-
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), WSP_GGML_FP16_TO_FP32(x0->d)*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), WSP_GGML_FP16_TO_FP32(x1->d)*y1->d);
-#endif
+                        wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
+                        wsp_ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), WSP_GGML_FP16_TO_FP32(x1->d)*y1->d);
     }
 
     *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
@@ -3550,34 +4655,13 @@ void wsp_ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * re
         const int8x16_t y1_0 = vld1q_s8(y1->qs);
         const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
 
-#if defined(__ARM_FEATURE_DOTPROD)
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
-                        vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
-                        vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d));
+                        wsp_ggml_vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
+                        wsp_ggml_vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d));
 
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
-                        vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
-                        vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d));
-
-#else
-        const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0));
-        const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
-        const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1));
-        const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
-
-        const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0));
-        const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
-        const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1));
-        const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
-
-        const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
-        const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
-        const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
-        const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
-
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d));
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d));
-#endif
+                        wsp_ggml_vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
+                        wsp_ggml_vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d));
     }
 
     *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
@@ -3650,12 +4734,10 @@ void wsp_ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * re
     const int nb = n / QK_K;
 
 #ifdef __ARM_NEON
-
     const uint8x16_t m3 = vdupq_n_u8(0x3);
    const uint8x16_t m4 = vdupq_n_u8(0xF);
-#if defined(__ARM_FEATURE_DOTPROD)
-    const int32x4_t vzero = vdupq_n_s32(0);
-#endif
+
+    const int32x4_t vzero = vdupq_n_s32(0);
 
     wsp_ggml_int8x16x2_t q2bytes;
     uint8_t aux[16];
@@ -3663,7 +4745,6 @@ void wsp_ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * re
     float sum = 0;
 
     for (int i = 0; i < nb; ++i) {
-
        const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d);
        const float dmin = -y[i].d * WSP_GGML_FP16_TO_FP32(x[i].dmin);
 
@@ -3677,7 +4758,7 @@ void wsp_ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * re
 
        const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
        const wsp_ggml_int16x8x2_t q8sums = wsp_ggml_vld1q_s16_x2(y[i].bsums);
-       const wsp_ggml_int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))};
+       const wsp_ggml_int16x8x2_t mins16 = {{vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}};
        const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
                                       vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
        const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
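The single-to-double brace change just above (and the matching q6scales fix in the q6_K hunk further down) is an initializer fix, not a behavior change: the wsp_ggml_*x2_t compat types are structs that wrap a val[2] array, so a fully braced initializer needs one brace pair for the struct and one for the array. A sketch of the shape involved, assuming the compat typedef matches ggml's usual shims:

    #include <arm_neon.h>

    typedef struct {
        int16x8_t val[2];   // assumed layout of the compat type
    } wsp_ggml_int16x8x2_t;

    // { a, b }   relies on brace elision to reach val[0]/val[1], which many
    //            compilers flag with -Wmissing-braces;
    // {{ a, b }} spells out both levels explicitly.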
@@ -3689,20 +4770,9 @@ void wsp_ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * re
 
        // We use this macro instead of a function call because for some reason
        // the code runs 2-3% slower, even if the function is declared inline
-#if defined(__ARM_FEATURE_DOTPROD)
-#define MULTIPLY_ACCUM_WITH_SCALE(index)\
-        isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
-        isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
-#else
 #define MULTIPLY_ACCUM_WITH_SCALE(index)\
-        {\
-    const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),\
-                                   vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));\
-    const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),\
-                                   vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));\
-    isum += vaddvq_s16(p1) * aux[is+(index)] + vaddvq_s16(p2) * aux[is+1+(index)];\
-        }
-#endif
+    isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
+    isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
 
 #define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
     q8bytes = wsp_ggml_vld1q_s8_x2(q8); q8 += 32;\
@@ -3710,26 +4780,23 @@ void wsp_ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * re
     q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
     MULTIPLY_ACCUM_WITH_SCALE((index));
 
-
     for (int j = 0; j < QK_K/128; ++j) {
-
        const wsp_ggml_uint8x16x2_t q2bits = wsp_ggml_vld1q_u8_x2(q2); q2 += 32;
 
        wsp_ggml_int8x16x2_t q8bytes = wsp_ggml_vld1q_s8_x2(q8); q8 += 32;
        q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
        q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
+
        MULTIPLY_ACCUM_WITH_SCALE(0);
 
        SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2);
-
        SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4);
-
        SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6);
 
        is += 8;
     }
-    sum += d * isum;
 
+    sum += d * isum;
     }
 
     *s = sum;
@@ -4043,11 +5110,9 @@ void wsp_ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * re
     const int nb = n / QK_K;
 
 #ifdef __ARM_NEON
-
     const uint8x16_t m3 = vdupq_n_u8(0x3);
-
-#if defined(__ARM_FEATURE_DOTPROD)
-    const int32x4_t vzero = vdupq_n_s32(0);
-#endif
+
+    const int32x4_t vzero = vdupq_n_s32(0);
 
     wsp_ggml_int8x16x4_t q2bytes;
 
@@ -4081,28 +5146,12 @@ void wsp_ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * re
        q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3));
        q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3));
 
-#if defined(__ARM_FEATURE_DOTPROD)
-        isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0];
-        isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1];
-        isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2];
-        isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3];
-#else
-        const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
-                                       vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));
-        const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
-                                       vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-        isum1 += vaddvq_s16(p1) * scales[0];
-        isum2 += vaddvq_s16(p2) * scales[1];
-
-        const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
-                                       vmull_s8(vget_high_s8(q2bytes.val[2]), vget_high_s8(q8bytes.val[2])));
-        const int16x8_t p4 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
-                                       vmull_s8(vget_high_s8(q2bytes.val[3]), vget_high_s8(q8bytes.val[3])));
-        isum1 += vaddvq_s16(p3) * scales[2];
-        isum2 += vaddvq_s16(p4) * scales[3];
-#endif
-        sum += d * (isum1 + isum2);
+        isum1 += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0];
+        isum2 += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1];
+        isum1 += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2];
+        isum2 += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3];
 
+        sum += d * (isum1 + isum2);
     }
 
     *s = sum;
@@ -4328,9 +5377,7 @@ void wsp_ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * re
     uint32_t utmp[4];
 
     const uint8x16_t m3b = vdupq_n_u8(0x3);
-#ifdef __ARM_FEATURE_DOTPROD
     const int32x4_t vzero = vdupq_n_s32(0);
-#endif
 
     const uint8x16_t m0 = vdupq_n_u8(1);
     const uint8x16_t m1 = vshlq_n_u8(m0, 1);
@@ -4382,22 +5429,11 @@ void wsp_ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * re
        q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
        q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
 
-#if defined(__ARM_FEATURE_DOTPROD)
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
-#else
-        int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_1.val[0])),
-                                 vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_1.val[0])));
-        int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_1.val[1])),
-                                 vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_1.val[1])));
-        int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_1.val[2])),
-                                 vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_1.val[2])));
-        int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_1.val[3])),
-                                 vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_1.val[3])));
-        isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
-#endif
+        isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
+        isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
+        isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
+        isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
+
        scale += 4;
 
        q3h.val[0] = vbicq_u8(m2, qhbits.val[0]);
@@ -4410,22 +5446,11 @@ void wsp_ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * re
        q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
        q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
 
-#if defined(__ARM_FEATURE_DOTPROD)
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
-#else
-        p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_2.val[0])),
-                       vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_2.val[0])));
-        p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_2.val[1])),
-                       vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_2.val[1])));
-        p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_2.val[2])),
-                       vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_2.val[2])));
-        p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_2.val[3])),
-                       vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_2.val[3])));
-        isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
-#endif
+        isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
+        isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
+        isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
+        isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
+
        scale += 4;
 
        if (j == 0) {
@@ -4864,10 +5889,7 @@ void wsp_ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * re
     const int nb = n / QK_K;
 
 #ifdef __ARM_NEON
-
-#ifdef __ARM_FEATURE_DOTPROD
-    const int32x4_t vzero = vdupq_n_s32(0);
-#endif
+    const int32x4_t vzero = vdupq_n_s32(0);
 
     const uint8x16_t m3b = vdupq_n_u8(0x3);
     const uint8x16_t mh = vdupq_n_u8(4);
@@ -4908,22 +5930,10 @@ void wsp_ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * re
        q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2]));
        q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6), q3h.val[3]));
 
-#if defined(__ARM_FEATURE_DOTPROD)
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2];
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1];
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];
-#else
-        const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
-                                       vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes.val[0])));
-        const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
-                                       vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-        const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
-                                       vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes.val[2])));
-        const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
-                                       vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes.val[3])));
-        isum += vaddvq_s16(p0) * scales[0] + vaddvq_s16(p1) * scales[2] + vaddvq_s16(p2) * scales[1] + vaddvq_s16(p3) * scales[3];
-#endif
+        isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
+        isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2];
+        isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1];
+        isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];
 
        sum += d * isum;
 
@@ -5228,11 +6238,8 @@ void wsp_ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * re
     uint32_t utmp[4];
 
 #ifdef __ARM_NEON
-
     const uint8x16_t m4b = vdupq_n_u8(0xf);
-#ifdef __ARM_FEATURE_DOTPROD
     const int32x4_t mzero = vdupq_n_s32(0);
-#endif
 
     wsp_ggml_int8x16x2_t q4bytes;
     wsp_ggml_int8x16x2_t q8bytes;
@@ -5269,44 +6276,22 @@ void wsp_ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * re
        int32_t sumi2 = 0;
 
        for (int j = 0; j < QK_K/64; ++j) {
-
            const wsp_ggml_uint8x16x2_t q4bits = wsp_ggml_vld1q_u8_x2(q4); q4 += 32;
 
-#ifdef __ARM_FEATURE_DOTPROD
            q8bytes = wsp_ggml_vld1q_s8_x2(q8); q8 += 32;
            q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
            q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
 
-            const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
-            sumi1 += vaddvq_s32(p1) * scales[2*j+0];
-
-            q8bytes = wsp_ggml_vld1q_s8_x2(q8); q8 += 32;
-            q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
-            q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
-
-            const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
-
-            sumi2 += vaddvq_s32(p2) * scales[2*j+1];
-#else
-            q8bytes = wsp_ggml_vld1q_s8_x2(q8); q8 += 32;
-            q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
-            q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
-            const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
-                                           vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
-            const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
-                                           vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-            sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0];
+            const int32x4_t p1 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
+            sumi1 += vaddvq_s32(p1) * scales[2*j+0];
 
            q8bytes = wsp_ggml_vld1q_s8_x2(q8); q8 += 32;
            q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
            q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
-            const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
-                                           vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
-            const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
-                                           vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-            sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) * scales[2*j+1];
 
-#endif
+            const int32x4_t p2 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
+
+            sumi2 += vaddvq_s32(p2) * scales[2*j+1];
        }
 
        sumf += d * (sumi1 + sumi2);
@@ -5603,12 +6588,9 @@ void wsp_ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * re
     const int nb = n / QK_K;
 
 #ifdef __ARM_NEON
-
     const uint8x16_t m4b = vdupq_n_u8(0xf);
 
-#ifdef __ARM_FEATURE_DOTPROD
     const int32x4_t mzero = vdupq_n_s32(0);
-#endif
 
     float sumf = 0;
 
@@ -5636,41 +6618,20 @@ void wsp_ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * re
 
    const wsp_ggml_uint8x16x2_t q4bits = wsp_ggml_vld1q_u8_x2(q4);
 
-#ifdef __ARM_FEATURE_DOTPROD
    q8bytes = wsp_ggml_vld1q_s8_x4(q8);
    q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
    q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
 
-    const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
+    const int32x4_t p1 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
    const int32_t sumi1 = vaddvq_s32(p1) * scales[0];
 
    q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
    q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
 
-    const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
+    const int32x4_t p2 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
    const int32_t sumi2 = vaddvq_s32(p2) * scales[1];
 
-#else
-    q8bytes = wsp_ggml_vld1q_s8_x4(q8);
-    q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
-    q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
-    const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
-                                   vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
-    const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
-                                   vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-    int32_t sumi1 = vaddvq_s16(vaddq_s16(p0, p1)) * scales[0];
-
-    q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
-    q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
-    const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[2])),
-                                   vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[2])));
-    const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[3])),
-                                   vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[3])));
-    int32_t sumi2 = vaddvq_s16(vaddq_s16(p2, p3)) * scales[1];
-
-#endif
    sumf += d * (sumi1 + sumi2);
-
    }
 
    *s = sumf - sum_mins;
@@ -5875,15 +6836,11 @@ void wsp_ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * re
 
    uint32_t utmp[4];
 
-
 #ifdef __ARM_NEON
-
    const uint8x16_t m4b = vdupq_n_u8(0xf);
    const uint8x16_t mone = vdupq_n_u8(1);
    const uint8x16_t mtwo = vdupq_n_u8(2);
-#if defined(__ARM_FEATURE_DOTPROD)
    const int32x4_t mzero = vdupq_n_s32(0);
-#endif
 
    wsp_ggml_int8x16x4_t q5bytes;
 
@@ -5938,28 +6895,11 @@ void wsp_ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * re
        q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2]));
        q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3]));
 
-
-#if defined(__ARM_FEATURE_DOTPROD)
-        sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
-        sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
-#else
-
-        const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
-                                       vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0])));
-        const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
-                                       vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-        sumi += vaddvq_s16(vaddq_s16(p0, p1)) * *scales++;
-
-        const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
-                                       vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2])));
-        const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
-                                       vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3])));
-        sumi += vaddvq_s16(vaddq_s16(p2, p3)) * *scales++;
-#endif
+        sumi += vaddvq_s32(wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
+        sumi += vaddvq_s32(wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
    }
 
    sumf += d * sumi - dmin * sumi_mins;
-
    }
 
    *s = sumf;
@@ -6311,12 +7251,9 @@ void wsp_ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * re
    const int nb = n / QK_K;
 
 #ifdef __ARM_NEON
-
    const uint8x16_t m4b = vdupq_n_u8(0xf);
    const uint8x16_t mh = vdupq_n_u8(16);
-#if defined(__ARM_FEATURE_DOTPROD)
    const int32x4_t mzero = vdupq_n_s32(0);
-#endif
 
    wsp_ggml_int8x16x4_t q5bytes;
    wsp_ggml_uint8x16x4_t q5h;
 
@@ -6348,32 +7285,12 @@ void wsp_ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * re
        q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2]));
        q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3]));
 
-
-#if defined(__ARM_FEATURE_DOTPROD)
-        int32_t sumi1 = sc[0] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]));
-        int32_t sumi2 = sc[1] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1]));
-        int32_t sumi3 = sc[2] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]));
-        int32_t sumi4 = sc[3] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3]));
+        int32_t sumi1 = sc[0] * vaddvq_s32(wsp_ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]));
+        int32_t sumi2 = sc[1] * vaddvq_s32(wsp_ggml_vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1]));
+        int32_t sumi3 = sc[2] * vaddvq_s32(wsp_ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]));
+        int32_t sumi4 = sc[3] * vaddvq_s32(wsp_ggml_vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3]));
 
        sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
-
-#else
-
-        const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
-                                       vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0])));
-        const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
-                                       vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-        int32_t sumi = sc[0] * vaddvq_s16(p0) + sc[1] * vaddvq_s16(p1);
-
-        const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
-                                       vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2])));
-        const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
-                                       vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3])));
-        sumi += sc[2] * vaddvq_s16(p2) + sc[3] * vaddvq_s16(p3);
-
-        sumf += d*sumi;
-#endif
-
    }
 
    *s = sumf;
@@ -6600,13 +7517,10 @@ void wsp_ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * re
    const int nb = n / QK_K;
 
 #ifdef __ARM_NEON
-
    float sum = 0;
 
    const uint8x16_t m4b = vdupq_n_u8(0xF);
-#if defined(__ARM_FEATURE_DOTPROD)
    const int32x4_t vzero = vdupq_n_s32(0);
-#endif
    //const int8x16_t m32s = vdupq_n_s8(32);
 
    const uint8x16_t mone = vdupq_n_u8(3);
@@ -6626,7 +7540,7 @@ void wsp_ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * re
 
        const wsp_ggml_int16x8x2_t q8sums = wsp_ggml_vld1q_s16_x2(y[i].bsums);
        const int8x16_t scales = vld1q_s8(scale);
-        const wsp_ggml_int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))};
+        const wsp_ggml_int16x8x2_t q6scales = {{vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}};
 
        const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
                                                   vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
@@ -6658,31 +7572,13 @@ void wsp_ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * re
        q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2]));
        q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3]));
 
-#if defined(__ARM_FEATURE_DOTPROD)
+        isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
+                vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
+                vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
+                vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
 
-        isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
-                vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
-                vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
-                vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
        scale += 4;
 
-#else
-
-        int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
-                                 vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
-        int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
-                                 vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-        isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
-        scale += 2;
-
-        int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
-                                 vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
-        int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
-                                 vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
-        isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1];
-        scale += 2;
-#endif
-
        q8bytes = wsp_ggml_vld1q_s8_x4(q8); q8 += 64;
 
        shifted = vshrq_n_u8(qhbits.val[0], 4);
@@ -6703,34 +7599,11 @@ void wsp_ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * re
        q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2]));
        q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3]));
 
-#if defined(__ARM_FEATURE_DOTPROD)
-
-        isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
-                vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
-                vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
-                vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
+        isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
+                vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
+                vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
+                vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
        scale += 4;
-
-        //for (int l = 0; l < 4; ++l) {
-        //    const int32x4_t p = vdotq_s32(vzero, q6bytes.val[l], q8bytes.val[l]);
-        //    isum += vaddvq_s32(p) * *scale++;
-        //}
-#else
-        p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
-                       vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
-        p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
-                       vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-        isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
-        scale += 2;
-
-        p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
-                       vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
-        p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
-                       vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
-        isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1];
-        scale += 2;
-#endif
-
    }
    //sum += isum * d_all * y[i].d;
    sum += d_all * y[i].d * (isum - 32 * isum_mins);
@@ -7076,14 +7949,11 @@ void wsp_ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * re
    const int nb = n / QK_K;
 
 #ifdef __ARM_NEON
-
    float sum = 0;
 
    const uint8x16_t m4b = vdupq_n_u8(0xF);
    const int8x16_t m32s = vdupq_n_s8(32);
-#if defined(__ARM_FEATURE_DOTPROD)
    const int32x4_t vzero = vdupq_n_s32(0);
-#endif
 
    const uint8x16_t mone = vdupq_n_u8(3);
 
@@ -7119,26 +7989,10 @@ void wsp_ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * re
        q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s);
        q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s);
 
-#if defined(__ARM_FEATURE_DOTPROD)
-
-        isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
-                vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
-                vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
-                vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
-#else
-
-        int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
-                                 vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
-        int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
-                                 vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-        isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
-
-        int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
-                                 vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
-        int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
-                                 vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
-        isum += vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
-#endif
+        isum += vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
+                vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
+                vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
+                vaddvq_s32(wsp_ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
 
        sum += isum * d_all * y[i].d;
 
@@ -7380,3 +8234,958 @@ void wsp_ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * re
}

#endif
+
+static const int8_t keven_signs_q2xs[1024] = {
+    1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1,
+    1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1,
+    1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1,
+    1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1,
+    1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1,
+    1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1,
+    1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1,
+    1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1,
+    1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1,
+    1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1,
+    1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1,
+    1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1,
+    1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1,
+    1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1,
+    1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1,
+    1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1,
+    1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1,
+    1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1,
+    1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1,
+    1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1,
+    1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1,
+    1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1,
+    1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1,
+    1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1,
+    1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1,
+    1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1,
+    1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1,
+    1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1,
+    1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1,
+    1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1,
+    1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1,
+    1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
+};
+
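The 1024-entry table packs 128 rows of 8 signs: bits 0..6 of the row index give the first seven signs directly, and the eighth sign is chosen so that every row contains an even number of -1 factors (hence "keven"). A small sketch that regenerates any row (helper name illustrative):

#include <stdint.h>

// Regenerates row k (0..127) of keven_signs_q2xs: seven explicit sign bits
// plus an eighth sign that completes even parity over the count of -1s.
static void even_signs_row(uint32_t k, int8_t s[8]) {
    uint32_t parity = 0;
    for (int i = 0; i < 7; ++i) {
        s[i] = (k >> i) & 1 ? -1 : 1;
        parity ^= (k >> i) & 1;
    }
    s[7] = parity ? -1 : 1;   // even number of -1s overall
}

This is why the kernels below can index the table with a 7-bit field such as (aux32[1] >> 7*l) & 127 and read one uint64_t of eight packed sign bytes per lookup.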
+void wsp_ggml_vec_dot_iq2_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    assert(n % QK_K == 0);
+
+    const block_iq2_xxs * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#if defined(__ARM_NEON)
+
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    uint32_t aux32[4];
+    const uint8_t * aux8 = (const uint8_t *)aux32;
+
+    wsp_ggml_int8x16x4_t q2u;
+    wsp_ggml_int8x16x4_t q2s;
+    wsp_ggml_int8x16x4_t q8b;
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const float d = WSP_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint16_t * restrict q2 = x[i].qs;
+        const int8_t * restrict q8 = y[i].qs;
+        float sumf1 = 0, sumf2 = 0;
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            q8b = wsp_ggml_vld1q_s8_x4(q8); q8 += 64;
+            memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
+            q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 0])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 1])));
+            q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 2])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 3])));
+            q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 8])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 9])));
+            q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[10])), vld1_s8((const void *)(iq2xxs_grid + aux8[11])));
+            q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >>  0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >>  7) & 127))));
+            q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127))));
+            q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >>  0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >>  7) & 127))));
+            q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 21) & 127))));
+            q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]);
+            q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]);
+            q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]);
+            q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]);
+            const int32x4_t p1 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]), q2u.val[1], q8b.val[1]);
+            const int32x4_t p2 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]), q2u.val[3], q8b.val[3]);
+            sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[1] >> 28));
+            sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[3] >> 28));
+        }
+        sumf += d*(sumf1 + sumf2);
+    }
+    *s = 0.25f * sumf;
+
+#elif defined(__AVX2__)
+
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    uint32_t aux32[4];
+    const uint8_t * aux8 = (const uint8_t *)aux32;
+
+    __m256 accumf = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+        const float d = WSP_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint16_t * restrict q2 = x[i].qs;
+        const int8_t * restrict q8 = y[i].qs;
+        __m256i sumi1 = _mm256_setzero_si256();
+        __m256i sumi2 = _mm256_setzero_si256();
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
+            const __m256i q2_1 = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
+            const __m256i q2_2 = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
+            const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
+                                                   signs64[(aux32[1] >>  7) & 127], signs64[(aux32[1] >>  0) & 127]);
+            const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127],
+                                                   signs64[(aux32[3] >>  7) & 127], signs64[(aux32[3] >>  0) & 127]);
+            const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
+            const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
+            const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
+            const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
+            const uint16_t ls1 = aux32[1] >> 28;
+            const uint16_t ls2 = aux32[3] >> 28;
+            const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
+            const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
+            sumi1 = _mm256_add_epi32(sumi1, p1);
+            sumi2 = _mm256_add_epi32(sumi2, p2);
+        }
+
+        accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
+
+    }
+
+    *s = 0.125f * hsum_float_8(accumf);
+
+#else
+
+    uint32_t aux32[2];
+    const uint8_t * aux8 = (const uint8_t *)aux32;
+
+    float sumf = 0.f;
+    for (int i = 0; i < nb; ++i) {
+        const float d = WSP_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint16_t * restrict q2 = x[i].qs;
+        const int8_t * restrict q8 = y[i].qs;
+        int32_t bsum = 0;
+        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+            memcpy(aux32, q2, 2*sizeof(uint32_t));
+            q2 += 4;
+            const uint32_t ls = 2*(aux32[1] >> 28) + 1;
+            int32_t sumi = 0;
+            for (int l = 0; l < 4; ++l) {
+                const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
+                const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];
+                for (int j = 0; j < 8; ++j) {
+                    sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+                }
+                q8 += 8;
+            }
+            bsum += sumi * ls;
+        }
+        sumf += d * bsum;
+    }
+    *s = 0.125f * sumf;
+#endif
+}
+
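All three paths of wsp_ggml_vec_dot_iq2_xxs_q8_K read the same per-sub-block metadata and only differ in scale bookkeeping. A sketch of the decode, with an illustrative struct and helper name, assuming the two 32-bit metadata words per 32-weight sub-block used above:

#include <stdint.h>

// Illustrative decode of one 32-weight iq2_xxs sub-block from its two
// metadata words, matching how the kernels above index grid and signs.
typedef struct {
    uint8_t grid[4];  // aux32[0]: four 8-bit indices into iq2xxs_grid
    uint8_t sign[4];  // aux32[1] bits 0..27: four 7-bit even-parity sign indices
    uint8_t ls;       // aux32[1] bits 28..31: 4-bit sub-block scale
} iq2xxs_sub_t;

static iq2xxs_sub_t iq2xxs_decode_sub(const uint32_t aux32[2]) {
    iq2xxs_sub_t sb;
    for (int l = 0; l < 4; ++l) {
        sb.grid[l] = (aux32[0] >> 8*l) & 255;
        sb.sign[l] = (aux32[1] >> 7*l) & 127;
    }
    sb.ls = aux32[1] >> 28;
    return sb;
}

The scale factors agree across paths: the NEON path applies 0.25f * d * (0.5f + ls) per sub-block while the AVX2 and scalar paths apply 0.125f * d * (2*ls + 1), and 0.25*(0.5 + ls) = 0.125*(1 + 2*ls).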
+void wsp_ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    assert(n % QK_K == 0);
+
+    const block_iq2_xs * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#if defined(__ARM_NEON)
+
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    wsp_ggml_int8x16x4_t q2u;
+    wsp_ggml_int8x16x4_t q2s;
+    wsp_ggml_int8x16x4_t q8b;
+
+    int32x4x4_t scales32;
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const float d = WSP_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint16_t * restrict q2 = x[i].qs;
+        const int8_t * restrict q8 = y[i].qs;
+        const uint8x8_t scales8 = vld1_u8(x[i].scales);
+        const uint8x8_t scales_l = vand_u8(scales8, vdup_n_u8(0xf));
+        const uint8x8_t scales_h = vshr_n_u8(scales8, 4);
+        uint8x16_t scales = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h));
+        scales = vaddq_u8(vshlq_n_u8(scales, 1), vdupq_n_u8(1));
+        const uint16x8_t scales1 = vmovl_u8(vget_low_u8(scales));
+        const uint16x8_t scales2 = vmovl_u8(vget_high_u8(scales));
+        scales32.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales1)));
+        scales32.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales1)));
+        scales32.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales2)));
+        scales32.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales2)));
+        int32x4_t sumi = vdupq_n_s32(0);
+        for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
+            q8b = wsp_ggml_vld1q_s8_x4(q8); q8 += 64;
+            q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[0] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[1] & 511))));
+            q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[2] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[3] & 511))));
+            q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[4] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[5] & 511))));
+            q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[6] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[7] & 511))));
+            q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[0] >> 9))), vld1_s8((const void *)(signs64 + (q2[1] >> 9))));
+            q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[2] >> 9))), vld1_s8((const void *)(signs64 + (q2[3] >> 9))));
+            q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[4] >> 9))), vld1_s8((const void *)(signs64 + (q2[5] >> 9))));
+            q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[6] >> 9))), vld1_s8((const void *)(signs64 + (q2[7] >> 9))));
+            q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]);
+            q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]);
+            q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]);
+            q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]);
+            const int32x4_t p1 = wsp_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]);
+            const int32x4_t p2 = wsp_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[1], q8b.val[1]);
+            const int32x4_t p3 = wsp_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]);
+            const int32x4_t p4 = wsp_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[3], q8b.val[3]);
+            const int32x4_t p = vpaddq_s32(vpaddq_s32(p1, p2), vpaddq_s32(p3, p4));
+            sumi = vmlaq_s32(sumi, p, scales32.val[ib64]);
+            q2 += 8;
+        }
+        sumf += d*vaddvq_s32(sumi);
+    }
+    *s = 0.125f * sumf;
+
+#elif defined(__AVX2__)
+
+    const __m128i m4 = _mm_set1_epi8(0xf);
+    const __m128i m1 = _mm_set1_epi8(1);
+    const __m128i m511 = _mm_set1_epi16(511);
+    const __m128i m127 = _mm_set1_epi16(127);
+
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    uint64_t aux64;
+
+    // somewhat hacky, but gives a significant boost in performance
+    __m128i aux_gindex, aux_sindex;
+    const uint16_t * gindex = (const uint16_t *)&aux_gindex;
+    const uint16_t * sindex = (const uint16_t *)&aux_sindex;
+
+    __m256 accumf = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+        const float d = WSP_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint16_t * restrict q2 = x[i].qs;
+        const int8_t * restrict q8 = y[i].qs;
+
+        memcpy(&aux64, x[i].scales, 8);
+        __m128i stmp = _mm_set1_epi64x(aux64);
+        stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4));
+        const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1);
+
+        __m256i sumi1 = _mm256_setzero_si256();
+        __m256i sumi2 = _mm256_setzero_si256();
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m128i q2_data = _mm_loadu_si128((const __m128i*)q2); q2 += 8;
+            aux_gindex = _mm_and_si128(q2_data, m511);
+            aux_sindex = _mm_and_si128(_mm_srli_epi16(q2_data, 9), m127);
+            const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]], iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]);
+            const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]], iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]);
+            const __m256i s2_1 = _mm256_set_epi64x(signs64[sindex[3]], signs64[sindex[2]], signs64[sindex[1]], signs64[sindex[0]]);
+            const __m256i s2_2 = _mm256_set_epi64x(signs64[sindex[7]], signs64[sindex[6]], signs64[sindex[5]], signs64[sindex[4]]);
+            const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
+            const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
+            const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
+            const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
+
+            const __m256i sc1 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0)));
+            const __m256i sc2 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1)));
+
+            sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot1, sc1));
+            sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot2, sc2));
+        }
+
+        accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
+
+    }
+
+    *s = 0.125f * hsum_float_8(accumf);
+
+#else
+
+    float sumf = 0.f;
+    for (int i = 0; i < nb; ++i) {
+        const float d = WSP_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint16_t * restrict q2 = x[i].qs;
+        const uint8_t * restrict sc = x[i].scales;
+        const int8_t * restrict q8 = y[i].qs;
+        int32_t bsum = 0;
+        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+            const uint16_t ls1 = 2*(sc[ib32] & 0xf) + 1;
+            const uint16_t ls2 = 2*(sc[ib32] >> 4) + 1;
+            int32_t sumi = 0;
+            for (int l = 0; l < 2; ++l) {
+                const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
+                const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
+                for (int j = 0; j < 8; ++j) {
+                    sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+                }
+                q8 += 8;
+            }
+            bsum += sumi * ls1;
+            sumi = 0;
+            for (int l = 2; l < 4; ++l) {
+                const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
+                const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
+                for (int j = 0; j < 8; ++j) {
+                    sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+                }
+                q8 += 8;
+            }
+            bsum += sumi * ls2;
+            q2 += 4;
+        }
+        sumf += d * bsum;
+    }
+    *s = 0.125f * sumf;
+#endif
+}
+
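For iq2_xs every group of 8 weights is one 16-bit code word, and each path above unpacks it the same way; a trivial sketch of that split (helper name illustrative):

#include <stdint.h>

// One iq2_xs code word: the low 9 bits select one of 512 iq2xs_grid rows,
// the high 7 bits select a ksigns_iq2xs / keven_signs_q2xs sign pattern.
static inline void iq2xs_split(uint16_t q, int * grid_index, int * sign_index) {
    *grid_index = q & 511;
    *sign_index = q >> 9;
}

Scales are stored as 4-bit nibbles per 16-weight sub-block and expanded as ls = 2*(sc & 0xf) + 1, which is exactly the vshlq_n_u8(scales, 1) plus 1 step in the NEON path above.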
+// ================================ IQ2 quantization =============================================
+
+typedef struct {
+    uint64_t * grid;
+    int      * map;
+    uint16_t * neighbours;
+} iq2_entry_t;
+
+static iq2_entry_t iq2_data[2] = {
+    {NULL, NULL, NULL},
+    {NULL, NULL, NULL},
+};
+
+static inline int iq2_data_index(int grid_size) {
+    WSP_GGML_ASSERT(grid_size == 256 || grid_size == 512);
+    return grid_size == 256 ? 0 : 1;
+}
+
+static int iq2_compare_func(const void * left, const void * right) {
+    const int * l = (const int *)left;
+    const int * r = (const int *)right;
+    return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 1 : 0;
+}
+
+void iq2xs_init_impl(int grid_size) {
+    const int gindex = iq2_data_index(grid_size);
+    if (iq2_data[gindex].grid) {
+        return;
+    }
+    static const uint16_t kgrid_256[256] = {
+            0,     2,     5,     8,    10,    17,    20,    32,    34,    40,    42,    65,    68,    80,    88,    97,
+          100,   128,   130,   138,   162,   257,   260,   272,   277,   320,   388,   408,   512,   514,   546,   642,
+         1025,  1028,  1040,  1057,  1060,  1088,  1090,  1096,  1120,  1153,  1156,  1168,  1188,  1280,  1282,  1288,
+         1312,  1350,  1385,  1408,  1425,  1545,  1552,  1600,  1668,  1700,  2048,  2053,  2056,  2068,  2088,  2113,
+         2116,  2128,  2130,  2184,  2308,  2368,  2562,  2580,  4097,  4100,  4112,  4129,  4160,  4192,  4228,  4240,
+         4245,  4352,  4360,  4384,  4432,  4442,  4480,  4644,  4677,  5120,  5128,  5152,  5157,  5193,  5248,  5400,
+         5474,  5632,  5654,  6145,  6148,  6160,  6208,  6273,  6400,  6405,  6560,  6737,  8192,  8194,  8202,  8260,
+         8289,  8320,  8322,  8489,  8520,  8704,  8706,  9217,  9220,  9232,  9280,  9302,  9472,  9537,  9572,  9872,
+        10248, 10272, 10388, 10820, 16385, 16388, 16400, 16408, 16417, 16420, 16448, 16456, 16470, 16480, 16513, 16516,
+        16528, 16640, 16672, 16737, 16768, 16773, 16897, 16912, 16968, 16982, 17000, 17408, 17416, 17440, 17536, 17561,
+        17682, 17700, 17920, 18433, 18436, 18448, 18496, 18501, 18688, 18776, 18785, 18818, 19013, 19088, 20480, 20488,
+        20497, 20505, 20512, 20608, 20616, 20740, 20802, 20900, 21137, 21648, 21650, 21770, 22017, 22100, 22528, 22545,
+        22553, 22628, 22848, 23048, 24580, 24592, 24640, 24680, 24832, 24917, 25112, 25184, 25600, 25605, 25872, 25874,
+        25988, 26690, 32768, 32770, 32778, 32833, 32898, 33028, 33048, 33088, 33297, 33793, 33796, 33808, 33813, 33856,
+        33888, 34048, 34118, 34196, 34313, 34368, 34400, 34818, 35076, 35345, 36868, 36880, 36900, 36928, 37025, 37142,
+        37248, 37445, 37888, 37922, 37956, 38225, 39041, 39200, 40962, 41040, 41093, 41225, 41472, 42008, 43088, 43268,
+    };
+    static const uint16_t kgrid_512[512] = {
+            0,     2,     5,     8,    10,    17,    20,    22,    25,    32,    34,    37,    40,    65,    68,    70,
+           73,    80,    82,    85,    88,    97,   100,   128,   130,   133,   136,   145,   148,   153,   160,   257,
+          260,   262,   265,   272,   274,   277,   280,   282,   289,   292,   320,   322,   325,   328,   337,   340,
+          352,   360,   385,   388,   400,   512,   514,   517,   520,   529,   532,   544,   577,   580,   592,   597,
+          640,   650,  1025,  1028,  1030,  1033,  1040,  1042,  1045,  1048,  1057,  1060,  1088,  1090,  1093,  1096,
+         1105,  1108,  1110,  1120,  1153,  1156,  1168,  1280,  1282,  1285,  1288,  1297,  1300,  1312,  1345,  1348,
+         1360,  1377,  1408,  1537,  1540,  1552,  1574,  1600,  1602,  1668,  2048,  2050,  2053,  2056,  2058,  2065,
+         2068,  2080,  2085,  2113,  2116,  2128,  2136,  2176,  2208,  2218,  2305,  2308,  2320,  2368,  2433,  2441,
+         2560,  2592,  2600,  2710,  2720,  4097,  4100,  4102,  4105,  4112,  4114,  4117,  4120,  4129,  4132,  4160,
+         4162,  4165,  4168,  4177,  4180,  4192,  4202,  4225,  4228,  4240,  4352,  4354,  4357,  4360,  4369,  4372,
+         4384,  4417,  4420,  4432,  4480,  4500,  4502,  4609,  4612,  4614,  4624,  4672,  4704,  5120,  5122,  5125,
+         5128,  5137,  5140,  5152,  5185,  5188,  5193,  5200,  5220,  5248,  5377,  5380,  5392,  5440,  5632,  5652,
+         5705,  6145,  6148,  6160,  6162,  6208,  6228,  6278,  6400,  6405,  6502,  6737,  6825,  8192,  8194,  8197,
+         8200,  8202,  8209,  8212,  8224,  8257,  8260,  8272,  8320,  8352,  8449,  8452,  8464,  8512,  8520,  8549,
+         8704,  8738,  8832,  8872,  9217,  9220,  9232,  9257,  9280,  9472,  9537,  9554,  9625,  9729,  9754,  9894,
+        10240, 10248, 10250, 10272, 10325, 10376, 10402, 10600, 10640, 10760, 10784, 10882, 10888, 10890, 16385, 16388,
+        16390, 16393, 16400, 16402, 16405, 16408, 16417, 16420, 16448, 16450, 16453, 16456, 16458, 16465, 16468, 16480,
+        16485, 16513, 16516, 16528, 16640, 16642, 16645, 16648, 16657, 16660, 16672, 16705, 16708, 16720, 16768, 16773,
+        16802, 16897, 16900, 16912, 16914, 16937, 16960, 17408, 17410, 17413, 17416, 17425, 17428, 17433, 17440, 17473,
+        17476, 17488, 17536, 17556, 17665, 17668, 17680, 17700, 17728, 17818, 17920, 17930, 17988, 18000, 18433, 18436,
+        18448, 18496, 18501, 18516, 18530, 18688, 18705, 18756, 18768, 18793, 18948, 20480, 20482, 20485, 20488, 20497,
+        20500, 20512, 20520, 20545, 20548, 20560, 20608, 20737, 20740, 20752, 20757, 20800, 20802, 20992, 21060, 21162,
+        21505, 21508, 21520, 21537, 21568, 21600, 21633, 21665, 21760, 21768, 21888, 21896, 22049, 22120, 22177, 22528,
+        22548, 22593, 22608, 22681, 22810, 22848, 22850, 23173, 24577, 24580, 24592, 24640, 24660, 24674, 24710, 24745,
+        24832, 25124, 25162, 25234, 25600, 25622, 25872, 25920, 25925, 26020, 26625, 26730, 26917, 27142, 27220, 27234,
+        32768, 32770, 32773, 32776, 32785, 32788, 32800, 32810, 32833, 32836, 32848, 32896, 32898, 32936, 32938, 33025,
+        33028, 33030, 33040, 33088, 33105, 33113, 33280, 33312, 33408, 33410, 33440, 33448, 33793, 33796, 33808, 33810,
+        33813, 33856, 33888, 33929, 34048, 34116, 34213, 34328, 34410, 34816, 34824, 34853, 34906, 34944, 34946, 34984,
+        35078, 35362, 35456, 35464, 35478, 35496, 36865, 36868, 36880, 36928, 36950, 36996, 37120, 37154, 37220, 37462,
+        37513, 37888, 37893, 37956, 37968, 37976, 38185, 38288, 38290, 38465, 38993, 39078, 39241, 39445, 39520, 40960,
+        40962, 40968, 40970, 40992, 41002, 41120, 41297, 41305, 41382, 41472, 41474, 41480, 41514, 41600, 41632, 42048,
+        42133, 42597, 42648, 43018, 43040, 43042, 43048, 43168, 43176, 43268, 43396, 43398, 43560, 43562, 43665, 43690,
+    };
+    const int kmap_size = 43692;
+    const int nwant = 2;
+    const uint16_t * kgrid = grid_size == 256 ? kgrid_256 : kgrid_512;
+    uint64_t * kgrid_q2xs;
+    int      * kmap_q2xs;
+    uint16_t * kneighbors_q2xs;
+
+    printf("================================================================= %s(grid_size = %d)\n", __func__, grid_size);
+    uint64_t * the_grid = (uint64_t *)malloc(grid_size*sizeof(uint64_t));
+    for (int k = 0; k < grid_size; ++k) {
+        int8_t * pos = (int8_t *)(the_grid + k);
+        for (int i = 0; i < 8; ++i) {
+            int l = (kgrid[k] >> 2*i) & 0x3;
+            pos[i] = 2*l + 1;
+        }
+    }
+    kgrid_q2xs = the_grid;
+    iq2_data[gindex].grid = the_grid;
+    kmap_q2xs = (int *)malloc(kmap_size*sizeof(int));
+    iq2_data[gindex].map = kmap_q2xs;
+    for (int i = 0; i < kmap_size; ++i) kmap_q2xs[i] = -1;
+    uint64_t aux64;
+    uint8_t * aux8 = (uint8_t *)&aux64;
+    for (int i = 0; i < grid_size; ++i) {
+        aux64 = kgrid_q2xs[i];
+        uint16_t index = 0;
+        for (int k=0; k<8; ++k) {
+            uint16_t q = (aux8[k] - 1)/2;
+            index |= (q << 2*k);
+        }
+        kmap_q2xs[index] = i;
+    }
+    int8_t pos[8];
+    int * dist2 = (int *)malloc(2*grid_size*sizeof(int));
+    int num_neighbors = 0, num_not_in_map = 0;
+    for (int i = 0; i < kmap_size; ++i) {
+        if (kmap_q2xs[i] >= 0) continue;
+        ++num_not_in_map;
+        for (int k = 0; k < 8; ++k) {
+            int l = (i >> 2*k) & 0x3;
+            pos[k] = 2*l + 1;
+        }
+        for (int j = 0; j < grid_size; ++j) {
+            const int8_t * pg = (const int8_t *)(kgrid_q2xs + j);
+            int d2 = 0;
+            for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
+            dist2[2*j+0] = d2;
+            dist2[2*j+1] = j;
+        }
+        qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func);
+        int n = 0; int d2 = dist2[0];
+        int nhave = 1;
+        for (int j = 0; j < grid_size; ++j) {
+            if (dist2[2*j] > d2) {
+                if (nhave == nwant) break;
+                d2 = dist2[2*j];
+                ++nhave;
+            }
+            ++n;
+        }
+        num_neighbors += n;
+    }
+    printf("%s: %d neighbours in total\n", __func__, num_neighbors);
+    kneighbors_q2xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t));
+    iq2_data[gindex].neighbours = kneighbors_q2xs;
+    int counter = 0;
+    for (int i = 0; i < kmap_size; ++i) {
+        if (kmap_q2xs[i] >= 0) continue;
+        for (int k = 0; k < 8; ++k) {
+            int l = (i >> 2*k) & 0x3;
+            pos[k] = 2*l + 1;
+        }
+        for (int j = 0; j < grid_size; ++j) {
+            const int8_t * pg = (const int8_t *)(kgrid_q2xs + j);
+            int d2 = 0;
+            for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
+            dist2[2*j+0] = d2;
+            dist2[2*j+1] = j;
+        }
+        qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func);
+        kmap_q2xs[i] = -(counter + 1);
+        int d2 = dist2[0];
+        uint16_t * start = &kneighbors_q2xs[counter++];
+        int n = 0, nhave = 1;
+        for (int j = 0; j < grid_size; ++j) {
+            if (dist2[2*j] > d2) {
+                if (nhave == nwant) break;
+                d2 = dist2[2*j];
+                ++nhave;
+            }
+            kneighbors_q2xs[counter++] = dist2[2*j+1];
+            ++n;
+        }
+        *start = n;
+    }
+    free(dist2);
+}
+
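iq2xs_init_impl() leaves behind a lookup convention that the code below relies on: kmap_q2xs[u] >= 0 means u is itself a grid point, while a negative value -(pos+1) points into the packed neighbour array, whose first slot holds the neighbour count. A sketch of the lookup side (function name illustrative; the kernels below keep the count slot inside the pointer they pass around, here it is split out for clarity):

#include <stdint.h>
#include <stddef.h>

// Resolve the neighbour list for a candidate code u, following the
// kmap[u] = -(counter+1) convention established by iq2xs_init_impl().
static const uint16_t * iq2_neighbours_of(const int * kmap, const uint16_t * kneighbors,
                                          uint16_t u, int * n) {
    const int v = kmap[u];
    if (v >= 0) {                                // u is on the grid: nothing to search
        *n = 0;
        return NULL;
    }
    const uint16_t * list = kneighbors - v - 1;  // first slot stores the count ...
    *n = list[0];
    return list + 1;                             // ... followed by that many grid indices
}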
+void iq2xs_free_impl(int grid_size) {
+    WSP_GGML_ASSERT(grid_size == 256 || grid_size == 512 || grid_size == 1024);
+    const int gindex = iq2_data_index(grid_size);
+    if (iq2_data[gindex].grid) {
+        free(iq2_data[gindex].grid);       iq2_data[gindex].grid = NULL;
+        free(iq2_data[gindex].map);        iq2_data[gindex].map = NULL;
+        free(iq2_data[gindex].neighbours); iq2_data[gindex].neighbours = NULL;
+    }
+}
+
+static int iq2_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid,
+        const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L) {
+    int num_neighbors = neighbours[0];
+    WSP_GGML_ASSERT(num_neighbors > 0);
+    float best_d2 = FLT_MAX;
+    int grid_index = -1;
+    for (int j = 1; j <= num_neighbors; ++j) {
+        const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
+        float d2 = 0;
+        for (int i = 0; i < 8; ++i) {
+            float q = pg[i];
+            float diff = scale*q - xval[i];
+            d2 += weight[i]*diff*diff;
+        }
+        if (d2 < best_d2) {
+            best_d2 = d2; grid_index = neighbours[j];
+        }
+    }
+    WSP_GGML_ASSERT(grid_index >= 0);
+    const int8_t * pg = (const int8_t *)(grid + grid_index);
+    for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2;
+    return grid_index;
+}
+
+static void wsp_quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
+
+    const int gindex = iq2_data_index(256);
+
+    const uint64_t * kgrid_q2xs      = iq2_data[gindex].grid;
+    const int      * kmap_q2xs       = iq2_data[gindex].map;
+    const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
+
+    WSP_GGML_ASSERT(quant_weights   && "missing quantization weights");
+    WSP_GGML_ASSERT(kgrid_q2xs      && "forgot to call wsp_ggml_wsp_quantize_init()?");
+    WSP_GGML_ASSERT(kmap_q2xs       && "forgot to call wsp_ggml_wsp_quantize_init()?");
+    WSP_GGML_ASSERT(kneighbors_q2xs && "forgot to call wsp_ggml_wsp_quantize_init()?");
+    WSP_GGML_ASSERT(n%QK_K == 0);
+
+    const int kMaxQ = 3;
+
+    const int nbl = n/256;
+
+    block_iq2_xxs * y = vy;
+
+    float scales[QK_K/32];
+    float weight[32];
+    float xval[32];
+    int8_t L[32];
+    int8_t Laux[32];
+    float  waux[32];
+    bool   is_on_grid[4];
+    bool   is_on_grid_aux[4];
+    uint8_t block_signs[4];
+    uint32_t q2[2*(QK_K/32)];
+
+    for (int ibl = 0; ibl < nbl; ++ibl) {
+
+        y[ibl].d = WSP_GGML_FP32_TO_FP16(0.f);
+        memset(q2, 0, QK_K/4);
+
+        float max_scale = 0;
+
+        const float * xbl = x + QK_K*ibl;
+        float sumx2 = 0;
+        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
+        float sigma2 = sumx2/QK_K;
+
+        for (int ib = 0; ib < QK_K/32; ++ib) {
+            const float * xb = xbl + 32*ib;
+            const float * qw = quant_weights + QK_K*ibl + 32*ib;
+            for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+            for (int i = 0; i < 32; ++i) waux[i] = sqrtf(weight[i]);
+            for (int k = 0; k < 4; ++k) {
+                int nflip = 0;
+                uint8_t s = 0;
+                for (int i = 0; i < 8; ++i) {
+                    if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
+                    else {
+                        xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i);
+                    }
+                }
+                if (nflip%2) {
+                    int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin];
+                    for (int i = 1; i < 8; ++i) {
+                        float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i];
+                        if (ax < min) {
+                            min = ax; imin = i;
+                        }
+                    }
+                    xval[8*k+imin] = -xval[8*k+imin];
+                    s ^= (1 << imin);
+                }
+                block_signs[k] = s & 127;
+            }
+            float max = xval[0];
+            for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]);
+            if (!max) {
+                scales[ib] = 0;
+                memset(L, 0, 32);
+                continue;
+            }
+            float best = 0;
+            float scale = max/(2*kMaxQ-1);
+            for (int is = -9; is <= 9; ++is) {
+                float id = (2*kMaxQ-1+is*0.1f)/max;
+                float this_scale = 1/id;
+                for (int k = 0; k < 4; ++k) {
+                    for (int i = 0; i < 8; ++i) {
+                        int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
+                        Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l));
+                    }
+                    uint16_t u = 0;
+                    for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i);
+                    int grid_index = kmap_q2xs[u];
+                    is_on_grid_aux[k] = true;
+                    if (grid_index < 0) {
+                        is_on_grid_aux[k] = false;
+                        const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+                        grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k);
+                    }
+                }
+                float sumqx = 0, sumq2 = 0;
+                for (int i = 0; i < 32; ++i) {
+                    float w = weight[i];
+                    float q = 2*Laux[i] + 1;
+                    sumqx += w*xval[i]*q;
+                    sumq2 += w*q*q;
+                }
+                if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+                    scale = sumqx/sumq2; best = scale*sumqx;
+                    for (int i = 0; i < 32; ++i) L[i] = Laux[i];
+                    for (int k = 0; k <  4; ++k) is_on_grid[k] = is_on_grid_aux[k];
+                }
+            }
+            int n_not_ongrid = 0;
+            for (int k = 0; k < 4; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
+            if (n_not_ongrid > 0 && scale > 0) {
+                float id = 1/scale;
+                for (int k = 0; k < 4; ++k) {
+                    if (is_on_grid[k]) continue;
+                    uint16_t u = 0;
+                    for (int i = 0; i < 8; ++i) {
+                        int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
+                        l = MAX(0, MIN(kMaxQ-1, l));
+                        u |= (l << 2*i);
+                    }
+                    int grid_index = kmap_q2xs[u];
+                    if (grid_index < 0) {
+                        const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+                        grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k);
+                    }
+                    const int8_t * pg = (const int8_t *)(kgrid_q2xs + grid_index);
+                    for (int i = 0; i < 8; ++i) L[8*k+i] = (pg[i] - 1)/2;
+                }
+                float sumqx = 0, sumq2 = 0;
+                for (int i = 0; i < 32; ++i) {
+                    float w = weight[i];
+                    float q = 2*L[i] + 1;
+                    sumqx += w*xval[i]*q;
+                    sumq2 += w*q*q;
+                }
+                if (sumq2 > 0) scale = sumqx/sumq2;
+            }
+            if (scale < 0) {
+                // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale)
+                // and correspondingly flip quant signs.
+                scale = -scale;
+                for (int k = 0; k < 4; ++k) block_signs[k] = (~block_signs[k]) & 127;
+            }
+            for (int k = 0; k < 4; ++k) {
+                uint16_t u = 0;
+                for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i);
+                int grid_index = kmap_q2xs[u];
+                if (grid_index < 0) {
+                    printf("Oops: found point %u not on grid:", u);
+                    for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]);
+                    printf("\n");
+                    WSP_GGML_ASSERT(false);
+                }
+                q2[2*ib+0] |= (grid_index << 8*k);
+                q2[2*ib+1] |= (block_signs[k] << 7*k);
+            }
+            WSP_GGML_ASSERT(scale >= 0);
+            scales[ib] = scale;
+            max_scale = MAX(max_scale, scale);
+        }
+
+        if (!max_scale) {
+            memset(y[ibl].qs, 0, QK_K/4);
+            continue;
+        }
+
+        float d = max_scale/31;
+        y[ibl].d = WSP_GGML_FP32_TO_FP16(d);
+        float id = 1/d;
+        float sumqx = 0, sumq2 = 0;
+        for (int ib = 0; ib < QK_K/32; ++ib) {
+            int l = nearest_int(0.5f*(id*scales[ib]-1));
+            l = MAX(0, MIN(15, l));
+            q2[2*ib+1] |= ((uint32_t)l << 28);
+            const float * xb = xbl + 32*ib;
+            const float * qw = quant_weights + QK_K*ibl + 32*ib;
+            for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+            const uint8_t * aux8 = (const uint8_t *)(q2 + 2*ib);
+            const float db = d * (1 + 2*l);
+            uint32_t u = 0;
+            for (int k = 0; k < 4; ++k) {
+                const int8_t * signs = keven_signs_q2xs + 8*((q2[2*ib+1] >> 7*k) & 127);
+                const float * xk = xb + 8*k;
+                const float * wk = weight + 8*k;
+                const uint8_t * grid = (const uint8_t *)(kgrid_q2xs + aux8[k]);
+                float best_mse = 0; int best_index = aux8[k];
+                for (int j = 0; j < 8; ++j) {
+                    float diff = db * grid[j] * signs[j] - xk[j];
+                    best_mse += wk[j] * diff * diff;
+                }
+                for (int idx = 0; idx < 256; ++idx) {
+                    grid = (const uint8_t *)(kgrid_q2xs + idx);
+                    float mse = 0;
+                    for (int j = 0; j < 8; ++j) {
+                        float diff = db * grid[j] * signs[j] - xk[j];
+                        mse += wk[j] * diff * diff;
+                    }
+                    if (mse < best_mse) {
+                        best_mse = mse; best_index = idx;
+                    }
+                }
+                u |= (best_index << 8*k);
+                grid = (const uint8_t *)(kgrid_q2xs + best_index);
+                //grid = (const uint8_t *)(kgrid_q2xs + aux8[k]);
+                for (int j = 0; j < 8; ++j) {
+                    float q = db * grid[j] * signs[j];
+                    sumqx += wk[j] * q * xk[j];
+                    sumq2 += wk[j] * q * q;
+                }
+            }
+            q2[2*ib] = u;
+            if (sumq2 > 0) y[ibl].d = WSP_GGML_FP32_TO_FP16(d*sumqx/sumq2);
+        }
+        memcpy(y[ibl].qs, q2, QK_K/4);
+    }
+}
+
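Because only sign patterns with an even number of -1s are representable (see keven_signs_q2xs above), the quantizer has to repair groups of 8 that end up with an odd negative count. The nflip%2 branch above does this by flipping the entry with the smallest weighted square w*x*x, i.e. the one whose sign costs least; a condensed sketch of just that step (helper name illustrative):

#include <stdint.h>

// Force an even number of negative signs in a group of 8 by flipping the
// least important entry; mirrors the nflip%2 branch in the quantizers here.
static uint8_t fix_sign_parity(float xval[8], const float weight[8],
                               uint8_t s, int nflip) {
    if (nflip % 2) {
        int imin = 0;
        float min = weight[0]*xval[0]*xval[0];   // weighted square is sign-invariant
        for (int i = 1; i < 8; ++i) {
            const float ax = weight[i]*xval[i]*xval[i];
            if (ax < min) { min = ax; imin = i; }
        }
        xval[imin] = -xval[imin];                // cheapest sign to get wrong
        s ^= (1 << imin);
    }
    return s & 127;   // 7 bits stored; the 8th sign is the implied parity bit
}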
+static void wsp_quantize_row_iq2_xs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
+
+    const int gindex = iq2_data_index(512);
+
+    const uint64_t * kgrid_q2xs      = iq2_data[gindex].grid;
+    const int      * kmap_q2xs       = iq2_data[gindex].map;
+    const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
+
+    WSP_GGML_ASSERT(quant_weights   && "missing quantization weights");
+    WSP_GGML_ASSERT(kmap_q2xs       && "forgot to call wsp_ggml_wsp_quantize_init()?");
+    WSP_GGML_ASSERT(kgrid_q2xs      && "forgot to call wsp_ggml_wsp_quantize_init()?");
+    WSP_GGML_ASSERT(kneighbors_q2xs && "forgot to call wsp_ggml_wsp_quantize_init()?");
+    WSP_GGML_ASSERT(n%QK_K == 0);
+
+    const int kMaxQ = 3;
+
+    const int nbl = n/256;
+
+    block_iq2_xs * y = vy;
+
+    float scales[QK_K/16];
+    float weight[16];
+    float xval[16];
+    int8_t L[16];
+    int8_t Laux[16];
+    float  waux[16];
+    bool   is_on_grid[2];
+    bool   is_on_grid_aux[2];
+    uint8_t block_signs[2];
+    uint16_t q2[2*(QK_K/16)];
+
+    for (int ibl = 0; ibl < nbl; ++ibl) {
+
+        y[ibl].d = WSP_GGML_FP32_TO_FP16(0.f);
+        memset(q2, 0, QK_K/4);
+        memset(y[ibl].scales, 0, QK_K/32);
+
+        float max_scale = 0;
+
+        const float * xbl = x + QK_K*ibl;
+        float sumx2 = 0;
+        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
+        float sigma2 = sumx2/QK_K;
+
+        for (int ib = 0; ib < QK_K/16; ++ib) {
+            const float * xb = xbl + 16*ib;
+            const float * qw = quant_weights + QK_K*ibl + 16*ib;
+            for (int i = 0; i < 16; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+            for (int i = 0; i < 16; ++i) waux[i] = sqrtf(weight[i]);
+            for (int k = 0; k < 2; ++k) {
+                int nflip = 0;
+                uint8_t s = 0;
+                for (int i = 0; i < 8; ++i) {
+                    if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
+                    else {
+                        xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i);
+                    }
+                }
+                if (nflip%2) {
+                    int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin];
+                    for (int i = 1; i < 8; ++i) {
+                        float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i];
+                        if (ax < min) {
+                            min = ax; imin = i;
+                        }
+                    }
+                    xval[8*k+imin] = -xval[8*k+imin];
+                    s ^= (1 << imin);
+                }
+                block_signs[k] = s & 127;
+            }
+            float max = xval[0];
+            for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]);
+            if (!max) {
+                scales[ib] = 0;
+                memset(L, 0, 16);
+                continue;
+            }
+            float best = 0;
+            float scale = max/(2*kMaxQ-1);
+            is_on_grid[0] = is_on_grid[1] = true;
+            for (int is = -9; is <= 9; ++is) {
+                float id = (2*kMaxQ-1+is*0.1f)/max;
+                float this_scale = 1/id;
+                for (int k = 0; k < 2; ++k) {
+                    for (int i = 0; i < 8; ++i) {
+                        int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
+                        Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l));
+                    }
+                    uint16_t u = 0;
+                    for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i);
+                    int grid_index = kmap_q2xs[u];
+                    is_on_grid_aux[k] = true;
+                    if (grid_index < 0) {
+                        is_on_grid_aux[k] = false;
+                        const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+                        grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k);
+                    }
+                }
+                float sumqx = 0, sumq2 = 0;
+                for (int i = 0; i < 16; ++i) {
+                    float w = weight[i];
+                    float q = 2*Laux[i] + 1;
+                    sumqx += w*xval[i]*q;
+                    sumq2 += w*q*q;
+                }
+                if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+                    scale = sumqx/sumq2; best = scale*sumqx;
+                    for (int i = 0; i < 16; ++i) L[i] = Laux[i];
+                    for (int k = 0; k <  2; ++k) is_on_grid[k] = is_on_grid_aux[k];
+                }
+            }
+            int n_not_ongrid = 0;
+            for (int k = 0; k < 2; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
+            if (n_not_ongrid > 0 && scale > 0) {
+                float id = 1/scale;
+                for (int k = 0; k < 2; ++k) {
+                    if (is_on_grid[k]) continue;
+                    uint16_t u = 0;
+                    for (int i = 0; i < 8; ++i) {
+                        int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
+                        l = MAX(0, MIN(kMaxQ-1, l));
+                        u |= (l << 2*i);
+                        L[8*k + i] = l;
+                    }
+                    int grid_index = kmap_q2xs[u];
+                    if (grid_index < 0) {
+                        const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+                        grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k);
+                    }
+                }
+                float sumqx = 0, sumq2 = 0;
+                for (int i = 0; i < 16; ++i) {
+                    float w = weight[i];
+                    float q = 2*L[i] + 1;
+                    sumqx += w*xval[i]*q;
+                    sumq2 += w*q*q;
+                }
+                if (sumq2 > 0) scale = sumqx/sumq2;
+            }
+            if (scale < 0) {
+                scale = -scale;
+                for (int k = 0; k < 2; ++k) block_signs[k] = (~block_signs[k]) & 127;
+            }
+            for (int k = 0; k < 2; ++k) {
+                uint16_t u = 0;
+                for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i);
+                int grid_index = kmap_q2xs[u];
+                if (grid_index < 0) {
+                    printf("Oops: found point %u not on grid:", u);
+                    for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]);
+                    printf("\n");
+                    WSP_GGML_ASSERT(false);
+                }
+                q2[2*ib+k] = grid_index | (block_signs[k] << 9);
+            }
+            WSP_GGML_ASSERT(scale >= 0);
+            scales[ib] = scale;
+            max_scale = MAX(max_scale, scale);
+        }
+
+        if (!max_scale) {
+            memset(y[ibl].qs, 0, QK_K/4);
+            continue;
+        }
+
+        float d = max_scale/31;
+        y[ibl].d = WSP_GGML_FP32_TO_FP16(d);
+        float id = 1/d;
+        for (int ib = 0; ib < QK_K/16; ++ib) {
+            int l = nearest_int(0.5f*(id*scales[ib]-1));
+            l = MAX(0, MIN(15, l));
+            if (ib%2 == 0) y[ibl].scales[ib/2] = l;
+            else y[ibl].scales[ib/2] |= (l << 4);
+        }
+        memcpy(y[ibl].qs, q2, QK_K/4);
+
+    }
+}
+
|
+
size_t wsp_quantize_iq2_xxs(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
|
|
9167
|
+
(void)hist;
|
|
9168
|
+
WSP_GGML_ASSERT(n_per_row%QK_K == 0);
|
|
9169
|
+
int nblock = n_per_row/QK_K;
|
|
9170
|
+
char * qrow = (char *)dst;
|
|
9171
|
+
for (int row = 0; row < nrow; ++row) {
|
|
9172
|
+
wsp_quantize_row_iq2_xxs_impl(src, qrow, n_per_row, quant_weights);
|
|
9173
|
+
src += n_per_row;
|
|
9174
|
+
qrow += nblock*sizeof(block_iq2_xxs);
|
|
9175
|
+
}
|
|
9176
|
+
return nrow * nblock * sizeof(block_iq2_xxs);
|
|
9177
|
+
}
|
|
9178
|
+
|
|
9179
|
+
size_t wsp_quantize_iq2_xs(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
|
|
9180
|
+
(void)hist;
|
|
9181
|
+
WSP_GGML_ASSERT(n_per_row%QK_K == 0);
|
|
9182
|
+
int nblock = n_per_row/QK_K;
|
|
9183
|
+
char * qrow = (char *)dst;
|
|
9184
|
+
for (int row = 0; row < nrow; ++row) {
|
|
9185
|
+
wsp_quantize_row_iq2_xs_impl(src, qrow, n_per_row, quant_weights);
|
|
9186
|
+
src += n_per_row;
|
|
9187
|
+
qrow += nblock*sizeof(block_iq2_xs);
|
|
9188
|
+
}
|
|
9189
|
+
return nrow * nblock * sizeof(block_iq2_xs);
|
|
9190
|
+
}
|
|
9191
|
+
|
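Putting the new entry points together, a hypothetical caller would initialize the grid tables once and then quantize row-major data. The caller below and its imatrix_weights parameter are illustrative only, standing in for the caller-supplied importance weights the asserts above require; they are not part of the package API:

#include <stddef.h>

// Hypothetical usage of the iq2_xxs entry points added in this diff.
static size_t quantize_tensor_iq2_xxs(const float * data, void * dst,
                                      int nrow, int n_per_row,
                                      const float * imatrix_weights) {
    iq2xs_init_impl(256);   // builds grid/map/neighbour tables; returns early if already built
    return wsp_quantize_iq2_xxs(data, dst, nrow, n_per_row,
                                /*hist=*/NULL, imatrix_weights);
    // result is nrow * (n_per_row/QK_K) * sizeof(block_iq2_xxs) bytes
}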