llama_cpp 0.10.2 → 0.10.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/ext/llama_cpp/src/ggml-alloc.c +1 -1
- data/ext/llama_cpp/src/ggml-backend.c +6 -10
- data/ext/llama_cpp/src/ggml-cuda.cu +510 -372
- data/ext/llama_cpp/src/ggml-quants.c +25 -344
- data/ext/llama_cpp/src/ggml.c +7 -8
- data/ext/llama_cpp/src/ggml.h +2 -0
- data/ext/llama_cpp/src/llama.cpp +432 -39
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +1 -0
- metadata +2 -2
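Most of the churn in data/ext/llama_cpp/src/ggml-quants.c (+25, -344) comes from the bundled ggml sources dropping the per-kernel `#if defined(__ARM_FEATURE_DOTPROD)` / `#else` branches: a NEON fallback for `vdotq_s32` is now defined once (first hunk below), so every quantized dot-product kernel can call the intrinsic unconditionally. A minimal sketch of the pattern, assuming an AArch64 toolchain; `dot16_before` and `dot16_after` are illustrative names, not functions from the diff:

```c
#include <arm_neon.h>

// Before: each kernel carried its own non-dotprod fallback.
static inline int32_t dot16_before(int8x16_t x, int8x16_t y) {
#if defined(__ARM_FEATURE_DOTPROD)
    return vaddvq_s32(vdotq_s32(vdupq_n_s32(0), x, y));
#else
    const int16x8_t pl = vmull_s8(vget_low_s8 (x), vget_low_s8 (y));  // 8 low  products, widened to 16 bit
    const int16x8_t ph = vmull_s8(vget_high_s8(x), vget_high_s8(y));  // 8 high products, widened to 16 bit
    return vaddvq_s32(vaddq_s32(vpaddlq_s16(pl), vpaddlq_s16(ph)));   // pairwise-add to 32 bit, then reduce
#endif
}

// After: vdotq_s32 always resolves, either to the hardware SDOT intrinsic or to
// the shim added in the first ggml-quants.c hunk, so one code path remains.
static inline int32_t dot16_after(int8x16_t x, int8x16_t y) {
    return vaddvq_s32(vdotq_s32(vdupq_n_s32(0), x, y));
}
```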
@@ -407,6 +407,18 @@ inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
 #define ggml_vld1q_s8_x4 vld1q_s8_x4
 
 #endif
+
+#if !defined(__ARM_FEATURE_DOTPROD)
+
+inline static int32x4_t vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
+    const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
+    const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
+
+    return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
+}
+
+#endif
+
 #endif
 
 #if defined(__ARM_NEON) || defined(__wasm_simd128__)
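The shim above is not lane-for-lane identical to the hardware SDOT instruction: `vpaddlq_s16` pairs neighbouring products, so the four partial sums land in different lanes than `vdotq_s32` would place them. The kernels in this file only ever scale the whole vector by a per-block constant and then reduce it horizontally (`vaddvq_s32` / `vaddvq_f32`), and that reduced value is identical either way. A small standalone check, not part of the gem, assuming an AArch64 compiler (the shim is compiled in only when `+dotprod` is unavailable, mirroring the hunk):

```c
#include <arm_neon.h>
#include <stdio.h>

#if !defined(__ARM_FEATURE_DOTPROD)
// Same fallback as in the hunk above, so this example builds without +dotprod.
inline static int32x4_t vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
    const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
    const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
    return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
}
#endif

int main(void) {
    int8_t a[16], b[16];
    int32_t ref = 0;
    for (int i = 0; i < 16; ++i) {
        a[i] = (int8_t)(i - 8);
        b[i] = (int8_t)(3 * i - 20);
        ref += a[i] * b[i];                  // scalar reference dot product
    }
    const int32x4_t acc = vdotq_s32(vdupq_n_s32(0), vld1q_s8(a), vld1q_s8(b));
    // The horizontal sum matches the scalar value with or without SDOT.
    printf("neon: %d  scalar: %d\n", vaddvq_s32(acc), ref);
    return 0;
}
```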
@@ -2468,32 +2480,12 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx,
         const int8x16_t v1_1l = vld1q_s8(y1->qs);
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
-#if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
         const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
         const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
-#else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
-
-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
-
-        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
-        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
-        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
-        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
-
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
-#endif
     }
 
     *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
@@ -2776,32 +2768,12 @@ void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restri
         const int8x16_t v1_1l = vld1q_s8(y1->qs);
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
-#if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
         const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
         const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*y1->d);
-#else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0h));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0h));
-
-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1l));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1l));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1h));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1h));
-
-        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
-        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
-        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
-        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
-
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
-#endif
     }
 
     *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
@@ -2963,32 +2935,12 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri
         const int8x16_t v1_1l = vld1q_s8(y1->qs);
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
-#if defined(__ARM_FEATURE_DOTPROD)
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
                         vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
                         vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
                         vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
                         vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
-#else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
-
-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
-
-        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
-        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
-        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
-        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
-
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
-#endif
     }
 
     *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
@@ -3275,32 +3227,12 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri
         const int8x16_t v1_1l = vld1q_s8(y1->qs);
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
-#if defined(__ARM_FEATURE_DOTPROD)
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
                         vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
                         vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
                         vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
                         vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d);
-#else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
-
-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
-
-        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
-        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
-        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
-        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
-
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
-#endif
     }
 
     *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
@@ -3550,7 +3482,6 @@ void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restri
         const int8x16_t y1_0 = vld1q_s8(y1->qs);
         const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
 
-#if defined(__ARM_FEATURE_DOTPROD)
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
                         vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
                         vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
@@ -3558,26 +3489,6 @@ void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restri
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
                         vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
                         vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
-
-#else
-        const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0));
-        const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
-        const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1));
-        const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
-
-        const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0));
-        const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
-        const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1));
-        const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
-
-        const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
-        const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
-        const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
-        const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
-
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
-#endif
     }
 
     *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
@@ -3650,12 +3561,10 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
     const int nb = n / QK_K;
 
 #ifdef __ARM_NEON
-
     const uint8x16_t m3 = vdupq_n_u8(0x3);
     const uint8x16_t m4 = vdupq_n_u8(0xF);
-#if defined(__ARM_FEATURE_DOTPROD)
-    const int32x4_t  vzero = vdupq_n_s32(0);
-#endif
+
+    const int32x4_t vzero = vdupq_n_s32(0);
 
     ggml_int8x16x2_t q2bytes;
     uint8_t aux[16];
@@ -3663,7 +3572,6 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
     float sum = 0;
 
     for (int i = 0; i < nb; ++i) {
-
         const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
         const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
 
@@ -3677,7 +3585,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
 
         const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
         const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
-        const ggml_int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))};
+        const ggml_int16x8x2_t mins16 = {{vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}};
         const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
                                        vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
         const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
@@ -3689,20 +3597,9 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
 
 // We use this macro instead of a function call because for some reason
 // the code runs 2-3% slower, even if the function is declared inline
-#if defined(__ARM_FEATURE_DOTPROD)
 #define MULTIPLY_ACCUM_WITH_SCALE(index)\
         isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
         isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
-#else
-#define MULTIPLY_ACCUM_WITH_SCALE(index)\
-        {\
-    const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),\
-                                   vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));\
-    const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),\
-                                   vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));\
-    isum += vaddvq_s16(p1) * aux[is+(index)] + vaddvq_s16(p2) * aux[is+1+(index)];\
-        }
-#endif
 
 #define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
         q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\
@@ -3710,26 +3607,23 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
         q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
         MULTIPLY_ACCUM_WITH_SCALE((index));
 
-
         for (int j = 0; j < QK_K/128; ++j) {
-
             const ggml_uint8x16x2_t q2bits = ggml_vld1q_u8_x2(q2); q2 += 32;
 
             ggml_int8x16x2_t q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
             q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
             q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
+
             MULTIPLY_ACCUM_WITH_SCALE(0);
 
             SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2);
-
             SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4);
-
             SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6);
 
             is += 8;
         }
-        sum += d * isum;
 
+        sum += d * isum;
     }
 
     *s = sum;
@@ -4043,11 +3937,9 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
     const int nb = n / QK_K;
 
 #ifdef __ARM_NEON
-
     const uint8x16_t m3 = vdupq_n_u8(0x3);
-#if defined(__ARM_FEATURE_DOTPROD)
-    const int32x4_t  vzero = vdupq_n_s32(0);
-#endif
+
+    const int32x4_t vzero = vdupq_n_s32(0);
 
     ggml_int8x16x4_t q2bytes;
 
@@ -4081,28 +3973,12 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
         q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3));
         q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3));
 
-#if defined(__ARM_FEATURE_DOTPROD)
         isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0];
         isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1];
         isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2];
         isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3];
-#else
-        const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
-                                       vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));
-        const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
-                                       vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-        isum1 += vaddvq_s16(p1) * scales[0];
-        isum2 += vaddvq_s16(p2) * scales[1];
-
-        const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
-                                       vmull_s8(vget_high_s8(q2bytes.val[2]), vget_high_s8(q8bytes.val[2])));
-        const int16x8_t p4 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
-                                       vmull_s8(vget_high_s8(q2bytes.val[3]), vget_high_s8(q8bytes.val[3])));
-        isum1 += vaddvq_s16(p3) * scales[2];
-        isum2 += vaddvq_s16(p4) * scales[3];
-#endif
-        sum += d * (isum1 + isum2);
 
+        sum += d * (isum1 + isum2);
     }
 
     *s = sum;
@@ -4328,9 +4204,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
     uint32_t utmp[4];
 
     const uint8x16_t m3b = vdupq_n_u8(0x3);
-#ifdef __ARM_FEATURE_DOTPROD
     const int32x4_t vzero = vdupq_n_s32(0);
-#endif
 
     const uint8x16_t m0 = vdupq_n_u8(1);
     const uint8x16_t m1 = vshlq_n_u8(m0, 1);
@@ -4382,22 +4256,11 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
             q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
             q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
 
-#if defined(__ARM_FEATURE_DOTPROD)
             isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
             isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
             isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
             isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
-#else
-            int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_1.val[0])),
-                                     vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_1.val[0])));
-            int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_1.val[1])),
-                                     vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_1.val[1])));
-            int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_1.val[2])),
-                                     vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_1.val[2])));
-            int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_1.val[3])),
-                                     vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_1.val[3])));
-            isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
-#endif
+
             scale += 4;
 
             q3h.val[0] = vbicq_u8(m2, qhbits.val[0]);
@@ -4410,22 +4273,11 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
             q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
             q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
 
-#if defined(__ARM_FEATURE_DOTPROD)
             isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
             isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
             isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
             isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
-#else
-            p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_2.val[0])),
-                           vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_2.val[0])));
-            p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_2.val[1])),
-                           vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_2.val[1])));
-            p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_2.val[2])),
-                           vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_2.val[2])));
-            p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_2.val[3])),
-                           vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_2.val[3])));
-            isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
-#endif
+
             scale += 4;
 
             if (j == 0) {
@@ -4864,10 +4716,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
    const int nb = n / QK_K;
 
 #ifdef __ARM_NEON
-
-#ifdef __ARM_FEATURE_DOTPROD
-    const int32x4_t vzero = vdupq_n_s32(0);
-#endif
+    const int32x4_t vzero = vdupq_n_s32(0);
 
     const uint8x16_t m3b = vdupq_n_u8(0x3);
     const uint8x16_t mh = vdupq_n_u8(4);
@@ -4908,22 +4757,10 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
         q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2]));
         q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6), q3h.val[3]));
 
-#if defined(__ARM_FEATURE_DOTPROD)
         isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
         isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2];
         isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1];
         isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];
-#else
-        const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
-                                       vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes.val[0])));
-        const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
-                                       vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-        const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
-                                       vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes.val[2])));
-        const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
-                                       vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes.val[3])));
-        isum += vaddvq_s16(p0) * scales[0] + vaddvq_s16(p1) * scales[2] + vaddvq_s16(p2) * scales[1] + vaddvq_s16(p3) * scales[3];
-#endif
 
         sum += d * isum;
 
@@ -5228,11 +5065,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
     uint32_t utmp[4];
 
 #ifdef __ARM_NEON
-
     const uint8x16_t m4b = vdupq_n_u8(0xf);
-#ifdef __ARM_FEATURE_DOTPROD
     const int32x4_t mzero = vdupq_n_s32(0);
-#endif
 
     ggml_int8x16x2_t q4bytes;
     ggml_int8x16x2_t q8bytes;
@@ -5269,10 +5103,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
         int32_t sumi2 = 0;
 
         for (int j = 0; j < QK_K/64; ++j) {
-
             const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); q4 += 32;
 
-#ifdef __ARM_FEATURE_DOTPROD
             q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
             q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
             q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
@@ -5287,26 +5119,6 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
             const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
 
             sumi2 += vaddvq_s32(p2) * scales[2*j+1];
-#else
-            q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
-            q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
-            q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
-            const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
-                                           vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
-            const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
-                                           vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-            sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0];
-
-            q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
-            q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
-            q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
-            const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
-                                           vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
-            const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
-                                           vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-            sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) * scales[2*j+1];
-
-#endif
         }
 
         sumf += d * (sumi1 + sumi2);
@@ -5603,12 +5415,9 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
     const int nb = n / QK_K;
 
 #ifdef __ARM_NEON
-
     const uint8x16_t m4b = vdupq_n_u8(0xf);
 
-#ifdef __ARM_FEATURE_DOTPROD
     const int32x4_t mzero = vdupq_n_s32(0);
-#endif
 
     float sumf = 0;
 
@@ -5636,7 +5445,6 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 
         const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4);
 
-#ifdef __ARM_FEATURE_DOTPROD
         q8bytes = ggml_vld1q_s8_x4(q8);
         q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
         q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
@@ -5650,27 +5458,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
         const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
         const int32_t sumi2 = vaddvq_s32(p2) * scales[1];
 
-#else
-        q8bytes = ggml_vld1q_s8_x4(q8);
-        q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
-        q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
-        const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
-                                       vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
-        const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
-                                       vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-        int32_t sumi1 = vaddvq_s16(vaddq_s16(p0, p1)) * scales[0];
-
-        q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
-        q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
-        const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[2])),
-                                       vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[2])));
-        const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[3])),
-                                       vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[3])));
-        int32_t sumi2 = vaddvq_s16(vaddq_s16(p2, p3)) * scales[1];
-
-#endif
         sumf += d * (sumi1 + sumi2);
-
     }
 
     *s = sumf - sum_mins;
@@ -5875,15 +5663,11 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
 
     uint32_t utmp[4];
 
-
 #ifdef __ARM_NEON
-
     const uint8x16_t m4b = vdupq_n_u8(0xf);
     const uint8x16_t mone = vdupq_n_u8(1);
     const uint8x16_t mtwo = vdupq_n_u8(2);
-#if defined(__ARM_FEATURE_DOTPROD)
     const int32x4_t mzero = vdupq_n_s32(0);
-#endif
 
     ggml_int8x16x4_t q5bytes;
 
@@ -5938,28 +5722,11 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
             q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2]));
             q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3]));
 
-#if defined(__ARM_FEATURE_DOTPROD)
-
             sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
             sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
-#else
-
-            const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
-                                           vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0])));
-            const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
-                                           vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-            sumi += vaddvq_s16(vaddq_s16(p0, p1)) * *scales++;
-
-            const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
-                                           vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2])));
-            const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
-                                           vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3])));
-            sumi += vaddvq_s16(vaddq_s16(p2, p3)) * *scales++;
-#endif
         }
 
         sumf += d * sumi - dmin * sumi_mins;
-
     }
 
     *s = sumf;
@@ -6311,12 +6078,9 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
     const int nb = n / QK_K;
 
 #ifdef __ARM_NEON
-
     const uint8x16_t m4b = vdupq_n_u8(0xf);
     const uint8x16_t mh = vdupq_n_u8(16);
-#if defined(__ARM_FEATURE_DOTPROD)
     const int32x4_t mzero = vdupq_n_s32(0);
-#endif
 
     ggml_int8x16x4_t q5bytes;
     ggml_uint8x16x4_t q5h;
@@ -6348,32 +6112,12 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
         q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2]));
         q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3]));
 
-#if defined(__ARM_FEATURE_DOTPROD)
-
         int32_t sumi1 = sc[0] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]));
         int32_t sumi2 = sc[1] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1]));
         int32_t sumi3 = sc[2] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]));
         int32_t sumi4 = sc[3] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3]));
 
         sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
-
-#else
-
-        const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
-                                       vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0])));
-        const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
-                                       vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-        int32_t sumi = sc[0] * vaddvq_s16(p0) + sc[1] * vaddvq_s16(p1);
-
-        const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
-                                       vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2])));
-        const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
-                                       vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3])));
-        sumi += sc[2] * vaddvq_s16(p2) + sc[3] * vaddvq_s16(p3);
-
-        sumf += d*sumi;
-#endif
-
     }
 
     *s = sumf;
@@ -6600,13 +6344,10 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
     const int nb = n / QK_K;
 
 #ifdef __ARM_NEON
-
     float sum = 0;
 
     const uint8x16_t m4b = vdupq_n_u8(0xF);
-#if defined(__ARM_FEATURE_DOTPROD)
     const int32x4_t vzero = vdupq_n_s32(0);
-#endif
     //const int8x16_t m32s = vdupq_n_s8(32);
 
     const uint8x16_t mone = vdupq_n_u8(3);
@@ -6626,7 +6367,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
         const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
         const int8x16_t scales = vld1q_s8(scale);
-        const ggml_int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))};
+        const ggml_int16x8x2_t q6scales = {{vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}};
 
         const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
                                                    vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
@@ -6658,30 +6399,12 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
             q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2]));
             q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3]));
 
-#if defined(__ARM_FEATURE_DOTPROD)
-
             isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
                     vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
                     vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
                     vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
-            scale += 4;
 
-#else
-
-            int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
-                                     vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
-            int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
-                                     vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-            isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
-            scale += 2;
-
-            int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
-                                     vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
-            int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
-                                     vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
-            isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1];
-            scale += 2;
-#endif
+            scale += 4;
 
             q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
 
@@ -6703,34 +6426,11 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
             q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2]));
             q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3]));
 
-#if defined(__ARM_FEATURE_DOTPROD)
-
             isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
                     vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
                     vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
                     vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
             scale += 4;
-
-            //for (int l = 0; l < 4; ++l) {
-            //    const int32x4_t p = vdotq_s32(vzero, q6bytes.val[l], q8bytes.val[l]);
-            //    isum += vaddvq_s32(p) * *scale++;
-            //}
-#else
-            p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
-                           vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
-            p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
-                           vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-            isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
-            scale += 2;
-
-            p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
-                           vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
-            p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
-                           vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
-            isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1];
-            scale += 2;
-#endif
-
         }
         //sum += isum * d_all * y[i].d;
         sum += d_all * y[i].d * (isum - 32 * isum_mins);
@@ -7076,14 +6776,11 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
     const int nb = n / QK_K;
 
 #ifdef __ARM_NEON
-
    float sum = 0;
 
     const uint8x16_t m4b = vdupq_n_u8(0xF);
     const int8x16_t m32s = vdupq_n_s8(32);
-#if defined(__ARM_FEATURE_DOTPROD)
     const int32x4_t vzero = vdupq_n_s32(0);
-#endif
 
     const uint8x16_t mone = vdupq_n_u8(3);
 
@@ -7119,26 +6816,10 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
         q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s);
         q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s);
 
-#if defined(__ARM_FEATURE_DOTPROD)
-
         isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
                 vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
                 vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
                 vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
-#else
-
-        int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
-                                 vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
-        int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
-                                 vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-        isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
-
-        int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
-                                 vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
-        int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
-                                 vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
-        isum += vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
-#endif
 
         sum += isum * d_all * y[i].d;
 