llama_cpp 0.12.4 → 0.12.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -49,6 +49,8 @@
49
49
  #define MIN(a, b) ((a) < (b) ? (a) : (b))
50
50
  #define MAX(a, b) ((a) > (b) ? (a) : (b))
51
51
 
52
+ #define UNUSED GGML_UNUSED
53
+
52
54
  #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
53
55
 
54
56
  #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
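
Note: the new UNUSED alias is used throughout the changes below to silence warnings about the extra stride/row-count parameters added to the ggml_vec_dot_* signatures. In upstream ggml, GGML_UNUSED is essentially a cast-to-void macro; a hedged sketch of what the alias expands to:

    // Sketch only; the real definition lives in ggml.h.
    #define GGML_UNUSED(x) (void)(x)
    #define UNUSED GGML_UNUSED
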
@@ -268,6 +270,17 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128
268
270
  #endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
269
271
 
270
272
  #if defined(__ARM_NEON)
273
+
274
+ #ifdef _MSC_VER
275
+
276
+ #define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) }
277
+
278
+ #else
279
+
280
+ #define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) }
281
+
282
+ #endif
283
+
271
284
  #if !defined(__aarch64__)
272
285
 
273
286
  // 64-bit compatibility
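
Note: ggml_vld1q_u32 papers over a compiler difference. GCC and Clang allow NEON vector types to be brace-initialized element by element, so {w, x, y, z} yields a uint32x4_t directly; MSVC's ARM64 intrinsics do not, so the macro instead packs each pair of 32-bit values into the two 64-bit halves of the vector. Either way the intent is the same; a hedged usage sketch:

    // Both expansions are meant to produce lanes {0, 1, 2, 3}.
    const uint32x4_t idx = ggml_vld1q_u32(0u, 1u, 2u, 3u);
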
@@ -2381,19 +2394,20 @@ static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restri
2381
2394
 
2382
2395
  uint8_t L[QK_K];
2383
2396
  uint8_t Laux[32];
2397
+ uint8_t Ls[QK_K/32];
2398
+ uint8_t Lm[QK_K/32];
2384
2399
  float weights[32];
2385
- float mins[QK_K/32];
2386
- float scales[QK_K/32];
2400
+ float sw[QK_K/32];
2401
+ float mins[QK_K/32];
2402
+ float scales[QK_K/32];
2387
2403
 
2388
2404
  for (int i = 0; i < nb; i++) {
2389
2405
 
2390
2406
  float sum_x2 = 0;
2391
2407
  for (int l = 0; l < QK_K; ++l) sum_x2 += x[l] * x[l];
2392
- float sigma2 = sum_x2/QK_K;
2408
+ float sigma2 = 2*sum_x2/QK_K;
2393
2409
  float av_x = sqrtf(sigma2);
2394
2410
 
2395
- float max_scale = 0; // as we are deducting the min, scales are always positive
2396
- float max_min = 0;
2397
2411
  for (int j = 0; j < QK_K/32; ++j) {
2398
2412
  if (quant_weights) {
2399
2413
  const float * qw = quant_weights + QK_K*i + 32*j;
@@ -2401,25 +2415,17 @@ static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restri
2401
2415
  } else {
2402
2416
  for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
2403
2417
  }
2418
+ float sumw = 0;
2419
+ for (int l = 0; l < 32; ++l) sumw += weights[l];
2420
+ sw[j] = sumw;
2404
2421
  scales[j] = make_qkx3_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
2405
- //scales[j] = make_qkx2_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -1.f, 0.1f, 20, false);
2406
- float scale = scales[j];
2407
- if (scale > max_scale) {
2408
- max_scale = scale;
2409
- }
2410
- float min = mins[j];
2411
- if (min > max_min) {
2412
- max_min = min;
2413
- }
2414
2422
  }
2415
2423
 
2416
- float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
2417
- float inv_min = max_min > 0 ? 63.f/max_min : 0.f;
2424
+ float d_block = make_qp_quants(QK_K/32, 63, scales, Ls, sw);
2425
+ float m_block = make_qp_quants(QK_K/32, 63, mins, Lm, sw);
2418
2426
  for (int j = 0; j < QK_K/32; ++j) {
2419
- uint8_t ls = nearest_int(inv_scale*scales[j]);
2420
- uint8_t lm = nearest_int(inv_min*mins[j]);
2421
- ls = MIN(63, ls);
2422
- lm = MIN(63, lm);
2427
+ uint8_t ls = Ls[j];
2428
+ uint8_t lm = Lm[j];
2423
2429
  if (j < 4) {
2424
2430
  y[i].scales[j] = ls;
2425
2431
  y[i].scales[j+4] = lm;
@@ -2429,8 +2435,8 @@ static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restri
2429
2435
  y[i].scales[j-0] |= ((lm >> 4) << 6);
2430
2436
  }
2431
2437
  }
2432
- y[i].d = GGML_FP32_TO_FP16(max_scale/63.f);
2433
- y[i].dmin = GGML_FP32_TO_FP16(max_min/63.f);
2438
+ y[i].d = GGML_FP32_TO_FP16(d_block);
2439
+ y[i].dmin = GGML_FP32_TO_FP16(m_block);
2434
2440
 
2435
2441
  uint8_t sc, m;
2436
2442
  for (int j = 0; j < QK_K/32; ++j) {
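
Note on the q4_K change above: instead of normalizing each sub-block scale/min by the block maximum (the removed max_scale/max_min logic), the 6-bit sub-block scales and mins are now chosen by make_qp_quants, weighted by the per-sub-block sums of the importance weights (sw[j]), and the returned values become d/dmin directly. The doubled sigma2 increases the variance estimate used to build the importance weights. A minimal conceptual sketch of weighted non-negative quantization in the spirit of make_qp_quants (not the package's actual implementation; assumes <float.h> plus the file's own nearest_int/MIN/MAX helpers):

    static float qp_quants_sketch(int n, int nmax, const float * x, uint8_t * L, const float * w) {
        float max = 0;
        for (int j = 0; j < n; ++j) max = MAX(max, x[j]);
        if (max == 0) { for (int j = 0; j < n; ++j) L[j] = 0; return 0.f; }
        float best_d = max/nmax, best_err = FLT_MAX;
        for (int is = -4; is <= 4; ++is) {                   // small search around max/nmax
            float d = max/(nmax + 0.1f*is), err = 0;
            for (int j = 0; j < n; ++j) {
                int l = MIN(nmax, MAX(0, nearest_int(x[j]/d)));
                float diff = x[j] - d*l;
                err += w[j]*diff*diff;                       // weighted squared error
            }
            if (err < best_err) { best_err = err; best_d = d; }
        }
        for (int j = 0; j < n; ++j) L[j] = MIN(nmax, MAX(0, nearest_int(x[j]/best_d)));
        return best_d;
    }
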
@@ -2688,20 +2694,21 @@ static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restri
2688
2694
  const int nb = n_per_row / QK_K;
2689
2695
 
2690
2696
  uint8_t L[QK_K];
2691
- float mins[QK_K/32];
2692
- float scales[QK_K/32];
2693
- float weights[32];
2694
2697
  uint8_t Laux[32];
2698
+ uint8_t Ls[QK_K/32];
2699
+ uint8_t Lm[QK_K/32];
2700
+ float mins[QK_K/32];
2701
+ float scales[QK_K/32];
2702
+ float sw[QK_K/32];
2703
+ float weights[32];
2695
2704
 
2696
2705
  for (int i = 0; i < nb; i++) {
2697
2706
 
2698
2707
  float sum_x2 = 0;
2699
2708
  for (int l = 0; l < QK_K; ++l) sum_x2 += x[l] * x[l];
2700
- float sigma2 = sum_x2/QK_K;
2709
+ float sigma2 = 2*sum_x2/QK_K;
2701
2710
  float av_x = sqrtf(sigma2);
2702
2711
 
2703
- float max_scale = 0; // as we are deducting the min, scales are always positive
2704
- float max_min = 0;
2705
2712
  for (int j = 0; j < QK_K/32; ++j) {
2706
2713
  if (quant_weights) {
2707
2714
  const float * qw = quant_weights + QK_K*i + 32*j;
@@ -2709,22 +2716,19 @@ static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restri
2709
2716
  } else {
2710
2717
  for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
2711
2718
  }
2719
+ float sumw = 0;
2720
+ for (int l = 0; l < 32; ++l) sumw += weights[l];
2721
+ sw[j] = sumw;
2722
+
2712
2723
  scales[j] = make_qkx3_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
2713
- float scale = scales[j];
2714
- if (scale > max_scale) {
2715
- max_scale = scale;
2716
- }
2717
- float min = mins[j];
2718
- if (min > max_min) {
2719
- max_min = min;
2720
- }
2721
2724
  }
2722
2725
 
2723
- float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
2724
- float inv_min = max_min > 0 ? 63.f/max_min : 0.f;
2726
+ float d_block = make_qp_quants(QK_K/32, 63, scales, Ls, sw);
2727
+ float m_block = make_qp_quants(QK_K/32, 63, mins, Lm, sw);
2728
+
2725
2729
  for (int j = 0; j < QK_K/32; ++j) {
2726
- uint8_t ls = nearest_int(inv_scale*scales[j]);
2727
- uint8_t lm = nearest_int(inv_min*mins[j]);
2730
+ uint8_t ls = Ls[j];
2731
+ uint8_t lm = Lm[j];
2728
2732
  ls = MIN(63, ls);
2729
2733
  lm = MIN(63, lm);
2730
2734
  if (j < 4) {
@@ -2736,8 +2740,8 @@ static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restri
2736
2740
  y[i].scales[j-0] |= ((lm >> 4) << 6);
2737
2741
  }
2738
2742
  }
2739
- y[i].d = GGML_FP32_TO_FP16(max_scale/63.f);
2740
- y[i].dmin = GGML_FP32_TO_FP16(max_min/63.f);
2743
+ y[i].d = GGML_FP32_TO_FP16(d_block);
2744
+ y[i].dmin = GGML_FP32_TO_FP16(m_block);
2741
2745
 
2742
2746
  uint8_t sc, m;
2743
2747
  for (int j = 0; j < QK_K/32; ++j) {
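
The q5_K path above receives the same restructuring as q4_K: per-sub-block weight sums feed make_qp_quants, whose return values become the block-wide d/dmin. The 6-bit ls/lm values are still packed into the 12-byte scales array exactly as before; a hedged sketch of the matching unpack (upstream ggml has an equivalent helper, get_scale_min_k4):

    // Read back the 6-bit (scale, min) pair for sub-block j from the packed scales[12] array.
    static void get_scale_min_sketch(int j, const uint8_t * q, uint8_t * d, uint8_t * m) {
        if (j < 4) {
            *d = q[j] & 63;
            *m = q[j + 4] & 63;
        } else {
            *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
            *m = (q[j + 4] >>  4) | ((q[j - 0] >> 6) << 4);
        }
    }
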
@@ -3675,15 +3679,92 @@ static inline __m128i get_scale_shuffle(int i) {
3675
3679
  }
3676
3680
  #endif
3677
3681
 
3678
- void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3682
+ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
3679
3683
  const int qk = QK8_0;
3680
3684
  const int nb = n / qk;
3681
3685
 
3682
3686
  assert(n % qk == 0);
3687
+ #if defined(__ARM_FEATURE_MATMUL_INT8)
3688
+ assert((nrc == 2) || (nrc == 1));
3689
+ #else
3690
+ assert(nrc == 1);
3691
+ #endif
3692
+ UNUSED(nrc);
3693
+ UNUSED(bx);
3694
+ UNUSED(by);
3695
+ UNUSED(bs);
3683
3696
 
3684
3697
  const block_q4_0 * restrict x = vx;
3685
3698
  const block_q8_0 * restrict y = vy;
3686
3699
 
3700
+ #if defined(__ARM_FEATURE_MATMUL_INT8)
3701
+ if (nrc == 2) {
3702
+ const block_q4_0 * restrict vx0 = vx;
3703
+ const block_q4_0 * restrict vx1 = vx + bx;
3704
+
3705
+ const block_q8_0 * restrict vy0 = vy;
3706
+ const block_q8_0 * restrict vy1 = vy + by;
3707
+
3708
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
3709
+
3710
+ for (int i = 0; i < nb; i++) {
3711
+ const block_q4_0 * restrict b_x0 = &vx0[i];
3712
+ const block_q4_0 * restrict b_x1 = &vx1[i];
3713
+ const block_q8_0 * restrict b_y0 = &vy0[i];
3714
+ const block_q8_0 * restrict b_y1 = &vy1[i];
3715
+
3716
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
3717
+ const int8x16_t s8b = vdupq_n_s8(0x8);
3718
+
3719
+ const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);
3720
+ const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);
3721
+
3722
+ // 4-bit -> 8-bit
3723
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
3724
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
3725
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
3726
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
3727
+
3728
+ // sub 8
3729
+ const int8x16_t x0_l = vsubq_s8(v0_0l, s8b);
3730
+ const int8x16_t x0_h = vsubq_s8(v0_0h, s8b);
3731
+ const int8x16_t x1_l = vsubq_s8(v0_1l, s8b);
3732
+ const int8x16_t x1_h = vsubq_s8(v0_1h, s8b);
3733
+
3734
+ // load y
3735
+ const int8x16_t y0_l = vld1q_s8(b_y0->qs);
3736
+ const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
3737
+ const int8x16_t y1_l = vld1q_s8(b_y1->qs);
3738
+ const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
3739
+
3740
+ float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
3741
+ GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
3742
+ GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
3743
+ GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
3744
+
3745
+ int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
3746
+ int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
3747
+
3748
+ int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
3749
+ int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
3750
+
3751
+ int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
3752
+ int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
3753
+
3754
+ int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
3755
+ int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
3756
+
3757
+ sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
3758
+ l1, r1)), l2, r2)), l3, r3))), scale);
3759
+ }
3760
+ float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
3761
+ float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
3762
+
3763
+ vst1_f32(s, vget_low_f32(sumv2));
3764
+ vst1_f32(s + bs, vget_high_f32(sumv2));
3765
+ return;
3766
+ }
3767
+ #endif
3687
3768
  #if defined(__ARM_NEON)
3688
3769
  float32x4_t sumv0 = vdupq_n_f32(0.0f);
3689
3770
  float32x4_t sumv1 = vdupq_n_f32(0.0f);
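
Note on the interface change above: every ggml_vec_dot_* routine now takes an output stride bs (in floats, as used by the s + bs store), byte strides bx/by between consecutive input rows, and nrc, the number of result rows produced per call. Only the q4_0, q4_1 and q8_0 kernels gain an nrc == 2 fast path, and only under __ARM_FEATURE_MATMUL_INT8, where vmmlaq_s32 produces a 2x2 int32 tile per register pair; all other functions in this file assert nrc == 1 and mark the new parameters UNUSED. In the AVX paths further below, the local __m256i variables formerly named bx/by are renamed qx/qy so they no longer collide with the new parameters. For orientation, a scalar sketch of what one q4_0 x q8_0 block contributes (consistent with the SIMD code above, not copied from the package):

    static float dot_q4_0_q8_0_block(const block_q4_0 * x, const block_q8_0 * y) {
        int sum = 0;
        for (int j = 0; j < QK4_0/2; ++j) {
            const int v0 = (x->qs[j] & 0x0F) - 8;      // low nibble  -> elements 0..15
            const int v1 = (x->qs[j] >>   4) - 8;      // high nibble -> elements 16..31
            sum += v0*y->qs[j] + v1*y->qs[j + QK4_0/2];
        }
        return GGML_FP16_TO_FP32(x->d) * GGML_FP16_TO_FP32(y->d) * sum;
    }
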
@@ -3738,15 +3819,15 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx,
3738
3819
  /* Compute combined scale for the block */
3739
3820
  const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
3740
3821
 
3741
- __m256i bx = bytes_from_nibbles_32(x[i].qs);
3822
+ __m256i qx = bytes_from_nibbles_32(x[i].qs);
3742
3823
 
3743
3824
  // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
3744
3825
  const __m256i off = _mm256_set1_epi8( 8 );
3745
- bx = _mm256_sub_epi8( bx, off );
3826
+ qx = _mm256_sub_epi8( qx, off );
3746
3827
 
3747
- __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
3828
+ __m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs);
3748
3829
 
3749
- const __m256 q = mul_sum_i8_pairs_float(bx, by);
3830
+ const __m256 q = mul_sum_i8_pairs_float(qx, qy);
3750
3831
 
3751
3832
  /* Multiply q with scale and accumulate */
3752
3833
  acc = _mm256_fmadd_ps( d, q, acc );
@@ -3965,15 +4046,93 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx,
3965
4046
  #endif
3966
4047
  }
3967
4048
 
3968
- void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
4049
+ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
3969
4050
  const int qk = QK8_1;
3970
4051
  const int nb = n / qk;
3971
4052
 
3972
4053
  assert(n % qk == 0);
4054
+ #if defined(__ARM_FEATURE_MATMUL_INT8)
4055
+ assert((nrc == 2) || (nrc == 1));
4056
+ #else
4057
+ assert(nrc == 1);
4058
+ #endif
4059
+ UNUSED(nrc);
4060
+ UNUSED(bx);
4061
+ UNUSED(by);
4062
+ UNUSED(bs);
3973
4063
 
3974
4064
  const block_q4_1 * restrict x = vx;
3975
4065
  const block_q8_1 * restrict y = vy;
3976
4066
 
4067
+ #if defined(__ARM_FEATURE_MATMUL_INT8)
4068
+ if (nrc == 2) {
4069
+ const block_q4_1 * restrict vx0 = vx;
4070
+ const block_q4_1 * restrict vx1 = vx + bx;
4071
+ const block_q8_1 * restrict vy0 = vy;
4072
+ const block_q8_1 * restrict vy1 = vy + by;
4073
+
4074
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
4075
+ float32x4_t summs0 = vdupq_n_f32(0.0f);
4076
+
4077
+ for (int i = 0; i < nb; i++) {
4078
+ const block_q4_1 * restrict b_x0 = &vx0[i];
4079
+ const block_q4_1 * restrict b_x1 = &vx1[i];
4080
+ const block_q8_1 * restrict b_y0 = &vy0[i];
4081
+ const block_q8_1 * restrict b_y1 = &vy1[i];
4082
+
4083
+ float32x4_t summs_t = {GGML_FP16_TO_FP32(b_x0->m) * b_y0->s,
4084
+ GGML_FP16_TO_FP32(b_x1->m) * b_y0->s,
4085
+ GGML_FP16_TO_FP32(b_x0->m) * b_y1->s,
4086
+ GGML_FP16_TO_FP32(b_x1->m) * b_y1->s};
4087
+ summs0 += summs_t;
4088
+
4089
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
4090
+
4091
+ const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);
4092
+ const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);
4093
+
4094
+ // 4-bit -> 8-bit
4095
+ const int8x16_t x0_l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
4096
+ const int8x16_t x0_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
4097
+ const int8x16_t x1_l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
4098
+ const int8x16_t x1_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
4099
+
4100
+ // load y
4101
+ const int8x16_t y0_l = vld1q_s8(b_y0->qs);
4102
+ const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
4103
+ const int8x16_t y1_l = vld1q_s8(b_y1->qs);
4104
+ const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
4105
+
4106
+ // mmla into int32x4_t
4107
+ float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
4108
+ GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
4109
+ GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
4110
+ GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
4111
+
4112
+ int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
4113
+ int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
4114
+
4115
+ int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
4116
+ int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
4117
+
4118
+ int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
4119
+ int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
4120
+
4121
+ int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
4122
+ int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
4123
+ sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
4124
+ l1, r1)), l2, r2)), l3, r3))), scale);
4125
+ }
4126
+
4127
+ float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
4128
+ float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
4129
+ sumv2 = sumv2 + summs0;
4130
+
4131
+ vst1_f32(s, vget_low_f32(sumv2));
4132
+ vst1_f32(s + bs, vget_high_f32(sumv2));
4133
+ return;
4134
+ }
4135
+ #endif
3977
4136
  // TODO: add WASM SIMD
3978
4137
  #if defined(__ARM_NEON)
3979
4138
  float32x4_t sumv0 = vdupq_n_f32(0.0f);
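
The q4_1/q8_1 variant carries an extra per-block term because q4_1 stores an offset m alongside the scale d: each block contributes d_x*d_y*sum(qx*qy) + m_x*s_y, where s_y = d_y*sum(qy) is precomputed when the q8_1 block is quantized. The summs0 accumulator in the new nrc == 2 path above collects exactly those m_x*s_y terms for the four row/column pairs before they are added to the mmla results.
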
@@ -4037,10 +4196,10 @@ void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restri
4037
4196
  const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
4038
4197
 
4039
4198
  // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
4040
- const __m256i bx = bytes_from_nibbles_32(x[i].qs);
4041
- const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
4199
+ const __m256i qx = bytes_from_nibbles_32(x[i].qs);
4200
+ const __m256i qy = _mm256_loadu_si256( (const __m256i *)y[i].qs );
4042
4201
 
4043
- const __m256 xy = mul_sum_us8_pairs_float(bx, by);
4202
+ const __m256 xy = mul_sum_us8_pairs_float(qx, qy);
4044
4203
 
4045
4204
  // Accumulate d0*d1*x*y
4046
4205
  #if defined(__AVX2__)
@@ -4105,12 +4264,17 @@ void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restri
4105
4264
  #endif
4106
4265
  }
4107
4266
 
4108
- void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
4267
+ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
4109
4268
  const int qk = QK8_0;
4110
4269
  const int nb = n / qk;
4111
4270
 
4112
4271
  assert(n % qk == 0);
4113
4272
  assert(qk == QK5_0);
4273
+ assert(nrc == 1);
4274
+ UNUSED(nrc);
4275
+ UNUSED(bx);
4276
+ UNUSED(by);
4277
+ UNUSED(bs);
4114
4278
 
4115
4279
  const block_q5_0 * restrict x = vx;
4116
4280
  const block_q8_0 * restrict y = vy;
@@ -4254,14 +4418,14 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri
4254
4418
  /* Compute combined scale for the block */
4255
4419
  const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
4256
4420
 
4257
- __m256i bx = bytes_from_nibbles_32(x[i].qs);
4421
+ __m256i qx = bytes_from_nibbles_32(x[i].qs);
4258
4422
  __m256i bxhi = bytes_from_bits_32(x[i].qh);
4259
4423
  bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));
4260
- bx = _mm256_or_si256(bx, bxhi);
4424
+ qx = _mm256_or_si256(qx, bxhi);
4261
4425
 
4262
- __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
4426
+ __m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs);
4263
4427
 
4264
- const __m256 q = mul_sum_i8_pairs_float(bx, by);
4428
+ const __m256 q = mul_sum_i8_pairs_float(qx, qy);
4265
4429
 
4266
4430
  /* Multiply q with scale and accumulate */
4267
4431
  acc = _mm256_fmadd_ps(d, q, acc);
@@ -4391,12 +4555,17 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri
4391
4555
  #endif
4392
4556
  }
4393
4557
 
4394
- void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
4558
+ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
4395
4559
  const int qk = QK8_1;
4396
4560
  const int nb = n / qk;
4397
4561
 
4398
4562
  assert(n % qk == 0);
4399
4563
  assert(qk == QK5_1);
4564
+ assert(nrc == 1);
4565
+ UNUSED(nrc);
4566
+ UNUSED(bx);
4567
+ UNUSED(by);
4568
+ UNUSED(bs);
4400
4569
 
4401
4570
  const block_q5_1 * restrict x = vx;
4402
4571
  const block_q8_1 * restrict y = vy;
@@ -4553,15 +4722,15 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri
4553
4722
 
4554
4723
  summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
4555
4724
 
4556
- __m256i bx = bytes_from_nibbles_32(x[i].qs);
4725
+ __m256i qx = bytes_from_nibbles_32(x[i].qs);
4557
4726
  __m256i bxhi = bytes_from_bits_32(x[i].qh);
4558
4727
  bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
4559
- bx = _mm256_or_si256(bx, bxhi);
4728
+ qx = _mm256_or_si256(qx, bxhi);
4560
4729
 
4561
4730
  const __m256 dy = _mm256_set1_ps(y[i].d);
4562
- const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
4731
+ const __m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs);
4563
4732
 
4564
- const __m256 q = mul_sum_us8_pairs_float(bx, by);
4733
+ const __m256 q = mul_sum_us8_pairs_float(qx, qy);
4565
4734
 
4566
4735
  acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
4567
4736
  }
@@ -4690,15 +4859,79 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri
4690
4859
  #endif
4691
4860
  }
4692
4861
 
4693
- void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
4862
+ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
4694
4863
  const int qk = QK8_0;
4695
4864
  const int nb = n / qk;
4696
4865
 
4697
4866
  assert(n % qk == 0);
4867
+ #if defined(__ARM_FEATURE_MATMUL_INT8)
4868
+ assert((nrc == 2) || (nrc == 1));
4869
+ #else
4870
+ assert(nrc == 1);
4871
+ #endif
4872
+ UNUSED(nrc);
4873
+ UNUSED(bx);
4874
+ UNUSED(by);
4875
+ UNUSED(bs);
4698
4876
 
4699
4877
  const block_q8_0 * restrict x = vx;
4700
4878
  const block_q8_0 * restrict y = vy;
4701
4879
 
4880
+ #if defined(__ARM_FEATURE_MATMUL_INT8)
4881
+ if (nrc == 2) {
4882
+ const block_q8_0 * restrict vx0 = vx;
4883
+ const block_q8_0 * restrict vx1 = vx + bx;
4884
+ const block_q8_0 * restrict vy0 = vy;
4885
+ const block_q8_0 * restrict vy1 = vy + by;
4886
+
4887
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
4888
+
4889
+ for (int i = 0; i < nb; i++) {
4890
+ const block_q8_0 * restrict b_x0 = &vx0[i];
4891
+ const block_q8_0 * restrict b_y0 = &vy0[i];
4892
+
4893
+ const block_q8_0 * restrict b_x1 = &vx1[i];
4894
+ const block_q8_0 * restrict b_y1 = &vy1[i];
4895
+
4896
+ const int8x16_t x0_l = vld1q_s8(b_x0->qs);
4897
+ const int8x16_t x0_h = vld1q_s8(b_x0->qs + 16);
4898
+ const int8x16_t x1_l = vld1q_s8(b_x1->qs);
4899
+ const int8x16_t x1_h = vld1q_s8(b_x1->qs + 16);
4900
+
4901
+ // load y
4902
+ const int8x16_t y0_l = vld1q_s8(b_y0->qs);
4903
+ const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
4904
+ const int8x16_t y1_l = vld1q_s8(b_y1->qs);
4905
+ const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
4906
+
4907
+ float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
4908
+ GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
4909
+ GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
4910
+ GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
4911
+
4912
+ int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
4913
+ int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
4914
+
4915
+ int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
4916
+ int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
4917
+
4918
+ int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
4919
+ int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
4920
+
4921
+ int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
4922
+ int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
4923
+
4924
+ sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
4925
+ l1, r1)), l2, r2)), l3, r3))), scale);
4926
+ }
4927
+ float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
4928
+ float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
4929
+
4930
+ vst1_f32(s, vget_low_f32(sumv2));
4931
+ vst1_f32(s + bs, vget_high_f32(sumv2));
4932
+ return;
4933
+ }
4934
+ #endif
4702
4935
  #if defined(__ARM_NEON)
4703
4936
  float32x4_t sumv0 = vdupq_n_f32(0.0f);
4704
4937
  float32x4_t sumv1 = vdupq_n_f32(0.0f);
@@ -4740,10 +4973,10 @@ void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restri
4740
4973
  for (int i = 0; i < nb; ++i) {
4741
4974
  // Compute combined scale for the block
4742
4975
  const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
4743
- __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs);
4744
- __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
4976
+ __m256i qx = _mm256_loadu_si256((const __m256i *)x[i].qs);
4977
+ __m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs);
4745
4978
 
4746
- const __m256 q = mul_sum_i8_pairs_float(bx, by);
4979
+ const __m256 q = mul_sum_i8_pairs_float(qx, qy);
4747
4980
 
4748
4981
  // Multiply q with scale and accumulate
4749
4982
  #if defined(__AVX2__)
@@ -4793,7 +5026,12 @@ void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restri
4793
5026
  }
4794
5027
 
4795
5028
  #if QK_K == 256
4796
- void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
5029
+ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
5030
+ assert(nrc == 1);
5031
+ UNUSED(nrc);
5032
+ UNUSED(bx);
5033
+ UNUSED(by);
5034
+ UNUSED(bs);
4797
5035
 
4798
5036
  const block_q2_K * restrict x = vx;
4799
5037
  const block_q8_K * restrict y = vy;
@@ -5169,7 +5407,12 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
5169
5407
 
5170
5408
  #else
5171
5409
 
5172
- void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
5410
+ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
5411
+ assert(nrc == 1);
5412
+ UNUSED(nrc);
5413
+ UNUSED(bx);
5414
+ UNUSED(by);
5415
+ UNUSED(bs);
5173
5416
 
5174
5417
  const block_q2_K * restrict x = vx;
5175
5418
  const block_q8_K * restrict y = vy;
@@ -5427,8 +5670,13 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
5427
5670
  #endif
5428
5671
 
5429
5672
  #if QK_K == 256
5430
- void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
5673
+ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
5431
5674
  assert(n % QK_K == 0);
5675
+ assert(nrc == 1);
5676
+ UNUSED(nrc);
5677
+ UNUSED(bx);
5678
+ UNUSED(by);
5679
+ UNUSED(bs);
5432
5680
 
5433
5681
  const uint32_t kmask1 = 0x03030303;
5434
5682
  const uint32_t kmask2 = 0x0f0f0f0f;
@@ -5947,8 +6195,13 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
5947
6195
 
5948
6196
  #else
5949
6197
 
5950
- void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
6198
+ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
5951
6199
  assert(n % QK_K == 0);
6200
+ assert(nrc == 1);
6201
+ UNUSED(nrc);
6202
+ UNUSED(bx);
6203
+ UNUSED(by);
6204
+ UNUSED(bs);
5952
6205
 
5953
6206
  const block_q3_K * restrict x = vx;
5954
6207
  const block_q8_K * restrict y = vy;
@@ -6290,8 +6543,13 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
6290
6543
  #endif
6291
6544
 
6292
6545
  #if QK_K == 256
6293
- void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
6546
+ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
6294
6547
  assert(n % QK_K == 0);
6548
+ assert(nrc == 1);
6549
+ UNUSED(nrc);
6550
+ UNUSED(bx);
6551
+ UNUSED(by);
6552
+ UNUSED(bs);
6295
6553
 
6296
6554
  const block_q4_K * restrict x = vx;
6297
6555
  const block_q8_K * restrict y = vy;
@@ -6646,8 +6904,13 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
6646
6904
  #endif
6647
6905
  }
6648
6906
  #else
6649
- void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
6907
+ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
6650
6908
  assert(n % QK_K == 0);
6909
+ assert(nrc == 1);
6910
+ UNUSED(nrc);
6911
+ UNUSED(bx);
6912
+ UNUSED(by);
6913
+ UNUSED(bs);
6651
6914
 
6652
6915
  const block_q4_K * restrict x = vx;
6653
6916
  const block_q8_K * restrict y = vy;
@@ -6889,8 +7152,13 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
6889
7152
  #endif
6890
7153
 
6891
7154
  #if QK_K == 256
6892
- void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
7155
+ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
6893
7156
  assert(n % QK_K == 0);
7157
+ assert(nrc == 1);
7158
+ UNUSED(nrc);
7159
+ UNUSED(bx);
7160
+ UNUSED(by);
7161
+ UNUSED(bs);
6894
7162
 
6895
7163
  const block_q5_K * restrict x = vx;
6896
7164
  const block_q8_K * restrict y = vy;
@@ -7309,8 +7577,13 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
7309
7577
 
7310
7578
  #else
7311
7579
 
7312
- void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
7580
+ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
7313
7581
  assert(n % QK_K == 0);
7582
+ assert(nrc == 1);
7583
+ UNUSED(nrc);
7584
+ UNUSED(bx);
7585
+ UNUSED(by);
7586
+ UNUSED(bs);
7314
7587
 
7315
7588
  const block_q5_K * restrict x = vx;
7316
7589
  const block_q8_K * restrict y = vy;
@@ -7575,8 +7848,13 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
7575
7848
 
7576
7849
 
7577
7850
  #if QK_K == 256
7578
- void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
7851
+ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
7579
7852
  assert(n % QK_K == 0);
7853
+ assert(nrc == 1);
7854
+ UNUSED(nrc);
7855
+ UNUSED(bx);
7856
+ UNUSED(by);
7857
+ UNUSED(bs);
7580
7858
 
7581
7859
  const block_q6_K * restrict x = vx;
7582
7860
  const block_q8_K * restrict y = vy;
@@ -8007,8 +8285,13 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
8007
8285
 
8008
8286
  #else
8009
8287
 
8010
- void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
8288
+ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
8011
8289
  assert(n % QK_K == 0);
8290
+ assert(nrc == 1);
8291
+ UNUSED(nrc);
8292
+ UNUSED(bx);
8293
+ UNUSED(by);
8294
+ UNUSED(bs);
8012
8295
 
8013
8296
  const block_q6_K * restrict x = vx;
8014
8297
  const block_q8_K * restrict y = vy;
@@ -8337,8 +8620,13 @@ static const int8_t keven_signs_q2xs[1024] = {
8337
8620
  1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
8338
8621
  };
8339
8622
 
8340
- void ggml_vec_dot_iq2_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
8623
+ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
8341
8624
  assert(n % QK_K == 0);
8625
+ assert(nrc == 1);
8626
+ UNUSED(nrc);
8627
+ UNUSED(bx);
8628
+ UNUSED(by);
8629
+ UNUSED(bs);
8342
8630
 
8343
8631
  const block_iq2_xxs * restrict x = vx;
8344
8632
  const block_q8_K * restrict y = vy;
@@ -8460,8 +8748,13 @@ void ggml_vec_dot_iq2_xxs_q8_K(const int n, float * restrict s, const void * res
8460
8748
  #endif
8461
8749
  }
8462
8750
 
8463
- void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
8751
+ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
8464
8752
  assert(n % QK_K == 0);
8753
+ assert(nrc == 1);
8754
+ UNUSED(nrc);
8755
+ UNUSED(bx);
8756
+ UNUSED(by);
8757
+ UNUSED(bs);
8465
8758
 
8466
8759
  const block_iq2_xs * restrict x = vx;
8467
8760
  const block_q8_K * restrict y = vy;
@@ -8680,8 +8973,13 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
8680
8973
  }
8681
8974
 
8682
8975
  // TODO
8683
- void ggml_vec_dot_iq3_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
8976
+ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
8684
8977
  assert(n % QK_K == 0);
8978
+ assert(nrc == 1);
8979
+ UNUSED(nrc);
8980
+ UNUSED(bx);
8981
+ UNUSED(by);
8982
+ UNUSED(bs);
8685
8983
 
8686
8984
  const block_iq3_xxs * restrict x = vx;
8687
8985
  const block_q8_K * restrict y = vy;
@@ -8707,10 +9005,10 @@ void ggml_vec_dot_iq3_xxs_q8_K(const int n, float * restrict s, const void * res
8707
9005
  for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
8708
9006
  q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
8709
9007
  memcpy(aux32, gas, 2*sizeof(uint32_t)); gas += 2*sizeof(uint32_t);
8710
- const uint32x4_t aux32x4_0 = {iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]};
8711
- const uint32x4_t aux32x4_1 = {iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]};
8712
- const uint32x4_t aux32x4_2 = {iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]};
8713
- const uint32x4_t aux32x4_3 = {iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]};
9008
+ const uint32x4_t aux32x4_0 = ggml_vld1q_u32(iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]);
9009
+ const uint32x4_t aux32x4_1 = ggml_vld1q_u32(iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]);
9010
+ const uint32x4_t aux32x4_2 = ggml_vld1q_u32(iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]);
9011
+ const uint32x4_t aux32x4_3 = ggml_vld1q_u32(iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]);
8714
9012
  q3 += 16;
8715
9013
  q3s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 7) & 127))));
8716
9014
  q3s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 21) & 127))));
@@ -9048,8 +9346,6 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
9048
9346
  int8_t L[32];
9049
9347
  int8_t Laux[32];
9050
9348
  float waux[32];
9051
- bool is_on_grid[4];
9052
- bool is_on_grid_aux[4];
9053
9349
  uint8_t block_signs[4];
9054
9350
  uint32_t q2[2*(QK_K/32)];
9055
9351
 
@@ -9099,10 +9395,11 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
9099
9395
  memset(L, 0, 32);
9100
9396
  continue;
9101
9397
  }
9398
+ float scale = make_qp_quants(32, kMaxQ+1, xval, (uint8_t*)L, weight);
9399
+ float eff_max = scale*kMaxQ;
9102
9400
  float best = 0;
9103
- float scale = max/(2*kMaxQ-1);
9104
- for (int is = -9; is <= 9; ++is) {
9105
- float id = (2*kMaxQ-1+is*0.1f)/max;
9401
+ for (int is = -6; is <= 6; ++is) {
9402
+ float id = (2*kMaxQ-1+is*0.1f)/eff_max;
9106
9403
  float this_scale = 1/id;
9107
9404
  for (int k = 0; k < 4; ++k) {
9108
9405
  for (int i = 0; i < 8; ++i) {
@@ -9112,9 +9409,7 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
9112
9409
  uint16_t u = 0;
9113
9410
  for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i);
9114
9411
  int grid_index = kmap_q2xs[u];
9115
- is_on_grid_aux[k] = true;
9116
9412
  if (grid_index < 0) {
9117
- is_on_grid_aux[k] = false;
9118
9413
  const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
9119
9414
  grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k);
9120
9415
  }
@@ -9128,16 +9423,12 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
9128
9423
  }
9129
9424
  if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
9130
9425
  scale = sumqx/sumq2; best = scale*sumqx;
9131
- for (int i = 0; i < 32; ++i) L[i] = Laux[i];
9132
- for (int k = 0; k < 4; ++k) is_on_grid[k] = is_on_grid_aux[k];
9426
+ memcpy(L, Laux, 32);
9133
9427
  }
9134
9428
  }
9135
- int n_not_ongrid = 0;
9136
- for (int k = 0; k < 4; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
9137
- if (n_not_ongrid > 0 && scale > 0) {
9429
+ if (scale > 0) {
9138
9430
  float id = 1/scale;
9139
9431
  for (int k = 0; k < 4; ++k) {
9140
- if (is_on_grid[k]) continue;
9141
9432
  uint16_t u = 0;
9142
9433
  for (int i = 0; i < 8; ++i) {
9143
9434
  int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
@@ -9193,49 +9484,10 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
9193
9484
  float d = max_scale/31;
9194
9485
  y[ibl].d = GGML_FP32_TO_FP16(d);
9195
9486
  float id = 1/d;
9196
- float sumqx = 0, sumq2 = 0;
9197
9487
  for (int ib = 0; ib < QK_K/32; ++ib) {
9198
9488
  int l = nearest_int(0.5f*(id*scales[ib]-1));
9199
9489
  l = MAX(0, MIN(15, l));
9200
9490
  q2[2*ib+1] |= ((uint32_t)l << 28);
9201
- const float * xb = xbl + 32*ib;
9202
- const float * qw = quant_weights + QK_K*ibl + 32*ib;
9203
- for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
9204
- const uint8_t * aux8 = (const uint8_t *)(q2 + 2*ib);
9205
- const float db = d * (1 + 2*l);
9206
- uint32_t u = 0;
9207
- for (int k = 0; k < 4; ++k) {
9208
- const int8_t * signs = keven_signs_q2xs + 8*((q2[2*ib+1] >> 7*k) & 127);
9209
- const float * xk = xb + 8*k;
9210
- const float * wk = weight + 8*k;
9211
- const uint8_t * grid = (const uint8_t *)(kgrid_q2xs + aux8[k]);
9212
- float best_mse = 0; int best_index = aux8[k];
9213
- for (int j = 0; j < 8; ++j) {
9214
- float diff = db * grid[j] * signs[j] - xk[j];
9215
- best_mse += wk[j] * diff * diff;
9216
- }
9217
- for (int idx = 0; idx < 256; ++idx) {
9218
- grid = (const uint8_t *)(kgrid_q2xs + idx);
9219
- float mse = 0;
9220
- for (int j = 0; j < 8; ++j) {
9221
- float diff = db * grid[j] * signs[j] - xk[j];
9222
- mse += wk[j] * diff * diff;
9223
- }
9224
- if (mse < best_mse) {
9225
- best_mse = mse; best_index = idx;
9226
- }
9227
- }
9228
- u |= (best_index << 8*k);
9229
- grid = (const uint8_t *)(kgrid_q2xs + best_index);
9230
- //grid = (const uint8_t *)(kgrid_q2xs + aux8[k]);
9231
- for (int j = 0; j < 8; ++j) {
9232
- float q = db * grid[j] * signs[j];
9233
- sumqx += wk[j] * q * xk[j];
9234
- sumq2 += wk[j] * q * q;
9235
- }
9236
- }
9237
- q2[2*ib] = u;
9238
- if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(d*sumqx/sumq2);
9239
9491
  }
9240
9492
  memcpy(y[ibl].qs, q2, QK_K/4);
9241
9493
  }
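
Note on the iq2_xxs quantization changes above: the initial per-group scale now comes from make_qp_quants rather than max/(2*kMaxQ-1), the surrounding scale search is narrowed (is now runs over [-6, 6] around the effective maximum instead of [-9, 9] around the raw maximum), the is_on_grid bookkeeping is dropped so every group is re-snapped to the grid whenever the best scale is positive, and the second pass that re-searched all 256 grid points per group while writing out the block has been removed.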