llama_cpp 0.10.2 → 0.10.3
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/ext/llama_cpp/src/ggml-alloc.c +1 -1
- data/ext/llama_cpp/src/ggml-backend.c +6 -10
- data/ext/llama_cpp/src/ggml-cuda.cu +510 -372
- data/ext/llama_cpp/src/ggml-quants.c +25 -344
- data/ext/llama_cpp/src/ggml.c +7 -8
- data/ext/llama_cpp/src/ggml.h +2 -0
- data/ext/llama_cpp/src/llama.cpp +432 -39
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +1 -0
- metadata +2 -2
data/ext/llama_cpp/src/ggml-quants.c

@@ -407,6 +407,18 @@ inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
 #define ggml_vld1q_s8_x4 vld1q_s8_x4
 
 #endif
+
+#if !defined(__ARM_FEATURE_DOTPROD)
+
+inline static int32x4_t vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
+    const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
+    const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
+
+    return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
+}
+
+#endif
+
 #endif
 
 #if defined(__ARM_NEON) || defined(__wasm_simd128__)
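The compatibility shim added in this hunk is the key to the rest of the file's changes: once `vdotq_s32` always exists, every `#if defined(__ARM_FEATURE_DOTPROD)` / `#else` pair in the dot-product kernels below can collapse to the dot-product form. A minimal standalone sketch (illustrative only; the test values and the `ref_dot16` helper are made up here, not part of the gem) showing that the shim is a drop-in for how these kernels use the intrinsic:

```c
// Standalone sketch (illustrative only, not from this diff): the guarded shim
// mirrors the fallback added above, so vdotq_s32() can be used unconditionally.
#include <arm_neon.h>
#include <stdio.h>

#if !defined(__ARM_FEATURE_DOTPROD)
inline static int32x4_t vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
    const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
    const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
    return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
}
#endif

static int32_t ref_dot16(const int8_t * a, const int8_t * b) {   // scalar reference
    int32_t s = 0;
    for (int i = 0; i < 16; ++i) s += (int32_t)a[i] * (int32_t)b[i];
    return s;
}

int main(void) {
    int8_t a[16], b[16];
    for (int i = 0; i < 16; ++i) { a[i] = (int8_t)(i - 8); b[i] = (int8_t)(3*i - 20); }

    const int32x4_t acc = vdotq_s32(vdupq_n_s32(0), vld1q_s8(a), vld1q_s8(b));
    // Per-lane layout differs between the native instruction and the shim, but the
    // horizontal sum is identical, and that is all the kernels below consume.
    printf("simd = %d, scalar = %d\n", (int)vaddvq_s32(acc), (int)ref_dot16(a, b));
    return 0;
}
```

Because the kernels only use `vdotq_s32` results through `vaddvq_s32` or a float conversion followed by a full reduction, the lane-layout difference between the shim and the native instruction does not change the computed values here.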
@@ -2468,32 +2480,12 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx,
         const int8x16_t v1_1l = vld1q_s8(y1->qs);
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
-#if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
         const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
         const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
-#else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
-
-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
-
-        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
-        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
-        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
-        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
-
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
-#endif
     }
 
     *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
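The surviving branch above leans on the chained-accumulator idiom `vdotq_s32(vdotq_s32(vdupq_n_s32(0), …), …)`: the second call keeps accumulating into the result of the first, so the products of 32 signed bytes land in a single `int32x4_t` before any float scaling. A small sketch of the same idiom in isolation (the helper name is made up, not from the diff):

```c
// Illustrative helper (not from the diff): two chained vdotq_s32 calls fold the
// products of 32 int8 pairs into one int32x4_t accumulator.
#include <arm_neon.h>

static inline int32_t dot32_i8(const int8_t * x, const int8_t * y) {
    const int8x16_t x0 = vld1q_s8(x), x1 = vld1q_s8(x + 16);
    const int8x16_t y0 = vld1q_s8(y), y1 = vld1q_s8(y + 16);
    // Start from a zero accumulator; the outer call adds the second block's products.
    const int32x4_t acc = vdotq_s32(vdotq_s32(vdupq_n_s32(0), x0, y0), x1, y1);
    return vaddvq_s32(acc);   // = sum over all 32 x[i] * y[i]
}
```

In `ggml_vec_dot_q4_0_q8_0` the analogous accumulator is not reduced directly; it is converted with `vcvtq_f32_s32` and folded into `sumv0`/`sumv1` via `vmlaq_n_f32`, scaled by the two blocks' `d` values, before the final `vaddvq_f32` reduction.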
@@ -2776,32 +2768,12 @@ void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restri
         const int8x16_t v1_1l = vld1q_s8(y1->qs);
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
-#if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
         const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
         const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*y1->d);
-#else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0h));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0h));
-
-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1l));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1l));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1h));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1h));
-
-        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
-        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
-        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
-        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
-
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
-#endif
     }
 
     *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
@@ -2963,32 +2935,12 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri
         const int8x16_t v1_1l = vld1q_s8(y1->qs);
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
-#if defined(__ARM_FEATURE_DOTPROD)
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
                         vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
                         vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
                         vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
                         vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
-#else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
-
-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
-
-        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
-        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
-        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
-        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
-
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
-#endif
     }
 
     *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
@@ -3275,32 +3227,12 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri
         const int8x16_t v1_1l = vld1q_s8(y1->qs);
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
-#if defined(__ARM_FEATURE_DOTPROD)
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
                         vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
                         vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
                         vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
                         vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d);
-#else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
-
-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
-
-        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
-        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
-        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
-        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
-
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
-#endif
     }
 
     *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
@@ -3550,7 +3482,6 @@ void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restri
         const int8x16_t y1_0 = vld1q_s8(y1->qs);
         const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
 
-#if defined(__ARM_FEATURE_DOTPROD)
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
                         vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
                         vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
@@ -3558,26 +3489,6 @@ void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restri
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
                         vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
                         vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
-
-#else
-        const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0));
-        const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
-        const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1));
-        const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
-
-        const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0));
-        const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
-        const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1));
-        const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
-
-        const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
-        const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
-        const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
-        const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
-
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
-#endif
     }
 
     *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
@@ -3650,12 +3561,10 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
     const int nb = n / QK_K;
 
 #ifdef __ARM_NEON
-
     const uint8x16_t m3 = vdupq_n_u8(0x3);
     const uint8x16_t m4 = vdupq_n_u8(0xF);
-#if defined(__ARM_FEATURE_DOTPROD)
-    const int32x4_t vzero = vdupq_n_s32(0);
-#endif
+
+    const int32x4_t vzero = vdupq_n_s32(0);
 
     ggml_int8x16x2_t q2bytes;
     uint8_t aux[16];
@@ -3663,7 +3572,6 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
     float sum = 0;
 
     for (int i = 0; i < nb; ++i) {
-
         const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
         const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
 
@@ -3677,7 +3585,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
 
         const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
         const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
-        const ggml_int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))};
+        const ggml_int16x8x2_t mins16 = {{vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}};
         const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
                                        vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
         const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
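Besides the preprocessor cleanup, this hunk adds a second pair of braces around the `mins16` initializer. Assuming `ggml_int16x8x2_t` is, like the other `ggml_*x2_t` compatibility types, a struct wrapping a two-element `val` array, the inner braces initialize the array and the outer braces the struct, which satisfies `-Wmissing-braces`; a minimal sketch with a stand-in type:

```c
// Minimal sketch with a stand-in type (the real ggml_int16x8x2_t is assumed to
// wrap an int16x8_t val[2] array, like the other ggml_*x2_t compatibility types).
typedef struct { int val[2]; } pair_t;

void brace_example(void) {
    pair_t a = { 1, 2 };      /* legal via brace elision, but -Wmissing-braces warns */
    pair_t b = { { 1, 2 } };  /* inner braces -> val[2], outer braces -> the struct  */
    (void)a; (void)b;
}
```

The same `{{ … }}` fix shows up again below for `q6scales` in `ggml_vec_dot_q6_K_q8_K`.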
@@ -3689,20 +3597,9 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
 
         // We use this macro instead of a function call because for some reason
         // the code runs 2-3% slower, even if the function is declared inline
-#if defined(__ARM_FEATURE_DOTPROD)
 #define MULTIPLY_ACCUM_WITH_SCALE(index)\
         isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
         isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
-#else
-#define MULTIPLY_ACCUM_WITH_SCALE(index)\
-        {\
-            const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),\
-                                           vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));\
-            const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),\
-                                           vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));\
-            isum += vaddvq_s16(p1) * aux[is+(index)] + vaddvq_s16(p2) * aux[is+1+(index)];\
-        }
-#endif
 
 #define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
         q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\
@@ -3710,26 +3607,23 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
         q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
         MULTIPLY_ACCUM_WITH_SCALE((index));
 
-
         for (int j = 0; j < QK_K/128; ++j) {
-
             const ggml_uint8x16x2_t q2bits = ggml_vld1q_u8_x2(q2); q2 += 32;
 
             ggml_int8x16x2_t q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
             q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
             q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
+
             MULTIPLY_ACCUM_WITH_SCALE(0);
 
             SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2);
-
             SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4);
-
             SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6);
 
             is += 8;
         }
-        sum += d * isum;
 
+        sum += d * isum;
     }
 
     *s = sum;
@@ -4043,11 +3937,9 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
     const int nb = n / QK_K;
 
 #ifdef __ARM_NEON
-
     const uint8x16_t m3 = vdupq_n_u8(0x3);
-#if defined(__ARM_FEATURE_DOTPROD)
-    const int32x4_t vzero = vdupq_n_s32(0);
-#endif
+
+    const int32x4_t vzero = vdupq_n_s32(0);
 
     ggml_int8x16x4_t q2bytes;
 
@@ -4081,28 +3973,12 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
         q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3));
         q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3));
 
-#if defined(__ARM_FEATURE_DOTPROD)
         isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0];
         isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1];
         isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2];
         isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3];
-#else
-        const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
-                                       vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));
-        const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
-                                       vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-        isum1 += vaddvq_s16(p1) * scales[0];
-        isum2 += vaddvq_s16(p2) * scales[1];
-
-        const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
-                                       vmull_s8(vget_high_s8(q2bytes.val[2]), vget_high_s8(q8bytes.val[2])));
-        const int16x8_t p4 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
-                                       vmull_s8(vget_high_s8(q2bytes.val[3]), vget_high_s8(q8bytes.val[3])));
-        isum1 += vaddvq_s16(p3) * scales[2];
-        isum2 += vaddvq_s16(p4) * scales[3];
-#endif
-        sum += d * (isum1 + isum2);
 
+        sum += d * (isum1 + isum2);
     }
 
     *s = sum;
@@ -4328,9 +4204,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
     uint32_t utmp[4];
 
     const uint8x16_t m3b = vdupq_n_u8(0x3);
-#ifdef __ARM_FEATURE_DOTPROD
     const int32x4_t vzero = vdupq_n_s32(0);
-#endif
 
     const uint8x16_t m0 = vdupq_n_u8(1);
     const uint8x16_t m1 = vshlq_n_u8(m0, 1);
@@ -4382,22 +4256,11 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
         q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
         q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
 
-#if defined(__ARM_FEATURE_DOTPROD)
         isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
         isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
         isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
         isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
-#else
-        int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_1.val[0])),
-                                 vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_1.val[0])));
-        int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_1.val[1])),
-                                 vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_1.val[1])));
-        int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_1.val[2])),
-                                 vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_1.val[2])));
-        int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_1.val[3])),
-                                 vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_1.val[3])));
-        isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
-#endif
+
         scale += 4;
 
         q3h.val[0] = vbicq_u8(m2, qhbits.val[0]);
@@ -4410,22 +4273,11 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
         q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
         q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
 
-#if defined(__ARM_FEATURE_DOTPROD)
         isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
         isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
         isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
         isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
-#else
-        p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_2.val[0])),
-                       vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_2.val[0])));
-        p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_2.val[1])),
-                       vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_2.val[1])));
-        p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_2.val[2])),
-                       vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_2.val[2])));
-        p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_2.val[3])),
-                       vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_2.val[3])));
-        isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
-#endif
+
         scale += 4;
 
         if (j == 0) {
@@ -4864,10 +4716,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
     const int nb = n / QK_K;
 
 #ifdef __ARM_NEON
-
-#ifdef __ARM_FEATURE_DOTPROD
-    const int32x4_t vzero = vdupq_n_s32(0);
-#endif
+    const int32x4_t vzero = vdupq_n_s32(0);
 
     const uint8x16_t m3b = vdupq_n_u8(0x3);
     const uint8x16_t mh = vdupq_n_u8(4);
@@ -4908,22 +4757,10 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
         q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2]));
         q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6), q3h.val[3]));
 
-#if defined(__ARM_FEATURE_DOTPROD)
         isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
         isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2];
         isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1];
         isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];
-#else
-        const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
-                                       vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes.val[0])));
-        const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
-                                       vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-        const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
-                                       vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes.val[2])));
-        const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
-                                       vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes.val[3])));
-        isum += vaddvq_s16(p0) * scales[0] + vaddvq_s16(p1) * scales[2] + vaddvq_s16(p2) * scales[1] + vaddvq_s16(p3) * scales[3];
-#endif
 
         sum += d * isum;
 
@@ -5228,11 +5065,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
     uint32_t utmp[4];
 
 #ifdef __ARM_NEON
-
     const uint8x16_t m4b = vdupq_n_u8(0xf);
-#ifdef __ARM_FEATURE_DOTPROD
     const int32x4_t mzero = vdupq_n_s32(0);
-#endif
 
     ggml_int8x16x2_t q4bytes;
     ggml_int8x16x2_t q8bytes;
@@ -5269,10 +5103,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
         int32_t sumi2 = 0;
 
         for (int j = 0; j < QK_K/64; ++j) {
-
             const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); q4 += 32;
 
-#ifdef __ARM_FEATURE_DOTPROD
             q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
             q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
             q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
@@ -5287,26 +5119,6 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
             const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
 
             sumi2 += vaddvq_s32(p2) * scales[2*j+1];
-#else
-            q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
-            q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
-            q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
-            const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
-                                           vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
-            const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
-                                           vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-            sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0];
-
-            q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
-            q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
-            q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
-            const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
-                                           vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
-            const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
-                                           vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-            sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) * scales[2*j+1];
-
-#endif
         }
 
         sumf += d * (sumi1 + sumi2);
@@ -5603,12 +5415,9 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
    const int nb = n / QK_K;
 
 #ifdef __ARM_NEON
-
     const uint8x16_t m4b = vdupq_n_u8(0xf);
 
-#ifdef __ARM_FEATURE_DOTPROD
     const int32x4_t mzero = vdupq_n_s32(0);
-#endif
 
     float sumf = 0;
 
@@ -5636,7 +5445,6 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 
         const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4);
 
-#ifdef __ARM_FEATURE_DOTPROD
         q8bytes = ggml_vld1q_s8_x4(q8);
         q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
         q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
@@ -5650,27 +5458,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
         const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
         const int32_t sumi2 = vaddvq_s32(p2) * scales[1];
 
-#else
-        q8bytes = ggml_vld1q_s8_x4(q8);
-        q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
-        q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
-        const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
-                                       vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
-        const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
-                                       vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-        int32_t sumi1 = vaddvq_s16(vaddq_s16(p0, p1)) * scales[0];
-
-        q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
-        q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
-        const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[2])),
-                                       vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[2])));
-        const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[3])),
-                                       vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[3])));
-        int32_t sumi2 = vaddvq_s16(vaddq_s16(p2, p3)) * scales[1];
-
-#endif
         sumf += d * (sumi1 + sumi2);
-
     }
 
     *s = sumf - sum_mins;
@@ -5875,15 +5663,11 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
 
     uint32_t utmp[4];
 
-
 #ifdef __ARM_NEON
-
     const uint8x16_t m4b = vdupq_n_u8(0xf);
     const uint8x16_t mone = vdupq_n_u8(1);
     const uint8x16_t mtwo = vdupq_n_u8(2);
-#if defined(__ARM_FEATURE_DOTPROD)
     const int32x4_t mzero = vdupq_n_s32(0);
-#endif
 
     ggml_int8x16x4_t q5bytes;
 
@@ -5938,28 +5722,11 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
             q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2]));
             q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3]));
 
-#if defined(__ARM_FEATURE_DOTPROD)
-
             sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
             sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
-#else
-
-            const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
-                                           vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0])));
-            const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
-                                           vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-            sumi += vaddvq_s16(vaddq_s16(p0, p1)) * *scales++;
-
-            const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
-                                           vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2])));
-            const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
-                                           vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3])));
-            sumi += vaddvq_s16(vaddq_s16(p2, p3)) * *scales++;
-#endif
         }
 
         sumf += d * sumi - dmin * sumi_mins;
-
     }
 
     *s = sumf;
@@ -6311,12 +6078,9 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
    const int nb = n / QK_K;
 
 #ifdef __ARM_NEON
-
     const uint8x16_t m4b = vdupq_n_u8(0xf);
     const uint8x16_t mh = vdupq_n_u8(16);
-#if defined(__ARM_FEATURE_DOTPROD)
     const int32x4_t mzero = vdupq_n_s32(0);
-#endif
 
     ggml_int8x16x4_t q5bytes;
     ggml_uint8x16x4_t q5h;
@@ -6348,32 +6112,12 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
         q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2]));
         q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3]));
 
-#if defined(__ARM_FEATURE_DOTPROD)
-
         int32_t sumi1 = sc[0] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]));
         int32_t sumi2 = sc[1] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1]));
         int32_t sumi3 = sc[2] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]));
         int32_t sumi4 = sc[3] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3]));
 
         sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
-
-#else
-
-        const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
-                                       vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0])));
-        const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
-                                       vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-        int32_t sumi = sc[0] * vaddvq_s16(p0) + sc[1] * vaddvq_s16(p1);
-
-        const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
-                                       vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2])));
-        const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
-                                       vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3])));
-        sumi += sc[2] * vaddvq_s16(p2) + sc[3] * vaddvq_s16(p3);
-
-        sumf += d*sumi;
-
-#endif
-
     }
 
     *s = sumf;
@@ -6600,13 +6344,10 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
     const int nb = n / QK_K;
 
 #ifdef __ARM_NEON
-
     float sum = 0;
 
     const uint8x16_t m4b = vdupq_n_u8(0xF);
-#if defined(__ARM_FEATURE_DOTPROD)
     const int32x4_t vzero = vdupq_n_s32(0);
-#endif
     //const int8x16_t m32s = vdupq_n_s8(32);
 
     const uint8x16_t mone = vdupq_n_u8(3);
@@ -6626,7 +6367,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
         const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
         const int8x16_t scales = vld1q_s8(scale);
-        const ggml_int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))};
+        const ggml_int16x8x2_t q6scales = {{vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}};
 
         const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
                                                    vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
@@ -6658,30 +6399,12 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
             q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2]));
             q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3]));
 
-#if defined(__ARM_FEATURE_DOTPROD)
-
             isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
                     vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
                     vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
                     vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
-            scale += 4;
 
-#else
-
-            int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
-                                     vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
-            int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
-                                     vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-            isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
-            scale += 2;
-
-            int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
-                                     vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
-            int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
-                                     vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
-            isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1];
-            scale += 2;
-#endif
+            scale += 4;
 
             q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
 
@@ -6703,34 +6426,11 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
             q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2]));
             q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3]));
 
-#if defined(__ARM_FEATURE_DOTPROD)
-
             isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
                     vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
                     vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
                     vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
             scale += 4;
-
-            //for (int l = 0; l < 4; ++l) {
-            //    const int32x4_t p = vdotq_s32(vzero, q6bytes.val[l], q8bytes.val[l]);
-            //    isum += vaddvq_s32(p) * *scale++;
-            //}
-#else
-            p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
-                           vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
-            p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
-                           vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-            isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
-            scale += 2;
-
-            p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
-                           vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
-            p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
-                           vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
-            isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1];
-            scale += 2;
-#endif
-
         }
         //sum += isum * d_all * y[i].d;
         sum += d_all * y[i].d * (isum - 32 * isum_mins);
@@ -7076,14 +6776,11 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
     const int nb = n / QK_K;
 
 #ifdef __ARM_NEON
-
     float sum = 0;
 
     const uint8x16_t m4b = vdupq_n_u8(0xF);
     const int8x16_t m32s = vdupq_n_s8(32);
-#if defined(__ARM_FEATURE_DOTPROD)
     const int32x4_t vzero = vdupq_n_s32(0);
-#endif
 
     const uint8x16_t mone = vdupq_n_u8(3);
 
@@ -7119,26 +6816,10 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
         q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s);
         q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s);
 
-#if defined(__ARM_FEATURE_DOTPROD)
-
         isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
                 vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
                 vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
                 vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
-#else
-
-        int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
-                                 vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
-        int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
-                                 vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-        isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
-
-        int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
-                                 vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
-        int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
-                                 vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
-        isum += vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
-#endif
 
         sum += isum * d_all * y[i].d;
 