llama_cpp 0.12.2 → 0.12.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +15 -0
- data/README.md +2 -2
- data/ext/llama_cpp/extconf.rb +1 -0
- data/ext/llama_cpp/llama_cpp.cpp +68 -6
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +6 -2
- data/vendor/tmp/llama.cpp/Makefile +25 -3
- data/vendor/tmp/llama.cpp/ggml-alloc.c +87 -27
- data/vendor/tmp/llama.cpp/ggml-backend-impl.h +6 -0
- data/vendor/tmp/llama.cpp/ggml-backend.c +176 -18
- data/vendor/tmp/llama.cpp/ggml-backend.h +14 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +1990 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.h +46 -0
- data/vendor/tmp/llama.cpp/ggml-metal.h +3 -0
- data/vendor/tmp/llama.cpp/ggml-metal.m +144 -113
- data/vendor/tmp/llama.cpp/ggml-metal.metal +303 -4
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +95 -3
- data/vendor/tmp/llama.cpp/ggml-opencl.h +1 -0
- data/vendor/tmp/llama.cpp/ggml-quants.c +736 -59
- data/vendor/tmp/llama.cpp/ggml-quants.h +20 -1
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +15255 -0
- data/vendor/tmp/llama.cpp/ggml-sycl.h +29 -0
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +60854 -0
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +5270 -0
- data/vendor/tmp/llama.cpp/ggml-vulkan.h +34 -0
- data/vendor/tmp/llama.cpp/ggml.c +664 -117
- data/vendor/tmp/llama.cpp/ggml.h +46 -11
- data/vendor/tmp/llama.cpp/llama.cpp +1426 -341
- data/vendor/tmp/llama.cpp/llama.h +24 -15
- data/vendor/tmp/llama.cpp/unicode.h +2 -1
- metadata +10 -3
@@ -1274,7 +1274,12 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
|
|
1274
1274
|
}
|
1275
1275
|
float sumlx = 0;
|
1276
1276
|
float suml2 = 0;
|
1277
|
+
#ifdef HAVE_BUGGY_APPLE_LINKER
|
1278
|
+
// use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
|
1279
|
+
for (volatile int i = 0; i < n; ++i) {
|
1280
|
+
#else
|
1277
1281
|
for (int i = 0; i < n; ++i) {
|
1282
|
+
#endif
|
1278
1283
|
int l = nearest_int(iscale * x[i]);
|
1279
1284
|
l = MAX(-nmax, MIN(nmax-1, l));
|
1280
1285
|
L[i] = l + nmax;
|
@@ -1649,7 +1654,12 @@ static float make_qkx3_quants(int n, int nmax, const float * restrict x, const f
|
|
1649
1654
|
float max = x[0];
|
1650
1655
|
float sum_w = weights ? weights[0] : x[0]*x[0];
|
1651
1656
|
float sum_x = sum_w * x[0];
|
1657
|
+
#ifdef HAVE_BUGGY_APPLE_LINKER
|
1658
|
+
// use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
|
1659
|
+
for (volatile int i = 1; i < n; ++i) {
|
1660
|
+
#else
|
1652
1661
|
for (int i = 1; i < n; ++i) {
|
1662
|
+
#endif
|
1653
1663
|
if (x[i] < min) min = x[i];
|
1654
1664
|
if (x[i] > max) max = x[i];
|
1655
1665
|
float w = weights ? weights[i] : x[i]*x[i];
|
@@ -1660,7 +1670,7 @@ static float make_qkx3_quants(int n, int nmax, const float * restrict x, const f
|
|
1660
1670
|
min = 0;
|
1661
1671
|
}
|
1662
1672
|
if (max <= min) {
|
1663
|
-
|
1673
|
+
memset(L, 0, n);
|
1664
1674
|
*the_min = -min;
|
1665
1675
|
return 0.f;
|
1666
1676
|
}
|
@@ -1862,7 +1872,7 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri
|
|
1862
1872
|
|
1863
1873
|
size_t quantize_q2_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
|
1864
1874
|
(void)hist;
|
1865
|
-
|
1875
|
+
size_t row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row);
|
1866
1876
|
if (!quant_weights) {
|
1867
1877
|
quantize_row_q2_K_reference(src, dst, nrow*n_per_row);
|
1868
1878
|
}
|
@@ -2181,7 +2191,7 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri
|
|
2181
2191
|
|
2182
2192
|
size_t quantize_q3_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
|
2183
2193
|
(void)hist;
|
2184
|
-
|
2194
|
+
size_t row_size = ggml_row_size(GGML_TYPE_Q3_K, n_per_row);
|
2185
2195
|
if (!quant_weights) {
|
2186
2196
|
quantize_row_q3_K_reference(src, dst, nrow*n_per_row);
|
2187
2197
|
}
|
@@ -2448,7 +2458,7 @@ static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restri
|
|
2448
2458
|
|
2449
2459
|
size_t quantize_q4_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
|
2450
2460
|
(void)hist;
|
2451
|
-
|
2461
|
+
size_t row_size = ggml_row_size(GGML_TYPE_Q4_K, n_per_row);
|
2452
2462
|
if (!quant_weights) {
|
2453
2463
|
quantize_row_q4_K_reference(src, dst, nrow*n_per_row);
|
2454
2464
|
}
|
@@ -2771,7 +2781,7 @@ static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restri
|
|
2771
2781
|
|
2772
2782
|
size_t quantize_q5_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
|
2773
2783
|
(void)hist;
|
2774
|
-
|
2784
|
+
size_t row_size = ggml_row_size(GGML_TYPE_Q5_K, n_per_row);
|
2775
2785
|
if (!quant_weights) {
|
2776
2786
|
quantize_row_q5_K_reference(src, dst, nrow*n_per_row);
|
2777
2787
|
}
|
@@ -3025,7 +3035,7 @@ static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restri
|
|
3025
3035
|
|
3026
3036
|
size_t quantize_q6_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
|
3027
3037
|
(void)hist;
|
3028
|
-
|
3038
|
+
size_t row_size = ggml_row_size(GGML_TYPE_Q6_K, n_per_row);
|
3029
3039
|
if (!quant_weights) {
|
3030
3040
|
quantize_row_q6_K_reference(src, dst, nrow*n_per_row);
|
3031
3041
|
}
|
@@ -3072,7 +3082,7 @@ size_t quantize_q4_0(const float * src, void * dst, int nrow, int n_per_row, int
|
|
3072
3082
|
if (!quant_weights) {
|
3073
3083
|
return ggml_quantize_q4_0(src, dst, nrow*n_per_row, n_per_row, hist);
|
3074
3084
|
}
|
3075
|
-
|
3085
|
+
size_t row_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row);
|
3076
3086
|
char * qrow = (char *)dst;
|
3077
3087
|
for (int row = 0; row < nrow; ++row) {
|
3078
3088
|
quantize_row_q4_0_impl(src, (block_q4_0*)qrow, n_per_row, quant_weights);
|
@@ -3116,7 +3126,7 @@ size_t quantize_q4_1(const float * src, void * dst, int nrow, int n_per_row, int
|
|
3116
3126
|
if (!quant_weights) {
|
3117
3127
|
return ggml_quantize_q4_1(src, dst, nrow*n_per_row, n_per_row, hist);
|
3118
3128
|
}
|
3119
|
-
|
3129
|
+
size_t row_size = ggml_row_size(GGML_TYPE_Q4_1, n_per_row);
|
3120
3130
|
char * qrow = (char *)dst;
|
3121
3131
|
for (int row = 0; row < nrow; ++row) {
|
3122
3132
|
quantize_row_q4_1_impl(src, (block_q4_1*)qrow, n_per_row, quant_weights);
|
@@ -3169,7 +3179,7 @@ size_t quantize_q5_0(const float * src, void * dst, int nrow, int n_per_row, int
|
|
3169
3179
|
if (!quant_weights) {
|
3170
3180
|
return ggml_quantize_q5_0(src, dst, nrow*n_per_row, n_per_row, hist);
|
3171
3181
|
}
|
3172
|
-
|
3182
|
+
size_t row_size = ggml_row_size(GGML_TYPE_Q5_0, n_per_row);
|
3173
3183
|
char * qrow = (char *)dst;
|
3174
3184
|
for (int row = 0; row < nrow; ++row) {
|
3175
3185
|
quantize_row_q5_0_impl(src, (block_q5_0*)qrow, n_per_row, quant_weights);
|
@@ -3221,7 +3231,7 @@ size_t quantize_q5_1(const float * src, void * dst, int nrow, int n_per_row, int
|
|
3221
3231
|
if (!quant_weights) {
|
3222
3232
|
return ggml_quantize_q5_1(src, dst, nrow*n_per_row, n_per_row, hist);
|
3223
3233
|
}
|
3224
|
-
|
3234
|
+
size_t row_size = ggml_row_size(GGML_TYPE_Q5_1, n_per_row);
|
3225
3235
|
char * qrow = (char *)dst;
|
3226
3236
|
for (int row = 0; row < nrow; ++row) {
|
3227
3237
|
quantize_row_q5_1_impl(src, (block_q5_1*)qrow, n_per_row, quant_weights);
|
@@ -3431,6 +3441,41 @@ static const uint64_t iq2xs_grid[512] = {
|
|
3431
3441
|
0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
|
3432
3442
|
};
|
3433
3443
|
|
3444
|
+
static const uint32_t iq3xxs_grid[256] = {
|
3445
|
+
0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
|
3446
|
+
0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
|
3447
|
+
0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404,
|
3448
|
+
0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e,
|
3449
|
+
0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c,
|
3450
|
+
0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c,
|
3451
|
+
0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34,
|
3452
|
+
0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c,
|
3453
|
+
0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,
|
3454
|
+
0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04,
|
3455
|
+
0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c,
|
3456
|
+
0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414,
|
3457
|
+
0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434,
|
3458
|
+
0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c,
|
3459
|
+
0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e,
|
3460
|
+
0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24,
|
3461
|
+
0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24,
|
3462
|
+
0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,
|
3463
|
+
0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c,
|
3464
|
+
0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14,
|
3465
|
+
0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414,
|
3466
|
+
0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e,
|
3467
|
+
0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404,
|
3468
|
+
0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c,
|
3469
|
+
0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c,
|
3470
|
+
0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14,
|
3471
|
+
0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c,
|
3472
|
+
0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c,
|
3473
|
+
0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14,
|
3474
|
+
0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14,
|
3475
|
+
0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
|
3476
|
+
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
3477
|
+
};
|
3478
|
+
|
3434
3479
|
static const uint8_t ksigns_iq2xs[128] = {
|
3435
3480
|
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
|
3436
3481
|
144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
|
@@ -3497,6 +3542,38 @@ void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y,
|
|
3497
3542
|
}
|
3498
3543
|
}
|
3499
3544
|
|
3545
|
+
// ====================== 3.0625 bpw (de)-quantization
|
3546
|
+
|
3547
|
+
void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y, int k) {
|
3548
|
+
assert(k % QK_K == 0);
|
3549
|
+
const int nb = k / QK_K;
|
3550
|
+
|
3551
|
+
uint32_t aux32;
|
3552
|
+
|
3553
|
+
for (int i = 0; i < nb; i++) {
|
3554
|
+
|
3555
|
+
const float d = GGML_FP16_TO_FP32(x[i].d);
|
3556
|
+
const uint8_t * qs = x[i].qs;
|
3557
|
+
const uint8_t * scales_and_signs = qs + QK_K/4;
|
3558
|
+
|
3559
|
+
for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
|
3560
|
+
memcpy(&aux32, scales_and_signs + 4*ib32, sizeof(uint32_t));
|
3561
|
+
const float db = d * (0.5f + (aux32 >> 28)) * 0.5f;
|
3562
|
+
for (int l = 0; l < 4; ++l) {
|
3563
|
+
const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
|
3564
|
+
const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + qs[2*l+0]);
|
3565
|
+
const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + qs[2*l+1]);
|
3566
|
+
for (int j = 0; j < 4; ++j) {
|
3567
|
+
y[j+0] = db * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
|
3568
|
+
y[j+4] = db * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
|
3569
|
+
}
|
3570
|
+
y += 8;
|
3571
|
+
}
|
3572
|
+
qs += 8;
|
3573
|
+
}
|
3574
|
+
}
|
3575
|
+
}
|
3576
|
+
|
3500
3577
|
//===================================== Q8_K ==============================================
|
3501
3578
|
|
3502
3579
|
void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
|
@@ -8448,17 +8525,36 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
|
|
8448
8525
|
|
8449
8526
|
const __m128i m4 = _mm_set1_epi8(0xf);
|
8450
8527
|
const __m128i m1 = _mm_set1_epi8(1);
|
8451
|
-
const
|
8452
|
-
const
|
8528
|
+
const __m256i m511 = _mm256_set1_epi16(511);
|
8529
|
+
const __m256i mone = _mm256_set1_epi8(1);
|
8453
8530
|
|
8454
|
-
const
|
8531
|
+
static const uint8_t k_bit_helper[32] = {
|
8532
|
+
0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
|
8533
|
+
0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
|
8534
|
+
};
|
8535
|
+
static const char block_sign_shuffle_mask_1[32] = {
|
8536
|
+
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
|
8537
|
+
0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
|
8538
|
+
};
|
8539
|
+
static const char block_sign_shuffle_mask_2[32] = {
|
8540
|
+
0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
|
8541
|
+
0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
|
8542
|
+
};
|
8543
|
+
static const uint8_t bit_selector_mask_bytes[32] = {
|
8544
|
+
0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
|
8545
|
+
0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
|
8546
|
+
};
|
8547
|
+
|
8548
|
+
const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper);
|
8549
|
+
const __m256i bit_selector_mask = _mm256_loadu_si256((const __m256i*)bit_selector_mask_bytes);
|
8550
|
+
const __m256i block_sign_shuffle_1 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_1);
|
8551
|
+
const __m256i block_sign_shuffle_2 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_2);
|
8455
8552
|
|
8456
8553
|
uint64_t aux64;
|
8457
8554
|
|
8458
8555
|
// somewhat hacky, but gives a significant boost in performance
|
8459
|
-
|
8556
|
+
__m256i aux_gindex;
|
8460
8557
|
const uint16_t * gindex = (const uint16_t *)&aux_gindex;
|
8461
|
-
const uint16_t * sindex = (const uint16_t *)&aux_sindex;
|
8462
8558
|
|
8463
8559
|
__m256 accumf = _mm256_setzero_ps();
|
8464
8560
|
for (int i = 0; i < nb; ++i) {
|
@@ -8473,26 +8569,68 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
|
|
8473
8569
|
|
8474
8570
|
__m256i sumi1 = _mm256_setzero_si256();
|
8475
8571
|
__m256i sumi2 = _mm256_setzero_si256();
|
8476
|
-
for (int ib32 = 0; ib32 < QK_K/32; ib32 +=
|
8572
|
+
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {
|
8573
|
+
|
8574
|
+
const __m256i q2_data = _mm256_loadu_si256((const __m256i*)q2); q2 += 16;
|
8575
|
+
aux_gindex = _mm256_and_si256(q2_data, m511);
|
8576
|
+
|
8577
|
+
const __m256i partial_sign_bits = _mm256_srli_epi16(q2_data, 9);
|
8578
|
+
const __m256i partial_sign_bits_upper = _mm256_srli_epi16(q2_data, 13);
|
8579
|
+
const __m256i partial_sign_bits_for_counting = _mm256_xor_si256(partial_sign_bits, partial_sign_bits_upper);
|
8580
|
+
|
8581
|
+
const __m256i odd_bits = _mm256_shuffle_epi8(bit_helper, partial_sign_bits_for_counting);
|
8582
|
+
const __m256i full_sign_bits = _mm256_or_si256(partial_sign_bits, odd_bits);
|
8583
|
+
|
8477
8584
|
const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
|
8478
8585
|
const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
|
8479
|
-
const
|
8480
|
-
|
8481
|
-
|
8482
|
-
const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]],
|
8483
|
-
|
8484
|
-
const __m256i
|
8485
|
-
|
8486
|
-
const __m256i
|
8487
|
-
|
8586
|
+
const __m256i q8_3 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
|
8587
|
+
const __m256i q8_4 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
|
8588
|
+
|
8589
|
+
const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]],
|
8590
|
+
iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]);
|
8591
|
+
const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]],
|
8592
|
+
iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]);
|
8593
|
+
const __m256i q2_3 = _mm256_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]],
|
8594
|
+
iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]);
|
8595
|
+
const __m256i q2_4 = _mm256_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]],
|
8596
|
+
iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);
|
8597
|
+
|
8598
|
+
const __m128i full_signs_l = _mm256_castsi256_si128(full_sign_bits);
|
8599
|
+
const __m128i full_signs_h = _mm256_extractf128_si256(full_sign_bits, 1);
|
8600
|
+
const __m256i full_signs_1 = _mm256_set_m128i(full_signs_l, full_signs_l);
|
8601
|
+
const __m256i full_signs_2 = _mm256_set_m128i(full_signs_h, full_signs_h);
|
8602
|
+
|
8603
|
+
__m256i signs;
|
8604
|
+
signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_1);
|
8605
|
+
signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
|
8606
|
+
const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone));
|
8607
|
+
|
8608
|
+
signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_2);
|
8609
|
+
signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
|
8610
|
+
const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone));
|
8611
|
+
|
8612
|
+
signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_1);
|
8613
|
+
signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
|
8614
|
+
const __m256i q8s_3 = _mm256_sign_epi8(q8_3, _mm256_or_si256(signs, mone));
|
8615
|
+
|
8616
|
+
signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_2);
|
8617
|
+
signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
|
8618
|
+
const __m256i q8s_4 = _mm256_sign_epi8(q8_4, _mm256_or_si256(signs, mone));
|
8619
|
+
|
8488
8620
|
const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
|
8489
8621
|
const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
|
8622
|
+
const __m256i dot3 = _mm256_maddubs_epi16(q2_3, q8s_3);
|
8623
|
+
const __m256i dot4 = _mm256_maddubs_epi16(q2_4, q8s_4);
|
8490
8624
|
|
8491
8625
|
const __m256i sc1 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0)));
|
8492
8626
|
const __m256i sc2 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1)));
|
8627
|
+
const __m256i sc3 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2)));
|
8628
|
+
const __m256i sc4 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3)));
|
8493
8629
|
|
8494
8630
|
sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot1, sc1));
|
8495
8631
|
sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot2, sc2));
|
8632
|
+
sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot3, sc3));
|
8633
|
+
sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot4, sc4));
|
8496
8634
|
}
|
8497
8635
|
|
8498
8636
|
accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
|
@@ -8541,6 +8679,136 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
|
|
8541
8679
|
#endif
|
8542
8680
|
}
|
8543
8681
|
|
8682
|
+
// TODO
|
8683
|
+
void ggml_vec_dot_iq3_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
8684
|
+
assert(n % QK_K == 0);
|
8685
|
+
|
8686
|
+
const block_iq3_xxs * restrict x = vx;
|
8687
|
+
const block_q8_K * restrict y = vy;
|
8688
|
+
|
8689
|
+
const int nb = n / QK_K;
|
8690
|
+
|
8691
|
+
#if defined(__ARM_NEON)
|
8692
|
+
|
8693
|
+
const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
|
8694
|
+
|
8695
|
+
uint32_t aux32[2];
|
8696
|
+
|
8697
|
+
ggml_int8x16x4_t q3s;
|
8698
|
+
ggml_int8x16x4_t q8b;
|
8699
|
+
|
8700
|
+
float sumf = 0;
|
8701
|
+
for (int i = 0; i < nb; ++i) {
|
8702
|
+
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
8703
|
+
const uint8_t * restrict q3 = x[i].qs;
|
8704
|
+
const uint8_t * restrict gas = x[i].qs + QK_K/4;
|
8705
|
+
const int8_t * restrict q8 = y[i].qs;
|
8706
|
+
float sumf1 = 0, sumf2 = 0;
|
8707
|
+
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
|
8708
|
+
q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
|
8709
|
+
memcpy(aux32, gas, 2*sizeof(uint32_t)); gas += 2*sizeof(uint32_t);
|
8710
|
+
const uint32x4_t aux32x4_0 = {iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]};
|
8711
|
+
const uint32x4_t aux32x4_1 = {iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]};
|
8712
|
+
const uint32x4_t aux32x4_2 = {iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]};
|
8713
|
+
const uint32x4_t aux32x4_3 = {iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]};
|
8714
|
+
q3 += 16;
|
8715
|
+
q3s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 7) & 127))));
|
8716
|
+
q3s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 21) & 127))));
|
8717
|
+
q3s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127))));
|
8718
|
+
q3s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127))));
|
8719
|
+
q3s.val[0] = vmulq_s8(q3s.val[0], vreinterpretq_s8_u32(aux32x4_0));
|
8720
|
+
q3s.val[1] = vmulq_s8(q3s.val[1], vreinterpretq_s8_u32(aux32x4_1));
|
8721
|
+
q3s.val[2] = vmulq_s8(q3s.val[2], vreinterpretq_s8_u32(aux32x4_2));
|
8722
|
+
q3s.val[3] = vmulq_s8(q3s.val[3], vreinterpretq_s8_u32(aux32x4_3));
|
8723
|
+
const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);
|
8724
|
+
const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
|
8725
|
+
sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[0] >> 28));
|
8726
|
+
sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[1] >> 28));
|
8727
|
+
}
|
8728
|
+
sumf += d*(sumf1 + sumf2);
|
8729
|
+
}
|
8730
|
+
*s = 0.5f * sumf;
|
8731
|
+
|
8732
|
+
#elif defined(__AVX2__)
|
8733
|
+
|
8734
|
+
const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
|
8735
|
+
|
8736
|
+
uint32_t aux32[2];
|
8737
|
+
|
8738
|
+
__m256 accumf = _mm256_setzero_ps();
|
8739
|
+
for (int i = 0; i < nb; ++i) {
|
8740
|
+
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
8741
|
+
const uint8_t * restrict q3 = x[i].qs;
|
8742
|
+
const uint8_t * restrict gas = x[i].qs + QK_K/4;
|
8743
|
+
const int8_t * restrict q8 = y[i].qs;
|
8744
|
+
__m256i sumi1 = _mm256_setzero_si256();
|
8745
|
+
__m256i sumi2 = _mm256_setzero_si256();
|
8746
|
+
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
|
8747
|
+
const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
|
8748
|
+
const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
|
8749
|
+
const __m256i q2_1 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
|
8750
|
+
iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
|
8751
|
+
q3 += 8;
|
8752
|
+
const __m256i q2_2 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
|
8753
|
+
iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
|
8754
|
+
q3 += 8;
|
8755
|
+
memcpy(aux32, gas, 8); gas += 8;
|
8756
|
+
const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127],
|
8757
|
+
signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]);
|
8758
|
+
const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
|
8759
|
+
signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
|
8760
|
+
const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
|
8761
|
+
const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
|
8762
|
+
const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
|
8763
|
+
const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
|
8764
|
+
const uint16_t ls1 = aux32[0] >> 28;
|
8765
|
+
const uint16_t ls2 = aux32[1] >> 28;
|
8766
|
+
const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
|
8767
|
+
const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
|
8768
|
+
sumi1 = _mm256_add_epi32(sumi1, p1);
|
8769
|
+
sumi2 = _mm256_add_epi32(sumi2, p2);
|
8770
|
+
}
|
8771
|
+
|
8772
|
+
accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
|
8773
|
+
|
8774
|
+
}
|
8775
|
+
|
8776
|
+
*s = 0.25f * hsum_float_8(accumf);
|
8777
|
+
|
8778
|
+
#else
|
8779
|
+
|
8780
|
+
uint32_t aux32;
|
8781
|
+
|
8782
|
+
float sumf = 0.f;
|
8783
|
+
for (int i = 0; i < nb; ++i) {
|
8784
|
+
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
8785
|
+
const uint8_t * restrict q3 = x[i].qs;
|
8786
|
+
const uint8_t * restrict gas = x[i].qs + QK_K/4;
|
8787
|
+
const int8_t * restrict q8 = y[i].qs;
|
8788
|
+
int32_t bsum = 0;
|
8789
|
+
for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
|
8790
|
+
memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t);
|
8791
|
+
const uint32_t ls = 2*(aux32 >> 28) + 1;
|
8792
|
+
int32_t sumi = 0;
|
8793
|
+
for (int l = 0; l < 4; ++l) {
|
8794
|
+
const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]);
|
8795
|
+
const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]);
|
8796
|
+
const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
|
8797
|
+
for (int j = 0; j < 4; ++j) {
|
8798
|
+
sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1);
|
8799
|
+
sumi += grid2[j] * q8[j+4] * (signs & kmask_iq2xs[j+4] ? -1 : 1);
|
8800
|
+
}
|
8801
|
+
q8 += 8;
|
8802
|
+
}
|
8803
|
+
q3 += 8;
|
8804
|
+
bsum += sumi * ls;
|
8805
|
+
}
|
8806
|
+
sumf += d * bsum;
|
8807
|
+
}
|
8808
|
+
*s = 0.25f * sumf;
|
8809
|
+
#endif
|
8810
|
+
}
|
8811
|
+
|
8544
8812
|
// ================================ IQ2 quantization =============================================
|
8545
8813
|
|
8546
8814
|
typedef struct {
|
@@ -8565,7 +8833,7 @@ static int iq2_compare_func(const void * left, const void * right) {
|
|
8565
8833
|
return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 1 : 0;
|
8566
8834
|
}
|
8567
8835
|
|
8568
|
-
|
8836
|
+
void iq2xs_init_impl(int grid_size) {
|
8569
8837
|
const int gindex = iq2_data_index(grid_size);
|
8570
8838
|
if (iq2_data[gindex].grid) {
|
8571
8839
|
return;
|
@@ -8720,19 +8988,7 @@ static void q2xs_init_impl(int grid_size) {
|
|
8720
8988
|
free(dist2);
|
8721
8989
|
}
|
8722
8990
|
|
8723
|
-
void
|
8724
|
-
if (type == GGML_TYPE_IQ2_XXS) {
|
8725
|
-
q2xs_init_impl(256);
|
8726
|
-
}
|
8727
|
-
else if (type == GGML_TYPE_IQ2_XS) {
|
8728
|
-
q2xs_init_impl(512);
|
8729
|
-
}
|
8730
|
-
else {
|
8731
|
-
fprintf(stderr, "======================== Why are you calling %s with type %d?\n", __func__, (int)type);
|
8732
|
-
}
|
8733
|
-
}
|
8734
|
-
|
8735
|
-
static void q2xs_deinit_impl(int grid_size) {
|
8991
|
+
void iq2xs_free_impl(int grid_size) {
|
8736
8992
|
GGML_ASSERT(grid_size == 256 || grid_size == 512 || grid_size == 1024);
|
8737
8993
|
const int gindex = iq2_data_index(grid_size);
|
8738
8994
|
if (iq2_data[gindex].grid) {
|
@@ -8742,18 +8998,6 @@ static void q2xs_deinit_impl(int grid_size) {
|
|
8742
8998
|
}
|
8743
8999
|
}
|
8744
9000
|
|
8745
|
-
void ggml_deinit_iq2_quantization(enum ggml_type type) {
|
8746
|
-
if (type == GGML_TYPE_IQ2_XXS) {
|
8747
|
-
q2xs_deinit_impl(256);
|
8748
|
-
}
|
8749
|
-
else if (type == GGML_TYPE_IQ2_XS) {
|
8750
|
-
q2xs_deinit_impl(512);
|
8751
|
-
}
|
8752
|
-
else {
|
8753
|
-
fprintf(stderr, "======================== Why are you calling %s with type %d?\n", __func__, (int)type);
|
8754
|
-
}
|
8755
|
-
}
|
8756
|
-
|
8757
9001
|
static int iq2_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid,
|
8758
9002
|
const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L) {
|
8759
9003
|
int num_neighbors = neighbours[0];
|
@@ -8786,10 +9030,10 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
|
|
8786
9030
|
const int * kmap_q2xs = iq2_data[gindex].map;
|
8787
9031
|
const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
|
8788
9032
|
|
8789
|
-
GGML_ASSERT(quant_weights);
|
8790
|
-
GGML_ASSERT(kgrid_q2xs);
|
8791
|
-
GGML_ASSERT(kmap_q2xs);
|
8792
|
-
GGML_ASSERT(kneighbors_q2xs);
|
9033
|
+
GGML_ASSERT(quant_weights && "missing quantization weights");
|
9034
|
+
GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?");
|
9035
|
+
GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?");
|
9036
|
+
GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
|
8793
9037
|
GGML_ASSERT(n%QK_K == 0);
|
8794
9038
|
|
8795
9039
|
const int kMaxQ = 3;
|
@@ -9005,10 +9249,10 @@ static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict v
|
|
9005
9249
|
const int * kmap_q2xs = iq2_data[gindex].map;
|
9006
9250
|
const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
|
9007
9251
|
|
9008
|
-
GGML_ASSERT(quant_weights);
|
9009
|
-
GGML_ASSERT(kmap_q2xs);
|
9010
|
-
GGML_ASSERT(kgrid_q2xs);
|
9011
|
-
GGML_ASSERT(kneighbors_q2xs);
|
9252
|
+
GGML_ASSERT(quant_weights && "missing quantization weights");
|
9253
|
+
GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?");
|
9254
|
+
GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?");
|
9255
|
+
GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
|
9012
9256
|
GGML_ASSERT(n%QK_K == 0);
|
9013
9257
|
|
9014
9258
|
const int kMaxQ = 3;
|
@@ -9203,3 +9447,436 @@ size_t quantize_iq2_xs(const float * src, void * dst, int nrow, int n_per_row, i
|
|
9203
9447
|
return nrow * nblock * sizeof(block_iq2_xs);
|
9204
9448
|
}
|
9205
9449
|
|
9450
|
+
//
|
9451
|
+
// ============================================= 3-bit using D4 lattice
|
9452
|
+
//
|
9453
|
+
|
9454
|
+
// Bookkeeping for the 3-bit "IQ3" quantization tables (points on a lattice).
// All three arrays are heap-allocated by iq3xs_init_impl() and released by
// iq3xs_free_impl(); grid == NULL means "not initialized yet".
typedef struct {
    uint32_t * grid;        // one lattice point (4 x int8 coordinates) packed per uint32_t
    int * map;              // 4096-entry lookup: packed 3-bit quadruple -> grid index (>= 0), or a negative neighbour-list offset
    uint16_t * neighbours;  // flattened neighbour lists: [count, idx, idx, ...] for each point not on the grid
} iq3_entry_t;

// Single slot: only grid_size == 256 is supported (see iq3_data_index()).
static iq3_entry_t iq3_data[1] = {
    {NULL, NULL, NULL},
};
|
9463
|
+
|
9464
|
+
// Map a grid size to its slot in iq3_data[]. Only the 256-point grid exists,
// so the answer is always slot 0.
static inline int iq3_data_index(int grid_size) {
    GGML_ASSERT(grid_size == 256);
    (void)grid_size; // unused when asserts are compiled out
    return 0;
}
|
9469
|
+
|
9470
|
+
// qsort comparator for (distance, index) int pairs: ascending by distance,
// ties broken by ascending index, so the sort order is fully deterministic.
static int iq3_compare_func(const void * left, const void * right) {
    const int * l = (const int *)left;
    const int * r = (const int *)right;
    if (l[0] != r[0]) {
        return l[0] < r[0] ? -1 : 1;
    }
    if (l[1] != r[1]) {
        return l[1] < r[1] ? -1 : 1;
    }
    return 0;
}
|
9475
|
+
|
9476
|
+
// Build the IQ3_XXS lookup tables (3-bit quants on a 256-point lattice).
// Idempotent: returns immediately if the tables already exist. Only
// grid_size == 256 is supported (enforced by iq3_data_index()).
//
// Produces, in iq3_data[gindex]:
//   grid       - the 256 lattice points, each stored as 4 odd int8 coordinates
//                (2*l + 1, l in 0..7) packed into one uint32_t
//   map        - for each of the 4096 packed 3-bit quadruples: the grid index
//                when the point is on the grid, else -(offset+1) into neighbours
//   neighbours - per off-grid point: a count followed by the indices of the
//                nearest grid points (all points at the nwant smallest
//                distinct squared distances)
void iq3xs_init_impl(int grid_size) {
    const int gindex = iq3_data_index(grid_size);
    if (iq3_data[gindex].grid) {
        // tables already built - nothing to do
        return;
    }
    // the 256 chosen lattice points, packed as 4 x 3-bit codes
    static const uint16_t kgrid_256[256] = {
        0, 2, 4, 9, 11, 15, 16, 18, 25, 34, 59, 61, 65, 67, 72, 74,
        81, 85, 88, 90, 97, 108, 120, 128, 130, 132, 137, 144, 146, 153, 155, 159,
        169, 175, 189, 193, 199, 200, 202, 213, 248, 267, 287, 292, 303, 315, 317, 321,
        327, 346, 362, 413, 436, 456, 460, 462, 483, 497, 513, 515, 520, 522, 529, 531,
        536, 538, 540, 551, 552, 576, 578, 585, 592, 594, 641, 643, 648, 650, 657, 664,
        698, 704, 706, 720, 729, 742, 758, 769, 773, 808, 848, 852, 870, 889, 901, 978,
        992, 1024, 1026, 1033, 1035, 1040, 1042, 1046, 1049, 1058, 1089, 1091, 1093, 1096, 1098, 1105,
        1112, 1139, 1143, 1144, 1152, 1154, 1161, 1167, 1168, 1170, 1183, 1184, 1197, 1217, 1224, 1228,
        1272, 1276, 1309, 1323, 1347, 1367, 1377, 1404, 1473, 1475, 1486, 1509, 1537, 1544, 1546, 1553,
        1555, 1576, 1589, 1594, 1600, 1602, 1616, 1625, 1636, 1638, 1665, 1667, 1672, 1685, 1706, 1722,
        1737, 1755, 1816, 1831, 1850, 1856, 1862, 1874, 1901, 1932, 1950, 1971, 2011, 2032, 2052, 2063,
        2077, 2079, 2091, 2095, 2172, 2192, 2207, 2208, 2224, 2230, 2247, 2277, 2308, 2345, 2356, 2389,
        2403, 2424, 2501, 2504, 2506, 2520, 2570, 2593, 2616, 2624, 2630, 2646, 2669, 2700, 2714, 2746,
        2754, 2795, 2824, 2835, 2839, 2874, 2882, 2905, 2984, 3028, 3042, 3092, 3108, 3110, 3124, 3153,
        3185, 3215, 3252, 3288, 3294, 3364, 3397, 3434, 3483, 3523, 3537, 3587, 3589, 3591, 3592, 3610,
        3626, 3670, 3680, 3722, 3749, 3754, 3776, 3789, 3803, 3824, 3857, 3873, 3904, 3906, 3924, 3992,
    };
    const int kmap_size = 4096; // 8^4 possible packed quadruples
    const int nwant = 2;        // keep neighbours at the 2 smallest distinct distances
    const uint16_t * kgrid = kgrid_256;
    uint32_t * kgrid_q3xs;
    int * kmap_q3xs;
    uint16_t * kneighbors_q3xs;

    // NOTE(review): development-time trace output left in; runs once per process
    printf("================================================================= %s(grid_size = %d)\n", __func__, grid_size);
    // unpack each 3-bit code l into the odd coordinate 2*l + 1
    // NOTE(review): malloc results are not checked here - OOM would crash on first use
    uint32_t * the_grid = (uint32_t *)malloc(grid_size*sizeof(uint32_t));
    for (int k = 0; k < grid_size; ++k) {
        int8_t * pos = (int8_t *)(the_grid + k);
        for (int i = 0; i < 4; ++i) {
            int l = (kgrid[k] >> 3*i) & 0x7;
            pos[i] = 2*l + 1;
        }
    }
    kgrid_q3xs = the_grid;
    iq3_data[gindex].grid = the_grid;
    // map: initialize every quadruple to -1 ("not on grid"), then record the
    // grid index for each point that is on the grid
    kmap_q3xs = (int *)malloc(kmap_size*sizeof(int));
    iq3_data[gindex].map = kmap_q3xs;
    for (int i = 0; i < kmap_size; ++i) kmap_q3xs[i] = -1;
    uint32_t aux32;
    uint8_t * aux8 = (uint8_t *)&aux32;
    for (int i = 0; i < grid_size; ++i) {
        aux32 = kgrid_q3xs[i];
        uint16_t index = 0;
        for (int k=0; k<4; ++k) {
            // invert 2*l + 1 back to the 3-bit code l and re-pack the index
            uint16_t q = (aux8[k] - 1)/2;
            index |= (q << 3*k);
        }
        kmap_q3xs[index] = i;
    }
    // first pass: for every off-grid point, count how many neighbours will be
    // stored (all grid points at the nwant smallest distinct squared distances)
    int8_t pos[4];
    int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); // (distance, index) pairs
    int num_neighbors = 0, num_not_in_map = 0;
    for (int i = 0; i < kmap_size; ++i) {
        if (kmap_q3xs[i] >= 0) continue; // on-grid: no neighbour list needed
        ++num_not_in_map;
        for (int k = 0; k < 4; ++k) {
            int l = (i >> 3*k) & 0x7;
            pos[k] = 2*l + 1;
        }
        for (int j = 0; j < grid_size; ++j) {
            const int8_t * pg = (const int8_t *)(kgrid_q3xs + j);
            int d2 = 0;
            for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
            dist2[2*j+0] = d2;
            dist2[2*j+1] = j;
        }
        qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func);
        int n = 0; int d2 = dist2[0];
        int nhave = 1;
        for (int j = 0; j < grid_size; ++j) {
            if (dist2[2*j] > d2) {
                if (nhave == nwant) break;
                d2 = dist2[2*j];
                ++nhave;
            }
            ++n;
        }
        num_neighbors += n;
    }
    printf("%s: %d neighbours in total\n", __func__, num_neighbors);
    // second pass: fill the neighbour lists; each list is [count, idx...] and
    // map[i] = -(offset+1) points at the list's count slot
    kneighbors_q3xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t));
    iq3_data[gindex].neighbours = kneighbors_q3xs;
    int counter = 0;
    for (int i = 0; i < kmap_size; ++i) {
        if (kmap_q3xs[i] >= 0) continue;
        for (int k = 0; k < 4; ++k) {
            int l = (i >> 3*k) & 0x7;
            pos[k] = 2*l + 1;
        }
        for (int j = 0; j < grid_size; ++j) {
            const int8_t * pg = (const int8_t *)(kgrid_q3xs + j);
            int d2 = 0;
            for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
            dist2[2*j+0] = d2;
            dist2[2*j+1] = j;
        }
        qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func);
        kmap_q3xs[i] = -(counter + 1); // negative encoding decoded by quantize_row_iq3_xxs_impl
        int d2 = dist2[0];
        uint16_t * start = &kneighbors_q3xs[counter++]; // reserve the count slot
        int n = 0, nhave = 1;
        for (int j = 0; j < grid_size; ++j) {
            if (dist2[2*j] > d2) {
                if (nhave == nwant) break;
                d2 = dist2[2*j];
                ++nhave;
            }
            kneighbors_q3xs[counter++] = dist2[2*j+1];
            ++n;
        }
        *start = n; // backfill the neighbour count
    }
    free(dist2);
}
|
9596
|
+
|
9597
|
+
void iq3xs_free_impl(int grid_size) {
|
9598
|
+
GGML_ASSERT(grid_size == 256);
|
9599
|
+
const int gindex = iq3_data_index(grid_size);
|
9600
|
+
if (iq3_data[gindex].grid) {
|
9601
|
+
free(iq3_data[gindex].grid); iq3_data[gindex].grid = NULL;
|
9602
|
+
free(iq3_data[gindex].map); iq3_data[gindex].map = NULL;
|
9603
|
+
free(iq3_data[gindex].neighbours); iq3_data[gindex].neighbours = NULL;
|
9604
|
+
}
|
9605
|
+
}
|
9606
|
+
|
9607
|
+
// Among the candidate lattice points listed in `neighbours` (format:
// neighbours[0] = count, followed by that many grid indices), pick the one
// that minimizes the weighted squared error against xval at the given scale.
// Writes the winner's quant levels into L[0..3] and returns its grid index.
static int iq3_find_best_neighbour(const uint16_t * restrict neighbours, const uint32_t * restrict grid,
        const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L) {
    const int count = neighbours[0];
    GGML_ASSERT(count > 0);
    int   best_index = -1;
    float best_err   = FLT_MAX;
    for (int j = 1; j <= count; ++j) {
        const int8_t * point = (const int8_t *)(grid + neighbours[j]);
        float err = 0;
        for (int i = 0; i < 4; ++i) {
            const float diff = scale*point[i] - xval[i];
            err += weight[i]*diff*diff;
        }
        if (err < best_err) {
            best_err   = err;
            best_index = neighbours[j];
        }
    }
    GGML_ASSERT(best_index >= 0);
    // convert the winning point's odd coordinates (2*l + 1) back to levels l
    const int8_t * point = (const int8_t *)(grid + best_index);
    for (int i = 0; i < 4; ++i) {
        L[i] = (point[i] - 1)/2;
    }
    return best_index;
}
|
9630
|
+
|
9631
|
+
// Quantize n floats from x into IQ3_XXS blocks at vy. n must be a multiple of
// QK_K. quant_weights (optional, may be NULL) supplies per-value importance;
// without it a plain x^2 weighting is used. Requires the IQ3 tables built by
// iq3xs_init_impl() (via ggml_quantize_init()).
//
// Per 32-value sub-block the encoding is: 8 grid indices (one per group of 4
// values) written to q3, plus one packed uint32_t holding 4 x 7 sign bits and
// a 4-bit scale. Only 7 of 8 signs per group are stored; the 8th is implied
// because the code forces an even number of sign flips per group.
static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {

    const int gindex = iq3_data_index(256);

    const uint32_t * kgrid_q3xs = iq3_data[gindex].grid;
    const int * kmap_q3xs = iq3_data[gindex].map;
    const uint16_t * kneighbors_q3xs = iq3_data[gindex].neighbours;

    // quant_weights is deliberately optional here (assert disabled)
    //GGML_ASSERT(quant_weights && "missing quantization weights");
    GGML_ASSERT(kgrid_q3xs && "forgot to call ggml_quantize_init()?");
    GGML_ASSERT(kmap_q3xs && "forgot to call ggml_quantize_init()?");
    GGML_ASSERT(kneighbors_q3xs && "forgot to call ggml_quantize_init()?");
    GGML_ASSERT(n%QK_K == 0);

    const int kMaxQ = 8; // quant levels are 0..7 (3 bits)

    const int nbl = n/256; // number of superblocks

    block_iq3_xxs * y = vy;

    float scales[QK_K/32];
    float weight[32];
    float xval[32];         // absolute values (signs handled separately)
    int8_t L[32];           // best quant levels found so far
    int8_t Laux[32];        // candidate quant levels for the current trial scale
    float waux[32];         // sqrt of weights, used by the neighbour search
    bool is_on_grid[8];
    bool is_on_grid_aux[8];
    uint8_t block_signs[8];
    uint8_t q3[3*(QK_K/8)]; // scratch: grid indices followed by scales/signs
    uint32_t * scales_and_signs = (uint32_t *)(q3 + QK_K/4);

    for (int ibl = 0; ibl < nbl; ++ibl) {

        y[ibl].d = GGML_FP32_TO_FP16(0.f);
        memset(q3, 0, 3*QK_K/8);

        float max_scale = 0;

        // sigma2 = mean square of the superblock, used in the importance weighting
        const float * xbl = x + QK_K*ibl;
        float sumx2 = 0;
        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
        float sigma2 = sumx2/QK_K;

        for (int ib = 0; ib < QK_K/32; ++ib) {
            const float * xb = xbl + 32*ib;
            if (quant_weights) {
                const float * qw = quant_weights + QK_K*ibl + 32*ib;
                for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
            } else {
                for (int i = 0; i < 32; ++i) weight[i] = xb[i]*xb[i];
            }
            for (int i = 0; i < 32; ++i) waux[i] = sqrtf(weight[i]);
            // split signs from magnitudes; force an even number of flips per
            // group of 8 by un-flipping the least important element, so the
            // 8th sign is recoverable from the other 7 (only 7 bits stored)
            for (int k = 0; k < 4; ++k) {
                int nflip = 0;
                uint8_t s = 0;
                for (int i = 0; i < 8; ++i) {
                    if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
                    else {
                        xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i);
                    }
                }
                if (nflip%2) {
                    int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin];
                    for (int i = 1; i < 8; ++i) {
                        float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i];
                        if (ax < min) {
                            min = ax; imin = i;
                        }
                    }
                    xval[8*k+imin] = -xval[8*k+imin];
                    s ^= (1 << imin);
                }
                block_signs[k] = s & 127;
            }
            float max = xval[0];
            for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]);
            if (!max) {
                // all-zero sub-block
                scales[ib] = 0;
                memset(L, 0, 32);
                continue;
            }
            // try 31 candidate scales around max/(2*kMaxQ-1); for each, snap
            // every group of 4 to the grid (via the neighbour table when off
            // grid) and keep the scale maximizing the weighted projection
            // sumqx^2/sumq2
            float best = 0;
            float scale = max/(2*kMaxQ-1);
            for (int is = -15; is <= 15; ++is) {
                float id = (2*kMaxQ-1+is*0.2f)/max;
                float this_scale = 1/id;
                for (int k = 0; k < 8; ++k) {
                    for (int i = 0; i < 4; ++i) {
                        int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
                        Laux[4*k+i] = MAX(0, MIN(kMaxQ-1, l));
                    }
                    uint16_t u = 0;
                    for (int i = 0; i < 4; ++i) u |= (Laux[4*k+i] << 3*i);
                    int grid_index = kmap_q3xs[u];
                    is_on_grid_aux[k] = true;
                    if (grid_index < 0) {
                        // map encodes -(offset+1) into the neighbours table
                        is_on_grid_aux[k] = false;
                        const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
                        grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, this_scale, Laux + 4*k);
                    }
                }
                float sumqx = 0, sumq2 = 0;
                for (int i = 0; i < 32; ++i) {
                    float w = weight[i];
                    float q = 2*Laux[i] + 1;
                    sumqx += w*xval[i]*q;
                    sumq2 += w*q*q;
                }
                if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
                    scale = sumqx/sumq2; best = scale*sumqx;
                    for (int i = 0; i < 32; ++i) L[i] = Laux[i];
                    for (int k = 0; k < 8; ++k) is_on_grid[k] = is_on_grid_aux[k];
                }
            }
            // NOTE(review): if no candidate in the loop above updated `best`,
            // is_on_grid[] is read here without being written this iteration -
            // confirm against upstream, which initializes it explicitly
            int n_not_ongrid = 0;
            for (int k = 0; k < 8; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
            if (n_not_ongrid > 0 && scale > 0) {
                // re-snap the off-grid groups at the chosen scale, then
                // recompute the optimal scale for the final levels
                float id = 1/scale;
                for (int k = 0; k < 8; ++k) {
                    if (is_on_grid[k]) continue;
                    uint16_t u = 0;
                    for (int i = 0; i < 4; ++i) {
                        int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
                        l = MAX(0, MIN(kMaxQ-1, l));
                        u |= (l << 3*i);
                    }
                    int grid_index = kmap_q3xs[u];
                    if (grid_index < 0) {
                        const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
                        grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, scale, L + 4*k);
                    }
                    const int8_t * pg = (const int8_t *)(kgrid_q3xs + grid_index);
                    for (int i = 0; i < 4; ++i) L[4*k+i] = (pg[i] - 1)/2;
                }
                float sumqx = 0, sumq2 = 0;
                for (int i = 0; i < 32; ++i) {
                    float w = weight[i];
                    float q = 2*L[i] + 1;
                    sumqx += w*xval[i]*q;
                    sumq2 += w*q*q;
                }
                if (sumq2 > 0) scale = sumqx/sumq2;
            }
            if (scale < 0) {
                // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale)
                // and correspondingly flip quant signs.
                scale = -scale;
                for (int k = 0; k < 4; ++k) block_signs[k] = (~block_signs[k]) & 127;
            }
            // emit the final grid index for each group of 4
            for (int k = 0; k < 8; ++k) {
                uint16_t u = 0;
                for (int i = 0; i < 4; ++i) u |= (L[4*k+i] << 3*i);
                int grid_index = kmap_q3xs[u];
                if (grid_index < 0) {
                    printf("Oops: found point %u not on grid:", u);
                    for (int i = 0; i < 4; ++i) printf(" %d", L[4*k+i]);
                    printf("\n");
                    GGML_ASSERT(false);
                }
                q3[8*ib+k] = grid_index;
            }
            // pack the 4 x 7 sign bits; the 4-bit scale is OR-ed in below
            scales_and_signs[ib] = block_signs[0] | (block_signs[1] << 7) | (block_signs[2] << 14) | (block_signs[3] << 21);
            GGML_ASSERT(scale >= 0);
            scales[ib] = scale;
            max_scale = MAX(max_scale, scale);
        }

        if (!max_scale) {
            memset(y[ibl].qs, 0, 3*QK_K/8);
            continue;
        }

        // encode each sub-block scale as a 4-bit multiplier of d = max_scale/31
        float d = max_scale/31;
        y[ibl].d = GGML_FP32_TO_FP16(d);
        float id = 1/d;
        float sumqx = 0, sumq2 = 0;
        for (int ib = 0; ib < QK_K/32; ++ib) {
            int l = nearest_int(0.5f*(id*scales[ib]-1));
            l = MAX(0, MIN(15, l));
            scales_and_signs[ib] |= ((uint32_t)l << 28);
            // disabled experimental refinement: exhaustive per-group re-search
            // over all 256 grid points plus a final re-fit of d. Kept for
            // reference; never executed.
            if (false) {
                const float * xb = xbl + 32*ib;
                if (quant_weights) {
                    const float * qw = quant_weights + QK_K*ibl + 32*ib;
                    for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
                } else {
                    for (int i = 0; i < 32; ++i) weight[i] = xb[i]*xb[i];
                }
                const float db = 0.25f * d * (1 + 2*l);
                for (int k = 0; k < 8; ++k) {
                    const int8_t * signs = keven_signs_q2xs + 8*((scales_and_signs[ib] >> 7*(k/2)) & 127) + 4*(k%2);
                    const float * xk = xb + 4*k;
                    const float * wk = weight + 4*k;
                    //const uint8_t * grid = (const uint8_t *)(kgrid_q3xs + q3[8*ib+k]);
                    const uint8_t * grid = (const uint8_t *)(iq3xxs_grid + q3[8*ib+k]);
                    float best_mse = 0; int best_index = q3[8*ib+k];
                    for (int j = 0; j < 4; ++j) {
                        float diff = db * grid[j] * signs[j] - xk[j];
                        best_mse += wk[j] * diff * diff;
                    }
                    for (int idx = 0; idx < 256; ++idx) {
                        //grid = (const uint8_t *)(kgrid_q3xs + idx);
                        grid = (const uint8_t *)(iq3xxs_grid + idx);
                        float mse = 0;
                        for (int j = 0; j < 4; ++j) {
                            float diff = db * grid[j] * signs[j] - xk[j];
                            mse += wk[j] * diff * diff;
                        }
                        if (mse < best_mse) {
                            best_mse = mse; best_index = idx;
                        }
                    }
                    q3[8*ib+k] = best_index;
                    //grid = (const uint8_t *)(kgrid_q3xs + best_index);
                    grid = (const uint8_t *)(iq3xxs_grid + best_index);
                    for (int j = 0; j < 4; ++j) {
                        float q = db * grid[j] * signs[j];
                        sumqx += wk[j] * q * xk[j];
                        sumq2 += wk[j] * q * q;
                    }
                }
                if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(d*sumqx/sumq2);
            }
        }
        memcpy(y[ibl].qs, q3, 3*QK_K/8);
    }
}
|
9859
|
+
|
9860
|
+
size_t quantize_iq3_xxs(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
|
9861
|
+
(void)hist;
|
9862
|
+
GGML_ASSERT(n_per_row%QK_K == 0);
|
9863
|
+
int nblock = n_per_row/QK_K;
|
9864
|
+
char * qrow = (char *)dst;
|
9865
|
+
for (int row = 0; row < nrow; ++row) {
|
9866
|
+
quantize_row_iq3_xxs_impl(src, qrow, n_per_row, quant_weights);
|
9867
|
+
src += n_per_row;
|
9868
|
+
qrow += nblock*sizeof(block_iq3_xxs);
|
9869
|
+
}
|
9870
|
+
return nrow * nblock * sizeof(block_iq3_xxs);
|
9871
|
+
}
|
9872
|
+
|
9873
|
+
void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int k) {
|
9874
|
+
assert(k % QK_K == 0);
|
9875
|
+
block_iq3_xxs * restrict y = vy;
|
9876
|
+
quantize_row_iq3_xxs_reference(x, y, k);
|
9877
|
+
}
|
9878
|
+
|
9879
|
+
// Reference IQ3_XXS row quantization: identical to quantize_row_iq3_xxs_impl
// with quant_weights == NULL (plain x^2 weighting).
void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int k) {
    assert(k % QK_K == 0);
    quantize_row_iq3_xxs_impl(x, (void *)y, k, NULL);
}
|