llama_cpp 0.10.2 → 0.10.3

@@ -407,6 +407,18 @@ inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
  #define ggml_vld1q_s8_x4 vld1q_s8_x4

  #endif
+
+ #if !defined(__ARM_FEATURE_DOTPROD)
+
+ inline static int32x4_t vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
+ const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
+ const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
+
+ return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
+ }
+
+ #endif
+
  #endif

  #if defined(__ARM_NEON) || defined(__wasm_simd128__)
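
For context: the hunk above adds a software fallback for vdotq_s32 on targets without the Armv8.2 dot-product extension, which is what lets the kernels below drop their #if defined(__ARM_FEATURE_DOTPROD) / #else branches. A minimal, self-contained sketch of the shim in use follows; it is illustrative only and not part of this diff, and the test harness (main, the sample arrays) is made up. Note that the fallback spreads the byte products across the four int32 lanes differently than the hardware instruction does, but the horizontal sum taken with vaddvq_s32, which is all these kernels rely on, is identical.

    // Illustrative sketch (not from the diff): exercising the vdotq_s32 fallback.
    // Assumes an AArch64 NEON toolchain, e.g. gcc -O2 -o dot dot.c
    #include <arm_neon.h>
    #include <stdio.h>

    #if !defined(__ARM_FEATURE_DOTPROD)
    // Same shape as the shim added above: widen with vmull_s8, pairwise-add the
    // int16 products down to int32 lanes, then accumulate into acc.
    inline static int32x4_t vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
        const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
        const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
        return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
    }
    #endif

    int main(void) {
        int8_t a[16], b[16];
        for (int i = 0; i < 16; ++i) { a[i] = (int8_t)(i - 8); b[i] = (int8_t)(2*i - 15); }
        // Accumulate the 16 byte products into four lanes, then reduce with
        // vaddvq_s32, exactly as the quantized dot-product kernels do.
        const int32x4_t dot = vdotq_s32(vdupq_n_s32(0), vld1q_s8(a), vld1q_s8(b));
        printf("dot = %d\n", vaddvq_s32(dot));
        return 0;
    }
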
@@ -2468,32 +2480,12 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx,
  const int8x16_t v1_1l = vld1q_s8(y1->qs);
  const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);

- #if defined(__ARM_FEATURE_DOTPROD)
  // dot product into int32x4_t
  const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
  const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);

  sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
  sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
- #else
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l));
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h));
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
-
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l));
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h));
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
-
- const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
- const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
- const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
- const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
-
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
- #endif
  }

  *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
@@ -2776,32 +2768,12 @@ void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restri
  const int8x16_t v1_1l = vld1q_s8(y1->qs);
  const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);

- #if defined(__ARM_FEATURE_DOTPROD)
  // dot product into int32x4_t
  const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
  const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);

  sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*y0->d);
  sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*y1->d);
- #else
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l));
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l));
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0h));
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0h));
-
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1l));
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1l));
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1h));
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1h));
-
- const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
- const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
- const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
- const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
-
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
- #endif
  }

  *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
@@ -2963,32 +2935,12 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri
  const int8x16_t v1_1l = vld1q_s8(y1->qs);
  const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);

- #if defined(__ARM_FEATURE_DOTPROD)
  sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
  vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
  vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
  sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
  vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
  vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
- #else
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
-
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
-
- const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
- const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
- const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
- const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
-
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
- #endif
  }

  *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
@@ -3275,32 +3227,12 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri
  const int8x16_t v1_1l = vld1q_s8(y1->qs);
  const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);

- #if defined(__ARM_FEATURE_DOTPROD)
  sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
  vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
  vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d);
  sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
  vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
  vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d);
- #else
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
-
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
-
- const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
- const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
- const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
- const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
-
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
- #endif
  }

  *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
@@ -3550,7 +3482,6 @@ void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restri
  const int8x16_t y1_0 = vld1q_s8(y1->qs);
  const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);

- #if defined(__ARM_FEATURE_DOTPROD)
  sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
  vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
  vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
@@ -3558,26 +3489,6 @@ void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restri
  sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
  vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
  vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
-
- #else
- const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0));
- const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
- const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1));
- const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
-
- const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0));
- const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
- const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1));
- const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
-
- const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
- const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
- const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
- const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
-
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
- #endif
  }

  *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
@@ -3650,12 +3561,10 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
  const int nb = n / QK_K;

  #ifdef __ARM_NEON
-
  const uint8x16_t m3 = vdupq_n_u8(0x3);
  const uint8x16_t m4 = vdupq_n_u8(0xF);
- #if defined(__ARM_FEATURE_DOTPROD)
- const int32x4_t vzero = vdupq_n_s32(0);
- #endif
+
+ const int32x4_t vzero = vdupq_n_s32(0);

  ggml_int8x16x2_t q2bytes;
  uint8_t aux[16];
@@ -3663,7 +3572,6 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
  float sum = 0;

  for (int i = 0; i < nb; ++i) {
-
  const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
  const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);

@@ -3677,7 +3585,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri

  const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
  const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
- const ggml_int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))};
+ const ggml_int16x8x2_t mins16 = {{vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}};
  const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
  vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
  const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
@@ -3689,20 +3597,9 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri

  // We use this macro instead of a function call because for some reason
  // the code runs 2-3% slower, even if the function is declared inline
- #if defined(__ARM_FEATURE_DOTPROD)
  #define MULTIPLY_ACCUM_WITH_SCALE(index)\
  isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
  isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
- #else
- #define MULTIPLY_ACCUM_WITH_SCALE(index)\
- {\
- const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),\
- vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));\
- const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),\
- vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));\
- isum += vaddvq_s16(p1) * aux[is+(index)] + vaddvq_s16(p2) * aux[is+1+(index)];\
- }
- #endif

  #define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
  q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\
@@ -3710,26 +3607,23 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
  q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
  MULTIPLY_ACCUM_WITH_SCALE((index));

-
  for (int j = 0; j < QK_K/128; ++j) {
-
  const ggml_uint8x16x2_t q2bits = ggml_vld1q_u8_x2(q2); q2 += 32;

  ggml_int8x16x2_t q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
  q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
  q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
+
  MULTIPLY_ACCUM_WITH_SCALE(0);

  SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2);
-
  SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4);
-
  SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6);

  is += 8;
  }
- sum += d * isum;

+ sum += d * isum;
  }

  *s = sum;
@@ -4043,11 +3937,9 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
  const int nb = n / QK_K;

  #ifdef __ARM_NEON
-
  const uint8x16_t m3 = vdupq_n_u8(0x3);
- #if defined(__ARM_FEATURE_DOTPROD)
- const int32x4_t vzero = vdupq_n_s32(0);
- #endif
+
+ const int32x4_t vzero = vdupq_n_s32(0);

  ggml_int8x16x4_t q2bytes;

@@ -4081,28 +3973,12 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
  q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3));
  q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3));

- #if defined(__ARM_FEATURE_DOTPROD)
  isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0];
  isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1];
  isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2];
  isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3];
- #else
- const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
- vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));
- const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
- vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));
- isum1 += vaddvq_s16(p1) * scales[0];
- isum2 += vaddvq_s16(p2) * scales[1];
-
- const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
- vmull_s8(vget_high_s8(q2bytes.val[2]), vget_high_s8(q8bytes.val[2])));
- const int16x8_t p4 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
- vmull_s8(vget_high_s8(q2bytes.val[3]), vget_high_s8(q8bytes.val[3])));
- isum1 += vaddvq_s16(p3) * scales[2];
- isum2 += vaddvq_s16(p4) * scales[3];
- #endif
- sum += d * (isum1 + isum2);

+ sum += d * (isum1 + isum2);
  }

  *s = sum;
@@ -4328,9 +4204,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
  uint32_t utmp[4];

  const uint8x16_t m3b = vdupq_n_u8(0x3);
- #ifdef __ARM_FEATURE_DOTPROD
  const int32x4_t vzero = vdupq_n_s32(0);
- #endif

  const uint8x16_t m0 = vdupq_n_u8(1);
  const uint8x16_t m1 = vshlq_n_u8(m0, 1);
@@ -4382,22 +4256,11 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
  q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
  q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3]));

- #if defined(__ARM_FEATURE_DOTPROD)
  isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
  isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
  isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
  isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
- #else
- int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_1.val[0])),
- vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_1.val[0])));
- int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_1.val[1])),
- vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_1.val[1])));
- int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_1.val[2])),
- vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_1.val[2])));
- int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_1.val[3])),
- vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_1.val[3])));
- isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
- #endif
+
  scale += 4;

  q3h.val[0] = vbicq_u8(m2, qhbits.val[0]);
@@ -4410,22 +4273,11 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
  q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
  q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3]));

- #if defined(__ARM_FEATURE_DOTPROD)
  isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
  isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
  isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
  isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
- #else
- p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_2.val[0])),
- vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_2.val[0])));
- p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_2.val[1])),
- vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_2.val[1])));
- p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_2.val[2])),
- vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_2.val[2])));
- p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_2.val[3])),
- vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_2.val[3])));
- isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
- #endif
+
  scale += 4;

  if (j == 0) {
@@ -4864,10 +4716,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
  const int nb = n / QK_K;

  #ifdef __ARM_NEON
-
- #ifdef __ARM_FEATURE_DOTPROD
- const int32x4_t vzero = vdupq_n_s32(0);
- #endif
+ const int32x4_t vzero = vdupq_n_s32(0);

  const uint8x16_t m3b = vdupq_n_u8(0x3);
  const uint8x16_t mh = vdupq_n_u8(4);
@@ -4908,22 +4757,10 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
  q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2]));
  q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6), q3h.val[3]));

- #if defined(__ARM_FEATURE_DOTPROD)
  isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
  isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2];
  isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1];
  isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];
- #else
- const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
- vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes.val[0])));
- const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
- vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes.val[1])));
- const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
- vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes.val[2])));
- const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
- vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes.val[3])));
- isum += vaddvq_s16(p0) * scales[0] + vaddvq_s16(p1) * scales[2] + vaddvq_s16(p2) * scales[1] + vaddvq_s16(p3) * scales[3];
- #endif

  sum += d * isum;

@@ -5228,11 +5065,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
  uint32_t utmp[4];

  #ifdef __ARM_NEON
-
  const uint8x16_t m4b = vdupq_n_u8(0xf);
- #ifdef __ARM_FEATURE_DOTPROD
  const int32x4_t mzero = vdupq_n_s32(0);
- #endif

  ggml_int8x16x2_t q4bytes;
  ggml_int8x16x2_t q8bytes;
@@ -5269,10 +5103,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
  int32_t sumi2 = 0;

  for (int j = 0; j < QK_K/64; ++j) {
-
  const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); q4 += 32;

- #ifdef __ARM_FEATURE_DOTPROD
  q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
  q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
  q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
@@ -5287,26 +5119,6 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
  const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);

  sumi2 += vaddvq_s32(p2) * scales[2*j+1];
- #else
- q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
- q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
- q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
- const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
- vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
- const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
- vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
- sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0];
-
- q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
- q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
- q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
- const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
- vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
- const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
- vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
- sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) * scales[2*j+1];
-
- #endif
  }

  sumf += d * (sumi1 + sumi2);
@@ -5603,12 +5415,9 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
  const int nb = n / QK_K;

  #ifdef __ARM_NEON
-
  const uint8x16_t m4b = vdupq_n_u8(0xf);

- #ifdef __ARM_FEATURE_DOTPROD
  const int32x4_t mzero = vdupq_n_s32(0);
- #endif

  float sumf = 0;

@@ -5636,7 +5445,6 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri

  const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4);

- #ifdef __ARM_FEATURE_DOTPROD
  q8bytes = ggml_vld1q_s8_x4(q8);
  q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
  q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
@@ -5650,27 +5458,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
  const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
  const int32_t sumi2 = vaddvq_s32(p2) * scales[1];

- #else
- q8bytes = ggml_vld1q_s8_x4(q8);
- q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
- q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
- const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
- vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
- const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
- vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
- int32_t sumi1 = vaddvq_s16(vaddq_s16(p0, p1)) * scales[0];
-
- q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
- q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
- const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[2])),
- vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[2])));
- const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[3])),
- vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[3])));
- int32_t sumi2 = vaddvq_s16(vaddq_s16(p2, p3)) * scales[1];
-
- #endif
  sumf += d * (sumi1 + sumi2);
-
  }

  *s = sumf - sum_mins;
@@ -5875,15 +5663,11 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri

  uint32_t utmp[4];

-
  #ifdef __ARM_NEON
-
  const uint8x16_t m4b = vdupq_n_u8(0xf);
  const uint8x16_t mone = vdupq_n_u8(1);
  const uint8x16_t mtwo = vdupq_n_u8(2);
- #if defined(__ARM_FEATURE_DOTPROD)
  const int32x4_t mzero = vdupq_n_s32(0);
- #endif

  ggml_int8x16x4_t q5bytes;

@@ -5938,28 +5722,11 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
  q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2]));
  q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3]));

- #if defined(__ARM_FEATURE_DOTPROD)
-
  sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
  sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
- #else
-
- const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
- vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0])));
- const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
- vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1])));
- sumi += vaddvq_s16(vaddq_s16(p0, p1)) * *scales++;
-
- const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
- vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2])));
- const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
- vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3])));
- sumi += vaddvq_s16(vaddq_s16(p2, p3)) * *scales++;
- #endif
  }

  sumf += d * sumi - dmin * sumi_mins;
-
  }

  *s = sumf;
@@ -6311,12 +6078,9 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
  const int nb = n / QK_K;

  #ifdef __ARM_NEON
-
  const uint8x16_t m4b = vdupq_n_u8(0xf);
  const uint8x16_t mh = vdupq_n_u8(16);
- #if defined(__ARM_FEATURE_DOTPROD)
  const int32x4_t mzero = vdupq_n_s32(0);
- #endif

  ggml_int8x16x4_t q5bytes;
  ggml_uint8x16x4_t q5h;
@@ -6348,32 +6112,12 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
  q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2]));
  q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3]));

- #if defined(__ARM_FEATURE_DOTPROD)
-
  int32_t sumi1 = sc[0] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]));
  int32_t sumi2 = sc[1] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1]));
  int32_t sumi3 = sc[2] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]));
  int32_t sumi4 = sc[3] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3]));

  sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
-
- #else
-
- const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
- vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0])));
- const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
- vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1])));
- int32_t sumi = sc[0] * vaddvq_s16(p0) + sc[1] * vaddvq_s16(p1);
-
- const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
- vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2])));
- const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
- vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3])));
- sumi += sc[2] * vaddvq_s16(p2) + sc[3] * vaddvq_s16(p3);
-
- sumf += d*sumi;
- #endif
-
  }

  *s = sumf;
@@ -6600,13 +6344,10 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
  const int nb = n / QK_K;

  #ifdef __ARM_NEON
-
  float sum = 0;

  const uint8x16_t m4b = vdupq_n_u8(0xF);
- #if defined(__ARM_FEATURE_DOTPROD)
  const int32x4_t vzero = vdupq_n_s32(0);
- #endif
  //const int8x16_t m32s = vdupq_n_s8(32);

  const uint8x16_t mone = vdupq_n_u8(3);
@@ -6626,7 +6367,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri

  const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
  const int8x16_t scales = vld1q_s8(scale);
- const ggml_int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))};
+ const ggml_int16x8x2_t q6scales = {{vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}};

  const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
  vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
@@ -6658,30 +6399,12 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
  q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2]));
  q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3]));

- #if defined(__ARM_FEATURE_DOTPROD)
-
  isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
  vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
  vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
  vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
- scale += 4;

- #else
-
- int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
- vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
- int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
- vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
- isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
- scale += 2;
-
- int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
- vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
- int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
- vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
- isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1];
- scale += 2;
- #endif
+ scale += 4;

  q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;

@@ -6703,34 +6426,11 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
  q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2]));
  q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3]));

- #if defined(__ARM_FEATURE_DOTPROD)
-
  isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
  vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
  vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
  vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
  scale += 4;
-
- //for (int l = 0; l < 4; ++l) {
- // const int32x4_t p = vdotq_s32(vzero, q6bytes.val[l], q8bytes.val[l]);
- // isum += vaddvq_s32(p) * *scale++;
- //}
- #else
- p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
- vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
- p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
- vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
- isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
- scale += 2;
-
- p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
- vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
- p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
- vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
- isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1];
- scale += 2;
- #endif
-
  }
  //sum += isum * d_all * y[i].d;
  sum += d_all * y[i].d * (isum - 32 * isum_mins);
@@ -7076,14 +6776,11 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
  const int nb = n / QK_K;

  #ifdef __ARM_NEON
-
  float sum = 0;

  const uint8x16_t m4b = vdupq_n_u8(0xF);
  const int8x16_t m32s = vdupq_n_s8(32);
- #if defined(__ARM_FEATURE_DOTPROD)
  const int32x4_t vzero = vdupq_n_s32(0);
- #endif

  const uint8x16_t mone = vdupq_n_u8(3);

@@ -7119,26 +6816,10 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
  q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s);
  q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s);

- #if defined(__ARM_FEATURE_DOTPROD)
-
  isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
  vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
  vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
  vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
- #else
-
- int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
- vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
- int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
- vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
- isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
-
- int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
- vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
- int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
- vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
- isum += vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
- #endif

  sum += isum * d_all * y[i].d;