llama_cpp 0.12.3 → 0.12.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3441,6 +3441,41 @@ static const uint64_t iq2xs_grid[512] = {
3441
3441
  0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
3442
3442
  };
3443
3443
 
3444
+ static const uint32_t iq3xxs_grid[256] = { // 256 grid points; consumers index it with one byte of block_iq3_xxs::qs and read each 32-bit entry through a uint8_t pointer as four magnitudes
3445
+ 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
3446
+ 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
3447
+ 0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404,
3448
+ 0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e,
3449
+ 0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c,
3450
+ 0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c,
3451
+ 0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34,
3452
+ 0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c,
3453
+ 0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,
3454
+ 0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04,
3455
+ 0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c,
3456
+ 0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414,
3457
+ 0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434,
3458
+ 0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c,
3459
+ 0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e,
3460
+ 0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24,
3461
+ 0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24,
3462
+ 0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,
3463
+ 0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c,
3464
+ 0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14,
3465
+ 0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414,
3466
+ 0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e,
3467
+ 0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404,
3468
+ 0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c,
3469
+ 0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c,
3470
+ 0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14,
3471
+ 0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c,
3472
+ 0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c,
3473
+ 0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14,
3474
+ 0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14,
3475
+ 0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
3476
+ 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
3477
+ };
3478
+
3444
3479
  static const uint8_t ksigns_iq2xs[128] = {
3445
3480
  0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
3446
3481
  144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
@@ -3507,6 +3542,38 @@ void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y,
3507
3542
  }
3508
3543
  }
3509
3544
 
3545
+ // ====================== 3.0625 bpw (de)-quantization
3546
+
3547
+ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y, int k) {
3548
+ assert(k % QK_K == 0);
3549
+ const int nb = k / QK_K;
3550
+
3551
+ uint32_t aux32;
3552
+
3553
+ for (int i = 0; i < nb; i++) {
3554
+
3555
+ const float d = GGML_FP16_TO_FP32(x[i].d);
3556
+ const uint8_t * qs = x[i].qs;
3557
+ const uint8_t * scales_and_signs = qs + QK_K/4;
3558
+
3559
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
3560
+ memcpy(&aux32, scales_and_signs + 4*ib32, sizeof(uint32_t));
3561
+ const float db = d * (0.5f + (aux32 >> 28)) * 0.5f;
3562
+ for (int l = 0; l < 4; ++l) {
3563
+ const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
3564
+ const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + qs[2*l+0]);
3565
+ const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + qs[2*l+1]);
3566
+ for (int j = 0; j < 4; ++j) {
3567
+ y[j+0] = db * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
3568
+ y[j+4] = db * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
3569
+ }
3570
+ y += 8;
3571
+ }
3572
+ qs += 8;
3573
+ }
3574
+ }
3575
+ }
3576
+
3510
3577
  //===================================== Q8_K ==============================================
3511
3578
 
3512
3579
  void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
@@ -8458,17 +8525,36 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
8458
8525
 
8459
8526
  const __m128i m4 = _mm_set1_epi8(0xf);
8460
8527
  const __m128i m1 = _mm_set1_epi8(1);
8461
- const __m128i m511 = _mm_set1_epi16(511);
8462
- const __m128i m127 = _mm_set1_epi16(127);
8528
+ const __m256i m511 = _mm256_set1_epi16(511);
8529
+ const __m256i mone = _mm256_set1_epi8(1);
8463
8530
 
8464
- const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
8531
+ static const uint8_t k_bit_helper[32] = {
8532
+ 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
8533
+ 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
8534
+ };
8535
+ static const char block_sign_shuffle_mask_1[32] = {
8536
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
8537
+ 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
8538
+ };
8539
+ static const char block_sign_shuffle_mask_2[32] = {
8540
+ 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
8541
+ 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
8542
+ };
8543
+ static const uint8_t bit_selector_mask_bytes[32] = {
8544
+ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
8545
+ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
8546
+ };
8547
+
8548
+ const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper);
8549
+ const __m256i bit_selector_mask = _mm256_loadu_si256((const __m256i*)bit_selector_mask_bytes);
8550
+ const __m256i block_sign_shuffle_1 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_1);
8551
+ const __m256i block_sign_shuffle_2 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_2);
8465
8552
 
8466
8553
  uint64_t aux64;
8467
8554
 
8468
8555
  // somewhat hacky, but gives a significant boost in performance
8469
- __m128i aux_gindex, aux_sindex;
8556
+ __m256i aux_gindex;
8470
8557
  const uint16_t * gindex = (const uint16_t *)&aux_gindex;
8471
- const uint16_t * sindex = (const uint16_t *)&aux_sindex;
8472
8558
 
8473
8559
  __m256 accumf = _mm256_setzero_ps();
8474
8560
  for (int i = 0; i < nb; ++i) {
@@ -8483,26 +8569,68 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
8483
8569
 
8484
8570
  __m256i sumi1 = _mm256_setzero_si256();
8485
8571
  __m256i sumi2 = _mm256_setzero_si256();
8486
- for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
8572
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {
8573
+
8574
+ const __m256i q2_data = _mm256_loadu_si256((const __m256i*)q2); q2 += 16;
8575
+ aux_gindex = _mm256_and_si256(q2_data, m511);
8576
+
8577
+ const __m256i partial_sign_bits = _mm256_srli_epi16(q2_data, 9);
8578
+ const __m256i partial_sign_bits_upper = _mm256_srli_epi16(q2_data, 13);
8579
+ const __m256i partial_sign_bits_for_counting = _mm256_xor_si256(partial_sign_bits, partial_sign_bits_upper);
8580
+
8581
+ const __m256i odd_bits = _mm256_shuffle_epi8(bit_helper, partial_sign_bits_for_counting);
8582
+ const __m256i full_sign_bits = _mm256_or_si256(partial_sign_bits, odd_bits);
8583
+
8487
8584
  const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
8488
8585
  const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
8489
- const __m128i q2_data = _mm_loadu_si128((const __m128i*)q2); q2 += 8;
8490
- aux_gindex = _mm_and_si128(q2_data, m511);
8491
- aux_sindex = _mm_and_si128(_mm_srli_epi16(q2_data, 9), m127);
8492
- const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]], iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]);
8493
- const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]], iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]);
8494
- const __m256i s2_1 = _mm256_set_epi64x(signs64[sindex[3]], signs64[sindex[2]], signs64[sindex[1]], signs64[sindex[0]]);
8495
- const __m256i s2_2 = _mm256_set_epi64x(signs64[sindex[7]], signs64[sindex[6]], signs64[sindex[5]], signs64[sindex[4]]);
8496
- const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
8497
- const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
8586
+ const __m256i q8_3 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
8587
+ const __m256i q8_4 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
8588
+
8589
+ const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]],
8590
+ iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]);
8591
+ const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]],
8592
+ iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]);
8593
+ const __m256i q2_3 = _mm256_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]],
8594
+ iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]);
8595
+ const __m256i q2_4 = _mm256_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]],
8596
+ iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);
8597
+
8598
+ const __m128i full_signs_l = _mm256_castsi256_si128(full_sign_bits);
8599
+ const __m128i full_signs_h = _mm256_extractf128_si256(full_sign_bits, 1);
8600
+ const __m256i full_signs_1 = _mm256_set_m128i(full_signs_l, full_signs_l);
8601
+ const __m256i full_signs_2 = _mm256_set_m128i(full_signs_h, full_signs_h);
8602
+
8603
+ __m256i signs;
8604
+ signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_1);
8605
+ signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
8606
+ const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone));
8607
+
8608
+ signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_2);
8609
+ signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
8610
+ const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone));
8611
+
8612
+ signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_1);
8613
+ signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
8614
+ const __m256i q8s_3 = _mm256_sign_epi8(q8_3, _mm256_or_si256(signs, mone));
8615
+
8616
+ signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_2);
8617
+ signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
8618
+ const __m256i q8s_4 = _mm256_sign_epi8(q8_4, _mm256_or_si256(signs, mone));
8619
+
8498
8620
  const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
8499
8621
  const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
8622
+ const __m256i dot3 = _mm256_maddubs_epi16(q2_3, q8s_3);
8623
+ const __m256i dot4 = _mm256_maddubs_epi16(q2_4, q8s_4);
8500
8624
 
8501
8625
  const __m256i sc1 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0)));
8502
8626
  const __m256i sc2 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1)));
8627
+ const __m256i sc3 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2)));
8628
+ const __m256i sc4 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3)));
8503
8629
 
8504
8630
  sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot1, sc1));
8505
8631
  sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot2, sc2));
8632
+ sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot3, sc3));
8633
+ sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot4, sc4));
8506
8634
  }
8507
8635
 
8508
8636
  accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
@@ -8551,6 +8679,136 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
8551
8679
  #endif
8552
8680
  }
8553
8681
 
8682
+ // TODO
8683
+ void ggml_vec_dot_iq3_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
8684
+ assert(n % QK_K == 0);
8685
+
8686
+ const block_iq3_xxs * restrict x = vx;
8687
+ const block_q8_K * restrict y = vy;
8688
+
8689
+ const int nb = n / QK_K;
8690
+
8691
+ #if defined(__ARM_NEON)
8692
+
8693
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
8694
+
8695
+ uint32_t aux32[2];
8696
+
8697
+ ggml_int8x16x4_t q3s;
8698
+ ggml_int8x16x4_t q8b;
8699
+
8700
+ float sumf = 0;
8701
+ for (int i = 0; i < nb; ++i) {
8702
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
8703
+ const uint8_t * restrict q3 = x[i].qs;
8704
+ const uint8_t * restrict gas = x[i].qs + QK_K/4;
8705
+ const int8_t * restrict q8 = y[i].qs;
8706
+ float sumf1 = 0, sumf2 = 0;
8707
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
8708
+ q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
8709
+ memcpy(aux32, gas, 2*sizeof(uint32_t)); gas += 2*sizeof(uint32_t);
8710
+ const uint32x4_t aux32x4_0 = {iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]};
8711
+ const uint32x4_t aux32x4_1 = {iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]};
8712
+ const uint32x4_t aux32x4_2 = {iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]};
8713
+ const uint32x4_t aux32x4_3 = {iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]};
8714
+ q3 += 16;
8715
+ q3s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 7) & 127))));
8716
+ q3s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 21) & 127))));
8717
+ q3s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127))));
8718
+ q3s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127))));
8719
+ q3s.val[0] = vmulq_s8(q3s.val[0], vreinterpretq_s8_u32(aux32x4_0));
8720
+ q3s.val[1] = vmulq_s8(q3s.val[1], vreinterpretq_s8_u32(aux32x4_1));
8721
+ q3s.val[2] = vmulq_s8(q3s.val[2], vreinterpretq_s8_u32(aux32x4_2));
8722
+ q3s.val[3] = vmulq_s8(q3s.val[3], vreinterpretq_s8_u32(aux32x4_3));
8723
+ const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);
8724
+ const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
8725
+ sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[0] >> 28));
8726
+ sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[1] >> 28));
8727
+ }
8728
+ sumf += d*(sumf1 + sumf2);
8729
+ }
8730
+ *s = 0.5f * sumf;
8731
+
8732
+ #elif defined(__AVX2__)
8733
+
8734
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
8735
+
8736
+ uint32_t aux32[2];
8737
+
8738
+ __m256 accumf = _mm256_setzero_ps();
8739
+ for (int i = 0; i < nb; ++i) {
8740
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
8741
+ const uint8_t * restrict q3 = x[i].qs;
8742
+ const uint8_t * restrict gas = x[i].qs + QK_K/4;
8743
+ const int8_t * restrict q8 = y[i].qs;
8744
+ __m256i sumi1 = _mm256_setzero_si256();
8745
+ __m256i sumi2 = _mm256_setzero_si256();
8746
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
8747
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
8748
+ const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
8749
+ const __m256i q2_1 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
8750
+ iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
8751
+ q3 += 8;
8752
+ const __m256i q2_2 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
8753
+ iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
8754
+ q3 += 8;
8755
+ memcpy(aux32, gas, 8); gas += 8;
8756
+ const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127],
8757
+ signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]);
8758
+ const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
8759
+ signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
8760
+ const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
8761
+ const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
8762
+ const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
8763
+ const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
8764
+ const uint16_t ls1 = aux32[0] >> 28;
8765
+ const uint16_t ls2 = aux32[1] >> 28;
8766
+ const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
8767
+ const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
8768
+ sumi1 = _mm256_add_epi32(sumi1, p1);
8769
+ sumi2 = _mm256_add_epi32(sumi2, p2);
8770
+ }
8771
+
8772
+ accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
8773
+
8774
+ }
8775
+
8776
+ *s = 0.25f * hsum_float_8(accumf);
8777
+
8778
+ #else
8779
+
8780
+ uint32_t aux32;
8781
+
8782
+ float sumf = 0.f;
8783
+ for (int i = 0; i < nb; ++i) {
8784
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
8785
+ const uint8_t * restrict q3 = x[i].qs;
8786
+ const uint8_t * restrict gas = x[i].qs + QK_K/4;
8787
+ const int8_t * restrict q8 = y[i].qs;
8788
+ int32_t bsum = 0;
8789
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
8790
+ memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t);
8791
+ const uint32_t ls = 2*(aux32 >> 28) + 1;
8792
+ int32_t sumi = 0;
8793
+ for (int l = 0; l < 4; ++l) {
8794
+ const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]);
8795
+ const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]);
8796
+ const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
8797
+ for (int j = 0; j < 4; ++j) {
8798
+ sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1);
8799
+ sumi += grid2[j] * q8[j+4] * (signs & kmask_iq2xs[j+4] ? -1 : 1);
8800
+ }
8801
+ q8 += 8;
8802
+ }
8803
+ q3 += 8;
8804
+ bsum += sumi * ls;
8805
+ }
8806
+ sumf += d * bsum;
8807
+ }
8808
+ *s = 0.25f * sumf;
8809
+ #endif
8810
+ }
8811
+
8554
8812
  // ================================ IQ2 quantization =============================================
8555
8813
 
8556
8814
  typedef struct {
@@ -9189,3 +9447,436 @@ size_t quantize_iq2_xs(const float * src, void * dst, int nrow, int n_per_row, i
9189
9447
  return nrow * nblock * sizeof(block_iq2_xs);
9190
9448
  }
9191
9449
 
9450
+ //
9451
+ // ============================================= 3-bit using D4 lattice
9452
+ //
9453
+
9454
+ typedef struct { // lookup tables for one grid size, allocated by iq3xs_init_impl() and released by iq3xs_free_impl()
9455
+ uint32_t * grid; // grid_size entries; each packs four byte-sized magnitudes 2*l+1 (l = 3-bit level)
9456
+ int * map; // 4096 entries indexed by a 12-bit code (four 3-bit levels): >= 0 is a grid index, < 0 encodes -(offset+1) into neighbours
9457
+ uint16_t * neighbours; // per off-grid code: a count followed by that many candidate grid indices
9458
+ } iq3_entry_t;
9459
+
9460
+ static iq3_entry_t iq3_data[1] = { // single slot, selected through iq3_data_index(); all pointers stay NULL until iq3xs_init_impl() runs
9461
+ {NULL, NULL, NULL},
9462
+ };
9463
+
9464
// Map a grid size onto its slot in iq3_data. Only the 256-point grid is supported,
// so the answer is always slot 0; any other size trips the assertion.
static inline int iq3_data_index(int grid_size) {
    GGML_ASSERT(grid_size == 256);
    // presumably here to silence an unused-parameter warning in builds where the
    // assertion does not evaluate its argument
    (void)grid_size;
    return 0;
}
9469
+
9470
// qsort comparator for pairs of ints: orders by the first element, then by the
// second one to break ties. Returns -1, 0 or 1.
static int iq3_compare_func(const void * left, const void * right) {
    const int * a = (const int *)left;
    const int * b = (const int *)right;
    if (a[0] != b[0]) {
        return a[0] < b[0] ? -1 : 1;
    }
    if (a[1] != b[1]) {
        return a[1] < b[1] ? -1 : 1;
    }
    return 0;
}
9475
+
9476
+ void iq3xs_init_impl(int grid_size) {
9477
+ const int gindex = iq3_data_index(grid_size);
9478
+ if (iq3_data[gindex].grid) {
9479
+ return;
9480
+ }
9481
+ static const uint16_t kgrid_256[256] = {
9482
+ 0, 2, 4, 9, 11, 15, 16, 18, 25, 34, 59, 61, 65, 67, 72, 74,
9483
+ 81, 85, 88, 90, 97, 108, 120, 128, 130, 132, 137, 144, 146, 153, 155, 159,
9484
+ 169, 175, 189, 193, 199, 200, 202, 213, 248, 267, 287, 292, 303, 315, 317, 321,
9485
+ 327, 346, 362, 413, 436, 456, 460, 462, 483, 497, 513, 515, 520, 522, 529, 531,
9486
+ 536, 538, 540, 551, 552, 576, 578, 585, 592, 594, 641, 643, 648, 650, 657, 664,
9487
+ 698, 704, 706, 720, 729, 742, 758, 769, 773, 808, 848, 852, 870, 889, 901, 978,
9488
+ 992, 1024, 1026, 1033, 1035, 1040, 1042, 1046, 1049, 1058, 1089, 1091, 1093, 1096, 1098, 1105,
9489
+ 1112, 1139, 1143, 1144, 1152, 1154, 1161, 1167, 1168, 1170, 1183, 1184, 1197, 1217, 1224, 1228,
9490
+ 1272, 1276, 1309, 1323, 1347, 1367, 1377, 1404, 1473, 1475, 1486, 1509, 1537, 1544, 1546, 1553,
9491
+ 1555, 1576, 1589, 1594, 1600, 1602, 1616, 1625, 1636, 1638, 1665, 1667, 1672, 1685, 1706, 1722,
9492
+ 1737, 1755, 1816, 1831, 1850, 1856, 1862, 1874, 1901, 1932, 1950, 1971, 2011, 2032, 2052, 2063,
9493
+ 2077, 2079, 2091, 2095, 2172, 2192, 2207, 2208, 2224, 2230, 2247, 2277, 2308, 2345, 2356, 2389,
9494
+ 2403, 2424, 2501, 2504, 2506, 2520, 2570, 2593, 2616, 2624, 2630, 2646, 2669, 2700, 2714, 2746,
9495
+ 2754, 2795, 2824, 2835, 2839, 2874, 2882, 2905, 2984, 3028, 3042, 3092, 3108, 3110, 3124, 3153,
9496
+ 3185, 3215, 3252, 3288, 3294, 3364, 3397, 3434, 3483, 3523, 3537, 3587, 3589, 3591, 3592, 3610,
9497
+ 3626, 3670, 3680, 3722, 3749, 3754, 3776, 3789, 3803, 3824, 3857, 3873, 3904, 3906, 3924, 3992,
9498
+ };
9499
+ const int kmap_size = 4096;
9500
+ const int nwant = 2;
9501
+ const uint16_t * kgrid = kgrid_256;
9502
+ uint32_t * kgrid_q3xs;
9503
+ int * kmap_q3xs;
9504
+ uint16_t * kneighbors_q3xs;
9505
+
9506
+ printf("================================================================= %s(grid_size = %d)\n", __func__, grid_size);
9507
+ uint32_t * the_grid = (uint32_t *)malloc(grid_size*sizeof(uint32_t));
9508
+ for (int k = 0; k < grid_size; ++k) {
9509
+ int8_t * pos = (int8_t *)(the_grid + k);
9510
+ for (int i = 0; i < 4; ++i) {
9511
+ int l = (kgrid[k] >> 3*i) & 0x7;
9512
+ pos[i] = 2*l + 1;
9513
+ }
9514
+ }
9515
+ kgrid_q3xs = the_grid;
9516
+ iq3_data[gindex].grid = the_grid;
9517
+ kmap_q3xs = (int *)malloc(kmap_size*sizeof(int));
9518
+ iq3_data[gindex].map = kmap_q3xs;
9519
+ for (int i = 0; i < kmap_size; ++i) kmap_q3xs[i] = -1;
9520
+ uint32_t aux32;
9521
+ uint8_t * aux8 = (uint8_t *)&aux32;
9522
+ for (int i = 0; i < grid_size; ++i) {
9523
+ aux32 = kgrid_q3xs[i];
9524
+ uint16_t index = 0;
9525
+ for (int k=0; k<4; ++k) {
9526
+ uint16_t q = (aux8[k] - 1)/2;
9527
+ index |= (q << 3*k);
9528
+ }
9529
+ kmap_q3xs[index] = i;
9530
+ }
9531
+ int8_t pos[4];
9532
+ int * dist2 = (int *)malloc(2*grid_size*sizeof(int));
9533
+ int num_neighbors = 0, num_not_in_map = 0;
9534
+ for (int i = 0; i < kmap_size; ++i) {
9535
+ if (kmap_q3xs[i] >= 0) continue;
9536
+ ++num_not_in_map;
9537
+ for (int k = 0; k < 4; ++k) {
9538
+ int l = (i >> 3*k) & 0x7;
9539
+ pos[k] = 2*l + 1;
9540
+ }
9541
+ for (int j = 0; j < grid_size; ++j) {
9542
+ const int8_t * pg = (const int8_t *)(kgrid_q3xs + j);
9543
+ int d2 = 0;
9544
+ for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
9545
+ dist2[2*j+0] = d2;
9546
+ dist2[2*j+1] = j;
9547
+ }
9548
+ qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func);
9549
+ int n = 0; int d2 = dist2[0];
9550
+ int nhave = 1;
9551
+ for (int j = 0; j < grid_size; ++j) {
9552
+ if (dist2[2*j] > d2) {
9553
+ if (nhave == nwant) break;
9554
+ d2 = dist2[2*j];
9555
+ ++nhave;
9556
+ }
9557
+ ++n;
9558
+ }
9559
+ num_neighbors += n;
9560
+ }
9561
+ printf("%s: %d neighbours in total\n", __func__, num_neighbors);
9562
+ kneighbors_q3xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t));
9563
+ iq3_data[gindex].neighbours = kneighbors_q3xs;
9564
+ int counter = 0;
9565
+ for (int i = 0; i < kmap_size; ++i) {
9566
+ if (kmap_q3xs[i] >= 0) continue;
9567
+ for (int k = 0; k < 4; ++k) {
9568
+ int l = (i >> 3*k) & 0x7;
9569
+ pos[k] = 2*l + 1;
9570
+ }
9571
+ for (int j = 0; j < grid_size; ++j) {
9572
+ const int8_t * pg = (const int8_t *)(kgrid_q3xs + j);
9573
+ int d2 = 0;
9574
+ for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
9575
+ dist2[2*j+0] = d2;
9576
+ dist2[2*j+1] = j;
9577
+ }
9578
+ qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func);
9579
+ kmap_q3xs[i] = -(counter + 1);
9580
+ int d2 = dist2[0];
9581
+ uint16_t * start = &kneighbors_q3xs[counter++];
9582
+ int n = 0, nhave = 1;
9583
+ for (int j = 0; j < grid_size; ++j) {
9584
+ if (dist2[2*j] > d2) {
9585
+ if (nhave == nwant) break;
9586
+ d2 = dist2[2*j];
9587
+ ++nhave;
9588
+ }
9589
+ kneighbors_q3xs[counter++] = dist2[2*j+1];
9590
+ ++n;
9591
+ }
9592
+ *start = n;
9593
+ }
9594
+ free(dist2);
9595
+ }
9596
+
9597
+ void iq3xs_free_impl(int grid_size) {
9598
+ GGML_ASSERT(grid_size == 256);
9599
+ const int gindex = iq3_data_index(grid_size);
9600
+ if (iq3_data[gindex].grid) {
9601
+ free(iq3_data[gindex].grid); iq3_data[gindex].grid = NULL;
9602
+ free(iq3_data[gindex].map); iq3_data[gindex].map = NULL;
9603
+ free(iq3_data[gindex].neighbours); iq3_data[gindex].neighbours = NULL;
9604
+ }
9605
+ }
9606
+
9607
// Pick, among the candidate grid points listed in neighbours, the one minimising the
// weighted squared error sum_i weight[i]*(scale*q_i - xval[i])^2 over four values.
//
// neighbours[0] is the number of candidates; neighbours[1..count] are their grid
// indices. The levels (q-1)/2 of the winning point are written to L[0..3] and its grid
// index is returned. On equal error the earliest candidate wins.
static int iq3_find_best_neighbour(const uint16_t * restrict neighbours, const uint32_t * restrict grid,
        const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L) {
    const int count = neighbours[0];
    GGML_ASSERT(count > 0);

    const uint16_t * candidates = neighbours + 1;
    int   winner = -1;
    float lowest = FLT_MAX;
    for (int c = 0; c < count; ++c) {
        const int8_t * point = (const int8_t *)(grid + candidates[c]);
        float err = 0;
        for (int i = 0; i < 4; ++i) {
            float q = point[i];
            float diff = scale*q - xval[i];
            err += weight[i]*diff*diff;
        }
        if (err < lowest) {
            lowest = err;
            winner = candidates[c];
        }
    }
    GGML_ASSERT(winner >= 0);

    const int8_t * chosen = (const int8_t *)(grid + winner);
    for (int i = 0; i < 4; ++i) {
        L[i] = (chosen[i] - 1)/2;
    }
    return winner;
}
9630
+
9631
+ static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
9632
+
9633
+ const int gindex = iq3_data_index(256);
9634
+
9635
+ const uint32_t * kgrid_q3xs = iq3_data[gindex].grid;
9636
+ const int * kmap_q3xs = iq3_data[gindex].map;
9637
+ const uint16_t * kneighbors_q3xs = iq3_data[gindex].neighbours;
9638
+
9639
+ //GGML_ASSERT(quant_weights && "missing quantization weights");
9640
+ GGML_ASSERT(kgrid_q3xs && "forgot to call ggml_quantize_init()?");
9641
+ GGML_ASSERT(kmap_q3xs && "forgot to call ggml_quantize_init()?");
9642
+ GGML_ASSERT(kneighbors_q3xs && "forgot to call ggml_quantize_init()?");
9643
+ GGML_ASSERT(n%QK_K == 0);
9644
+
9645
+ const int kMaxQ = 8;
9646
+
9647
+ const int nbl = n/256;
9648
+
9649
+ block_iq3_xxs * y = vy;
9650
+
9651
+ float scales[QK_K/32];
9652
+ float weight[32];
9653
+ float xval[32];
9654
+ int8_t L[32];
9655
+ int8_t Laux[32];
9656
+ float waux[32];
9657
+ bool is_on_grid[8];
9658
+ bool is_on_grid_aux[8];
9659
+ uint8_t block_signs[8];
9660
+ uint8_t q3[3*(QK_K/8)];
9661
+ uint32_t * scales_and_signs = (uint32_t *)(q3 + QK_K/4);
9662
+
9663
+ for (int ibl = 0; ibl < nbl; ++ibl) {
9664
+
9665
+ y[ibl].d = GGML_FP32_TO_FP16(0.f);
9666
+ memset(q3, 0, 3*QK_K/8);
9667
+
9668
+ float max_scale = 0;
9669
+
9670
+ const float * xbl = x + QK_K*ibl;
9671
+ float sumx2 = 0;
9672
+ for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
9673
+ float sigma2 = sumx2/QK_K;
9674
+
9675
+ for (int ib = 0; ib < QK_K/32; ++ib) {
9676
+ const float * xb = xbl + 32*ib;
9677
+ if (quant_weights) {
9678
+ const float * qw = quant_weights + QK_K*ibl + 32*ib;
9679
+ for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
9680
+ } else {
9681
+ for (int i = 0; i < 32; ++i) weight[i] = xb[i]*xb[i];
9682
+ }
9683
+ for (int i = 0; i < 32; ++i) waux[i] = sqrtf(weight[i]);
9684
+ for (int k = 0; k < 4; ++k) {
9685
+ int nflip = 0;
9686
+ uint8_t s = 0;
9687
+ for (int i = 0; i < 8; ++i) {
9688
+ if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
9689
+ else {
9690
+ xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i);
9691
+ }
9692
+ }
9693
+ if (nflip%2) {
9694
+ int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin];
9695
+ for (int i = 1; i < 8; ++i) {
9696
+ float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i];
9697
+ if (ax < min) {
9698
+ min = ax; imin = i;
9699
+ }
9700
+ }
9701
+ xval[8*k+imin] = -xval[8*k+imin];
9702
+ s ^= (1 << imin);
9703
+ }
9704
+ block_signs[k] = s & 127;
9705
+ }
9706
+ float max = xval[0];
9707
+ for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]);
9708
+ if (!max) {
9709
+ scales[ib] = 0;
9710
+ memset(L, 0, 32);
9711
+ continue;
9712
+ }
9713
+ float best = 0;
9714
+ float scale = max/(2*kMaxQ-1);
9715
+ for (int is = -15; is <= 15; ++is) {
9716
+ float id = (2*kMaxQ-1+is*0.2f)/max;
9717
+ float this_scale = 1/id;
9718
+ for (int k = 0; k < 8; ++k) {
9719
+ for (int i = 0; i < 4; ++i) {
9720
+ int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
9721
+ Laux[4*k+i] = MAX(0, MIN(kMaxQ-1, l));
9722
+ }
9723
+ uint16_t u = 0;
9724
+ for (int i = 0; i < 4; ++i) u |= (Laux[4*k+i] << 3*i);
9725
+ int grid_index = kmap_q3xs[u];
9726
+ is_on_grid_aux[k] = true;
9727
+ if (grid_index < 0) {
9728
+ is_on_grid_aux[k] = false;
9729
+ const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
9730
+ grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, this_scale, Laux + 4*k);
9731
+ }
9732
+ }
9733
+ float sumqx = 0, sumq2 = 0;
9734
+ for (int i = 0; i < 32; ++i) {
9735
+ float w = weight[i];
9736
+ float q = 2*Laux[i] + 1;
9737
+ sumqx += w*xval[i]*q;
9738
+ sumq2 += w*q*q;
9739
+ }
9740
+ if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
9741
+ scale = sumqx/sumq2; best = scale*sumqx;
9742
+ for (int i = 0; i < 32; ++i) L[i] = Laux[i];
9743
+ for (int k = 0; k < 8; ++k) is_on_grid[k] = is_on_grid_aux[k];
9744
+ }
9745
+ }
9746
+ int n_not_ongrid = 0;
9747
+ for (int k = 0; k < 8; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
9748
+ if (n_not_ongrid > 0 && scale > 0) {
9749
+ float id = 1/scale;
9750
+ for (int k = 0; k < 8; ++k) {
9751
+ if (is_on_grid[k]) continue;
9752
+ uint16_t u = 0;
9753
+ for (int i = 0; i < 4; ++i) {
9754
+ int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
9755
+ l = MAX(0, MIN(kMaxQ-1, l));
9756
+ u |= (l << 3*i);
9757
+ }
9758
+ int grid_index = kmap_q3xs[u];
9759
+ if (grid_index < 0) {
9760
+ const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
9761
+ grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, scale, L + 4*k);
9762
+ }
9763
+ const int8_t * pg = (const int8_t *)(kgrid_q3xs + grid_index);
9764
+ for (int i = 0; i < 4; ++i) L[4*k+i] = (pg[i] - 1)/2;
9765
+ }
9766
+ float sumqx = 0, sumq2 = 0;
9767
+ for (int i = 0; i < 32; ++i) {
9768
+ float w = weight[i];
9769
+ float q = 2*L[i] + 1;
9770
+ sumqx += w*xval[i]*q;
9771
+ sumq2 += w*q*q;
9772
+ }
9773
+ if (sumq2 > 0) scale = sumqx/sumq2;
9774
+ }
9775
+ if (scale < 0) {
9776
+ // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale)
9777
+ // and correspondingly flip quant signs.
9778
+ scale = -scale;
9779
+ for (int k = 0; k < 4; ++k) block_signs[k] = (~block_signs[k]) & 127;
9780
+ }
9781
+ for (int k = 0; k < 8; ++k) {
9782
+ uint16_t u = 0;
9783
+ for (int i = 0; i < 4; ++i) u |= (L[4*k+i] << 3*i);
9784
+ int grid_index = kmap_q3xs[u];
9785
+ if (grid_index < 0) {
9786
+ printf("Oops: found point %u not on grid:", u);
9787
+ for (int i = 0; i < 4; ++i) printf(" %d", L[4*k+i]);
9788
+ printf("\n");
9789
+ GGML_ASSERT(false);
9790
+ }
9791
+ q3[8*ib+k] = grid_index;
9792
+ }
9793
+ scales_and_signs[ib] = block_signs[0] | (block_signs[1] << 7) | (block_signs[2] << 14) | (block_signs[3] << 21);
9794
+ GGML_ASSERT(scale >= 0);
9795
+ scales[ib] = scale;
9796
+ max_scale = MAX(max_scale, scale);
9797
+ }
9798
+
9799
+ if (!max_scale) {
9800
+ memset(y[ibl].qs, 0, 3*QK_K/8);
9801
+ continue;
9802
+ }
9803
+
9804
+ float d = max_scale/31;
9805
+ y[ibl].d = GGML_FP32_TO_FP16(d);
9806
+ float id = 1/d;
9807
+ float sumqx = 0, sumq2 = 0;
9808
+ for (int ib = 0; ib < QK_K/32; ++ib) {
9809
+ int l = nearest_int(0.5f*(id*scales[ib]-1));
9810
+ l = MAX(0, MIN(15, l));
9811
+ scales_and_signs[ib] |= ((uint32_t)l << 28);
9812
+ if (false) {
9813
+ const float * xb = xbl + 32*ib;
9814
+ if (quant_weights) {
9815
+ const float * qw = quant_weights + QK_K*ibl + 32*ib;
9816
+ for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
9817
+ } else {
9818
+ for (int i = 0; i < 32; ++i) weight[i] = xb[i]*xb[i];
9819
+ }
9820
+ const float db = 0.25f * d * (1 + 2*l);
9821
+ for (int k = 0; k < 8; ++k) {
9822
+ const int8_t * signs = keven_signs_q2xs + 8*((scales_and_signs[ib] >> 7*(k/2)) & 127) + 4*(k%2);
9823
+ const float * xk = xb + 4*k;
9824
+ const float * wk = weight + 4*k;
9825
+ //const uint8_t * grid = (const uint8_t *)(kgrid_q3xs + q3[8*ib+k]);
9826
+ const uint8_t * grid = (const uint8_t *)(iq3xxs_grid + q3[8*ib+k]);
9827
+ float best_mse = 0; int best_index = q3[8*ib+k];
9828
+ for (int j = 0; j < 4; ++j) {
9829
+ float diff = db * grid[j] * signs[j] - xk[j];
9830
+ best_mse += wk[j] * diff * diff;
9831
+ }
9832
+ for (int idx = 0; idx < 256; ++idx) {
9833
+ //grid = (const uint8_t *)(kgrid_q3xs + idx);
9834
+ grid = (const uint8_t *)(iq3xxs_grid + idx);
9835
+ float mse = 0;
9836
+ for (int j = 0; j < 4; ++j) {
9837
+ float diff = db * grid[j] * signs[j] - xk[j];
9838
+ mse += wk[j] * diff * diff;
9839
+ }
9840
+ if (mse < best_mse) {
9841
+ best_mse = mse; best_index = idx;
9842
+ }
9843
+ }
9844
+ q3[8*ib+k] = best_index;
9845
+ //grid = (const uint8_t *)(kgrid_q3xs + best_index);
9846
+ grid = (const uint8_t *)(iq3xxs_grid + best_index);
9847
+ for (int j = 0; j < 4; ++j) {
9848
+ float q = db * grid[j] * signs[j];
9849
+ sumqx += wk[j] * q * xk[j];
9850
+ sumq2 += wk[j] * q * q;
9851
+ }
9852
+ }
9853
+ if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(d*sumqx/sumq2);
9854
+ }
9855
+ }
9856
+ memcpy(y[ibl].qs, q3, 3*QK_K/8);
9857
+ }
9858
+ }
9859
+
9860
+ size_t quantize_iq3_xxs(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
9861
+ (void)hist;
9862
+ GGML_ASSERT(n_per_row%QK_K == 0);
9863
+ int nblock = n_per_row/QK_K;
9864
+ char * qrow = (char *)dst;
9865
+ for (int row = 0; row < nrow; ++row) {
9866
+ quantize_row_iq3_xxs_impl(src, qrow, n_per_row, quant_weights);
9867
+ src += n_per_row;
9868
+ qrow += nblock*sizeof(block_iq3_xxs);
9869
+ }
9870
+ return nrow * nblock * sizeof(block_iq3_xxs);
9871
+ }
9872
+
9873
+ void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int k) {
9874
+ assert(k % QK_K == 0);
9875
+ block_iq3_xxs * restrict y = vy;
9876
+ quantize_row_iq3_xxs_reference(x, y, k);
9877
+ }
9878
+
9879
+ void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int k) {
9880
+ assert(k % QK_K == 0);
9881
+ quantize_row_iq3_xxs_impl(x, y, k, NULL);
9882
+ }