llama_cpp 0.12.2 → 0.12.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1274,7 +1274,12 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
1274
1274
  }
1275
1275
  float sumlx = 0;
1276
1276
  float suml2 = 0;
1277
+ #ifdef HAVE_BUGGY_APPLE_LINKER
1278
+ // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
1279
+ for (volatile int i = 0; i < n; ++i) {
1280
+ #else
1277
1281
  for (int i = 0; i < n; ++i) {
1282
+ #endif
1278
1283
  int l = nearest_int(iscale * x[i]);
1279
1284
  l = MAX(-nmax, MIN(nmax-1, l));
1280
1285
  L[i] = l + nmax;
@@ -1649,7 +1654,12 @@ static float make_qkx3_quants(int n, int nmax, const float * restrict x, const f
1649
1654
  float max = x[0];
1650
1655
  float sum_w = weights ? weights[0] : x[0]*x[0];
1651
1656
  float sum_x = sum_w * x[0];
1657
+ #ifdef HAVE_BUGGY_APPLE_LINKER
1658
+ // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
1659
+ for (volatile int i = 1; i < n; ++i) {
1660
+ #else
1652
1661
  for (int i = 1; i < n; ++i) {
1662
+ #endif
1653
1663
  if (x[i] < min) min = x[i];
1654
1664
  if (x[i] > max) max = x[i];
1655
1665
  float w = weights ? weights[i] : x[i]*x[i];
@@ -1660,7 +1670,7 @@ static float make_qkx3_quants(int n, int nmax, const float * restrict x, const f
1660
1670
  min = 0;
1661
1671
  }
1662
1672
  if (max <= min) {
1663
- for (int i = 0; i < n; ++i) L[i] = 0;
1673
+ memset(L, 0, n);
1664
1674
  *the_min = -min;
1665
1675
  return 0.f;
1666
1676
  }
@@ -1862,7 +1872,7 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri
1862
1872
 
1863
1873
  size_t quantize_q2_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
1864
1874
  (void)hist;
1865
- int row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row);
1875
+ size_t row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row);
1866
1876
  if (!quant_weights) {
1867
1877
  quantize_row_q2_K_reference(src, dst, nrow*n_per_row);
1868
1878
  }
@@ -2181,7 +2191,7 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri
2181
2191
 
2182
2192
  size_t quantize_q3_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
2183
2193
  (void)hist;
2184
- int row_size = ggml_row_size(GGML_TYPE_Q3_K, n_per_row);
2194
+ size_t row_size = ggml_row_size(GGML_TYPE_Q3_K, n_per_row);
2185
2195
  if (!quant_weights) {
2186
2196
  quantize_row_q3_K_reference(src, dst, nrow*n_per_row);
2187
2197
  }
@@ -2448,7 +2458,7 @@ static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restri
2448
2458
 
2449
2459
  size_t quantize_q4_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
2450
2460
  (void)hist;
2451
- int row_size = ggml_row_size(GGML_TYPE_Q4_K, n_per_row);
2461
+ size_t row_size = ggml_row_size(GGML_TYPE_Q4_K, n_per_row);
2452
2462
  if (!quant_weights) {
2453
2463
  quantize_row_q4_K_reference(src, dst, nrow*n_per_row);
2454
2464
  }
@@ -2771,7 +2781,7 @@ static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restri
2771
2781
 
2772
2782
  size_t quantize_q5_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
2773
2783
  (void)hist;
2774
- int row_size = ggml_row_size(GGML_TYPE_Q5_K, n_per_row);
2784
+ size_t row_size = ggml_row_size(GGML_TYPE_Q5_K, n_per_row);
2775
2785
  if (!quant_weights) {
2776
2786
  quantize_row_q5_K_reference(src, dst, nrow*n_per_row);
2777
2787
  }
@@ -3025,7 +3035,7 @@ static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restri
3025
3035
 
3026
3036
  size_t quantize_q6_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
3027
3037
  (void)hist;
3028
- int row_size = ggml_row_size(GGML_TYPE_Q6_K, n_per_row);
3038
+ size_t row_size = ggml_row_size(GGML_TYPE_Q6_K, n_per_row);
3029
3039
  if (!quant_weights) {
3030
3040
  quantize_row_q6_K_reference(src, dst, nrow*n_per_row);
3031
3041
  }
@@ -3072,7 +3082,7 @@ size_t quantize_q4_0(const float * src, void * dst, int nrow, int n_per_row, int
3072
3082
  if (!quant_weights) {
3073
3083
  return ggml_quantize_q4_0(src, dst, nrow*n_per_row, n_per_row, hist);
3074
3084
  }
3075
- int row_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row);
3085
+ size_t row_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row);
3076
3086
  char * qrow = (char *)dst;
3077
3087
  for (int row = 0; row < nrow; ++row) {
3078
3088
  quantize_row_q4_0_impl(src, (block_q4_0*)qrow, n_per_row, quant_weights);
@@ -3116,7 +3126,7 @@ size_t quantize_q4_1(const float * src, void * dst, int nrow, int n_per_row, int
3116
3126
  if (!quant_weights) {
3117
3127
  return ggml_quantize_q4_1(src, dst, nrow*n_per_row, n_per_row, hist);
3118
3128
  }
3119
- int row_size = ggml_row_size(GGML_TYPE_Q4_1, n_per_row);
3129
+ size_t row_size = ggml_row_size(GGML_TYPE_Q4_1, n_per_row);
3120
3130
  char * qrow = (char *)dst;
3121
3131
  for (int row = 0; row < nrow; ++row) {
3122
3132
  quantize_row_q4_1_impl(src, (block_q4_1*)qrow, n_per_row, quant_weights);
@@ -3169,7 +3179,7 @@ size_t quantize_q5_0(const float * src, void * dst, int nrow, int n_per_row, int
3169
3179
  if (!quant_weights) {
3170
3180
  return ggml_quantize_q5_0(src, dst, nrow*n_per_row, n_per_row, hist);
3171
3181
  }
3172
- int row_size = ggml_row_size(GGML_TYPE_Q5_0, n_per_row);
3182
+ size_t row_size = ggml_row_size(GGML_TYPE_Q5_0, n_per_row);
3173
3183
  char * qrow = (char *)dst;
3174
3184
  for (int row = 0; row < nrow; ++row) {
3175
3185
  quantize_row_q5_0_impl(src, (block_q5_0*)qrow, n_per_row, quant_weights);
@@ -3221,7 +3231,7 @@ size_t quantize_q5_1(const float * src, void * dst, int nrow, int n_per_row, int
3221
3231
  if (!quant_weights) {
3222
3232
  return ggml_quantize_q5_1(src, dst, nrow*n_per_row, n_per_row, hist);
3223
3233
  }
3224
- int row_size = ggml_row_size(GGML_TYPE_Q5_1, n_per_row);
3234
+ size_t row_size = ggml_row_size(GGML_TYPE_Q5_1, n_per_row);
3225
3235
  char * qrow = (char *)dst;
3226
3236
  for (int row = 0; row < nrow; ++row) {
3227
3237
  quantize_row_q5_1_impl(src, (block_q5_1*)qrow, n_per_row, quant_weights);
@@ -3431,6 +3441,41 @@ static const uint64_t iq2xs_grid[512] = {
3431
3441
  0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
3432
3442
  };
3433
3443
 
3444
+ static const uint32_t iq3xxs_grid[256] = {
3445
+ 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
3446
+ 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
3447
+ 0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404,
3448
+ 0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e,
3449
+ 0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c,
3450
+ 0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c,
3451
+ 0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34,
3452
+ 0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c,
3453
+ 0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,
3454
+ 0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04,
3455
+ 0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c,
3456
+ 0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414,
3457
+ 0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434,
3458
+ 0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c,
3459
+ 0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e,
3460
+ 0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24,
3461
+ 0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24,
3462
+ 0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,
3463
+ 0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c,
3464
+ 0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14,
3465
+ 0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414,
3466
+ 0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e,
3467
+ 0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404,
3468
+ 0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c,
3469
+ 0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c,
3470
+ 0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14,
3471
+ 0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c,
3472
+ 0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c,
3473
+ 0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14,
3474
+ 0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14,
3475
+ 0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
3476
+ 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
3477
+ };
3478
+
3434
3479
  static const uint8_t ksigns_iq2xs[128] = {
3435
3480
  0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
3436
3481
  144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
@@ -3497,6 +3542,38 @@ void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y,
3497
3542
  }
3498
3543
  }
3499
3544
 
3545
+ // ====================== 3.0625 bpw (de)-quantization
3546
+
3547
+ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y, int k) {
3548
+ assert(k % QK_K == 0);
3549
+ const int nb = k / QK_K;
3550
+
3551
+ uint32_t aux32;
3552
+
3553
+ for (int i = 0; i < nb; i++) {
3554
+
3555
+ const float d = GGML_FP16_TO_FP32(x[i].d);
3556
+ const uint8_t * qs = x[i].qs;
3557
+ const uint8_t * scales_and_signs = qs + QK_K/4;
3558
+
3559
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
3560
+ memcpy(&aux32, scales_and_signs + 4*ib32, sizeof(uint32_t));
3561
+ const float db = d * (0.5f + (aux32 >> 28)) * 0.5f;
3562
+ for (int l = 0; l < 4; ++l) {
3563
+ const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
3564
+ const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + qs[2*l+0]);
3565
+ const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + qs[2*l+1]);
3566
+ for (int j = 0; j < 4; ++j) {
3567
+ y[j+0] = db * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
3568
+ y[j+4] = db * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
3569
+ }
3570
+ y += 8;
3571
+ }
3572
+ qs += 8;
3573
+ }
3574
+ }
3575
+ }
3576
+
3500
3577
  //===================================== Q8_K ==============================================
3501
3578
 
3502
3579
  void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
@@ -8448,17 +8525,36 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
8448
8525
 
8449
8526
  const __m128i m4 = _mm_set1_epi8(0xf);
8450
8527
  const __m128i m1 = _mm_set1_epi8(1);
8451
- const __m128i m511 = _mm_set1_epi16(511);
8452
- const __m128i m127 = _mm_set1_epi16(127);
8528
+ const __m256i m511 = _mm256_set1_epi16(511);
8529
+ const __m256i mone = _mm256_set1_epi8(1);
8453
8530
 
8454
- const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
8531
+ static const uint8_t k_bit_helper[32] = {
8532
+ 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
8533
+ 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
8534
+ };
8535
+ static const char block_sign_shuffle_mask_1[32] = {
8536
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
8537
+ 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
8538
+ };
8539
+ static const char block_sign_shuffle_mask_2[32] = {
8540
+ 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
8541
+ 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
8542
+ };
8543
+ static const uint8_t bit_selector_mask_bytes[32] = {
8544
+ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
8545
+ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
8546
+ };
8547
+
8548
+ const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper);
8549
+ const __m256i bit_selector_mask = _mm256_loadu_si256((const __m256i*)bit_selector_mask_bytes);
8550
+ const __m256i block_sign_shuffle_1 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_1);
8551
+ const __m256i block_sign_shuffle_2 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_2);
8455
8552
 
8456
8553
  uint64_t aux64;
8457
8554
 
8458
8555
  // somewhat hacky, but gives a significant boost in performance
8459
- __m128i aux_gindex, aux_sindex;
8556
+ __m256i aux_gindex;
8460
8557
  const uint16_t * gindex = (const uint16_t *)&aux_gindex;
8461
- const uint16_t * sindex = (const uint16_t *)&aux_sindex;
8462
8558
 
8463
8559
  __m256 accumf = _mm256_setzero_ps();
8464
8560
  for (int i = 0; i < nb; ++i) {
@@ -8473,26 +8569,68 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
8473
8569
 
8474
8570
  __m256i sumi1 = _mm256_setzero_si256();
8475
8571
  __m256i sumi2 = _mm256_setzero_si256();
8476
- for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
8572
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {
8573
+
8574
+ const __m256i q2_data = _mm256_loadu_si256((const __m256i*)q2); q2 += 16;
8575
+ aux_gindex = _mm256_and_si256(q2_data, m511);
8576
+
8577
+ const __m256i partial_sign_bits = _mm256_srli_epi16(q2_data, 9);
8578
+ const __m256i partial_sign_bits_upper = _mm256_srli_epi16(q2_data, 13);
8579
+ const __m256i partial_sign_bits_for_counting = _mm256_xor_si256(partial_sign_bits, partial_sign_bits_upper);
8580
+
8581
+ const __m256i odd_bits = _mm256_shuffle_epi8(bit_helper, partial_sign_bits_for_counting);
8582
+ const __m256i full_sign_bits = _mm256_or_si256(partial_sign_bits, odd_bits);
8583
+
8477
8584
  const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
8478
8585
  const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
8479
- const __m128i q2_data = _mm_loadu_si128((const __m128i*)q2); q2 += 8;
8480
- aux_gindex = _mm_and_si128(q2_data, m511);
8481
- aux_sindex = _mm_and_si128(_mm_srli_epi16(q2_data, 9), m127);
8482
- const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]], iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]);
8483
- const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]], iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]);
8484
- const __m256i s2_1 = _mm256_set_epi64x(signs64[sindex[3]], signs64[sindex[2]], signs64[sindex[1]], signs64[sindex[0]]);
8485
- const __m256i s2_2 = _mm256_set_epi64x(signs64[sindex[7]], signs64[sindex[6]], signs64[sindex[5]], signs64[sindex[4]]);
8486
- const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
8487
- const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
8586
+ const __m256i q8_3 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
8587
+ const __m256i q8_4 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
8588
+
8589
+ const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]],
8590
+ iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]);
8591
+ const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]],
8592
+ iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]);
8593
+ const __m256i q2_3 = _mm256_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]],
8594
+ iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]);
8595
+ const __m256i q2_4 = _mm256_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]],
8596
+ iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);
8597
+
8598
+ const __m128i full_signs_l = _mm256_castsi256_si128(full_sign_bits);
8599
+ const __m128i full_signs_h = _mm256_extractf128_si256(full_sign_bits, 1);
8600
+ const __m256i full_signs_1 = _mm256_set_m128i(full_signs_l, full_signs_l);
8601
+ const __m256i full_signs_2 = _mm256_set_m128i(full_signs_h, full_signs_h);
8602
+
8603
+ __m256i signs;
8604
+ signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_1);
8605
+ signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
8606
+ const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone));
8607
+
8608
+ signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_2);
8609
+ signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
8610
+ const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone));
8611
+
8612
+ signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_1);
8613
+ signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
8614
+ const __m256i q8s_3 = _mm256_sign_epi8(q8_3, _mm256_or_si256(signs, mone));
8615
+
8616
+ signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_2);
8617
+ signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
8618
+ const __m256i q8s_4 = _mm256_sign_epi8(q8_4, _mm256_or_si256(signs, mone));
8619
+
8488
8620
  const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
8489
8621
  const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
8622
+ const __m256i dot3 = _mm256_maddubs_epi16(q2_3, q8s_3);
8623
+ const __m256i dot4 = _mm256_maddubs_epi16(q2_4, q8s_4);
8490
8624
 
8491
8625
  const __m256i sc1 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0)));
8492
8626
  const __m256i sc2 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1)));
8627
+ const __m256i sc3 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2)));
8628
+ const __m256i sc4 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3)));
8493
8629
 
8494
8630
  sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot1, sc1));
8495
8631
  sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot2, sc2));
8632
+ sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot3, sc3));
8633
+ sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot4, sc4));
8496
8634
  }
8497
8635
 
8498
8636
  accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
@@ -8541,6 +8679,136 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
8541
8679
  #endif
8542
8680
  }
8543
8681
 
8682
+ // TODO
8683
+ void ggml_vec_dot_iq3_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
8684
+ assert(n % QK_K == 0);
8685
+
8686
+ const block_iq3_xxs * restrict x = vx;
8687
+ const block_q8_K * restrict y = vy;
8688
+
8689
+ const int nb = n / QK_K;
8690
+
8691
+ #if defined(__ARM_NEON)
8692
+
8693
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
8694
+
8695
+ uint32_t aux32[2];
8696
+
8697
+ ggml_int8x16x4_t q3s;
8698
+ ggml_int8x16x4_t q8b;
8699
+
8700
+ float sumf = 0;
8701
+ for (int i = 0; i < nb; ++i) {
8702
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
8703
+ const uint8_t * restrict q3 = x[i].qs;
8704
+ const uint8_t * restrict gas = x[i].qs + QK_K/4;
8705
+ const int8_t * restrict q8 = y[i].qs;
8706
+ float sumf1 = 0, sumf2 = 0;
8707
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
8708
+ q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
8709
+ memcpy(aux32, gas, 2*sizeof(uint32_t)); gas += 2*sizeof(uint32_t);
8710
+ const uint32x4_t aux32x4_0 = {iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]};
8711
+ const uint32x4_t aux32x4_1 = {iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]};
8712
+ const uint32x4_t aux32x4_2 = {iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]};
8713
+ const uint32x4_t aux32x4_3 = {iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]};
8714
+ q3 += 16;
8715
+ q3s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 7) & 127))));
8716
+ q3s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 21) & 127))));
8717
+ q3s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127))));
8718
+ q3s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127))));
8719
+ q3s.val[0] = vmulq_s8(q3s.val[0], vreinterpretq_s8_u32(aux32x4_0));
8720
+ q3s.val[1] = vmulq_s8(q3s.val[1], vreinterpretq_s8_u32(aux32x4_1));
8721
+ q3s.val[2] = vmulq_s8(q3s.val[2], vreinterpretq_s8_u32(aux32x4_2));
8722
+ q3s.val[3] = vmulq_s8(q3s.val[3], vreinterpretq_s8_u32(aux32x4_3));
8723
+ const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);
8724
+ const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
8725
+ sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[0] >> 28));
8726
+ sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[1] >> 28));
8727
+ }
8728
+ sumf += d*(sumf1 + sumf2);
8729
+ }
8730
+ *s = 0.5f * sumf;
8731
+
8732
+ #elif defined(__AVX2__)
8733
+
8734
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
8735
+
8736
+ uint32_t aux32[2];
8737
+
8738
+ __m256 accumf = _mm256_setzero_ps();
8739
+ for (int i = 0; i < nb; ++i) {
8740
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
8741
+ const uint8_t * restrict q3 = x[i].qs;
8742
+ const uint8_t * restrict gas = x[i].qs + QK_K/4;
8743
+ const int8_t * restrict q8 = y[i].qs;
8744
+ __m256i sumi1 = _mm256_setzero_si256();
8745
+ __m256i sumi2 = _mm256_setzero_si256();
8746
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
8747
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
8748
+ const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
8749
+ const __m256i q2_1 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
8750
+ iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
8751
+ q3 += 8;
8752
+ const __m256i q2_2 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
8753
+ iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
8754
+ q3 += 8;
8755
+ memcpy(aux32, gas, 8); gas += 8;
8756
+ const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127],
8757
+ signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]);
8758
+ const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
8759
+ signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
8760
+ const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
8761
+ const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
8762
+ const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
8763
+ const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
8764
+ const uint16_t ls1 = aux32[0] >> 28;
8765
+ const uint16_t ls2 = aux32[1] >> 28;
8766
+ const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
8767
+ const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
8768
+ sumi1 = _mm256_add_epi32(sumi1, p1);
8769
+ sumi2 = _mm256_add_epi32(sumi2, p2);
8770
+ }
8771
+
8772
+ accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
8773
+
8774
+ }
8775
+
8776
+ *s = 0.25f * hsum_float_8(accumf);
8777
+
8778
+ #else
8779
+
8780
+ uint32_t aux32;
8781
+
8782
+ float sumf = 0.f;
8783
+ for (int i = 0; i < nb; ++i) {
8784
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
8785
+ const uint8_t * restrict q3 = x[i].qs;
8786
+ const uint8_t * restrict gas = x[i].qs + QK_K/4;
8787
+ const int8_t * restrict q8 = y[i].qs;
8788
+ int32_t bsum = 0;
8789
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
8790
+ memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t);
8791
+ const uint32_t ls = 2*(aux32 >> 28) + 1;
8792
+ int32_t sumi = 0;
8793
+ for (int l = 0; l < 4; ++l) {
8794
+ const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]);
8795
+ const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]);
8796
+ const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
8797
+ for (int j = 0; j < 4; ++j) {
8798
+ sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1);
8799
+ sumi += grid2[j] * q8[j+4] * (signs & kmask_iq2xs[j+4] ? -1 : 1);
8800
+ }
8801
+ q8 += 8;
8802
+ }
8803
+ q3 += 8;
8804
+ bsum += sumi * ls;
8805
+ }
8806
+ sumf += d * bsum;
8807
+ }
8808
+ *s = 0.25f * sumf;
8809
+ #endif
8810
+ }
8811
+
8544
8812
  // ================================ IQ2 quantization =============================================
8545
8813
 
8546
8814
  typedef struct {
@@ -8565,7 +8833,7 @@ static int iq2_compare_func(const void * left, const void * right) {
8565
8833
  return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 1 : 0;
8566
8834
  }
8567
8835
 
8568
- static void q2xs_init_impl(int grid_size) {
8836
+ void iq2xs_init_impl(int grid_size) {
8569
8837
  const int gindex = iq2_data_index(grid_size);
8570
8838
  if (iq2_data[gindex].grid) {
8571
8839
  return;
@@ -8720,19 +8988,7 @@ static void q2xs_init_impl(int grid_size) {
8720
8988
  free(dist2);
8721
8989
  }
8722
8990
 
8723
- void ggml_init_iq2_quantization(enum ggml_type type) {
8724
- if (type == GGML_TYPE_IQ2_XXS) {
8725
- q2xs_init_impl(256);
8726
- }
8727
- else if (type == GGML_TYPE_IQ2_XS) {
8728
- q2xs_init_impl(512);
8729
- }
8730
- else {
8731
- fprintf(stderr, "======================== Why are you calling %s with type %d?\n", __func__, (int)type);
8732
- }
8733
- }
8734
-
8735
- static void q2xs_deinit_impl(int grid_size) {
8991
+ void iq2xs_free_impl(int grid_size) {
8736
8992
  GGML_ASSERT(grid_size == 256 || grid_size == 512 || grid_size == 1024);
8737
8993
  const int gindex = iq2_data_index(grid_size);
8738
8994
  if (iq2_data[gindex].grid) {
@@ -8742,18 +8998,6 @@ static void q2xs_deinit_impl(int grid_size) {
8742
8998
  }
8743
8999
  }
8744
9000
 
8745
- void ggml_deinit_iq2_quantization(enum ggml_type type) {
8746
- if (type == GGML_TYPE_IQ2_XXS) {
8747
- q2xs_deinit_impl(256);
8748
- }
8749
- else if (type == GGML_TYPE_IQ2_XS) {
8750
- q2xs_deinit_impl(512);
8751
- }
8752
- else {
8753
- fprintf(stderr, "======================== Why are you calling %s with type %d?\n", __func__, (int)type);
8754
- }
8755
- }
8756
-
8757
9001
  static int iq2_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid,
8758
9002
  const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L) {
8759
9003
  int num_neighbors = neighbours[0];
@@ -8786,10 +9030,10 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
8786
9030
  const int * kmap_q2xs = iq2_data[gindex].map;
8787
9031
  const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
8788
9032
 
8789
- GGML_ASSERT(quant_weights);
8790
- GGML_ASSERT(kgrid_q2xs);
8791
- GGML_ASSERT(kmap_q2xs);
8792
- GGML_ASSERT(kneighbors_q2xs);
9033
+ GGML_ASSERT(quant_weights && "missing quantization weights");
9034
+ GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?");
9035
+ GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?");
9036
+ GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
8793
9037
  GGML_ASSERT(n%QK_K == 0);
8794
9038
 
8795
9039
  const int kMaxQ = 3;
@@ -9005,10 +9249,10 @@ static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict v
9005
9249
  const int * kmap_q2xs = iq2_data[gindex].map;
9006
9250
  const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
9007
9251
 
9008
- GGML_ASSERT(quant_weights);
9009
- GGML_ASSERT(kmap_q2xs);
9010
- GGML_ASSERT(kgrid_q2xs);
9011
- GGML_ASSERT(kneighbors_q2xs);
9252
+ GGML_ASSERT(quant_weights && "missing quantization weights");
9253
+ GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?");
9254
+ GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?");
9255
+ GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
9012
9256
  GGML_ASSERT(n%QK_K == 0);
9013
9257
 
9014
9258
  const int kMaxQ = 3;
@@ -9203,3 +9447,436 @@ size_t quantize_iq2_xs(const float * src, void * dst, int nrow, int n_per_row, i
9203
9447
  return nrow * nblock * sizeof(block_iq2_xs);
9204
9448
  }
9205
9449
 
9450
//
// ============================================= 3-bit using D4 lattice
//

// Lookup tables for one IQ3_XXS grid size (256 lattice points):
//  - grid:       one uint32 per grid point, packing 4 int8 coordinates (see iq3xs_init_impl)
//  - map:        kmap_size (4096) entries; >= 0 is a grid index, negative values encode an
//                offset into `neighbours` for points that are not on the grid
//  - neighbours: for each off-grid point: a count followed by that many grid indices
typedef struct {
    uint32_t * grid;
    int      * map;
    uint16_t * neighbours;
} iq3_entry_t;

// Single slot: only grid_size == 256 is supported (enforced by iq3_data_index).
// All pointers stay NULL until iq3xs_init_impl() populates them.
static iq3_entry_t iq3_data[1] = {
    {NULL, NULL, NULL},
};
9463
+
9464
// Map a grid size to its slot in iq3_data. Only one size (256) exists,
// so the answer is always slot 0.
static inline int iq3_data_index(int grid_size) {
    GGML_ASSERT(grid_size == 256);
    (void)grid_size; // silence unused-parameter warning when asserts compile out
    return 0;
}
9469
+
9470
// qsort comparator for (distance, index) int pairs: orders by distance
// first, then by index, ascending.
static int iq3_compare_func(const void * left, const void * right) {
    const int * a = (const int *)left;
    const int * b = (const int *)right;
    if (a[0] != b[0]) {
        return a[0] < b[0] ? -1 : 1;
    }
    if (a[1] != b[1]) {
        return a[1] < b[1] ? -1 : 1;
    }
    return 0;
}
9475
+
9476
+ void iq3xs_init_impl(int grid_size) {
9477
+ const int gindex = iq3_data_index(grid_size);
9478
+ if (iq3_data[gindex].grid) {
9479
+ return;
9480
+ }
9481
+ static const uint16_t kgrid_256[256] = {
9482
+ 0, 2, 4, 9, 11, 15, 16, 18, 25, 34, 59, 61, 65, 67, 72, 74,
9483
+ 81, 85, 88, 90, 97, 108, 120, 128, 130, 132, 137, 144, 146, 153, 155, 159,
9484
+ 169, 175, 189, 193, 199, 200, 202, 213, 248, 267, 287, 292, 303, 315, 317, 321,
9485
+ 327, 346, 362, 413, 436, 456, 460, 462, 483, 497, 513, 515, 520, 522, 529, 531,
9486
+ 536, 538, 540, 551, 552, 576, 578, 585, 592, 594, 641, 643, 648, 650, 657, 664,
9487
+ 698, 704, 706, 720, 729, 742, 758, 769, 773, 808, 848, 852, 870, 889, 901, 978,
9488
+ 992, 1024, 1026, 1033, 1035, 1040, 1042, 1046, 1049, 1058, 1089, 1091, 1093, 1096, 1098, 1105,
9489
+ 1112, 1139, 1143, 1144, 1152, 1154, 1161, 1167, 1168, 1170, 1183, 1184, 1197, 1217, 1224, 1228,
9490
+ 1272, 1276, 1309, 1323, 1347, 1367, 1377, 1404, 1473, 1475, 1486, 1509, 1537, 1544, 1546, 1553,
9491
+ 1555, 1576, 1589, 1594, 1600, 1602, 1616, 1625, 1636, 1638, 1665, 1667, 1672, 1685, 1706, 1722,
9492
+ 1737, 1755, 1816, 1831, 1850, 1856, 1862, 1874, 1901, 1932, 1950, 1971, 2011, 2032, 2052, 2063,
9493
+ 2077, 2079, 2091, 2095, 2172, 2192, 2207, 2208, 2224, 2230, 2247, 2277, 2308, 2345, 2356, 2389,
9494
+ 2403, 2424, 2501, 2504, 2506, 2520, 2570, 2593, 2616, 2624, 2630, 2646, 2669, 2700, 2714, 2746,
9495
+ 2754, 2795, 2824, 2835, 2839, 2874, 2882, 2905, 2984, 3028, 3042, 3092, 3108, 3110, 3124, 3153,
9496
+ 3185, 3215, 3252, 3288, 3294, 3364, 3397, 3434, 3483, 3523, 3537, 3587, 3589, 3591, 3592, 3610,
9497
+ 3626, 3670, 3680, 3722, 3749, 3754, 3776, 3789, 3803, 3824, 3857, 3873, 3904, 3906, 3924, 3992,
9498
+ };
9499
+ const int kmap_size = 4096;
9500
+ const int nwant = 2;
9501
+ const uint16_t * kgrid = kgrid_256;
9502
+ uint32_t * kgrid_q3xs;
9503
+ int * kmap_q3xs;
9504
+ uint16_t * kneighbors_q3xs;
9505
+
9506
+ printf("================================================================= %s(grid_size = %d)\n", __func__, grid_size);
9507
+ uint32_t * the_grid = (uint32_t *)malloc(grid_size*sizeof(uint32_t));
9508
+ for (int k = 0; k < grid_size; ++k) {
9509
+ int8_t * pos = (int8_t *)(the_grid + k);
9510
+ for (int i = 0; i < 4; ++i) {
9511
+ int l = (kgrid[k] >> 3*i) & 0x7;
9512
+ pos[i] = 2*l + 1;
9513
+ }
9514
+ }
9515
+ kgrid_q3xs = the_grid;
9516
+ iq3_data[gindex].grid = the_grid;
9517
+ kmap_q3xs = (int *)malloc(kmap_size*sizeof(int));
9518
+ iq3_data[gindex].map = kmap_q3xs;
9519
+ for (int i = 0; i < kmap_size; ++i) kmap_q3xs[i] = -1;
9520
+ uint32_t aux32;
9521
+ uint8_t * aux8 = (uint8_t *)&aux32;
9522
+ for (int i = 0; i < grid_size; ++i) {
9523
+ aux32 = kgrid_q3xs[i];
9524
+ uint16_t index = 0;
9525
+ for (int k=0; k<4; ++k) {
9526
+ uint16_t q = (aux8[k] - 1)/2;
9527
+ index |= (q << 3*k);
9528
+ }
9529
+ kmap_q3xs[index] = i;
9530
+ }
9531
+ int8_t pos[4];
9532
+ int * dist2 = (int *)malloc(2*grid_size*sizeof(int));
9533
+ int num_neighbors = 0, num_not_in_map = 0;
9534
+ for (int i = 0; i < kmap_size; ++i) {
9535
+ if (kmap_q3xs[i] >= 0) continue;
9536
+ ++num_not_in_map;
9537
+ for (int k = 0; k < 4; ++k) {
9538
+ int l = (i >> 3*k) & 0x7;
9539
+ pos[k] = 2*l + 1;
9540
+ }
9541
+ for (int j = 0; j < grid_size; ++j) {
9542
+ const int8_t * pg = (const int8_t *)(kgrid_q3xs + j);
9543
+ int d2 = 0;
9544
+ for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
9545
+ dist2[2*j+0] = d2;
9546
+ dist2[2*j+1] = j;
9547
+ }
9548
+ qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func);
9549
+ int n = 0; int d2 = dist2[0];
9550
+ int nhave = 1;
9551
+ for (int j = 0; j < grid_size; ++j) {
9552
+ if (dist2[2*j] > d2) {
9553
+ if (nhave == nwant) break;
9554
+ d2 = dist2[2*j];
9555
+ ++nhave;
9556
+ }
9557
+ ++n;
9558
+ }
9559
+ num_neighbors += n;
9560
+ }
9561
+ printf("%s: %d neighbours in total\n", __func__, num_neighbors);
9562
+ kneighbors_q3xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t));
9563
+ iq3_data[gindex].neighbours = kneighbors_q3xs;
9564
+ int counter = 0;
9565
+ for (int i = 0; i < kmap_size; ++i) {
9566
+ if (kmap_q3xs[i] >= 0) continue;
9567
+ for (int k = 0; k < 4; ++k) {
9568
+ int l = (i >> 3*k) & 0x7;
9569
+ pos[k] = 2*l + 1;
9570
+ }
9571
+ for (int j = 0; j < grid_size; ++j) {
9572
+ const int8_t * pg = (const int8_t *)(kgrid_q3xs + j);
9573
+ int d2 = 0;
9574
+ for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
9575
+ dist2[2*j+0] = d2;
9576
+ dist2[2*j+1] = j;
9577
+ }
9578
+ qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func);
9579
+ kmap_q3xs[i] = -(counter + 1);
9580
+ int d2 = dist2[0];
9581
+ uint16_t * start = &kneighbors_q3xs[counter++];
9582
+ int n = 0, nhave = 1;
9583
+ for (int j = 0; j < grid_size; ++j) {
9584
+ if (dist2[2*j] > d2) {
9585
+ if (nhave == nwant) break;
9586
+ d2 = dist2[2*j];
9587
+ ++nhave;
9588
+ }
9589
+ kneighbors_q3xs[counter++] = dist2[2*j+1];
9590
+ ++n;
9591
+ }
9592
+ *start = n;
9593
+ }
9594
+ free(dist2);
9595
+ }
9596
+
9597
+ void iq3xs_free_impl(int grid_size) {
9598
+ GGML_ASSERT(grid_size == 256);
9599
+ const int gindex = iq3_data_index(grid_size);
9600
+ if (iq3_data[gindex].grid) {
9601
+ free(iq3_data[gindex].grid); iq3_data[gindex].grid = NULL;
9602
+ free(iq3_data[gindex].map); iq3_data[gindex].map = NULL;
9603
+ free(iq3_data[gindex].neighbours); iq3_data[gindex].neighbours = NULL;
9604
+ }
9605
+ }
9606
+
9607
// Among the candidate grid points listed in `neighbours` (count first, then
// indices), pick the one minimizing the weighted squared error against the
// 4 values in xval at the given scale. Writes the chosen point's quant
// levels into L[0..3] and returns its grid index.
static int iq3_find_best_neighbour(const uint16_t * restrict neighbours, const uint32_t * restrict grid,
        const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L) {
    const int num_neighbors = neighbours[0];
    GGML_ASSERT(num_neighbors > 0);
    float best_d2   = FLT_MAX;
    int   best_index = -1;
    for (int j = 1; j <= num_neighbors; ++j) {
        const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
        float d2 = 0;
        for (int i = 0; i < 4; ++i) {
            float q = pg[i];
            float diff = scale*q - xval[i];
            d2 += weight[i]*diff*diff;
        }
        if (d2 < best_d2) {
            best_d2    = d2;
            best_index = neighbours[j];
        }
    }
    GGML_ASSERT(best_index >= 0);
    // Decode the winner's coordinates (2*l + 1) back into quant levels l.
    const int8_t * pg = (const int8_t *)(grid + best_index);
    for (int i = 0; i < 4; ++i) {
        L[i] = (pg[i] - 1)/2;
    }
    return best_index;
}
9630
+
9631
// Quantize n floats from x into IQ3_XXS blocks at vy (n/QK_K blocks).
// quant_weights, if non-NULL, supplies per-value importance weights
// (QK_K per block). Requires iq3xs_init_impl(256) to have been called.
static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {

    const int gindex = iq3_data_index(256);

    const uint32_t * kgrid_q3xs      = iq3_data[gindex].grid;
    const int      * kmap_q3xs       = iq3_data[gindex].map;
    const uint16_t * kneighbors_q3xs = iq3_data[gindex].neighbours;

    //GGML_ASSERT(quant_weights   && "missing quantization weights");
    GGML_ASSERT(kgrid_q3xs      && "forgot to call ggml_quantize_init()?");
    GGML_ASSERT(kmap_q3xs       && "forgot to call ggml_quantize_init()?");
    GGML_ASSERT(kneighbors_q3xs && "forgot to call ggml_quantize_init()?");
    GGML_ASSERT(n%QK_K == 0);

    const int kMaxQ = 8; // quant levels per coordinate: 0..7, used as odd values 2*l + 1

    const int nbl = n/256; // NOTE(review): literal 256 here vs QK_K elsewhere — assumes QK_K == 256

    block_iq3_xxs * y = vy;

    float scales[QK_K/32];       // per-32-value sub-block scale
    float weight[32];            // importance weights for the current sub-block
    float xval[32];              // |x| values (signs handled separately)
    int8_t L[32];                // best quant levels found so far
    int8_t Laux[32];             // candidate quant levels for the current trial scale
    float  waux[32];             // sqrt(weight), used by the neighbour search
    bool   is_on_grid[8];        // NOTE(review): only written inside the "if (sumq2 > 0 && ...)" branch
                                 // below; if that branch never fires for a sub-block, the read in the
                                 // n_not_ongrid loop (and of L[] later) is of uninitialized data — verify
    bool   is_on_grid_aux[8];
    uint8_t block_signs[8];      // 7-bit sign mask per group of 8 values
    uint8_t q3[3*(QK_K/8)];      // staging for the block payload
    // The last QK_K/8 bytes of q3 are overlaid as QK_K/32 uint32 words holding
    // packed signs (3 x 7 bits + 7 bits) plus a 4-bit scale in the top nibble.
    uint32_t * scales_and_signs = (uint32_t *)(q3 + QK_K/4);

    for (int ibl = 0; ibl < nbl; ++ibl) {

        y[ibl].d = GGML_FP32_TO_FP16(0.f);
        memset(q3, 0, 3*QK_K/8);

        float max_scale = 0;

        // sigma2: mean square of the block, used to soften the default weights.
        const float * xbl = x + QK_K*ibl;
        float sumx2 = 0;
        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
        float sigma2 = sumx2/QK_K;

        for (int ib = 0; ib < QK_K/32; ++ib) {
            const float * xb = xbl + 32*ib;
            if (quant_weights) {
                const float * qw = quant_weights + QK_K*ibl + 32*ib;
                for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
            } else {
                for (int i = 0; i < 32; ++i) weight[i] = xb[i]*xb[i];
            }
            for (int i = 0; i < 32; ++i) waux[i] = sqrtf(weight[i]);

            // Strip signs into block_signs; only 7 sign bits are stored per group
            // of 8, so the sign parity must be even: if an odd number of values
            // were flipped, un-flip the one with the least weighted magnitude.
            for (int k = 0; k < 4; ++k) {
                int nflip = 0;
                uint8_t s = 0;
                for (int i = 0; i < 8; ++i) {
                    if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
                    else {
                        xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i);
                    }
                }
                if (nflip%2) {
                    int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin];
                    for (int i = 1; i < 8; ++i) {
                        float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i];
                        if (ax < min) {
                            min = ax; imin = i;
                        }
                    }
                    xval[8*k+imin] = -xval[8*k+imin];
                    s ^= (1 << imin);
                }
                block_signs[k] = s & 127;
            }

            float max = xval[0];
            for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]);
            if (!max) {
                // All-zero sub-block: nothing to quantize.
                scales[ib] = 0;
                memset(L, 0, 32);
                continue;
            }

            // Grid search over candidate scales around max/(2*kMaxQ-1); for each
            // trial, snap every 4-value group to the D4 grid (via the neighbour
            // table when off-grid) and keep the scale maximizing sumqx^2/sumq2.
            float best = 0;
            float scale = max/(2*kMaxQ-1);
            for (int is = -15; is <= 15; ++is) {
                float id = (2*kMaxQ-1+is*0.2f)/max;
                float this_scale = 1/id;
                for (int k = 0; k < 8; ++k) {
                    for (int i = 0; i < 4; ++i) {
                        int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
                        Laux[4*k+i] = MAX(0, MIN(kMaxQ-1, l));
                    }
                    uint16_t u = 0;
                    for (int i = 0; i < 4; ++i) u |= (Laux[4*k+i] << 3*i);
                    int grid_index = kmap_q3xs[u];
                    is_on_grid_aux[k] = true;
                    if (grid_index < 0) {
                        is_on_grid_aux[k] = false;
                        // map stores -(offset+1) for off-grid points
                        const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
                        grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, this_scale, Laux + 4*k);
                    }
                }
                float sumqx = 0, sumq2 = 0;
                for (int i = 0; i < 32; ++i) {
                    float w = weight[i];
                    float q = 2*Laux[i] + 1;
                    sumqx += w*xval[i]*q;
                    sumq2 += w*q*q;
                }
                if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
                    scale = sumqx/sumq2; best = scale*sumqx;
                    for (int i = 0; i < 32; ++i) L[i] = Laux[i];
                    for (int k = 0; k < 8; ++k) is_on_grid[k] = is_on_grid_aux[k];
                }
            }

            // Re-snap any groups that ended up off-grid at the winning scale,
            // then recompute the optimal scale for the final levels.
            int n_not_ongrid = 0;
            for (int k = 0; k < 8; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
            if (n_not_ongrid > 0 && scale > 0) {
                float id = 1/scale;
                for (int k = 0; k < 8; ++k) {
                    if (is_on_grid[k]) continue;
                    uint16_t u = 0;
                    for (int i = 0; i < 4; ++i) {
                        int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
                        l = MAX(0, MIN(kMaxQ-1, l));
                        u |= (l << 3*i);
                    }
                    int grid_index = kmap_q3xs[u];
                    if (grid_index < 0) {
                        const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
                        grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, scale, L + 4*k);
                    }
                    const int8_t * pg = (const int8_t *)(kgrid_q3xs + grid_index);
                    for (int i = 0; i < 4; ++i) L[4*k+i] = (pg[i] - 1)/2;
                }
                float sumqx = 0, sumq2 = 0;
                for (int i = 0; i < 32; ++i) {
                    float w = weight[i];
                    float q = 2*L[i] + 1;
                    sumqx += w*xval[i]*q;
                    sumq2 += w*q*q;
                }
                if (sumq2 > 0) scale = sumqx/sumq2;
            }
            if (scale < 0) {
                // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale)
                // and correspondingly flip quant signs.
                scale = -scale;
                for (int k = 0; k < 4; ++k) block_signs[k] = (~block_signs[k]) & 127;
            }

            // Pack the 8 grid indices for this sub-block; every group must be
            // on-grid by now, otherwise the encoding is invalid.
            for (int k = 0; k < 8; ++k) {
                uint16_t u = 0;
                for (int i = 0; i < 4; ++i) u |= (L[4*k+i] << 3*i);
                int grid_index = kmap_q3xs[u];
                if (grid_index < 0) {
                    printf("Oops: found point %u not on grid:", u);
                    for (int i = 0; i < 4; ++i) printf(" %d", L[4*k+i]);
                    printf("\n");
                    GGML_ASSERT(false);
                }
                q3[8*ib+k] = grid_index;
            }
            scales_and_signs[ib] = block_signs[0] | (block_signs[1] << 7) | (block_signs[2] << 14) | (block_signs[3] << 21);
            GGML_ASSERT(scale >= 0);
            scales[ib] = scale;
            max_scale = MAX(max_scale, scale);
        }

        if (!max_scale) {
            memset(y[ibl].qs, 0, 3*QK_K/8);
            continue;
        }

        // Encode the per-sub-block scales as 4-bit values relative to d,
        // stored in the top nibble of each scales_and_signs word.
        float d = max_scale/31;
        y[ibl].d = GGML_FP32_TO_FP16(d);
        float id = 1/d;
        float sumqx = 0, sumq2 = 0;
        for (int ib = 0; ib < QK_K/32; ++ib) {
            int l = nearest_int(0.5f*(id*scales[ib]-1));
            l = MAX(0, MIN(15, l));
            scales_and_signs[ib] |= ((uint32_t)l << 28);
            // NOTE(review): disabled experimental refinement — exhaustively re-picks
            // each grid index against the final quantized scale and re-derives d.
            // Dead code kept as-is; consider removing or gating behind a flag.
            if (false) {
                const float * xb = xbl + 32*ib;
                if (quant_weights) {
                    const float * qw = quant_weights + QK_K*ibl + 32*ib;
                    for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
                } else {
                    for (int i = 0; i < 32; ++i) weight[i] = xb[i]*xb[i];
                }
                const float db = 0.25f * d * (1 + 2*l);
                for (int k = 0; k < 8; ++k) {
                    const int8_t * signs = keven_signs_q2xs + 8*((scales_and_signs[ib] >> 7*(k/2)) & 127) + 4*(k%2);
                    const float * xk = xb + 4*k;
                    const float * wk = weight + 4*k;
                    //const uint8_t * grid = (const uint8_t *)(kgrid_q3xs + q3[8*ib+k]);
                    const uint8_t * grid = (const uint8_t *)(iq3xxs_grid + q3[8*ib+k]);
                    float best_mse = 0; int best_index = q3[8*ib+k];
                    for (int j = 0; j < 4; ++j) {
                        float diff = db * grid[j] * signs[j] - xk[j];
                        best_mse += wk[j] * diff * diff;
                    }
                    for (int idx = 0; idx < 256; ++idx) {
                        //grid = (const uint8_t *)(kgrid_q3xs + idx);
                        grid = (const uint8_t *)(iq3xxs_grid + idx);
                        float mse = 0;
                        for (int j = 0; j < 4; ++j) {
                            float diff = db * grid[j] * signs[j] - xk[j];
                            mse += wk[j] * diff * diff;
                        }
                        if (mse < best_mse) {
                            best_mse = mse; best_index = idx;
                        }
                    }
                    q3[8*ib+k] = best_index;
                    //grid = (const uint8_t *)(kgrid_q3xs + best_index);
                    grid = (const uint8_t *)(iq3xxs_grid + best_index);
                    for (int j = 0; j < 4; ++j) {
                        float q = db * grid[j] * signs[j];
                        sumqx += wk[j] * q * xk[j];
                        sumq2 += wk[j] * q * q;
                    }
                }
                if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(d*sumqx/sumq2);
            }
        }
        memcpy(y[ibl].qs, q3, 3*QK_K/8);
    }
}
9859
+
9860
+ size_t quantize_iq3_xxs(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
9861
+ (void)hist;
9862
+ GGML_ASSERT(n_per_row%QK_K == 0);
9863
+ int nblock = n_per_row/QK_K;
9864
+ char * qrow = (char *)dst;
9865
+ for (int row = 0; row < nrow; ++row) {
9866
+ quantize_row_iq3_xxs_impl(src, qrow, n_per_row, quant_weights);
9867
+ src += n_per_row;
9868
+ qrow += nblock*sizeof(block_iq3_xxs);
9869
+ }
9870
+ return nrow * nblock * sizeof(block_iq3_xxs);
9871
+ }
9872
+
9873
+ void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int k) {
9874
+ assert(k % QK_K == 0);
9875
+ block_iq3_xxs * restrict y = vy;
9876
+ quantize_row_iq3_xxs_reference(x, y, k);
9877
+ }
9878
+
9879
+ void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int k) {
9880
+ assert(k % QK_K == 0);
9881
+ quantize_row_iq3_xxs_impl(x, y, k, NULL);
9882
+ }