llama_cpp 0.12.3 → 0.12.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2381,19 +2381,20 @@ static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restri
2381
2381
 
2382
2382
  uint8_t L[QK_K];
2383
2383
  uint8_t Laux[32];
2384
+ uint8_t Ls[QK_K/32];
2385
+ uint8_t Lm[QK_K/32];
2384
2386
  float weights[32];
2385
- float mins[QK_K/32];
2386
- float scales[QK_K/32];
2387
+ float sw[QK_K/32];
2388
+ float mins[QK_K/32];
2389
+ float scales[QK_K/32];
2387
2390
 
2388
2391
  for (int i = 0; i < nb; i++) {
2389
2392
 
2390
2393
  float sum_x2 = 0;
2391
2394
  for (int l = 0; l < QK_K; ++l) sum_x2 += x[l] * x[l];
2392
- float sigma2 = sum_x2/QK_K;
2395
+ float sigma2 = 2*sum_x2/QK_K;
2393
2396
  float av_x = sqrtf(sigma2);
2394
2397
 
2395
- float max_scale = 0; // as we are deducting the min, scales are always positive
2396
- float max_min = 0;
2397
2398
  for (int j = 0; j < QK_K/32; ++j) {
2398
2399
  if (quant_weights) {
2399
2400
  const float * qw = quant_weights + QK_K*i + 32*j;
@@ -2401,25 +2402,17 @@ static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restri
2401
2402
  } else {
2402
2403
  for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
2403
2404
  }
2405
+ float sumw = 0;
2406
+ for (int l = 0; l < 32; ++l) sumw += weights[l];
2407
+ sw[j] = sumw;
2404
2408
  scales[j] = make_qkx3_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
2405
- //scales[j] = make_qkx2_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -1.f, 0.1f, 20, false);
2406
- float scale = scales[j];
2407
- if (scale > max_scale) {
2408
- max_scale = scale;
2409
- }
2410
- float min = mins[j];
2411
- if (min > max_min) {
2412
- max_min = min;
2413
- }
2414
2409
  }
2415
2410
 
2416
- float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
2417
- float inv_min = max_min > 0 ? 63.f/max_min : 0.f;
2411
+ float d_block = make_qp_quants(QK_K/32, 63, scales, Ls, sw);
2412
+ float m_block = make_qp_quants(QK_K/32, 63, mins, Lm, sw);
2418
2413
  for (int j = 0; j < QK_K/32; ++j) {
2419
- uint8_t ls = nearest_int(inv_scale*scales[j]);
2420
- uint8_t lm = nearest_int(inv_min*mins[j]);
2421
- ls = MIN(63, ls);
2422
- lm = MIN(63, lm);
2414
+ uint8_t ls = Ls[j];
2415
+ uint8_t lm = Lm[j];
2423
2416
  if (j < 4) {
2424
2417
  y[i].scales[j] = ls;
2425
2418
  y[i].scales[j+4] = lm;
@@ -2429,8 +2422,8 @@ static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restri
2429
2422
  y[i].scales[j-0] |= ((lm >> 4) << 6);
2430
2423
  }
2431
2424
  }
2432
- y[i].d = GGML_FP32_TO_FP16(max_scale/63.f);
2433
- y[i].dmin = GGML_FP32_TO_FP16(max_min/63.f);
2425
+ y[i].d = GGML_FP32_TO_FP16(d_block);
2426
+ y[i].dmin = GGML_FP32_TO_FP16(m_block);
2434
2427
 
2435
2428
  uint8_t sc, m;
2436
2429
  for (int j = 0; j < QK_K/32; ++j) {
@@ -2688,20 +2681,21 @@ static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restri
2688
2681
  const int nb = n_per_row / QK_K;
2689
2682
 
2690
2683
  uint8_t L[QK_K];
2691
- float mins[QK_K/32];
2692
- float scales[QK_K/32];
2693
- float weights[32];
2694
2684
  uint8_t Laux[32];
2685
+ uint8_t Ls[QK_K/32];
2686
+ uint8_t Lm[QK_K/32];
2687
+ float mins[QK_K/32];
2688
+ float scales[QK_K/32];
2689
+ float sw[QK_K/32];
2690
+ float weights[32];
2695
2691
 
2696
2692
  for (int i = 0; i < nb; i++) {
2697
2693
 
2698
2694
  float sum_x2 = 0;
2699
2695
  for (int l = 0; l < QK_K; ++l) sum_x2 += x[l] * x[l];
2700
- float sigma2 = sum_x2/QK_K;
2696
+ float sigma2 = 2*sum_x2/QK_K;
2701
2697
  float av_x = sqrtf(sigma2);
2702
2698
 
2703
- float max_scale = 0; // as we are deducting the min, scales are always positive
2704
- float max_min = 0;
2705
2699
  for (int j = 0; j < QK_K/32; ++j) {
2706
2700
  if (quant_weights) {
2707
2701
  const float * qw = quant_weights + QK_K*i + 32*j;
@@ -2709,22 +2703,19 @@ static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restri
2709
2703
  } else {
2710
2704
  for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
2711
2705
  }
2706
+ float sumw = 0;
2707
+ for (int l = 0; l < 32; ++l) sumw += weights[l];
2708
+ sw[j] = sumw;
2709
+
2712
2710
  scales[j] = make_qkx3_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
2713
- float scale = scales[j];
2714
- if (scale > max_scale) {
2715
- max_scale = scale;
2716
- }
2717
- float min = mins[j];
2718
- if (min > max_min) {
2719
- max_min = min;
2720
- }
2721
2711
  }
2722
2712
 
2723
- float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
2724
- float inv_min = max_min > 0 ? 63.f/max_min : 0.f;
2713
+ float d_block = make_qp_quants(QK_K/32, 63, scales, Ls, sw);
2714
+ float m_block = make_qp_quants(QK_K/32, 63, mins, Lm, sw);
2715
+
2725
2716
  for (int j = 0; j < QK_K/32; ++j) {
2726
- uint8_t ls = nearest_int(inv_scale*scales[j]);
2727
- uint8_t lm = nearest_int(inv_min*mins[j]);
2717
+ uint8_t ls = Ls[j];
2718
+ uint8_t lm = Lm[j];
2728
2719
  ls = MIN(63, ls);
2729
2720
  lm = MIN(63, lm);
2730
2721
  if (j < 4) {
@@ -2736,8 +2727,8 @@ static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restri
2736
2727
  y[i].scales[j-0] |= ((lm >> 4) << 6);
2737
2728
  }
2738
2729
  }
2739
- y[i].d = GGML_FP32_TO_FP16(max_scale/63.f);
2740
- y[i].dmin = GGML_FP32_TO_FP16(max_min/63.f);
2730
+ y[i].d = GGML_FP32_TO_FP16(d_block);
2731
+ y[i].dmin = GGML_FP32_TO_FP16(m_block);
2741
2732
 
2742
2733
  uint8_t sc, m;
2743
2734
  for (int j = 0; j < QK_K/32; ++j) {
@@ -3441,6 +3432,41 @@ static const uint64_t iq2xs_grid[512] = {
3441
3432
  0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
3442
3433
  };
3443
3434
 
3435
+ static const uint32_t iq3xxs_grid[256] = {
3436
+ 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
3437
+ 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
3438
+ 0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404,
3439
+ 0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e,
3440
+ 0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c,
3441
+ 0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c,
3442
+ 0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34,
3443
+ 0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c,
3444
+ 0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,
3445
+ 0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04,
3446
+ 0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c,
3447
+ 0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414,
3448
+ 0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434,
3449
+ 0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c,
3450
+ 0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e,
3451
+ 0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24,
3452
+ 0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24,
3453
+ 0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,
3454
+ 0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c,
3455
+ 0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14,
3456
+ 0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414,
3457
+ 0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e,
3458
+ 0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404,
3459
+ 0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c,
3460
+ 0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c,
3461
+ 0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14,
3462
+ 0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c,
3463
+ 0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c,
3464
+ 0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14,
3465
+ 0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14,
3466
+ 0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
3467
+ 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
3468
+ };
3469
+
3444
3470
  static const uint8_t ksigns_iq2xs[128] = {
3445
3471
  0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
3446
3472
  144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
@@ -3507,6 +3533,38 @@ void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y,
3507
3533
  }
3508
3534
  }
3509
3535
 
3536
+ // ====================== 3.0625 bpw (de)-quantization
3537
+
3538
+ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y, int k) {
3539
+ assert(k % QK_K == 0);
3540
+ const int nb = k / QK_K;
3541
+
3542
+ uint32_t aux32;
3543
+
3544
+ for (int i = 0; i < nb; i++) {
3545
+
3546
+ const float d = GGML_FP16_TO_FP32(x[i].d);
3547
+ const uint8_t * qs = x[i].qs;
3548
+ const uint8_t * scales_and_signs = qs + QK_K/4;
3549
+
3550
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
3551
+ memcpy(&aux32, scales_and_signs + 4*ib32, sizeof(uint32_t));
3552
+ const float db = d * (0.5f + (aux32 >> 28)) * 0.5f;
3553
+ for (int l = 0; l < 4; ++l) {
3554
+ const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
3555
+ const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + qs[2*l+0]);
3556
+ const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + qs[2*l+1]);
3557
+ for (int j = 0; j < 4; ++j) {
3558
+ y[j+0] = db * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
3559
+ y[j+4] = db * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
3560
+ }
3561
+ y += 8;
3562
+ }
3563
+ qs += 8;
3564
+ }
3565
+ }
3566
+ }
3567
+
3510
3568
  //===================================== Q8_K ==============================================
3511
3569
 
3512
3570
  void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
@@ -8458,17 +8516,36 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
8458
8516
 
8459
8517
  const __m128i m4 = _mm_set1_epi8(0xf);
8460
8518
  const __m128i m1 = _mm_set1_epi8(1);
8461
- const __m128i m511 = _mm_set1_epi16(511);
8462
- const __m128i m127 = _mm_set1_epi16(127);
8519
+ const __m256i m511 = _mm256_set1_epi16(511);
8520
+ const __m256i mone = _mm256_set1_epi8(1);
8463
8521
 
8464
- const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
8522
+ static const uint8_t k_bit_helper[32] = {
8523
+ 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
8524
+ 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
8525
+ };
8526
+ static const char block_sign_shuffle_mask_1[32] = {
8527
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
8528
+ 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
8529
+ };
8530
+ static const char block_sign_shuffle_mask_2[32] = {
8531
+ 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
8532
+ 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
8533
+ };
8534
+ static const uint8_t bit_selector_mask_bytes[32] = {
8535
+ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
8536
+ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
8537
+ };
8538
+
8539
+ const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper);
8540
+ const __m256i bit_selector_mask = _mm256_loadu_si256((const __m256i*)bit_selector_mask_bytes);
8541
+ const __m256i block_sign_shuffle_1 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_1);
8542
+ const __m256i block_sign_shuffle_2 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_2);
8465
8543
 
8466
8544
  uint64_t aux64;
8467
8545
 
8468
8546
  // somewhat hacky, but gives a significant boost in performance
8469
- __m128i aux_gindex, aux_sindex;
8547
+ __m256i aux_gindex;
8470
8548
  const uint16_t * gindex = (const uint16_t *)&aux_gindex;
8471
- const uint16_t * sindex = (const uint16_t *)&aux_sindex;
8472
8549
 
8473
8550
  __m256 accumf = _mm256_setzero_ps();
8474
8551
  for (int i = 0; i < nb; ++i) {
@@ -8483,26 +8560,68 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
8483
8560
 
8484
8561
  __m256i sumi1 = _mm256_setzero_si256();
8485
8562
  __m256i sumi2 = _mm256_setzero_si256();
8486
- for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
8563
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {
8564
+
8565
+ const __m256i q2_data = _mm256_loadu_si256((const __m256i*)q2); q2 += 16;
8566
+ aux_gindex = _mm256_and_si256(q2_data, m511);
8567
+
8568
+ const __m256i partial_sign_bits = _mm256_srli_epi16(q2_data, 9);
8569
+ const __m256i partial_sign_bits_upper = _mm256_srli_epi16(q2_data, 13);
8570
+ const __m256i partial_sign_bits_for_counting = _mm256_xor_si256(partial_sign_bits, partial_sign_bits_upper);
8571
+
8572
+ const __m256i odd_bits = _mm256_shuffle_epi8(bit_helper, partial_sign_bits_for_counting);
8573
+ const __m256i full_sign_bits = _mm256_or_si256(partial_sign_bits, odd_bits);
8574
+
8487
8575
  const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
8488
8576
  const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
8489
- const __m128i q2_data = _mm_loadu_si128((const __m128i*)q2); q2 += 8;
8490
- aux_gindex = _mm_and_si128(q2_data, m511);
8491
- aux_sindex = _mm_and_si128(_mm_srli_epi16(q2_data, 9), m127);
8492
- const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]], iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]);
8493
- const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]], iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]);
8494
- const __m256i s2_1 = _mm256_set_epi64x(signs64[sindex[3]], signs64[sindex[2]], signs64[sindex[1]], signs64[sindex[0]]);
8495
- const __m256i s2_2 = _mm256_set_epi64x(signs64[sindex[7]], signs64[sindex[6]], signs64[sindex[5]], signs64[sindex[4]]);
8496
- const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
8497
- const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
8577
+ const __m256i q8_3 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
8578
+ const __m256i q8_4 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
8579
+
8580
+ const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]],
8581
+ iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]);
8582
+ const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]],
8583
+ iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]);
8584
+ const __m256i q2_3 = _mm256_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]],
8585
+ iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]);
8586
+ const __m256i q2_4 = _mm256_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]],
8587
+ iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);
8588
+
8589
+ const __m128i full_signs_l = _mm256_castsi256_si128(full_sign_bits);
8590
+ const __m128i full_signs_h = _mm256_extractf128_si256(full_sign_bits, 1);
8591
+ const __m256i full_signs_1 = _mm256_set_m128i(full_signs_l, full_signs_l);
8592
+ const __m256i full_signs_2 = _mm256_set_m128i(full_signs_h, full_signs_h);
8593
+
8594
+ __m256i signs;
8595
+ signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_1);
8596
+ signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
8597
+ const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone));
8598
+
8599
+ signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_2);
8600
+ signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
8601
+ const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone));
8602
+
8603
+ signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_1);
8604
+ signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
8605
+ const __m256i q8s_3 = _mm256_sign_epi8(q8_3, _mm256_or_si256(signs, mone));
8606
+
8607
+ signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_2);
8608
+ signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
8609
+ const __m256i q8s_4 = _mm256_sign_epi8(q8_4, _mm256_or_si256(signs, mone));
8610
+
8498
8611
  const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
8499
8612
  const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
8613
+ const __m256i dot3 = _mm256_maddubs_epi16(q2_3, q8s_3);
8614
+ const __m256i dot4 = _mm256_maddubs_epi16(q2_4, q8s_4);
8500
8615
 
8501
8616
  const __m256i sc1 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0)));
8502
8617
  const __m256i sc2 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1)));
8618
+ const __m256i sc3 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2)));
8619
+ const __m256i sc4 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3)));
8503
8620
 
8504
8621
  sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot1, sc1));
8505
8622
  sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot2, sc2));
8623
+ sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot3, sc3));
8624
+ sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot4, sc4));
8506
8625
  }
8507
8626
 
8508
8627
  accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
@@ -8551,6 +8670,136 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
8551
8670
  #endif
8552
8671
  }
8553
8672
 
8673
+ // TODO
8674
+ void ggml_vec_dot_iq3_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
8675
+ assert(n % QK_K == 0);
8676
+
8677
+ const block_iq3_xxs * restrict x = vx;
8678
+ const block_q8_K * restrict y = vy;
8679
+
8680
+ const int nb = n / QK_K;
8681
+
8682
+ #if defined(__ARM_NEON)
8683
+
8684
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
8685
+
8686
+ uint32_t aux32[2];
8687
+
8688
+ ggml_int8x16x4_t q3s;
8689
+ ggml_int8x16x4_t q8b;
8690
+
8691
+ float sumf = 0;
8692
+ for (int i = 0; i < nb; ++i) {
8693
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
8694
+ const uint8_t * restrict q3 = x[i].qs;
8695
+ const uint8_t * restrict gas = x[i].qs + QK_K/4;
8696
+ const int8_t * restrict q8 = y[i].qs;
8697
+ float sumf1 = 0, sumf2 = 0;
8698
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
8699
+ q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
8700
+ memcpy(aux32, gas, 2*sizeof(uint32_t)); gas += 2*sizeof(uint32_t);
8701
+ const uint32x4_t aux32x4_0 = {iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]};
8702
+ const uint32x4_t aux32x4_1 = {iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]};
8703
+ const uint32x4_t aux32x4_2 = {iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]};
8704
+ const uint32x4_t aux32x4_3 = {iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]};
8705
+ q3 += 16;
8706
+ q3s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 7) & 127))));
8707
+ q3s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 21) & 127))));
8708
+ q3s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127))));
8709
+ q3s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127))));
8710
+ q3s.val[0] = vmulq_s8(q3s.val[0], vreinterpretq_s8_u32(aux32x4_0));
8711
+ q3s.val[1] = vmulq_s8(q3s.val[1], vreinterpretq_s8_u32(aux32x4_1));
8712
+ q3s.val[2] = vmulq_s8(q3s.val[2], vreinterpretq_s8_u32(aux32x4_2));
8713
+ q3s.val[3] = vmulq_s8(q3s.val[3], vreinterpretq_s8_u32(aux32x4_3));
8714
+ const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);
8715
+ const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
8716
+ sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[0] >> 28));
8717
+ sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[1] >> 28));
8718
+ }
8719
+ sumf += d*(sumf1 + sumf2);
8720
+ }
8721
+ *s = 0.5f * sumf;
8722
+
8723
+ #elif defined(__AVX2__)
8724
+
8725
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
8726
+
8727
+ uint32_t aux32[2];
8728
+
8729
+ __m256 accumf = _mm256_setzero_ps();
8730
+ for (int i = 0; i < nb; ++i) {
8731
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
8732
+ const uint8_t * restrict q3 = x[i].qs;
8733
+ const uint8_t * restrict gas = x[i].qs + QK_K/4;
8734
+ const int8_t * restrict q8 = y[i].qs;
8735
+ __m256i sumi1 = _mm256_setzero_si256();
8736
+ __m256i sumi2 = _mm256_setzero_si256();
8737
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
8738
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
8739
+ const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
8740
+ const __m256i q2_1 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
8741
+ iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
8742
+ q3 += 8;
8743
+ const __m256i q2_2 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
8744
+ iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
8745
+ q3 += 8;
8746
+ memcpy(aux32, gas, 8); gas += 8;
8747
+ const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127],
8748
+ signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]);
8749
+ const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
8750
+ signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
8751
+ const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
8752
+ const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
8753
+ const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
8754
+ const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
8755
+ const uint16_t ls1 = aux32[0] >> 28;
8756
+ const uint16_t ls2 = aux32[1] >> 28;
8757
+ const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
8758
+ const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
8759
+ sumi1 = _mm256_add_epi32(sumi1, p1);
8760
+ sumi2 = _mm256_add_epi32(sumi2, p2);
8761
+ }
8762
+
8763
+ accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
8764
+
8765
+ }
8766
+
8767
+ *s = 0.25f * hsum_float_8(accumf);
8768
+
8769
+ #else
8770
+
8771
+ uint32_t aux32;
8772
+
8773
+ float sumf = 0.f;
8774
+ for (int i = 0; i < nb; ++i) {
8775
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
8776
+ const uint8_t * restrict q3 = x[i].qs;
8777
+ const uint8_t * restrict gas = x[i].qs + QK_K/4;
8778
+ const int8_t * restrict q8 = y[i].qs;
8779
+ int32_t bsum = 0;
8780
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
8781
+ memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t);
8782
+ const uint32_t ls = 2*(aux32 >> 28) + 1;
8783
+ int32_t sumi = 0;
8784
+ for (int l = 0; l < 4; ++l) {
8785
+ const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]);
8786
+ const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]);
8787
+ const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
8788
+ for (int j = 0; j < 4; ++j) {
8789
+ sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1);
8790
+ sumi += grid2[j] * q8[j+4] * (signs & kmask_iq2xs[j+4] ? -1 : 1);
8791
+ }
8792
+ q8 += 8;
8793
+ }
8794
+ q3 += 8;
8795
+ bsum += sumi * ls;
8796
+ }
8797
+ sumf += d * bsum;
8798
+ }
8799
+ *s = 0.25f * sumf;
8800
+ #endif
8801
+ }
8802
+
8554
8803
  // ================================ IQ2 quantization =============================================
8555
8804
 
8556
8805
  typedef struct {
@@ -8790,8 +9039,6 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
8790
9039
  int8_t L[32];
8791
9040
  int8_t Laux[32];
8792
9041
  float waux[32];
8793
- bool is_on_grid[4];
8794
- bool is_on_grid_aux[4];
8795
9042
  uint8_t block_signs[4];
8796
9043
  uint32_t q2[2*(QK_K/32)];
8797
9044
 
@@ -8841,10 +9088,11 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
8841
9088
  memset(L, 0, 32);
8842
9089
  continue;
8843
9090
  }
9091
+ float scale = make_qp_quants(32, kMaxQ+1, xval, (uint8_t*)L, weight);
9092
+ float eff_max = scale*kMaxQ;
8844
9093
  float best = 0;
8845
- float scale = max/(2*kMaxQ-1);
8846
- for (int is = -9; is <= 9; ++is) {
8847
- float id = (2*kMaxQ-1+is*0.1f)/max;
9094
+ for (int is = -6; is <= 6; ++is) {
9095
+ float id = (2*kMaxQ-1+is*0.1f)/eff_max;
8848
9096
  float this_scale = 1/id;
8849
9097
  for (int k = 0; k < 4; ++k) {
8850
9098
  for (int i = 0; i < 8; ++i) {
@@ -8854,9 +9102,7 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
8854
9102
  uint16_t u = 0;
8855
9103
  for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i);
8856
9104
  int grid_index = kmap_q2xs[u];
8857
- is_on_grid_aux[k] = true;
8858
9105
  if (grid_index < 0) {
8859
- is_on_grid_aux[k] = false;
8860
9106
  const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
8861
9107
  grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k);
8862
9108
  }
@@ -8870,16 +9116,12 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
8870
9116
  }
8871
9117
  if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
8872
9118
  scale = sumqx/sumq2; best = scale*sumqx;
8873
- for (int i = 0; i < 32; ++i) L[i] = Laux[i];
8874
- for (int k = 0; k < 4; ++k) is_on_grid[k] = is_on_grid_aux[k];
9119
+ memcpy(L, Laux, 32);
8875
9120
  }
8876
9121
  }
8877
- int n_not_ongrid = 0;
8878
- for (int k = 0; k < 4; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
8879
- if (n_not_ongrid > 0 && scale > 0) {
9122
+ if (scale > 0) {
8880
9123
  float id = 1/scale;
8881
9124
  for (int k = 0; k < 4; ++k) {
8882
- if (is_on_grid[k]) continue;
8883
9125
  uint16_t u = 0;
8884
9126
  for (int i = 0; i < 8; ++i) {
8885
9127
  int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
@@ -8935,49 +9177,10 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
8935
9177
  float d = max_scale/31;
8936
9178
  y[ibl].d = GGML_FP32_TO_FP16(d);
8937
9179
  float id = 1/d;
8938
- float sumqx = 0, sumq2 = 0;
8939
9180
  for (int ib = 0; ib < QK_K/32; ++ib) {
8940
9181
  int l = nearest_int(0.5f*(id*scales[ib]-1));
8941
9182
  l = MAX(0, MIN(15, l));
8942
9183
  q2[2*ib+1] |= ((uint32_t)l << 28);
8943
- const float * xb = xbl + 32*ib;
8944
- const float * qw = quant_weights + QK_K*ibl + 32*ib;
8945
- for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
8946
- const uint8_t * aux8 = (const uint8_t *)(q2 + 2*ib);
8947
- const float db = d * (1 + 2*l);
8948
- uint32_t u = 0;
8949
- for (int k = 0; k < 4; ++k) {
8950
- const int8_t * signs = keven_signs_q2xs + 8*((q2[2*ib+1] >> 7*k) & 127);
8951
- const float * xk = xb + 8*k;
8952
- const float * wk = weight + 8*k;
8953
- const uint8_t * grid = (const uint8_t *)(kgrid_q2xs + aux8[k]);
8954
- float best_mse = 0; int best_index = aux8[k];
8955
- for (int j = 0; j < 8; ++j) {
8956
- float diff = db * grid[j] * signs[j] - xk[j];
8957
- best_mse += wk[j] * diff * diff;
8958
- }
8959
- for (int idx = 0; idx < 256; ++idx) {
8960
- grid = (const uint8_t *)(kgrid_q2xs + idx);
8961
- float mse = 0;
8962
- for (int j = 0; j < 8; ++j) {
8963
- float diff = db * grid[j] * signs[j] - xk[j];
8964
- mse += wk[j] * diff * diff;
8965
- }
8966
- if (mse < best_mse) {
8967
- best_mse = mse; best_index = idx;
8968
- }
8969
- }
8970
- u |= (best_index << 8*k);
8971
- grid = (const uint8_t *)(kgrid_q2xs + best_index);
8972
- //grid = (const uint8_t *)(kgrid_q2xs + aux8[k]);
8973
- for (int j = 0; j < 8; ++j) {
8974
- float q = db * grid[j] * signs[j];
8975
- sumqx += wk[j] * q * xk[j];
8976
- sumq2 += wk[j] * q * q;
8977
- }
8978
- }
8979
- q2[2*ib] = u;
8980
- if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(d*sumqx/sumq2);
8981
9184
  }
8982
9185
  memcpy(y[ibl].qs, q2, QK_K/4);
8983
9186
  }
@@ -9189,3 +9392,436 @@ size_t quantize_iq2_xs(const float * src, void * dst, int nrow, int n_per_row, i
9189
9392
  return nrow * nblock * sizeof(block_iq2_xs);
9190
9393
  }
9191
9394
 
9395
+ //
9396
+ // ============================================= 3-bit using D4 lattice
9397
+ //
9398
+
9399
// Lazily-built lookup tables for one iq3_xxs quantization grid (4-D lattice
// points with odd coordinates; see iq3xs_init_impl for construction).
typedef struct {
    uint32_t * grid;        // grid[i]: 4 packed int8 coordinates of grid point i (each 2*l+1, l in 0..7)
    int      * map;         // 4096-entry inverse map: packed 4x3-bit position -> grid index;
                            // for off-grid positions holds -(offset+1) into `neighbours`
    uint16_t * neighbours;  // concatenated per-position lists: [count, idx0, idx1, ...]
} iq3_entry_t;

// Singleton table; only the 256-point grid (slot 0) exists today.
// Filled by iq3xs_init_impl(), released by iq3xs_free_impl().
static iq3_entry_t iq3_data[1] = {
    {NULL, NULL, NULL},
};
9408
+
9409
// Map a grid size to its slot in iq3_data. Only grid_size == 256 is
// supported, so this always returns 0.
static inline int iq3_data_index(int grid_size) {
    (void)grid_size;  // only referenced by the assert; silences -Wunused in release builds
    GGML_ASSERT(grid_size == 256);
    return 0;
}
9414
+
9415
// qsort comparator for (distance, index) int pairs: primary key is the
// squared distance, ties are broken by index so the ordering is fully
// deterministic.
static int iq3_compare_func(const void * left, const void * right) {
    const int * a = (const int *)left;
    const int * b = (const int *)right;
    if (a[0] != b[0]) {
        return a[0] < b[0] ? -1 : 1;
    }
    if (a[1] != b[1]) {
        return a[1] < b[1] ? -1 : 1;
    }
    return 0;
}
9420
+
9421
+ void iq3xs_init_impl(int grid_size) {
9422
+ const int gindex = iq3_data_index(grid_size);
9423
+ if (iq3_data[gindex].grid) {
9424
+ return;
9425
+ }
9426
+ static const uint16_t kgrid_256[256] = {
9427
+ 0, 2, 4, 9, 11, 15, 16, 18, 25, 34, 59, 61, 65, 67, 72, 74,
9428
+ 81, 85, 88, 90, 97, 108, 120, 128, 130, 132, 137, 144, 146, 153, 155, 159,
9429
+ 169, 175, 189, 193, 199, 200, 202, 213, 248, 267, 287, 292, 303, 315, 317, 321,
9430
+ 327, 346, 362, 413, 436, 456, 460, 462, 483, 497, 513, 515, 520, 522, 529, 531,
9431
+ 536, 538, 540, 551, 552, 576, 578, 585, 592, 594, 641, 643, 648, 650, 657, 664,
9432
+ 698, 704, 706, 720, 729, 742, 758, 769, 773, 808, 848, 852, 870, 889, 901, 978,
9433
+ 992, 1024, 1026, 1033, 1035, 1040, 1042, 1046, 1049, 1058, 1089, 1091, 1093, 1096, 1098, 1105,
9434
+ 1112, 1139, 1143, 1144, 1152, 1154, 1161, 1167, 1168, 1170, 1183, 1184, 1197, 1217, 1224, 1228,
9435
+ 1272, 1276, 1309, 1323, 1347, 1367, 1377, 1404, 1473, 1475, 1486, 1509, 1537, 1544, 1546, 1553,
9436
+ 1555, 1576, 1589, 1594, 1600, 1602, 1616, 1625, 1636, 1638, 1665, 1667, 1672, 1685, 1706, 1722,
9437
+ 1737, 1755, 1816, 1831, 1850, 1856, 1862, 1874, 1901, 1932, 1950, 1971, 2011, 2032, 2052, 2063,
9438
+ 2077, 2079, 2091, 2095, 2172, 2192, 2207, 2208, 2224, 2230, 2247, 2277, 2308, 2345, 2356, 2389,
9439
+ 2403, 2424, 2501, 2504, 2506, 2520, 2570, 2593, 2616, 2624, 2630, 2646, 2669, 2700, 2714, 2746,
9440
+ 2754, 2795, 2824, 2835, 2839, 2874, 2882, 2905, 2984, 3028, 3042, 3092, 3108, 3110, 3124, 3153,
9441
+ 3185, 3215, 3252, 3288, 3294, 3364, 3397, 3434, 3483, 3523, 3537, 3587, 3589, 3591, 3592, 3610,
9442
+ 3626, 3670, 3680, 3722, 3749, 3754, 3776, 3789, 3803, 3824, 3857, 3873, 3904, 3906, 3924, 3992,
9443
+ };
9444
+ const int kmap_size = 4096;
9445
+ const int nwant = 2;
9446
+ const uint16_t * kgrid = kgrid_256;
9447
+ uint32_t * kgrid_q3xs;
9448
+ int * kmap_q3xs;
9449
+ uint16_t * kneighbors_q3xs;
9450
+
9451
+ printf("================================================================= %s(grid_size = %d)\n", __func__, grid_size);
9452
+ uint32_t * the_grid = (uint32_t *)malloc(grid_size*sizeof(uint32_t));
9453
+ for (int k = 0; k < grid_size; ++k) {
9454
+ int8_t * pos = (int8_t *)(the_grid + k);
9455
+ for (int i = 0; i < 4; ++i) {
9456
+ int l = (kgrid[k] >> 3*i) & 0x7;
9457
+ pos[i] = 2*l + 1;
9458
+ }
9459
+ }
9460
+ kgrid_q3xs = the_grid;
9461
+ iq3_data[gindex].grid = the_grid;
9462
+ kmap_q3xs = (int *)malloc(kmap_size*sizeof(int));
9463
+ iq3_data[gindex].map = kmap_q3xs;
9464
+ for (int i = 0; i < kmap_size; ++i) kmap_q3xs[i] = -1;
9465
+ uint32_t aux32;
9466
+ uint8_t * aux8 = (uint8_t *)&aux32;
9467
+ for (int i = 0; i < grid_size; ++i) {
9468
+ aux32 = kgrid_q3xs[i];
9469
+ uint16_t index = 0;
9470
+ for (int k=0; k<4; ++k) {
9471
+ uint16_t q = (aux8[k] - 1)/2;
9472
+ index |= (q << 3*k);
9473
+ }
9474
+ kmap_q3xs[index] = i;
9475
+ }
9476
+ int8_t pos[4];
9477
+ int * dist2 = (int *)malloc(2*grid_size*sizeof(int));
9478
+ int num_neighbors = 0, num_not_in_map = 0;
9479
+ for (int i = 0; i < kmap_size; ++i) {
9480
+ if (kmap_q3xs[i] >= 0) continue;
9481
+ ++num_not_in_map;
9482
+ for (int k = 0; k < 4; ++k) {
9483
+ int l = (i >> 3*k) & 0x7;
9484
+ pos[k] = 2*l + 1;
9485
+ }
9486
+ for (int j = 0; j < grid_size; ++j) {
9487
+ const int8_t * pg = (const int8_t *)(kgrid_q3xs + j);
9488
+ int d2 = 0;
9489
+ for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
9490
+ dist2[2*j+0] = d2;
9491
+ dist2[2*j+1] = j;
9492
+ }
9493
+ qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func);
9494
+ int n = 0; int d2 = dist2[0];
9495
+ int nhave = 1;
9496
+ for (int j = 0; j < grid_size; ++j) {
9497
+ if (dist2[2*j] > d2) {
9498
+ if (nhave == nwant) break;
9499
+ d2 = dist2[2*j];
9500
+ ++nhave;
9501
+ }
9502
+ ++n;
9503
+ }
9504
+ num_neighbors += n;
9505
+ }
9506
+ printf("%s: %d neighbours in total\n", __func__, num_neighbors);
9507
+ kneighbors_q3xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t));
9508
+ iq3_data[gindex].neighbours = kneighbors_q3xs;
9509
+ int counter = 0;
9510
+ for (int i = 0; i < kmap_size; ++i) {
9511
+ if (kmap_q3xs[i] >= 0) continue;
9512
+ for (int k = 0; k < 4; ++k) {
9513
+ int l = (i >> 3*k) & 0x7;
9514
+ pos[k] = 2*l + 1;
9515
+ }
9516
+ for (int j = 0; j < grid_size; ++j) {
9517
+ const int8_t * pg = (const int8_t *)(kgrid_q3xs + j);
9518
+ int d2 = 0;
9519
+ for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
9520
+ dist2[2*j+0] = d2;
9521
+ dist2[2*j+1] = j;
9522
+ }
9523
+ qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func);
9524
+ kmap_q3xs[i] = -(counter + 1);
9525
+ int d2 = dist2[0];
9526
+ uint16_t * start = &kneighbors_q3xs[counter++];
9527
+ int n = 0, nhave = 1;
9528
+ for (int j = 0; j < grid_size; ++j) {
9529
+ if (dist2[2*j] > d2) {
9530
+ if (nhave == nwant) break;
9531
+ d2 = dist2[2*j];
9532
+ ++nhave;
9533
+ }
9534
+ kneighbors_q3xs[counter++] = dist2[2*j+1];
9535
+ ++n;
9536
+ }
9537
+ *start = n;
9538
+ }
9539
+ free(dist2);
9540
+ }
9541
+
9542
+ void iq3xs_free_impl(int grid_size) {
9543
+ GGML_ASSERT(grid_size == 256);
9544
+ const int gindex = iq3_data_index(grid_size);
9545
+ if (iq3_data[gindex].grid) {
9546
+ free(iq3_data[gindex].grid); iq3_data[gindex].grid = NULL;
9547
+ free(iq3_data[gindex].map); iq3_data[gindex].map = NULL;
9548
+ free(iq3_data[gindex].neighbours); iq3_data[gindex].neighbours = NULL;
9549
+ }
9550
+ }
9551
+
9552
// Among the precomputed neighbour candidates of an off-grid point, pick the
// grid point minimizing the weighted squared error against xval at the given
// scale. neighbours[0] holds the candidate count; indices follow. Writes the
// chosen point's quant levels ((coord-1)/2) into L and returns its grid index.
static int iq3_find_best_neighbour(const uint16_t * restrict neighbours, const uint32_t * restrict grid,
        const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L) {
    const int count = neighbours[0];
    GGML_ASSERT(count > 0);
    float best_err = FLT_MAX;
    int best_index = -1;
    for (int j = 1; j <= count; ++j) {
        const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
        float err = 0;
        for (int i = 0; i < 4; ++i) {
            const float diff = scale*(float)pg[i] - xval[i];
            err += weight[i]*diff*diff;
        }
        if (err < best_err) {
            best_err   = err;
            best_index = neighbours[j];
        }
    }
    GGML_ASSERT(best_index >= 0);
    const int8_t * pg = (const int8_t *)(grid + best_index);
    for (int i = 0; i < 4; ++i) {
        L[i] = (pg[i] - 1)/2;
    }
    return best_index;
}
9575
+
9576
+ static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
9577
+
9578
+ const int gindex = iq3_data_index(256);
9579
+
9580
+ const uint32_t * kgrid_q3xs = iq3_data[gindex].grid;
9581
+ const int * kmap_q3xs = iq3_data[gindex].map;
9582
+ const uint16_t * kneighbors_q3xs = iq3_data[gindex].neighbours;
9583
+
9584
+ //GGML_ASSERT(quant_weights && "missing quantization weights");
9585
+ GGML_ASSERT(kgrid_q3xs && "forgot to call ggml_quantize_init()?");
9586
+ GGML_ASSERT(kmap_q3xs && "forgot to call ggml_quantize_init()?");
9587
+ GGML_ASSERT(kneighbors_q3xs && "forgot to call ggml_quantize_init()?");
9588
+ GGML_ASSERT(n%QK_K == 0);
9589
+
9590
+ const int kMaxQ = 8;
9591
+
9592
+ const int nbl = n/256;
9593
+
9594
+ block_iq3_xxs * y = vy;
9595
+
9596
+ float scales[QK_K/32];
9597
+ float weight[32];
9598
+ float xval[32];
9599
+ int8_t L[32];
9600
+ int8_t Laux[32];
9601
+ float waux[32];
9602
+ bool is_on_grid[8];
9603
+ bool is_on_grid_aux[8];
9604
+ uint8_t block_signs[8];
9605
+ uint8_t q3[3*(QK_K/8)];
9606
+ uint32_t * scales_and_signs = (uint32_t *)(q3 + QK_K/4);
9607
+
9608
+ for (int ibl = 0; ibl < nbl; ++ibl) {
9609
+
9610
+ y[ibl].d = GGML_FP32_TO_FP16(0.f);
9611
+ memset(q3, 0, 3*QK_K/8);
9612
+
9613
+ float max_scale = 0;
9614
+
9615
+ const float * xbl = x + QK_K*ibl;
9616
+ float sumx2 = 0;
9617
+ for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
9618
+ float sigma2 = sumx2/QK_K;
9619
+
9620
+ for (int ib = 0; ib < QK_K/32; ++ib) {
9621
+ const float * xb = xbl + 32*ib;
9622
+ if (quant_weights) {
9623
+ const float * qw = quant_weights + QK_K*ibl + 32*ib;
9624
+ for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
9625
+ } else {
9626
+ for (int i = 0; i < 32; ++i) weight[i] = xb[i]*xb[i];
9627
+ }
9628
+ for (int i = 0; i < 32; ++i) waux[i] = sqrtf(weight[i]);
9629
+ for (int k = 0; k < 4; ++k) {
9630
+ int nflip = 0;
9631
+ uint8_t s = 0;
9632
+ for (int i = 0; i < 8; ++i) {
9633
+ if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
9634
+ else {
9635
+ xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i);
9636
+ }
9637
+ }
9638
+ if (nflip%2) {
9639
+ int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin];
9640
+ for (int i = 1; i < 8; ++i) {
9641
+ float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i];
9642
+ if (ax < min) {
9643
+ min = ax; imin = i;
9644
+ }
9645
+ }
9646
+ xval[8*k+imin] = -xval[8*k+imin];
9647
+ s ^= (1 << imin);
9648
+ }
9649
+ block_signs[k] = s & 127;
9650
+ }
9651
+ float max = xval[0];
9652
+ for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]);
9653
+ if (!max) {
9654
+ scales[ib] = 0;
9655
+ memset(L, 0, 32);
9656
+ continue;
9657
+ }
9658
+ float best = 0;
9659
+ float scale = max/(2*kMaxQ-1);
9660
+ for (int is = -15; is <= 15; ++is) {
9661
+ float id = (2*kMaxQ-1+is*0.2f)/max;
9662
+ float this_scale = 1/id;
9663
+ for (int k = 0; k < 8; ++k) {
9664
+ for (int i = 0; i < 4; ++i) {
9665
+ int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
9666
+ Laux[4*k+i] = MAX(0, MIN(kMaxQ-1, l));
9667
+ }
9668
+ uint16_t u = 0;
9669
+ for (int i = 0; i < 4; ++i) u |= (Laux[4*k+i] << 3*i);
9670
+ int grid_index = kmap_q3xs[u];
9671
+ is_on_grid_aux[k] = true;
9672
+ if (grid_index < 0) {
9673
+ is_on_grid_aux[k] = false;
9674
+ const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
9675
+ grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, this_scale, Laux + 4*k);
9676
+ }
9677
+ }
9678
+ float sumqx = 0, sumq2 = 0;
9679
+ for (int i = 0; i < 32; ++i) {
9680
+ float w = weight[i];
9681
+ float q = 2*Laux[i] + 1;
9682
+ sumqx += w*xval[i]*q;
9683
+ sumq2 += w*q*q;
9684
+ }
9685
+ if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
9686
+ scale = sumqx/sumq2; best = scale*sumqx;
9687
+ for (int i = 0; i < 32; ++i) L[i] = Laux[i];
9688
+ for (int k = 0; k < 8; ++k) is_on_grid[k] = is_on_grid_aux[k];
9689
+ }
9690
+ }
9691
+ int n_not_ongrid = 0;
9692
+ for (int k = 0; k < 8; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
9693
+ if (n_not_ongrid > 0 && scale > 0) {
9694
+ float id = 1/scale;
9695
+ for (int k = 0; k < 8; ++k) {
9696
+ if (is_on_grid[k]) continue;
9697
+ uint16_t u = 0;
9698
+ for (int i = 0; i < 4; ++i) {
9699
+ int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
9700
+ l = MAX(0, MIN(kMaxQ-1, l));
9701
+ u |= (l << 3*i);
9702
+ }
9703
+ int grid_index = kmap_q3xs[u];
9704
+ if (grid_index < 0) {
9705
+ const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
9706
+ grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, scale, L + 4*k);
9707
+ }
9708
+ const int8_t * pg = (const int8_t *)(kgrid_q3xs + grid_index);
9709
+ for (int i = 0; i < 4; ++i) L[4*k+i] = (pg[i] - 1)/2;
9710
+ }
9711
+ float sumqx = 0, sumq2 = 0;
9712
+ for (int i = 0; i < 32; ++i) {
9713
+ float w = weight[i];
9714
+ float q = 2*L[i] + 1;
9715
+ sumqx += w*xval[i]*q;
9716
+ sumq2 += w*q*q;
9717
+ }
9718
+ if (sumq2 > 0) scale = sumqx/sumq2;
9719
+ }
9720
+ if (scale < 0) {
9721
+ // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale)
9722
+ // and correspondingly flip quant signs.
9723
+ scale = -scale;
9724
+ for (int k = 0; k < 4; ++k) block_signs[k] = (~block_signs[k]) & 127;
9725
+ }
9726
+ for (int k = 0; k < 8; ++k) {
9727
+ uint16_t u = 0;
9728
+ for (int i = 0; i < 4; ++i) u |= (L[4*k+i] << 3*i);
9729
+ int grid_index = kmap_q3xs[u];
9730
+ if (grid_index < 0) {
9731
+ printf("Oops: found point %u not on grid:", u);
9732
+ for (int i = 0; i < 4; ++i) printf(" %d", L[4*k+i]);
9733
+ printf("\n");
9734
+ GGML_ASSERT(false);
9735
+ }
9736
+ q3[8*ib+k] = grid_index;
9737
+ }
9738
+ scales_and_signs[ib] = block_signs[0] | (block_signs[1] << 7) | (block_signs[2] << 14) | (block_signs[3] << 21);
9739
+ GGML_ASSERT(scale >= 0);
9740
+ scales[ib] = scale;
9741
+ max_scale = MAX(max_scale, scale);
9742
+ }
9743
+
9744
+ if (!max_scale) {
9745
+ memset(y[ibl].qs, 0, 3*QK_K/8);
9746
+ continue;
9747
+ }
9748
+
9749
+ float d = max_scale/31;
9750
+ y[ibl].d = GGML_FP32_TO_FP16(d);
9751
+ float id = 1/d;
9752
+ float sumqx = 0, sumq2 = 0;
9753
+ for (int ib = 0; ib < QK_K/32; ++ib) {
9754
+ int l = nearest_int(0.5f*(id*scales[ib]-1));
9755
+ l = MAX(0, MIN(15, l));
9756
+ scales_and_signs[ib] |= ((uint32_t)l << 28);
9757
+ if (false) {
9758
+ const float * xb = xbl + 32*ib;
9759
+ if (quant_weights) {
9760
+ const float * qw = quant_weights + QK_K*ibl + 32*ib;
9761
+ for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
9762
+ } else {
9763
+ for (int i = 0; i < 32; ++i) weight[i] = xb[i]*xb[i];
9764
+ }
9765
+ const float db = 0.25f * d * (1 + 2*l);
9766
+ for (int k = 0; k < 8; ++k) {
9767
+ const int8_t * signs = keven_signs_q2xs + 8*((scales_and_signs[ib] >> 7*(k/2)) & 127) + 4*(k%2);
9768
+ const float * xk = xb + 4*k;
9769
+ const float * wk = weight + 4*k;
9770
+ //const uint8_t * grid = (const uint8_t *)(kgrid_q3xs + q3[8*ib+k]);
9771
+ const uint8_t * grid = (const uint8_t *)(iq3xxs_grid + q3[8*ib+k]);
9772
+ float best_mse = 0; int best_index = q3[8*ib+k];
9773
+ for (int j = 0; j < 4; ++j) {
9774
+ float diff = db * grid[j] * signs[j] - xk[j];
9775
+ best_mse += wk[j] * diff * diff;
9776
+ }
9777
+ for (int idx = 0; idx < 256; ++idx) {
9778
+ //grid = (const uint8_t *)(kgrid_q3xs + idx);
9779
+ grid = (const uint8_t *)(iq3xxs_grid + idx);
9780
+ float mse = 0;
9781
+ for (int j = 0; j < 4; ++j) {
9782
+ float diff = db * grid[j] * signs[j] - xk[j];
9783
+ mse += wk[j] * diff * diff;
9784
+ }
9785
+ if (mse < best_mse) {
9786
+ best_mse = mse; best_index = idx;
9787
+ }
9788
+ }
9789
+ q3[8*ib+k] = best_index;
9790
+ //grid = (const uint8_t *)(kgrid_q3xs + best_index);
9791
+ grid = (const uint8_t *)(iq3xxs_grid + best_index);
9792
+ for (int j = 0; j < 4; ++j) {
9793
+ float q = db * grid[j] * signs[j];
9794
+ sumqx += wk[j] * q * xk[j];
9795
+ sumq2 += wk[j] * q * q;
9796
+ }
9797
+ }
9798
+ if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(d*sumqx/sumq2);
9799
+ }
9800
+ }
9801
+ memcpy(y[ibl].qs, q3, 3*QK_K/8);
9802
+ }
9803
+ }
9804
+
9805
+ size_t quantize_iq3_xxs(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
9806
+ (void)hist;
9807
+ GGML_ASSERT(n_per_row%QK_K == 0);
9808
+ int nblock = n_per_row/QK_K;
9809
+ char * qrow = (char *)dst;
9810
+ for (int row = 0; row < nrow; ++row) {
9811
+ quantize_row_iq3_xxs_impl(src, qrow, n_per_row, quant_weights);
9812
+ src += n_per_row;
9813
+ qrow += nblock*sizeof(block_iq3_xxs);
9814
+ }
9815
+ return nrow * nblock * sizeof(block_iq3_xxs);
9816
+ }
9817
+
9818
+ void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int k) {
9819
+ assert(k % QK_K == 0);
9820
+ block_iq3_xxs * restrict y = vy;
9821
+ quantize_row_iq3_xxs_reference(x, y, k);
9822
+ }
9823
+
9824
+ void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int k) {
9825
+ assert(k % QK_K == 0);
9826
+ quantize_row_iq3_xxs_impl(x, y, k, NULL);
9827
+ }