llama_cpp 0.12.3 → 0.12.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +13 -0
- data/ext/llama_cpp/extconf.rb +1 -0
- data/ext/llama_cpp/llama_cpp.cpp +22 -6
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +4 -2
- data/vendor/tmp/llama.cpp/Makefile +160 -56
- data/vendor/tmp/llama.cpp/ggml-alloc.c +85 -25
- data/vendor/tmp/llama.cpp/ggml-backend-impl.h +6 -0
- data/vendor/tmp/llama.cpp/ggml-backend.c +115 -3
- data/vendor/tmp/llama.cpp/ggml-backend.h +3 -0
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +688 -270
- data/vendor/tmp/llama.cpp/ggml-impl.h +2 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +1990 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.h +46 -0
- data/vendor/tmp/llama.cpp/ggml-metal.h +3 -0
- data/vendor/tmp/llama.cpp/ggml-metal.m +121 -86
- data/vendor/tmp/llama.cpp/ggml-metal.metal +303 -4
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +95 -3
- data/vendor/tmp/llama.cpp/ggml-opencl.h +1 -0
- data/vendor/tmp/llama.cpp/ggml-quants.c +745 -109
- data/vendor/tmp/llama.cpp/ggml-quants.h +81 -56
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +15296 -0
- data/vendor/tmp/llama.cpp/ggml-sycl.h +29 -0
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +51714 -0
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +5726 -0
- data/vendor/tmp/llama.cpp/ggml-vulkan.h +39 -0
- data/vendor/tmp/llama.cpp/ggml.c +356 -60
- data/vendor/tmp/llama.cpp/ggml.h +7 -1
- data/vendor/tmp/llama.cpp/llama.cpp +876 -118
- data/vendor/tmp/llama.cpp/llama.h +12 -16
- metadata +9 -2
@@ -2381,19 +2381,20 @@ static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restri
|
|
2381
2381
|
|
2382
2382
|
uint8_t L[QK_K];
|
2383
2383
|
uint8_t Laux[32];
|
2384
|
+
uint8_t Ls[QK_K/32];
|
2385
|
+
uint8_t Lm[QK_K/32];
|
2384
2386
|
float weights[32];
|
2385
|
-
float
|
2386
|
-
float
|
2387
|
+
float sw[QK_K/32];
|
2388
|
+
float mins[QK_K/32];
|
2389
|
+
float scales[QK_K/32];
|
2387
2390
|
|
2388
2391
|
for (int i = 0; i < nb; i++) {
|
2389
2392
|
|
2390
2393
|
float sum_x2 = 0;
|
2391
2394
|
for (int l = 0; l < QK_K; ++l) sum_x2 += x[l] * x[l];
|
2392
|
-
float sigma2 = sum_x2/QK_K;
|
2395
|
+
float sigma2 = 2*sum_x2/QK_K;
|
2393
2396
|
float av_x = sqrtf(sigma2);
|
2394
2397
|
|
2395
|
-
float max_scale = 0; // as we are deducting the min, scales are always positive
|
2396
|
-
float max_min = 0;
|
2397
2398
|
for (int j = 0; j < QK_K/32; ++j) {
|
2398
2399
|
if (quant_weights) {
|
2399
2400
|
const float * qw = quant_weights + QK_K*i + 32*j;
|
@@ -2401,25 +2402,17 @@ static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restri
|
|
2401
2402
|
} else {
|
2402
2403
|
for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
|
2403
2404
|
}
|
2405
|
+
float sumw = 0;
|
2406
|
+
for (int l = 0; l < 32; ++l) sumw += weights[l];
|
2407
|
+
sw[j] = sumw;
|
2404
2408
|
scales[j] = make_qkx3_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
|
2405
|
-
//scales[j] = make_qkx2_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -1.f, 0.1f, 20, false);
|
2406
|
-
float scale = scales[j];
|
2407
|
-
if (scale > max_scale) {
|
2408
|
-
max_scale = scale;
|
2409
|
-
}
|
2410
|
-
float min = mins[j];
|
2411
|
-
if (min > max_min) {
|
2412
|
-
max_min = min;
|
2413
|
-
}
|
2414
2409
|
}
|
2415
2410
|
|
2416
|
-
float
|
2417
|
-
float
|
2411
|
+
float d_block = make_qp_quants(QK_K/32, 63, scales, Ls, sw);
|
2412
|
+
float m_block = make_qp_quants(QK_K/32, 63, mins, Lm, sw);
|
2418
2413
|
for (int j = 0; j < QK_K/32; ++j) {
|
2419
|
-
uint8_t ls =
|
2420
|
-
uint8_t lm =
|
2421
|
-
ls = MIN(63, ls);
|
2422
|
-
lm = MIN(63, lm);
|
2414
|
+
uint8_t ls = Ls[j];
|
2415
|
+
uint8_t lm = Lm[j];
|
2423
2416
|
if (j < 4) {
|
2424
2417
|
y[i].scales[j] = ls;
|
2425
2418
|
y[i].scales[j+4] = lm;
|
@@ -2429,8 +2422,8 @@ static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restri
|
|
2429
2422
|
y[i].scales[j-0] |= ((lm >> 4) << 6);
|
2430
2423
|
}
|
2431
2424
|
}
|
2432
|
-
y[i].d = GGML_FP32_TO_FP16(
|
2433
|
-
y[i].dmin = GGML_FP32_TO_FP16(
|
2425
|
+
y[i].d = GGML_FP32_TO_FP16(d_block);
|
2426
|
+
y[i].dmin = GGML_FP32_TO_FP16(m_block);
|
2434
2427
|
|
2435
2428
|
uint8_t sc, m;
|
2436
2429
|
for (int j = 0; j < QK_K/32; ++j) {
|
@@ -2688,20 +2681,21 @@ static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restri
|
|
2688
2681
|
const int nb = n_per_row / QK_K;
|
2689
2682
|
|
2690
2683
|
uint8_t L[QK_K];
|
2691
|
-
float mins[QK_K/32];
|
2692
|
-
float scales[QK_K/32];
|
2693
|
-
float weights[32];
|
2694
2684
|
uint8_t Laux[32];
|
2685
|
+
uint8_t Ls[QK_K/32];
|
2686
|
+
uint8_t Lm[QK_K/32];
|
2687
|
+
float mins[QK_K/32];
|
2688
|
+
float scales[QK_K/32];
|
2689
|
+
float sw[QK_K/32];
|
2690
|
+
float weights[32];
|
2695
2691
|
|
2696
2692
|
for (int i = 0; i < nb; i++) {
|
2697
2693
|
|
2698
2694
|
float sum_x2 = 0;
|
2699
2695
|
for (int l = 0; l < QK_K; ++l) sum_x2 += x[l] * x[l];
|
2700
|
-
float sigma2 = sum_x2/QK_K;
|
2696
|
+
float sigma2 = 2*sum_x2/QK_K;
|
2701
2697
|
float av_x = sqrtf(sigma2);
|
2702
2698
|
|
2703
|
-
float max_scale = 0; // as we are deducting the min, scales are always positive
|
2704
|
-
float max_min = 0;
|
2705
2699
|
for (int j = 0; j < QK_K/32; ++j) {
|
2706
2700
|
if (quant_weights) {
|
2707
2701
|
const float * qw = quant_weights + QK_K*i + 32*j;
|
@@ -2709,22 +2703,19 @@ static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restri
|
|
2709
2703
|
} else {
|
2710
2704
|
for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
|
2711
2705
|
}
|
2706
|
+
float sumw = 0;
|
2707
|
+
for (int l = 0; l < 32; ++l) sumw += weights[l];
|
2708
|
+
sw[j] = sumw;
|
2709
|
+
|
2712
2710
|
scales[j] = make_qkx3_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
|
2713
|
-
float scale = scales[j];
|
2714
|
-
if (scale > max_scale) {
|
2715
|
-
max_scale = scale;
|
2716
|
-
}
|
2717
|
-
float min = mins[j];
|
2718
|
-
if (min > max_min) {
|
2719
|
-
max_min = min;
|
2720
|
-
}
|
2721
2711
|
}
|
2722
2712
|
|
2723
|
-
float
|
2724
|
-
float
|
2713
|
+
float d_block = make_qp_quants(QK_K/32, 63, scales, Ls, sw);
|
2714
|
+
float m_block = make_qp_quants(QK_K/32, 63, mins, Lm, sw);
|
2715
|
+
|
2725
2716
|
for (int j = 0; j < QK_K/32; ++j) {
|
2726
|
-
uint8_t ls =
|
2727
|
-
uint8_t lm =
|
2717
|
+
uint8_t ls = Ls[j];
|
2718
|
+
uint8_t lm = Lm[j];
|
2728
2719
|
ls = MIN(63, ls);
|
2729
2720
|
lm = MIN(63, lm);
|
2730
2721
|
if (j < 4) {
|
@@ -2736,8 +2727,8 @@ static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restri
|
|
2736
2727
|
y[i].scales[j-0] |= ((lm >> 4) << 6);
|
2737
2728
|
}
|
2738
2729
|
}
|
2739
|
-
y[i].d = GGML_FP32_TO_FP16(
|
2740
|
-
y[i].dmin = GGML_FP32_TO_FP16(
|
2730
|
+
y[i].d = GGML_FP32_TO_FP16(d_block);
|
2731
|
+
y[i].dmin = GGML_FP32_TO_FP16(m_block);
|
2741
2732
|
|
2742
2733
|
uint8_t sc, m;
|
2743
2734
|
for (int j = 0; j < QK_K/32; ++j) {
|
@@ -3441,6 +3432,41 @@ static const uint64_t iq2xs_grid[512] = {
|
|
3441
3432
|
0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
|
3442
3433
|
};
|
3443
3434
|
|
3435
|
+
static const uint32_t iq3xxs_grid[256] = {
|
3436
|
+
0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
|
3437
|
+
0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
|
3438
|
+
0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404,
|
3439
|
+
0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e,
|
3440
|
+
0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c,
|
3441
|
+
0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c,
|
3442
|
+
0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34,
|
3443
|
+
0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c,
|
3444
|
+
0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,
|
3445
|
+
0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04,
|
3446
|
+
0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c,
|
3447
|
+
0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414,
|
3448
|
+
0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434,
|
3449
|
+
0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c,
|
3450
|
+
0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e,
|
3451
|
+
0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24,
|
3452
|
+
0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24,
|
3453
|
+
0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,
|
3454
|
+
0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c,
|
3455
|
+
0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14,
|
3456
|
+
0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414,
|
3457
|
+
0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e,
|
3458
|
+
0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404,
|
3459
|
+
0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c,
|
3460
|
+
0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c,
|
3461
|
+
0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14,
|
3462
|
+
0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c,
|
3463
|
+
0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c,
|
3464
|
+
0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14,
|
3465
|
+
0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14,
|
3466
|
+
0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
|
3467
|
+
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
3468
|
+
};
|
3469
|
+
|
3444
3470
|
static const uint8_t ksigns_iq2xs[128] = {
|
3445
3471
|
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
|
3446
3472
|
144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
|
@@ -3507,6 +3533,38 @@ void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y,
|
|
3507
3533
|
}
|
3508
3534
|
}
|
3509
3535
|
|
3536
|
+
// ====================== 3.0625 bpw (de)-quantization
|
3537
|
+
|
3538
|
+
void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y, int k) {
|
3539
|
+
assert(k % QK_K == 0);
|
3540
|
+
const int nb = k / QK_K;
|
3541
|
+
|
3542
|
+
uint32_t aux32;
|
3543
|
+
|
3544
|
+
for (int i = 0; i < nb; i++) {
|
3545
|
+
|
3546
|
+
const float d = GGML_FP16_TO_FP32(x[i].d);
|
3547
|
+
const uint8_t * qs = x[i].qs;
|
3548
|
+
const uint8_t * scales_and_signs = qs + QK_K/4;
|
3549
|
+
|
3550
|
+
for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
|
3551
|
+
memcpy(&aux32, scales_and_signs + 4*ib32, sizeof(uint32_t));
|
3552
|
+
const float db = d * (0.5f + (aux32 >> 28)) * 0.5f;
|
3553
|
+
for (int l = 0; l < 4; ++l) {
|
3554
|
+
const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
|
3555
|
+
const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + qs[2*l+0]);
|
3556
|
+
const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + qs[2*l+1]);
|
3557
|
+
for (int j = 0; j < 4; ++j) {
|
3558
|
+
y[j+0] = db * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
|
3559
|
+
y[j+4] = db * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
|
3560
|
+
}
|
3561
|
+
y += 8;
|
3562
|
+
}
|
3563
|
+
qs += 8;
|
3564
|
+
}
|
3565
|
+
}
|
3566
|
+
}
|
3567
|
+
|
3510
3568
|
//===================================== Q8_K ==============================================
|
3511
3569
|
|
3512
3570
|
void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
|
@@ -8458,17 +8516,36 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
|
|
8458
8516
|
|
8459
8517
|
const __m128i m4 = _mm_set1_epi8(0xf);
|
8460
8518
|
const __m128i m1 = _mm_set1_epi8(1);
|
8461
|
-
const
|
8462
|
-
const
|
8519
|
+
const __m256i m511 = _mm256_set1_epi16(511);
|
8520
|
+
const __m256i mone = _mm256_set1_epi8(1);
|
8463
8521
|
|
8464
|
-
const
|
8522
|
+
static const uint8_t k_bit_helper[32] = {
|
8523
|
+
0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
|
8524
|
+
0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
|
8525
|
+
};
|
8526
|
+
static const char block_sign_shuffle_mask_1[32] = {
|
8527
|
+
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
|
8528
|
+
0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
|
8529
|
+
};
|
8530
|
+
static const char block_sign_shuffle_mask_2[32] = {
|
8531
|
+
0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
|
8532
|
+
0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
|
8533
|
+
};
|
8534
|
+
static const uint8_t bit_selector_mask_bytes[32] = {
|
8535
|
+
0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
|
8536
|
+
0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
|
8537
|
+
};
|
8538
|
+
|
8539
|
+
const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper);
|
8540
|
+
const __m256i bit_selector_mask = _mm256_loadu_si256((const __m256i*)bit_selector_mask_bytes);
|
8541
|
+
const __m256i block_sign_shuffle_1 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_1);
|
8542
|
+
const __m256i block_sign_shuffle_2 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_2);
|
8465
8543
|
|
8466
8544
|
uint64_t aux64;
|
8467
8545
|
|
8468
8546
|
// somewhat hacky, but gives a significant boost in performance
|
8469
|
-
|
8547
|
+
__m256i aux_gindex;
|
8470
8548
|
const uint16_t * gindex = (const uint16_t *)&aux_gindex;
|
8471
|
-
const uint16_t * sindex = (const uint16_t *)&aux_sindex;
|
8472
8549
|
|
8473
8550
|
__m256 accumf = _mm256_setzero_ps();
|
8474
8551
|
for (int i = 0; i < nb; ++i) {
|
@@ -8483,26 +8560,68 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
|
|
8483
8560
|
|
8484
8561
|
__m256i sumi1 = _mm256_setzero_si256();
|
8485
8562
|
__m256i sumi2 = _mm256_setzero_si256();
|
8486
|
-
for (int ib32 = 0; ib32 < QK_K/32; ib32 +=
|
8563
|
+
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {
|
8564
|
+
|
8565
|
+
const __m256i q2_data = _mm256_loadu_si256((const __m256i*)q2); q2 += 16;
|
8566
|
+
aux_gindex = _mm256_and_si256(q2_data, m511);
|
8567
|
+
|
8568
|
+
const __m256i partial_sign_bits = _mm256_srli_epi16(q2_data, 9);
|
8569
|
+
const __m256i partial_sign_bits_upper = _mm256_srli_epi16(q2_data, 13);
|
8570
|
+
const __m256i partial_sign_bits_for_counting = _mm256_xor_si256(partial_sign_bits, partial_sign_bits_upper);
|
8571
|
+
|
8572
|
+
const __m256i odd_bits = _mm256_shuffle_epi8(bit_helper, partial_sign_bits_for_counting);
|
8573
|
+
const __m256i full_sign_bits = _mm256_or_si256(partial_sign_bits, odd_bits);
|
8574
|
+
|
8487
8575
|
const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
|
8488
8576
|
const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
|
8489
|
-
const
|
8490
|
-
|
8491
|
-
|
8492
|
-
const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]],
|
8493
|
-
|
8494
|
-
const __m256i
|
8495
|
-
|
8496
|
-
const __m256i
|
8497
|
-
|
8577
|
+
const __m256i q8_3 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
|
8578
|
+
const __m256i q8_4 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
|
8579
|
+
|
8580
|
+
const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]],
|
8581
|
+
iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]);
|
8582
|
+
const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]],
|
8583
|
+
iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]);
|
8584
|
+
const __m256i q2_3 = _mm256_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]],
|
8585
|
+
iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]);
|
8586
|
+
const __m256i q2_4 = _mm256_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]],
|
8587
|
+
iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);
|
8588
|
+
|
8589
|
+
const __m128i full_signs_l = _mm256_castsi256_si128(full_sign_bits);
|
8590
|
+
const __m128i full_signs_h = _mm256_extractf128_si256(full_sign_bits, 1);
|
8591
|
+
const __m256i full_signs_1 = _mm256_set_m128i(full_signs_l, full_signs_l);
|
8592
|
+
const __m256i full_signs_2 = _mm256_set_m128i(full_signs_h, full_signs_h);
|
8593
|
+
|
8594
|
+
__m256i signs;
|
8595
|
+
signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_1);
|
8596
|
+
signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
|
8597
|
+
const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone));
|
8598
|
+
|
8599
|
+
signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_2);
|
8600
|
+
signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
|
8601
|
+
const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone));
|
8602
|
+
|
8603
|
+
signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_1);
|
8604
|
+
signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
|
8605
|
+
const __m256i q8s_3 = _mm256_sign_epi8(q8_3, _mm256_or_si256(signs, mone));
|
8606
|
+
|
8607
|
+
signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_2);
|
8608
|
+
signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
|
8609
|
+
const __m256i q8s_4 = _mm256_sign_epi8(q8_4, _mm256_or_si256(signs, mone));
|
8610
|
+
|
8498
8611
|
const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
|
8499
8612
|
const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
|
8613
|
+
const __m256i dot3 = _mm256_maddubs_epi16(q2_3, q8s_3);
|
8614
|
+
const __m256i dot4 = _mm256_maddubs_epi16(q2_4, q8s_4);
|
8500
8615
|
|
8501
8616
|
const __m256i sc1 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0)));
|
8502
8617
|
const __m256i sc2 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1)));
|
8618
|
+
const __m256i sc3 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2)));
|
8619
|
+
const __m256i sc4 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3)));
|
8503
8620
|
|
8504
8621
|
sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot1, sc1));
|
8505
8622
|
sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot2, sc2));
|
8623
|
+
sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot3, sc3));
|
8624
|
+
sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot4, sc4));
|
8506
8625
|
}
|
8507
8626
|
|
8508
8627
|
accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
|
@@ -8551,6 +8670,136 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
|
|
8551
8670
|
#endif
|
8552
8671
|
}
|
8553
8672
|
|
8673
|
+
// TODO
|
8674
|
+
void ggml_vec_dot_iq3_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
8675
|
+
assert(n % QK_K == 0);
|
8676
|
+
|
8677
|
+
const block_iq3_xxs * restrict x = vx;
|
8678
|
+
const block_q8_K * restrict y = vy;
|
8679
|
+
|
8680
|
+
const int nb = n / QK_K;
|
8681
|
+
|
8682
|
+
#if defined(__ARM_NEON)
|
8683
|
+
|
8684
|
+
const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
|
8685
|
+
|
8686
|
+
uint32_t aux32[2];
|
8687
|
+
|
8688
|
+
ggml_int8x16x4_t q3s;
|
8689
|
+
ggml_int8x16x4_t q8b;
|
8690
|
+
|
8691
|
+
float sumf = 0;
|
8692
|
+
for (int i = 0; i < nb; ++i) {
|
8693
|
+
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
8694
|
+
const uint8_t * restrict q3 = x[i].qs;
|
8695
|
+
const uint8_t * restrict gas = x[i].qs + QK_K/4;
|
8696
|
+
const int8_t * restrict q8 = y[i].qs;
|
8697
|
+
float sumf1 = 0, sumf2 = 0;
|
8698
|
+
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
|
8699
|
+
q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
|
8700
|
+
memcpy(aux32, gas, 2*sizeof(uint32_t)); gas += 2*sizeof(uint32_t);
|
8701
|
+
const uint32x4_t aux32x4_0 = {iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]};
|
8702
|
+
const uint32x4_t aux32x4_1 = {iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]};
|
8703
|
+
const uint32x4_t aux32x4_2 = {iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]};
|
8704
|
+
const uint32x4_t aux32x4_3 = {iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]};
|
8705
|
+
q3 += 16;
|
8706
|
+
q3s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 7) & 127))));
|
8707
|
+
q3s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 21) & 127))));
|
8708
|
+
q3s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127))));
|
8709
|
+
q3s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127))));
|
8710
|
+
q3s.val[0] = vmulq_s8(q3s.val[0], vreinterpretq_s8_u32(aux32x4_0));
|
8711
|
+
q3s.val[1] = vmulq_s8(q3s.val[1], vreinterpretq_s8_u32(aux32x4_1));
|
8712
|
+
q3s.val[2] = vmulq_s8(q3s.val[2], vreinterpretq_s8_u32(aux32x4_2));
|
8713
|
+
q3s.val[3] = vmulq_s8(q3s.val[3], vreinterpretq_s8_u32(aux32x4_3));
|
8714
|
+
const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);
|
8715
|
+
const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
|
8716
|
+
sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[0] >> 28));
|
8717
|
+
sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[1] >> 28));
|
8718
|
+
}
|
8719
|
+
sumf += d*(sumf1 + sumf2);
|
8720
|
+
}
|
8721
|
+
*s = 0.5f * sumf;
|
8722
|
+
|
8723
|
+
#elif defined(__AVX2__)
|
8724
|
+
|
8725
|
+
const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
|
8726
|
+
|
8727
|
+
uint32_t aux32[2];
|
8728
|
+
|
8729
|
+
__m256 accumf = _mm256_setzero_ps();
|
8730
|
+
for (int i = 0; i < nb; ++i) {
|
8731
|
+
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
8732
|
+
const uint8_t * restrict q3 = x[i].qs;
|
8733
|
+
const uint8_t * restrict gas = x[i].qs + QK_K/4;
|
8734
|
+
const int8_t * restrict q8 = y[i].qs;
|
8735
|
+
__m256i sumi1 = _mm256_setzero_si256();
|
8736
|
+
__m256i sumi2 = _mm256_setzero_si256();
|
8737
|
+
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
|
8738
|
+
const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
|
8739
|
+
const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
|
8740
|
+
const __m256i q2_1 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
|
8741
|
+
iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
|
8742
|
+
q3 += 8;
|
8743
|
+
const __m256i q2_2 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
|
8744
|
+
iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
|
8745
|
+
q3 += 8;
|
8746
|
+
memcpy(aux32, gas, 8); gas += 8;
|
8747
|
+
const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127],
|
8748
|
+
signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]);
|
8749
|
+
const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
|
8750
|
+
signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
|
8751
|
+
const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
|
8752
|
+
const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
|
8753
|
+
const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
|
8754
|
+
const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
|
8755
|
+
const uint16_t ls1 = aux32[0] >> 28;
|
8756
|
+
const uint16_t ls2 = aux32[1] >> 28;
|
8757
|
+
const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
|
8758
|
+
const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
|
8759
|
+
sumi1 = _mm256_add_epi32(sumi1, p1);
|
8760
|
+
sumi2 = _mm256_add_epi32(sumi2, p2);
|
8761
|
+
}
|
8762
|
+
|
8763
|
+
accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
|
8764
|
+
|
8765
|
+
}
|
8766
|
+
|
8767
|
+
*s = 0.25f * hsum_float_8(accumf);
|
8768
|
+
|
8769
|
+
#else
|
8770
|
+
|
8771
|
+
uint32_t aux32;
|
8772
|
+
|
8773
|
+
float sumf = 0.f;
|
8774
|
+
for (int i = 0; i < nb; ++i) {
|
8775
|
+
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
8776
|
+
const uint8_t * restrict q3 = x[i].qs;
|
8777
|
+
const uint8_t * restrict gas = x[i].qs + QK_K/4;
|
8778
|
+
const int8_t * restrict q8 = y[i].qs;
|
8779
|
+
int32_t bsum = 0;
|
8780
|
+
for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
|
8781
|
+
memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t);
|
8782
|
+
const uint32_t ls = 2*(aux32 >> 28) + 1;
|
8783
|
+
int32_t sumi = 0;
|
8784
|
+
for (int l = 0; l < 4; ++l) {
|
8785
|
+
const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]);
|
8786
|
+
const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]);
|
8787
|
+
const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
|
8788
|
+
for (int j = 0; j < 4; ++j) {
|
8789
|
+
sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1);
|
8790
|
+
sumi += grid2[j] * q8[j+4] * (signs & kmask_iq2xs[j+4] ? -1 : 1);
|
8791
|
+
}
|
8792
|
+
q8 += 8;
|
8793
|
+
}
|
8794
|
+
q3 += 8;
|
8795
|
+
bsum += sumi * ls;
|
8796
|
+
}
|
8797
|
+
sumf += d * bsum;
|
8798
|
+
}
|
8799
|
+
*s = 0.25f * sumf;
|
8800
|
+
#endif
|
8801
|
+
}
|
8802
|
+
|
8554
8803
|
// ================================ IQ2 quantization =============================================
|
8555
8804
|
|
8556
8805
|
typedef struct {
|
@@ -8790,8 +9039,6 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
|
|
8790
9039
|
int8_t L[32];
|
8791
9040
|
int8_t Laux[32];
|
8792
9041
|
float waux[32];
|
8793
|
-
bool is_on_grid[4];
|
8794
|
-
bool is_on_grid_aux[4];
|
8795
9042
|
uint8_t block_signs[4];
|
8796
9043
|
uint32_t q2[2*(QK_K/32)];
|
8797
9044
|
|
@@ -8841,10 +9088,11 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
|
|
8841
9088
|
memset(L, 0, 32);
|
8842
9089
|
continue;
|
8843
9090
|
}
|
9091
|
+
float scale = make_qp_quants(32, kMaxQ+1, xval, (uint8_t*)L, weight);
|
9092
|
+
float eff_max = scale*kMaxQ;
|
8844
9093
|
float best = 0;
|
8845
|
-
|
8846
|
-
|
8847
|
-
float id = (2*kMaxQ-1+is*0.1f)/max;
|
9094
|
+
for (int is = -6; is <= 6; ++is) {
|
9095
|
+
float id = (2*kMaxQ-1+is*0.1f)/eff_max;
|
8848
9096
|
float this_scale = 1/id;
|
8849
9097
|
for (int k = 0; k < 4; ++k) {
|
8850
9098
|
for (int i = 0; i < 8; ++i) {
|
@@ -8854,9 +9102,7 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
|
|
8854
9102
|
uint16_t u = 0;
|
8855
9103
|
for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i);
|
8856
9104
|
int grid_index = kmap_q2xs[u];
|
8857
|
-
is_on_grid_aux[k] = true;
|
8858
9105
|
if (grid_index < 0) {
|
8859
|
-
is_on_grid_aux[k] = false;
|
8860
9106
|
const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
|
8861
9107
|
grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k);
|
8862
9108
|
}
|
@@ -8870,16 +9116,12 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
|
|
8870
9116
|
}
|
8871
9117
|
if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
|
8872
9118
|
scale = sumqx/sumq2; best = scale*sumqx;
|
8873
|
-
|
8874
|
-
for (int k = 0; k < 4; ++k) is_on_grid[k] = is_on_grid_aux[k];
|
9119
|
+
memcpy(L, Laux, 32);
|
8875
9120
|
}
|
8876
9121
|
}
|
8877
|
-
|
8878
|
-
for (int k = 0; k < 4; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
|
8879
|
-
if (n_not_ongrid > 0 && scale > 0) {
|
9122
|
+
if (scale > 0) {
|
8880
9123
|
float id = 1/scale;
|
8881
9124
|
for (int k = 0; k < 4; ++k) {
|
8882
|
-
if (is_on_grid[k]) continue;
|
8883
9125
|
uint16_t u = 0;
|
8884
9126
|
for (int i = 0; i < 8; ++i) {
|
8885
9127
|
int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
|
@@ -8935,49 +9177,10 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
|
|
8935
9177
|
float d = max_scale/31;
|
8936
9178
|
y[ibl].d = GGML_FP32_TO_FP16(d);
|
8937
9179
|
float id = 1/d;
|
8938
|
-
float sumqx = 0, sumq2 = 0;
|
8939
9180
|
for (int ib = 0; ib < QK_K/32; ++ib) {
|
8940
9181
|
int l = nearest_int(0.5f*(id*scales[ib]-1));
|
8941
9182
|
l = MAX(0, MIN(15, l));
|
8942
9183
|
q2[2*ib+1] |= ((uint32_t)l << 28);
|
8943
|
-
const float * xb = xbl + 32*ib;
|
8944
|
-
const float * qw = quant_weights + QK_K*ibl + 32*ib;
|
8945
|
-
for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
|
8946
|
-
const uint8_t * aux8 = (const uint8_t *)(q2 + 2*ib);
|
8947
|
-
const float db = d * (1 + 2*l);
|
8948
|
-
uint32_t u = 0;
|
8949
|
-
for (int k = 0; k < 4; ++k) {
|
8950
|
-
const int8_t * signs = keven_signs_q2xs + 8*((q2[2*ib+1] >> 7*k) & 127);
|
8951
|
-
const float * xk = xb + 8*k;
|
8952
|
-
const float * wk = weight + 8*k;
|
8953
|
-
const uint8_t * grid = (const uint8_t *)(kgrid_q2xs + aux8[k]);
|
8954
|
-
float best_mse = 0; int best_index = aux8[k];
|
8955
|
-
for (int j = 0; j < 8; ++j) {
|
8956
|
-
float diff = db * grid[j] * signs[j] - xk[j];
|
8957
|
-
best_mse += wk[j] * diff * diff;
|
8958
|
-
}
|
8959
|
-
for (int idx = 0; idx < 256; ++idx) {
|
8960
|
-
grid = (const uint8_t *)(kgrid_q2xs + idx);
|
8961
|
-
float mse = 0;
|
8962
|
-
for (int j = 0; j < 8; ++j) {
|
8963
|
-
float diff = db * grid[j] * signs[j] - xk[j];
|
8964
|
-
mse += wk[j] * diff * diff;
|
8965
|
-
}
|
8966
|
-
if (mse < best_mse) {
|
8967
|
-
best_mse = mse; best_index = idx;
|
8968
|
-
}
|
8969
|
-
}
|
8970
|
-
u |= (best_index << 8*k);
|
8971
|
-
grid = (const uint8_t *)(kgrid_q2xs + best_index);
|
8972
|
-
//grid = (const uint8_t *)(kgrid_q2xs + aux8[k]);
|
8973
|
-
for (int j = 0; j < 8; ++j) {
|
8974
|
-
float q = db * grid[j] * signs[j];
|
8975
|
-
sumqx += wk[j] * q * xk[j];
|
8976
|
-
sumq2 += wk[j] * q * q;
|
8977
|
-
}
|
8978
|
-
}
|
8979
|
-
q2[2*ib] = u;
|
8980
|
-
if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(d*sumqx/sumq2);
|
8981
9184
|
}
|
8982
9185
|
memcpy(y[ibl].qs, q2, QK_K/4);
|
8983
9186
|
}
|
@@ -9189,3 +9392,436 @@ size_t quantize_iq2_xs(const float * src, void * dst, int nrow, int n_per_row, i
|
|
9189
9392
|
return nrow * nblock * sizeof(block_iq2_xs);
|
9190
9393
|
}
|
9191
9394
|
|
9395
|
+
//
// ============================================= 3-bit using D4 lattice
//

// Lazily-built lookup tables for the iq3_xxs quantization grid.
// Filled by iq3xs_init_impl() and released by iq3xs_free_impl().
typedef struct {
    uint32_t * grid;       // lattice points, 4 packed int8 coordinates per uint32_t entry
    int * map;             // packed 12-bit code -> grid index; negative entries point into `neighbours`
    uint16_t * neighbours; // runs of [count, idx, idx, ...] candidate grid points for off-grid codes
} iq3_entry_t;

// One cached entry per supported grid size (currently only the 256-point grid).
static iq3_entry_t iq3_data[1] = {
    {NULL, NULL, NULL},
};
|
9408
|
+
|
9409
|
+
// Map a grid size to its slot in iq3_data. Only the 256-point grid is
// supported, so the answer is always slot 0.
static inline int iq3_data_index(int grid_size) {
    GGML_ASSERT(grid_size == 256);
    (void)grid_size; // silence "unused parameter" when asserts compile out
    return 0;
}
|
9414
|
+
|
9415
|
+
// qsort comparator for (distance, index) int pairs: order by distance first,
// then by index so that ties resolve deterministically.
static int iq3_compare_func(const void * left, const void * right) {
    const int * a = (const int *)left;
    const int * b = (const int *)right;
    if (a[0] != b[0]) {
        return a[0] < b[0] ? -1 : 1;
    }
    if (a[1] != b[1]) {
        return a[1] < b[1] ? -1 : 1;
    }
    return 0;
}
|
9420
|
+
|
9421
|
+
void iq3xs_init_impl(int grid_size) {
|
9422
|
+
const int gindex = iq3_data_index(grid_size);
|
9423
|
+
if (iq3_data[gindex].grid) {
|
9424
|
+
return;
|
9425
|
+
}
|
9426
|
+
static const uint16_t kgrid_256[256] = {
|
9427
|
+
0, 2, 4, 9, 11, 15, 16, 18, 25, 34, 59, 61, 65, 67, 72, 74,
|
9428
|
+
81, 85, 88, 90, 97, 108, 120, 128, 130, 132, 137, 144, 146, 153, 155, 159,
|
9429
|
+
169, 175, 189, 193, 199, 200, 202, 213, 248, 267, 287, 292, 303, 315, 317, 321,
|
9430
|
+
327, 346, 362, 413, 436, 456, 460, 462, 483, 497, 513, 515, 520, 522, 529, 531,
|
9431
|
+
536, 538, 540, 551, 552, 576, 578, 585, 592, 594, 641, 643, 648, 650, 657, 664,
|
9432
|
+
698, 704, 706, 720, 729, 742, 758, 769, 773, 808, 848, 852, 870, 889, 901, 978,
|
9433
|
+
992, 1024, 1026, 1033, 1035, 1040, 1042, 1046, 1049, 1058, 1089, 1091, 1093, 1096, 1098, 1105,
|
9434
|
+
1112, 1139, 1143, 1144, 1152, 1154, 1161, 1167, 1168, 1170, 1183, 1184, 1197, 1217, 1224, 1228,
|
9435
|
+
1272, 1276, 1309, 1323, 1347, 1367, 1377, 1404, 1473, 1475, 1486, 1509, 1537, 1544, 1546, 1553,
|
9436
|
+
1555, 1576, 1589, 1594, 1600, 1602, 1616, 1625, 1636, 1638, 1665, 1667, 1672, 1685, 1706, 1722,
|
9437
|
+
1737, 1755, 1816, 1831, 1850, 1856, 1862, 1874, 1901, 1932, 1950, 1971, 2011, 2032, 2052, 2063,
|
9438
|
+
2077, 2079, 2091, 2095, 2172, 2192, 2207, 2208, 2224, 2230, 2247, 2277, 2308, 2345, 2356, 2389,
|
9439
|
+
2403, 2424, 2501, 2504, 2506, 2520, 2570, 2593, 2616, 2624, 2630, 2646, 2669, 2700, 2714, 2746,
|
9440
|
+
2754, 2795, 2824, 2835, 2839, 2874, 2882, 2905, 2984, 3028, 3042, 3092, 3108, 3110, 3124, 3153,
|
9441
|
+
3185, 3215, 3252, 3288, 3294, 3364, 3397, 3434, 3483, 3523, 3537, 3587, 3589, 3591, 3592, 3610,
|
9442
|
+
3626, 3670, 3680, 3722, 3749, 3754, 3776, 3789, 3803, 3824, 3857, 3873, 3904, 3906, 3924, 3992,
|
9443
|
+
};
|
9444
|
+
const int kmap_size = 4096;
|
9445
|
+
const int nwant = 2;
|
9446
|
+
const uint16_t * kgrid = kgrid_256;
|
9447
|
+
uint32_t * kgrid_q3xs;
|
9448
|
+
int * kmap_q3xs;
|
9449
|
+
uint16_t * kneighbors_q3xs;
|
9450
|
+
|
9451
|
+
printf("================================================================= %s(grid_size = %d)\n", __func__, grid_size);
|
9452
|
+
uint32_t * the_grid = (uint32_t *)malloc(grid_size*sizeof(uint32_t));
|
9453
|
+
for (int k = 0; k < grid_size; ++k) {
|
9454
|
+
int8_t * pos = (int8_t *)(the_grid + k);
|
9455
|
+
for (int i = 0; i < 4; ++i) {
|
9456
|
+
int l = (kgrid[k] >> 3*i) & 0x7;
|
9457
|
+
pos[i] = 2*l + 1;
|
9458
|
+
}
|
9459
|
+
}
|
9460
|
+
kgrid_q3xs = the_grid;
|
9461
|
+
iq3_data[gindex].grid = the_grid;
|
9462
|
+
kmap_q3xs = (int *)malloc(kmap_size*sizeof(int));
|
9463
|
+
iq3_data[gindex].map = kmap_q3xs;
|
9464
|
+
for (int i = 0; i < kmap_size; ++i) kmap_q3xs[i] = -1;
|
9465
|
+
uint32_t aux32;
|
9466
|
+
uint8_t * aux8 = (uint8_t *)&aux32;
|
9467
|
+
for (int i = 0; i < grid_size; ++i) {
|
9468
|
+
aux32 = kgrid_q3xs[i];
|
9469
|
+
uint16_t index = 0;
|
9470
|
+
for (int k=0; k<4; ++k) {
|
9471
|
+
uint16_t q = (aux8[k] - 1)/2;
|
9472
|
+
index |= (q << 3*k);
|
9473
|
+
}
|
9474
|
+
kmap_q3xs[index] = i;
|
9475
|
+
}
|
9476
|
+
int8_t pos[4];
|
9477
|
+
int * dist2 = (int *)malloc(2*grid_size*sizeof(int));
|
9478
|
+
int num_neighbors = 0, num_not_in_map = 0;
|
9479
|
+
for (int i = 0; i < kmap_size; ++i) {
|
9480
|
+
if (kmap_q3xs[i] >= 0) continue;
|
9481
|
+
++num_not_in_map;
|
9482
|
+
for (int k = 0; k < 4; ++k) {
|
9483
|
+
int l = (i >> 3*k) & 0x7;
|
9484
|
+
pos[k] = 2*l + 1;
|
9485
|
+
}
|
9486
|
+
for (int j = 0; j < grid_size; ++j) {
|
9487
|
+
const int8_t * pg = (const int8_t *)(kgrid_q3xs + j);
|
9488
|
+
int d2 = 0;
|
9489
|
+
for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
|
9490
|
+
dist2[2*j+0] = d2;
|
9491
|
+
dist2[2*j+1] = j;
|
9492
|
+
}
|
9493
|
+
qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func);
|
9494
|
+
int n = 0; int d2 = dist2[0];
|
9495
|
+
int nhave = 1;
|
9496
|
+
for (int j = 0; j < grid_size; ++j) {
|
9497
|
+
if (dist2[2*j] > d2) {
|
9498
|
+
if (nhave == nwant) break;
|
9499
|
+
d2 = dist2[2*j];
|
9500
|
+
++nhave;
|
9501
|
+
}
|
9502
|
+
++n;
|
9503
|
+
}
|
9504
|
+
num_neighbors += n;
|
9505
|
+
}
|
9506
|
+
printf("%s: %d neighbours in total\n", __func__, num_neighbors);
|
9507
|
+
kneighbors_q3xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t));
|
9508
|
+
iq3_data[gindex].neighbours = kneighbors_q3xs;
|
9509
|
+
int counter = 0;
|
9510
|
+
for (int i = 0; i < kmap_size; ++i) {
|
9511
|
+
if (kmap_q3xs[i] >= 0) continue;
|
9512
|
+
for (int k = 0; k < 4; ++k) {
|
9513
|
+
int l = (i >> 3*k) & 0x7;
|
9514
|
+
pos[k] = 2*l + 1;
|
9515
|
+
}
|
9516
|
+
for (int j = 0; j < grid_size; ++j) {
|
9517
|
+
const int8_t * pg = (const int8_t *)(kgrid_q3xs + j);
|
9518
|
+
int d2 = 0;
|
9519
|
+
for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
|
9520
|
+
dist2[2*j+0] = d2;
|
9521
|
+
dist2[2*j+1] = j;
|
9522
|
+
}
|
9523
|
+
qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func);
|
9524
|
+
kmap_q3xs[i] = -(counter + 1);
|
9525
|
+
int d2 = dist2[0];
|
9526
|
+
uint16_t * start = &kneighbors_q3xs[counter++];
|
9527
|
+
int n = 0, nhave = 1;
|
9528
|
+
for (int j = 0; j < grid_size; ++j) {
|
9529
|
+
if (dist2[2*j] > d2) {
|
9530
|
+
if (nhave == nwant) break;
|
9531
|
+
d2 = dist2[2*j];
|
9532
|
+
++nhave;
|
9533
|
+
}
|
9534
|
+
kneighbors_q3xs[counter++] = dist2[2*j+1];
|
9535
|
+
++n;
|
9536
|
+
}
|
9537
|
+
*start = n;
|
9538
|
+
}
|
9539
|
+
free(dist2);
|
9540
|
+
}
|
9541
|
+
|
9542
|
+
void iq3xs_free_impl(int grid_size) {
|
9543
|
+
GGML_ASSERT(grid_size == 256);
|
9544
|
+
const int gindex = iq3_data_index(grid_size);
|
9545
|
+
if (iq3_data[gindex].grid) {
|
9546
|
+
free(iq3_data[gindex].grid); iq3_data[gindex].grid = NULL;
|
9547
|
+
free(iq3_data[gindex].map); iq3_data[gindex].map = NULL;
|
9548
|
+
free(iq3_data[gindex].neighbours); iq3_data[gindex].neighbours = NULL;
|
9549
|
+
}
|
9550
|
+
}
|
9551
|
+
|
9552
|
+
// Among the candidate grid points listed in `neighbours` (neighbours[0] is
// the candidate count, indices follow), pick the one minimizing the weighted
// squared error against xval at the given scale. Writes the winning point's
// 3-bit codes into L and returns its grid index.
static int iq3_find_best_neighbour(const uint16_t * restrict neighbours, const uint32_t * restrict grid,
        const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L) {
    const int ncand = neighbours[0];
    GGML_ASSERT(ncand > 0);
    float best_score = FLT_MAX;
    int   best_index = -1;
    for (int j = 1; j <= ncand; ++j) {
        const int8_t * pt = (const int8_t *)(grid + neighbours[j]);
        float score = 0;
        for (int i = 0; i < 4; ++i) {
            float diff = scale*pt[i] - xval[i];
            score += weight[i]*diff*diff;
        }
        if (score < best_score) {
            best_score = score;
            best_index = neighbours[j];
        }
    }
    GGML_ASSERT(best_index >= 0);
    const int8_t * pt = (const int8_t *)(grid + best_index);
    // Grid coordinates are stored as 2*l + 1; recover the 3-bit code l.
    for (int i = 0; i < 4; ++i) L[i] = (pt[i] - 1)/2;
    return best_index;
}
|
9575
|
+
|
9576
|
+
static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
|
9577
|
+
|
9578
|
+
const int gindex = iq3_data_index(256);
|
9579
|
+
|
9580
|
+
const uint32_t * kgrid_q3xs = iq3_data[gindex].grid;
|
9581
|
+
const int * kmap_q3xs = iq3_data[gindex].map;
|
9582
|
+
const uint16_t * kneighbors_q3xs = iq3_data[gindex].neighbours;
|
9583
|
+
|
9584
|
+
//GGML_ASSERT(quant_weights && "missing quantization weights");
|
9585
|
+
GGML_ASSERT(kgrid_q3xs && "forgot to call ggml_quantize_init()?");
|
9586
|
+
GGML_ASSERT(kmap_q3xs && "forgot to call ggml_quantize_init()?");
|
9587
|
+
GGML_ASSERT(kneighbors_q3xs && "forgot to call ggml_quantize_init()?");
|
9588
|
+
GGML_ASSERT(n%QK_K == 0);
|
9589
|
+
|
9590
|
+
const int kMaxQ = 8;
|
9591
|
+
|
9592
|
+
const int nbl = n/256;
|
9593
|
+
|
9594
|
+
block_iq3_xxs * y = vy;
|
9595
|
+
|
9596
|
+
float scales[QK_K/32];
|
9597
|
+
float weight[32];
|
9598
|
+
float xval[32];
|
9599
|
+
int8_t L[32];
|
9600
|
+
int8_t Laux[32];
|
9601
|
+
float waux[32];
|
9602
|
+
bool is_on_grid[8];
|
9603
|
+
bool is_on_grid_aux[8];
|
9604
|
+
uint8_t block_signs[8];
|
9605
|
+
uint8_t q3[3*(QK_K/8)];
|
9606
|
+
uint32_t * scales_and_signs = (uint32_t *)(q3 + QK_K/4);
|
9607
|
+
|
9608
|
+
for (int ibl = 0; ibl < nbl; ++ibl) {
|
9609
|
+
|
9610
|
+
y[ibl].d = GGML_FP32_TO_FP16(0.f);
|
9611
|
+
memset(q3, 0, 3*QK_K/8);
|
9612
|
+
|
9613
|
+
float max_scale = 0;
|
9614
|
+
|
9615
|
+
const float * xbl = x + QK_K*ibl;
|
9616
|
+
float sumx2 = 0;
|
9617
|
+
for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
|
9618
|
+
float sigma2 = sumx2/QK_K;
|
9619
|
+
|
9620
|
+
for (int ib = 0; ib < QK_K/32; ++ib) {
|
9621
|
+
const float * xb = xbl + 32*ib;
|
9622
|
+
if (quant_weights) {
|
9623
|
+
const float * qw = quant_weights + QK_K*ibl + 32*ib;
|
9624
|
+
for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
|
9625
|
+
} else {
|
9626
|
+
for (int i = 0; i < 32; ++i) weight[i] = xb[i]*xb[i];
|
9627
|
+
}
|
9628
|
+
for (int i = 0; i < 32; ++i) waux[i] = sqrtf(weight[i]);
|
9629
|
+
for (int k = 0; k < 4; ++k) {
|
9630
|
+
int nflip = 0;
|
9631
|
+
uint8_t s = 0;
|
9632
|
+
for (int i = 0; i < 8; ++i) {
|
9633
|
+
if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
|
9634
|
+
else {
|
9635
|
+
xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i);
|
9636
|
+
}
|
9637
|
+
}
|
9638
|
+
if (nflip%2) {
|
9639
|
+
int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin];
|
9640
|
+
for (int i = 1; i < 8; ++i) {
|
9641
|
+
float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i];
|
9642
|
+
if (ax < min) {
|
9643
|
+
min = ax; imin = i;
|
9644
|
+
}
|
9645
|
+
}
|
9646
|
+
xval[8*k+imin] = -xval[8*k+imin];
|
9647
|
+
s ^= (1 << imin);
|
9648
|
+
}
|
9649
|
+
block_signs[k] = s & 127;
|
9650
|
+
}
|
9651
|
+
float max = xval[0];
|
9652
|
+
for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]);
|
9653
|
+
if (!max) {
|
9654
|
+
scales[ib] = 0;
|
9655
|
+
memset(L, 0, 32);
|
9656
|
+
continue;
|
9657
|
+
}
|
9658
|
+
float best = 0;
|
9659
|
+
float scale = max/(2*kMaxQ-1);
|
9660
|
+
for (int is = -15; is <= 15; ++is) {
|
9661
|
+
float id = (2*kMaxQ-1+is*0.2f)/max;
|
9662
|
+
float this_scale = 1/id;
|
9663
|
+
for (int k = 0; k < 8; ++k) {
|
9664
|
+
for (int i = 0; i < 4; ++i) {
|
9665
|
+
int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
|
9666
|
+
Laux[4*k+i] = MAX(0, MIN(kMaxQ-1, l));
|
9667
|
+
}
|
9668
|
+
uint16_t u = 0;
|
9669
|
+
for (int i = 0; i < 4; ++i) u |= (Laux[4*k+i] << 3*i);
|
9670
|
+
int grid_index = kmap_q3xs[u];
|
9671
|
+
is_on_grid_aux[k] = true;
|
9672
|
+
if (grid_index < 0) {
|
9673
|
+
is_on_grid_aux[k] = false;
|
9674
|
+
const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
|
9675
|
+
grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, this_scale, Laux + 4*k);
|
9676
|
+
}
|
9677
|
+
}
|
9678
|
+
float sumqx = 0, sumq2 = 0;
|
9679
|
+
for (int i = 0; i < 32; ++i) {
|
9680
|
+
float w = weight[i];
|
9681
|
+
float q = 2*Laux[i] + 1;
|
9682
|
+
sumqx += w*xval[i]*q;
|
9683
|
+
sumq2 += w*q*q;
|
9684
|
+
}
|
9685
|
+
if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
|
9686
|
+
scale = sumqx/sumq2; best = scale*sumqx;
|
9687
|
+
for (int i = 0; i < 32; ++i) L[i] = Laux[i];
|
9688
|
+
for (int k = 0; k < 8; ++k) is_on_grid[k] = is_on_grid_aux[k];
|
9689
|
+
}
|
9690
|
+
}
|
9691
|
+
int n_not_ongrid = 0;
|
9692
|
+
for (int k = 0; k < 8; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
|
9693
|
+
if (n_not_ongrid > 0 && scale > 0) {
|
9694
|
+
float id = 1/scale;
|
9695
|
+
for (int k = 0; k < 8; ++k) {
|
9696
|
+
if (is_on_grid[k]) continue;
|
9697
|
+
uint16_t u = 0;
|
9698
|
+
for (int i = 0; i < 4; ++i) {
|
9699
|
+
int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
|
9700
|
+
l = MAX(0, MIN(kMaxQ-1, l));
|
9701
|
+
u |= (l << 3*i);
|
9702
|
+
}
|
9703
|
+
int grid_index = kmap_q3xs[u];
|
9704
|
+
if (grid_index < 0) {
|
9705
|
+
const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
|
9706
|
+
grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, scale, L + 4*k);
|
9707
|
+
}
|
9708
|
+
const int8_t * pg = (const int8_t *)(kgrid_q3xs + grid_index);
|
9709
|
+
for (int i = 0; i < 4; ++i) L[4*k+i] = (pg[i] - 1)/2;
|
9710
|
+
}
|
9711
|
+
float sumqx = 0, sumq2 = 0;
|
9712
|
+
for (int i = 0; i < 32; ++i) {
|
9713
|
+
float w = weight[i];
|
9714
|
+
float q = 2*L[i] + 1;
|
9715
|
+
sumqx += w*xval[i]*q;
|
9716
|
+
sumq2 += w*q*q;
|
9717
|
+
}
|
9718
|
+
if (sumq2 > 0) scale = sumqx/sumq2;
|
9719
|
+
}
|
9720
|
+
if (scale < 0) {
|
9721
|
+
// This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale)
|
9722
|
+
// and correspondingly flip quant signs.
|
9723
|
+
scale = -scale;
|
9724
|
+
for (int k = 0; k < 4; ++k) block_signs[k] = (~block_signs[k]) & 127;
|
9725
|
+
}
|
9726
|
+
for (int k = 0; k < 8; ++k) {
|
9727
|
+
uint16_t u = 0;
|
9728
|
+
for (int i = 0; i < 4; ++i) u |= (L[4*k+i] << 3*i);
|
9729
|
+
int grid_index = kmap_q3xs[u];
|
9730
|
+
if (grid_index < 0) {
|
9731
|
+
printf("Oops: found point %u not on grid:", u);
|
9732
|
+
for (int i = 0; i < 4; ++i) printf(" %d", L[4*k+i]);
|
9733
|
+
printf("\n");
|
9734
|
+
GGML_ASSERT(false);
|
9735
|
+
}
|
9736
|
+
q3[8*ib+k] = grid_index;
|
9737
|
+
}
|
9738
|
+
scales_and_signs[ib] = block_signs[0] | (block_signs[1] << 7) | (block_signs[2] << 14) | (block_signs[3] << 21);
|
9739
|
+
GGML_ASSERT(scale >= 0);
|
9740
|
+
scales[ib] = scale;
|
9741
|
+
max_scale = MAX(max_scale, scale);
|
9742
|
+
}
|
9743
|
+
|
9744
|
+
if (!max_scale) {
|
9745
|
+
memset(y[ibl].qs, 0, 3*QK_K/8);
|
9746
|
+
continue;
|
9747
|
+
}
|
9748
|
+
|
9749
|
+
float d = max_scale/31;
|
9750
|
+
y[ibl].d = GGML_FP32_TO_FP16(d);
|
9751
|
+
float id = 1/d;
|
9752
|
+
float sumqx = 0, sumq2 = 0;
|
9753
|
+
for (int ib = 0; ib < QK_K/32; ++ib) {
|
9754
|
+
int l = nearest_int(0.5f*(id*scales[ib]-1));
|
9755
|
+
l = MAX(0, MIN(15, l));
|
9756
|
+
scales_and_signs[ib] |= ((uint32_t)l << 28);
|
9757
|
+
if (false) {
|
9758
|
+
const float * xb = xbl + 32*ib;
|
9759
|
+
if (quant_weights) {
|
9760
|
+
const float * qw = quant_weights + QK_K*ibl + 32*ib;
|
9761
|
+
for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
|
9762
|
+
} else {
|
9763
|
+
for (int i = 0; i < 32; ++i) weight[i] = xb[i]*xb[i];
|
9764
|
+
}
|
9765
|
+
const float db = 0.25f * d * (1 + 2*l);
|
9766
|
+
for (int k = 0; k < 8; ++k) {
|
9767
|
+
const int8_t * signs = keven_signs_q2xs + 8*((scales_and_signs[ib] >> 7*(k/2)) & 127) + 4*(k%2);
|
9768
|
+
const float * xk = xb + 4*k;
|
9769
|
+
const float * wk = weight + 4*k;
|
9770
|
+
//const uint8_t * grid = (const uint8_t *)(kgrid_q3xs + q3[8*ib+k]);
|
9771
|
+
const uint8_t * grid = (const uint8_t *)(iq3xxs_grid + q3[8*ib+k]);
|
9772
|
+
float best_mse = 0; int best_index = q3[8*ib+k];
|
9773
|
+
for (int j = 0; j < 4; ++j) {
|
9774
|
+
float diff = db * grid[j] * signs[j] - xk[j];
|
9775
|
+
best_mse += wk[j] * diff * diff;
|
9776
|
+
}
|
9777
|
+
for (int idx = 0; idx < 256; ++idx) {
|
9778
|
+
//grid = (const uint8_t *)(kgrid_q3xs + idx);
|
9779
|
+
grid = (const uint8_t *)(iq3xxs_grid + idx);
|
9780
|
+
float mse = 0;
|
9781
|
+
for (int j = 0; j < 4; ++j) {
|
9782
|
+
float diff = db * grid[j] * signs[j] - xk[j];
|
9783
|
+
mse += wk[j] * diff * diff;
|
9784
|
+
}
|
9785
|
+
if (mse < best_mse) {
|
9786
|
+
best_mse = mse; best_index = idx;
|
9787
|
+
}
|
9788
|
+
}
|
9789
|
+
q3[8*ib+k] = best_index;
|
9790
|
+
//grid = (const uint8_t *)(kgrid_q3xs + best_index);
|
9791
|
+
grid = (const uint8_t *)(iq3xxs_grid + best_index);
|
9792
|
+
for (int j = 0; j < 4; ++j) {
|
9793
|
+
float q = db * grid[j] * signs[j];
|
9794
|
+
sumqx += wk[j] * q * xk[j];
|
9795
|
+
sumq2 += wk[j] * q * q;
|
9796
|
+
}
|
9797
|
+
}
|
9798
|
+
if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(d*sumqx/sumq2);
|
9799
|
+
}
|
9800
|
+
}
|
9801
|
+
memcpy(y[ibl].qs, q3, 3*QK_K/8);
|
9802
|
+
}
|
9803
|
+
}
|
9804
|
+
|
9805
|
+
size_t quantize_iq3_xxs(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
|
9806
|
+
(void)hist;
|
9807
|
+
GGML_ASSERT(n_per_row%QK_K == 0);
|
9808
|
+
int nblock = n_per_row/QK_K;
|
9809
|
+
char * qrow = (char *)dst;
|
9810
|
+
for (int row = 0; row < nrow; ++row) {
|
9811
|
+
quantize_row_iq3_xxs_impl(src, qrow, n_per_row, quant_weights);
|
9812
|
+
src += n_per_row;
|
9813
|
+
qrow += nblock*sizeof(block_iq3_xxs);
|
9814
|
+
}
|
9815
|
+
return nrow * nblock * sizeof(block_iq3_xxs);
|
9816
|
+
}
|
9817
|
+
|
9818
|
+
void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int k) {
|
9819
|
+
assert(k % QK_K == 0);
|
9820
|
+
block_iq3_xxs * restrict y = vy;
|
9821
|
+
quantize_row_iq3_xxs_reference(x, y, k);
|
9822
|
+
}
|
9823
|
+
|
9824
|
+
// Reference iq3_xxs quantization of k floats: delegates to the shared
// implementation with no importance weights (quant_weights == NULL).
void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int k) {
    assert(k % QK_K == 0);
    quantize_row_iq3_xxs_impl(x, y, k, NULL);
}
|