cui-llama.rn 1.1.2 → 1.1.5

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their respective public registries.
package/cpp/ggml-quants.c CHANGED
@@ -1630,7 +1630,7 @@ void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int6
1630
1630
  // ===================== Helper functions
1631
1631
  //
1632
1632
  static inline int nearest_int(float fval) {
1633
- assert(fval <= 4194303.f);
1633
+ assert(fabsf(fval) <= 4194303.f);
1634
1634
  float val = fval + 12582912.f;
1635
1635
  int i; memcpy(&i, &val, sizeof(int));
1636
1636
  return (i & 0x007fffff) - 0x00400000;
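
The assert change above widens the guard to negative inputs: nearest_int() relies on magic-number rounding, which is only exact while |fval| <= 4194303 (just under 2^22), so the bound must apply to the absolute value. A minimal standalone sketch of the trick (the test harness below is illustrative and not part of the package):

    #include <assert.h>
    #include <math.h>
    #include <stdio.h>
    #include <string.h>

    // Adding 12582912.0f (1.5 * 2^23) forces the rounded integer into the float's
    // mantissa; masking the low 23 bits and subtracting 0x00400000 recovers the
    // signed result. Only valid while |fval| <= 4194303, hence the fabsf() assert.
    static inline int nearest_int_sketch(float fval) {
        assert(fabsf(fval) <= 4194303.f);
        float val = fval + 12582912.f;
        int i; memcpy(&i, &val, sizeof(int));
        return (i & 0x007fffff) - 0x00400000;
    }

    int main(void) {
        for (float f = -1000.25f; f <= 1000.25f; f += 0.25f) {
            // ties round to even, exactly like the floating-point addition does
            assert(nearest_int_sketch(f) == (int) rintf(f));
        }
        printf("magic-number rounding agrees with rintf over the tested range\n");
        return 0;
    }
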
@@ -3306,6 +3306,191 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr
3306
3306
  return nrow * row_size;
3307
3307
  }
3308
3308
 
3309
+ // ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
3310
+
3311
+ void quantize_row_tq1_0_ref(const float * restrict x, block_tq1_0 * restrict y, int64_t k) {
3312
+ assert(k % QK_K == 0);
3313
+ const int64_t nb = k / QK_K;
3314
+
3315
+ for (int64_t i = 0; i < nb; i++) {
3316
+ float amax = 0.0f; // absolute max
3317
+
3318
+ for (int j = 0; j < QK_K; j++) {
3319
+ const float v = x[j];
3320
+ amax = MAX(amax, fabsf(v));
3321
+ }
3322
+
3323
+ const float d = amax;
3324
+ const float id = d ? 1.0f/d : 0.0f;
3325
+
3326
+ y[i].d = LM_GGML_FP32_TO_FP16(d);
3327
+
3328
+ // 5 elements per byte, along 32 bytes
3329
+ for (size_t j = 0; j < sizeof(y->qs) - sizeof(y->qs) % 32; j += 32) {
3330
+ for (size_t m = 0; m < 32; ++m) {
3331
+ uint8_t q = 0;
3332
+ for (size_t n = 0; n < 5; ++n) {
3333
+ int xi = lroundf(x[m + n*32] * id) + 1; // -1, 0, 1 -> 0, 1, 2
3334
+ q *= 3;
3335
+ q += xi;
3336
+ }
3337
+ // ceiling division (243 == pow(3, 5))
3338
+ q = ((uint16_t)q * 256 + (243 - 1)) / 243;
3339
+ y[i].qs[j + m] = q;
3340
+ }
3341
+ x += 5*32;
3342
+ }
3343
+ // along 16 bytes
3344
+ for (size_t j = sizeof(y->qs) - sizeof(y->qs) % 32; j < sizeof(y->qs); j += 16) {
3345
+ for (size_t m = 0; m < 16; ++m) {
3346
+ uint8_t q = 0;
3347
+ for (size_t n = 0; n < 5; ++n) {
3348
+ int xi = lroundf(x[m + n*16] * id) + 1; // -1, 0, 1 -> 0, 1, 2
3349
+ q *= 3;
3350
+ q += xi;
3351
+ }
3352
+ // ceiling division (243 == pow(3, 5))
3353
+ q = ((uint16_t)q * 256 + (243 - 1)) / 243;
3354
+ y[i].qs[j + m] = q;
3355
+ }
3356
+ x += 5*16;
3357
+ }
3358
+ // 4 elements per byte
3359
+ for (size_t j = 0; j < sizeof(y->qh); ++j) {
3360
+ uint8_t q = 0;
3361
+ for (size_t m = 0; m < 4; ++m) {
3362
+ // -1, 0, 1 -> 0, 1, 2
3363
+ int xi = lroundf(x[j + m*sizeof(y->qh)] * id) + 1;
3364
+ q *= 3;
3365
+ q += xi;
3366
+ }
3367
+ // shift the first value to the most significant trit
3368
+ q *= 3;
3369
+ // ceiling division (243 == pow(3, 5))
3370
+ q = ((uint16_t)q * 256 + (243 - 1)) / 243;
3371
+ y[i].qh[j] = q;
3372
+ }
3373
+ x += 4*sizeof(y->qh);
3374
+ }
3375
+ }
3376
+
3377
+ void quantize_row_tq2_0_ref(const float * restrict x, block_tq2_0 * restrict y, int64_t k) {
3378
+ assert(k % QK_K == 0);
3379
+ const int64_t nb = k / QK_K;
3380
+
3381
+ for (int64_t i = 0; i < nb; i++) {
3382
+ float amax = 0.0f; // absolute max
3383
+
3384
+ for (int j = 0; j < QK_K; j++) {
3385
+ const float v = x[j];
3386
+ amax = MAX(amax, fabsf(v));
3387
+ }
3388
+
3389
+ const float d = amax;
3390
+ const float id = d ? 1.0f/d : 0.0f;
3391
+
3392
+ y[i].d = LM_GGML_FP32_TO_FP16(d);
3393
+
3394
+ for (size_t j = 0; j < sizeof(y->qs); j += 32) {
3395
+ for (size_t m = 0; m < 32; ++m) {
3396
+ uint8_t q = 0;
3397
+ for (size_t n = 0; n < 4; ++n) {
3398
+ // -1, 0, 1 -> 0, 1, 2
3399
+ int xi = lroundf(x[m + n*32] * id) + 1;
3400
+ q += (xi & 3) << (2*n);
3401
+ }
3402
+ y[i].qs[j + m] = q;
3403
+ }
3404
+ x += 4*32;
3405
+ }
3406
+ }
3407
+ }
3408
+
3409
+ void quantize_row_tq1_0(const float * restrict x, void * restrict vy, int64_t k) {
3410
+ assert(k % QK_K == 0);
3411
+ block_tq1_0 * restrict y = vy;
3412
+ quantize_row_tq1_0_ref(x, y, k);
3413
+ }
3414
+
3415
+ void quantize_row_tq2_0(const float * restrict x, void * restrict vy, int64_t k) {
3416
+ assert(k % QK_K == 0);
3417
+ block_tq2_0 * restrict y = vy;
3418
+ quantize_row_tq2_0_ref(x, y, k);
3419
+ }
3420
+
3421
+ size_t quantize_tq1_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
3422
+ (void)quant_weights; // not used
3423
+ const size_t row_size = lm_ggml_row_size(LM_GGML_TYPE_TQ1_0, n_per_row);
3424
+ quantize_row_tq1_0(src, dst, (int64_t)nrow*n_per_row);
3425
+ return nrow * row_size;
3426
+ }
3427
+
3428
+ size_t quantize_tq2_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
3429
+ (void)quant_weights; // not used
3430
+ const size_t row_size = lm_ggml_row_size(LM_GGML_TYPE_TQ2_0, n_per_row);
3431
+ quantize_row_tq2_0(src, dst, (int64_t)nrow*n_per_row);
3432
+ return nrow * row_size;
3433
+ }
3434
+
3435
+
3436
+ void dequantize_row_tq1_0(const block_tq1_0 * restrict x, float * restrict y, int64_t k) {
3437
+ assert(k % QK_K == 0);
3438
+ const int64_t nb = k / QK_K;
3439
+
3440
+ const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243};
3441
+
3442
+ for (int64_t i = 0; i < nb; ++i) {
3443
+
3444
+ const float d = LM_GGML_FP16_TO_FP32(x[i].d);
3445
+
3446
+ for (size_t j = 0; j < sizeof(x->qs) - sizeof(x->qs) % 32; j += 32) {
3447
+ for (size_t n = 0; n < 5; ++n) {
3448
+ for (size_t m = 0; m < 32; ++m) {
3449
+ uint8_t q = x[i].qs[j + m] * pow3[n];
3450
+ int16_t xi = ((uint16_t) q * 3) >> 8;
3451
+ *y++ = (float) (xi - 1) * d;
3452
+ }
3453
+ }
3454
+ }
3455
+ for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) {
3456
+ for (size_t n = 0; n < 5; ++n) {
3457
+ for (size_t m = 0; m < 16; ++m) {
3458
+ uint8_t q = x[i].qs[j + m] * pow3[n];
3459
+ int16_t xi = ((uint16_t) q * 3) >> 8;
3460
+ *y++ = (float) (xi - 1) * d;
3461
+ }
3462
+ }
3463
+ }
3464
+
3465
+ for (size_t n = 0; n < 4; ++n) {
3466
+ for (size_t j = 0; j < sizeof(x->qh); ++j) {
3467
+ uint8_t q = x[i].qh[j] * pow3[n];
3468
+ int16_t xi = ((uint16_t) q * 3) >> 8;
3469
+ *y++ = (float) (xi - 1) * d;
3470
+ }
3471
+ }
3472
+ }
3473
+ }
3474
+
3475
+ void dequantize_row_tq2_0(const block_tq2_0 * restrict x, float * restrict y, int64_t k) {
3476
+ assert(k % QK_K == 0);
3477
+ const int64_t nb = k / QK_K;
3478
+
3479
+ for (int64_t i = 0; i < nb; ++i) {
3480
+
3481
+ const float d = LM_GGML_FP16_TO_FP32(x[i].d);
3482
+
3483
+ for (size_t j = 0; j < sizeof(x->qs); j += 32) {
3484
+ for (size_t l = 0; l < 4; ++l) {
3485
+ for (size_t m = 0; m < 32; ++m) {
3486
+ int8_t q = (x[i].qs[j + m] >> (l*2)) & 3;
3487
+ *y++ = (float) (q - 1) * d;
3488
+ }
3489
+ }
3490
+ }
3491
+ }
3492
+ }
3493
+
3309
3494
  // ====================== "True" 2-bit (de)-quantization
3310
3495
 
3311
3496
  void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int64_t k) {
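
TQ1_0 packs five ternary weights per byte: each group of five values in {-1, 0, 1} is encoded as a base-3 number in [0, 242] and rescaled to a byte with a ceiling division by 243, so that decoding reduces to a multiply by a power of three followed by ((uint16_t) q * 3) >> 8. A standalone round-trip sketch of that packing (helper names are illustrative, not from the package):

    #include <assert.h>
    #include <stdint.h>

    static uint8_t pack5_trits(const int8_t t[5]) {        // each t[n] in {-1, 0, 1}
        uint16_t q = 0;
        for (int n = 0; n < 5; ++n) {
            q = (uint16_t)(q * 3 + (uint16_t)(t[n] + 1));  // base-3, t[0] is the most significant trit
        }
        return (uint8_t)((q * 256 + (243 - 1)) / 243);     // ceiling division (243 == 3^5)
    }

    static void unpack5_trits(uint8_t q, int8_t t[5]) {
        static const uint8_t pow3[5] = {1, 3, 9, 27, 81};
        for (int n = 0; n < 5; ++n) {
            const uint8_t qn = (uint8_t)(q * pow3[n]);     // bring trit n to the top of the byte
            t[n] = (int8_t)((((uint16_t) qn * 3) >> 8) - 1);
        }
    }

    int main(void) {
        int8_t in[5], out[5];
        for (int v = 0; v < 243; ++v) {                    // all 3^5 possible groups
            int x = v;
            for (int n = 4; n >= 0; --n) { in[n] = (int8_t)(x % 3 - 1); x /= 3; }
            unpack5_trits(pack5_trits(in), out);
            for (int n = 0; n < 5; ++n) assert(in[n] == out[n]);
        }
        return 0;
    }
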
@@ -3644,7 +3829,7 @@ void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) {
3644
3829
  quantize_row_q8_K_ref(x, y, k);
3645
3830
  }
3646
3831
 
3647
- //===================================== Dot ptoducts =================================
3832
+ //===================================== Dot products =================================
3648
3833
 
3649
3834
  //
3650
3835
  // Helper functions
@@ -3818,42 +4003,141 @@ void lm_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void
3818
4003
  float sumf = 0;
3819
4004
 
3820
4005
  #if defined(__ARM_FEATURE_SVE)
3821
- if (lm_ggml_sve_cnt_b == QK8_0) {
3822
- const svbool_t ptrueh = svptrue_pat_b8(SV_VL16);
3823
- const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh);
3824
-
3825
- svfloat32_t sumv0 = svdup_n_f32(0.0f);
3826
- svfloat32_t sumv1 = svdup_n_f32(0.0f);
3827
-
3828
- for (; ib + 1 < nb; ib += 2) {
3829
- const block_q4_0 * restrict x0 = &x[ib + 0];
3830
- const block_q4_0 * restrict x1 = &x[ib + 1];
3831
- const block_q8_0 * restrict y0 = &y[ib + 0];
3832
- const block_q8_0 * restrict y1 = &y[ib + 1];
4006
+ svfloat32_t sumv0 = svdup_n_f32(0.0f);
4007
+ svfloat32_t sumv1 = svdup_n_f32(0.0f);
3833
4008
 
3834
- // load x
3835
- const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
3836
- const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
4009
+ const int vector_length = lm_ggml_sve_cnt_b*8;
3837
4010
 
3838
- // 4-bit -> 8-bit
3839
- const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04));
3840
- const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04));
3841
-
3842
- // sub 8
3843
- const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
3844
- const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
4011
+ // VLA Implementation using switch case
4012
+ switch (vector_length) {
4013
+ case 128:
4014
+ {
4015
+ // predicate for activating higher lanes for 4 float32 elements
4016
+ const svbool_t ph4 = svptrue_pat_b32(SV_VL4);
4017
+
4018
+ for (; ib + 1 < nb; ib += 2) {
4019
+ const block_q4_0 * restrict x0 = &x[ib + 0];
4020
+ const block_q4_0 * restrict x1 = &x[ib + 1];
4021
+ const block_q8_0 * restrict y0 = &y[ib + 0];
4022
+ const block_q8_0 * restrict y1 = &y[ib + 1];
4023
+
4024
+ // load x
4025
+ const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
4026
+ const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
4027
+
4028
+ // 4-bit -> 8-bit
4029
+ const svint8_t qx0l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx0r, 0x0F));
4030
+ const svint8_t qx0h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx0r, 0x04));
4031
+ const svint8_t qx1l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx1r, 0x0F));
4032
+ const svint8_t qx1h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx1r, 0x04));
4033
+
4034
+ // sub 8
4035
+ const svint8_t qx0ls = svsub_n_s8_x(svptrue_b8(), qx0h, 8);
4036
+ const svint8_t qx0hs = svsub_n_s8_x(svptrue_b8(), qx0l, 8);
4037
+ const svint8_t qx1ls = svsub_n_s8_x(svptrue_b8(), qx1h, 8);
4038
+ const svint8_t qx1hs = svsub_n_s8_x(svptrue_b8(), qx1l, 8);
4039
+
4040
+ // load y
4041
+ const svint8_t qy0h = svld1_s8(svptrue_b8(), y0->qs);
4042
+ const svint8_t qy0l = svld1_s8(svptrue_b8(), y0->qs + 16);
4043
+ const svint8_t qy1h = svld1_s8(svptrue_b8(), y1->qs);
4044
+ const svint8_t qy1l = svld1_s8(svptrue_b8(), y1->qs + 16);
4045
+
4046
+ // dot product
4047
+ sumv0 = svmla_n_f32_x(ph4, sumv0, svcvt_f32_s32_x(ph4, svadd_x(ph4,
4048
+ svdot_s32(svdup_n_s32(0), qx0ls, qy0l),
4049
+ svdot_s32(svdup_n_s32(0), qx0hs, qy0h))), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d));
4050
+ sumv1 = svmla_n_f32_x(ph4, sumv1, svcvt_f32_s32_x(ph4, svadd_x(ph4,
4051
+ svdot_s32(svdup_n_s32(0), qx1ls, qy1l),
4052
+ svdot_s32(svdup_n_s32(0), qx1hs, qy1h))), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d));
4053
+ }
3845
4054
 
3846
- // load y
3847
- const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
3848
- const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
4055
+ sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
4056
+ } break;
4057
+ case 256:
4058
+ {
4059
+ // predicate for activating higher lanes for 16 int8 elements
4060
+ const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
4061
+ // predicate for activating lower lanes for 16 int8 elements
4062
+ const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16);
4063
+
4064
+ for (; ib + 1 < nb; ib += 2) {
4065
+ const block_q4_0 * restrict x0 = &x[ib + 0];
4066
+ const block_q4_0 * restrict x1 = &x[ib + 1];
4067
+ const block_q8_0 * restrict y0 = &y[ib + 0];
4068
+ const block_q8_0 * restrict y1 = &y[ib + 1];
4069
+
4070
+ // load x
4071
+ const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
4072
+ const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
4073
+
4074
+ // 4-bit -> 8-bit
4075
+ const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04));
4076
+ const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04));
4077
+
4078
+ // sub 8
4079
+ const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
4080
+ const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
4081
+
4082
+ // load y
4083
+ const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
4084
+ const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
4085
+
4086
+ // dot product
4087
+ sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(),
4088
+ svdot_s32(svdup_n_s32(0), qx0s, qy0)), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d));
4089
+ sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(),
4090
+ svdot_s32(svdup_n_s32(0), qx1s, qy1)), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d));
4091
+ }
3849
4092
 
3850
- // dot product
3851
- sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d));
3852
- sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d));
3853
- }
4093
+ sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
4094
+ } break;
4095
+ case 512:
4096
+ {
4097
+ // predicate for activating higher lanes for 32 int8 elements
4098
+ const svbool_t ph32 = svptrue_pat_b8(SV_VL32);
4099
+
4100
+ // predicate for activating higher lanes for 16 int8 elements
4101
+ const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
4102
+ // predicate for activating lower lanes for 16 int8 elements from first 32 int8 activated lanes
4103
+ const svbool_t pl16 = svnot_b_z(ph32, ph16);
4104
+
4105
+ for (; ib + 1 < nb; ib += 2) {
4106
+ const block_q4_0 * restrict x0 = &x[ib + 0];
4107
+ const block_q4_0 * restrict x1 = &x[ib + 1];
4108
+ const block_q8_0 * restrict y0 = &y[ib + 0];
4109
+ const block_q8_0 * restrict y1 = &y[ib + 1];
4110
+
4111
+ // load x
4112
+ const svuint8_t qx0r = svld1rq_u8(ph32, x0->qs);
4113
+ const svuint8_t qx1r = svld1rq_u8(ph32, x1->qs);
4114
+
4115
+ // 4-bit -> 8-bit
4116
+ const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04));
4117
+ const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04));
4118
+
4119
+ // sub 8
4120
+ const svint8_t qx0s = svsub_n_s8_x(ph32, qx0, 8);
4121
+ const svint8_t qx1s = svsub_n_s8_x(ph32, qx1, 8);
4122
+
4123
+ // load y
4124
+ const svint8_t qy0 = svld1_s8(ph32, y0->qs);
4125
+ const svint8_t qy1 = svld1_s8(ph32, y1->qs);
4126
+
4127
+ // dot product
4128
+ sumv0 = svmla_n_f32_x(ph32, sumv0, svcvt_f32_s32_x(ph32,
4129
+ svdot_s32(svdup_n_s32(0), qx0s, qy0)), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d));
4130
+ sumv1 = svmla_n_f32_x(ph32, sumv1, svcvt_f32_s32_x(ph32,
4131
+ svdot_s32(svdup_n_s32(0), qx1s, qy1)), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d));
4132
+ }
3854
4133
 
3855
- sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
4134
+ sumf = svaddv_f32(ph32, svadd_f32_x(ph32, sumv0, sumv1));
4135
+ } break;
4136
+ default:
4137
+ assert(false && "Unsupported vector length");
4138
+ break;
3856
4139
  }
4140
+
3857
4141
  #elif defined(__ARM_NEON)
3858
4142
  float32x4_t sumv0 = vdupq_n_f32(0.0f);
3859
4143
  float32x4_t sumv1 = vdupq_n_f32(0.0f);
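
The rewritten SVE path in lm_ggml_vec_dot_q4_0_q8_0 (and likewise in lm_ggml_vec_dot_q8_0_q8_0 below) replaces the fixed 256-bit assumption (lm_ggml_sve_cnt_b == QK8_0) with a dispatch on the runtime vector length, lm_ggml_sve_cnt_b * 8. A minimal sketch of that dispatch shape, assuming the width could equally be read from the ACLE svcntb() intrinsic; the kernel names are placeholders, not package symbols:

    #if defined(__ARM_FEATURE_SVE)
    #include <arm_sve.h>

    // Placeholder kernels standing in for the three specialized cases above.
    static void dot_kernel_128(void) { /* 128-bit SVE path */ }
    static void dot_kernel_256(void) { /* 256-bit SVE path */ }
    static void dot_kernel_512(void) { /* 512-bit SVE path */ }

    static void dispatch_by_sve_width(void) {
        // svcntb() returns the SVE register width in bytes; the diff multiplies
        // lm_ggml_sve_cnt_b (also a byte count) by 8 to get the width in bits.
        const int vector_length = (int) svcntb() * 8;
        switch (vector_length) {
            case 128: dot_kernel_128(); break;
            case 256: dot_kernel_256(); break;
            case 512: dot_kernel_512(); break;
            default:  break; // unhandled widths; the kernels above assert in this case
        }
    }
    #endif // __ARM_FEATURE_SVE
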
@@ -5303,29 +5587,124 @@ void lm_ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void
5303
5587
  float sumf = 0;
5304
5588
 
5305
5589
  #if defined(__ARM_FEATURE_SVE)
5306
- if (lm_ggml_sve_cnt_b == QK8_0) {
5307
- svfloat32_t sumv0 = svdup_n_f32(0.0f);
5308
- svfloat32_t sumv1 = svdup_n_f32(0.0f);
5590
+ svfloat32_t sumv0 = svdup_n_f32(0.0f);
5591
+ svfloat32_t sumv1 = svdup_n_f32(0.0f);
5309
5592
 
5310
- for (; ib + 1 < nb; ib += 2) {
5311
- const block_q8_0 * restrict x0 = &x[ib + 0];
5312
- const block_q8_0 * restrict x1 = &x[ib + 1];
5313
- const block_q8_0 * restrict y0 = &y[ib + 0];
5314
- const block_q8_0 * restrict y1 = &y[ib + 1];
5593
+ const int vector_length = lm_ggml_sve_cnt_b*8;
5315
5594
 
5316
- // load x
5317
- const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
5318
- const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
5595
+ // VLA Implementation for SVE
5596
+ switch (vector_length) {
5597
+ case 128:
5598
+ {
5599
+ // predicate for activating lanes for 16 Int8 elements
5600
+ const svbool_t ph16 = svptrue_pat_b8 (SV_VL16);
5601
+ const svbool_t pl16 = svptrue_pat_b32(SV_VL4);
5602
+
5603
+ for (; ib + 1 < nb; ib += 2) {
5604
+ const block_q8_0 * restrict x0 = &x[ib + 0];
5605
+ const block_q8_0 * restrict x1 = &x[ib + 1];
5606
+ const block_q8_0 * restrict y0 = &y[ib + 0];
5607
+ const block_q8_0 * restrict y1 = &y[ib + 1];
5608
+
5609
+ // load x
5610
+ const svint8_t qx0_0 = svld1_s8(ph16, x0->qs);
5611
+ const svint8_t qx0_1 = svld1_s8(ph16, x0->qs+16);
5612
+ const svint8_t qx1_0 = svld1_s8(ph16, x1->qs);
5613
+ const svint8_t qx1_1 = svld1_s8(ph16, x1->qs+16);
5614
+
5615
+ // load y
5616
+ const svint8_t qy0_0 = svld1_s8(ph16, y0->qs);
5617
+ const svint8_t qy0_1 = svld1_s8(ph16, y0->qs+16);
5618
+ const svint8_t qy1_0 = svld1_s8(ph16, y1->qs);
5619
+ const svint8_t qy1_1 = svld1_s8(ph16, y1->qs+16);
5620
+
5621
+ sumv0 = svmla_n_f32_x(pl16, sumv0, svcvt_f32_s32_x(pl16, svadd_x(pl16,
5622
+ svdot_s32(svdup_n_s32(0), qx0_0, qy0_0),
5623
+ svdot_s32(svdup_n_s32(0), qx0_1, qy0_1))), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d));
5624
+ sumv1 = svmla_n_f32_x(pl16, sumv1, svcvt_f32_s32_x(pl16, svadd_x(pl16,
5625
+ svdot_s32(svdup_n_s32(0), qx1_0, qy1_0),
5626
+ svdot_s32(svdup_n_s32(0), qx1_1, qy1_1))), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d));
5627
+ }
5319
5628
 
5320
- // load y
5321
- const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
5322
- const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
5629
+ sumf = svaddv_f32(pl16, svadd_f32_x(pl16, sumv0, sumv1));
5630
+ } break;
5631
+ case 256:
5632
+ {
5633
+ //printf("sve256");
5634
+ for (; ib + 1 < nb; ib += 2) {
5635
+ const block_q8_0 * restrict x0 = &x[ib + 0];
5636
+ const block_q8_0 * restrict x1 = &x[ib + 1];
5637
+ const block_q8_0 * restrict y0 = &y[ib + 0];
5638
+ const block_q8_0 * restrict y1 = &y[ib + 1];
5639
+
5640
+ // load x
5641
+ const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
5642
+ const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
5643
+
5644
+ // load y
5645
+ const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
5646
+ const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
5647
+
5648
+ sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(),
5649
+ svdot_s32(svdup_n_s32(0), qx0, qy0)), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d));
5650
+ sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(),
5651
+ svdot_s32(svdup_n_s32(0), qx1, qy1)), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d));
5652
+ }
5323
5653
 
5324
- sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d));
5325
- sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d));
5326
- }
5654
+ sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
5655
+ } break;
5656
+ case 512:
5657
+ {
5658
+ // predicate for activating high 256 bit
5659
+ const svbool_t ph32 = svptrue_pat_b8(SV_VL32);
5660
+ // predicate for activating low 256 bit
5661
+ const svbool_t pl32 = svnot_b_z(svptrue_b8(), ph32);
5662
+
5663
+ // predicate for activating high lanes for 8 float32 elements
5664
+ const svbool_t ph8 = svptrue_pat_b32(SV_VL8);
5665
+ // predicate for activating low lanes for 8 float32 elements
5666
+ const svbool_t pl8 = svnot_b_z(svptrue_b32(), ph8);
5667
+
5668
+ svfloat32_t sumv00 = svdup_n_f32(0.0f);
5669
+
5670
+ for (; ib + 1 < nb; ib += 2) {
5671
+ const block_q8_0 * restrict x0 = &x[ib + 0];
5672
+ const block_q8_0 * restrict x1 = &x[ib + 1];
5673
+ const block_q8_0 * restrict y0 = &y[ib + 0];
5674
+ const block_q8_0 * restrict y1 = &y[ib + 1];
5327
5675
 
5328
- sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
5676
+ // load 32 int8_t into the first half of the vector and put another 32 int8_t into the lower bits of a second vector,
5677
+ // and add them to make one 64-element vector
5678
+ // load x
5679
+ const svint8_t qx_32 = svld1_s8(ph32, x0->qs);
5680
+ svint8_t qx_64 = svld1_s8(pl32, x0->qs + 2);
5681
+
5682
+ qx_64 = svadd_s8_x(svptrue_b8(), qx_32, qx_64);
5683
+
5684
+ // load y
5685
+ const svint8_t qy_32 = svld1_s8(ph32, y0->qs);
5686
+ svint8_t qy_64 = svld1_s8(pl32, y0->qs + 2);
5687
+
5688
+ qy_64 = svadd_s8_x(svptrue_b8(), qy_32, qy_64);
5689
+
5690
+ // scale creation
5691
+ const float32_t deq1 = LM_GGML_FP16_TO_FP32(x0->d)*LM_GGML_FP16_TO_FP32(y0->d);
5692
+ const float32_t deq2 = LM_GGML_FP16_TO_FP32(x1->d)*LM_GGML_FP16_TO_FP32(y1->d);
5693
+
5694
+ // duplicate deq1 in first half of vector and deq2 in second half of vector
5695
+ const svfloat32_t temp = svdup_f32_m(svdup_f32_z(ph8, deq1), pl8, deq2);
5696
+
5697
+ const svfloat32_t sumvt = svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx_64, qy_64));
5698
+
5699
+ sumv00 = svmla_f32_m(svptrue_b32(), sumv00, sumvt, temp);
5700
+ }
5701
+
5702
+ sumf = svaddv_f32(svptrue_b32(), sumv00);
5703
+ break;
5704
+ }
5705
+ default:
5706
+ assert(false && "Unsupported vector length");
5707
+ break;
5329
5708
  }
5330
5709
  #elif defined(__ARM_NEON)
5331
5710
  float32x4_t sumv0 = vdupq_n_f32(0.0f);
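
Every branch of lm_ggml_vec_dot_q8_0_q8_0 above computes the same quantity: for each pair of 32-element blocks, an int8 dot product scaled by the product of the two block scales. A scalar reference sketch (the block type is approximated locally; the real block_q8_0 stores its scale as fp16 and converts it with LM_GGML_FP16_TO_FP32):

    #include <stdint.h>

    #define QK8_0 32

    // Illustrative stand-in for block_q8_0; a float scale keeps the sketch short.
    typedef struct {
        float  d;            // block scale
        int8_t qs[QK8_0];    // quantized values
    } block_q8_0_sketch;

    static float vec_dot_q8_0_q8_0_ref(int n, const block_q8_0_sketch * x,
                                       const block_q8_0_sketch * y) {
        const int nb = n / QK8_0;
        float sumf = 0.0f;
        for (int ib = 0; ib < nb; ++ib) {
            int32_t sumi = 0;
            for (int j = 0; j < QK8_0; ++j) {
                sumi += (int32_t) x[ib].qs[j] * (int32_t) y[ib].qs[j];
            }
            sumf += x[ib].d * y[ib].d * (float) sumi;  // scale the integer dot by both block scales
        }
        return sumf;
    }
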
@@ -5470,6 +5849,501 @@ void lm_ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void
5470
5849
  *s = sumf;
5471
5850
  }
5472
5851
 
5852
+ void lm_ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
5853
+ assert(nrc == 1);
5854
+ UNUSED(nrc);
5855
+ UNUSED(bx);
5856
+ UNUSED(by);
5857
+ UNUSED(bs);
5858
+
5859
+ const block_tq1_0 * restrict x = vx;
5860
+ const block_q8_K * restrict y = vy;
5861
+
5862
+ const int nb = n / QK_K;
5863
+
5864
+ #if defined(__ARM_NEON)
5865
+ float sumf = 0.0f;
5866
+
5867
+ uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27};
5868
+
5869
+ const uint8x16_t shift = vld1q_u8(k_shift);
5870
+
5871
+ for (int i = 0; i < nb; ++i) {
5872
+ #if defined(__ARM_FEATURE_DOTPROD)
5873
+ int32x4_t sumi0 = vdupq_n_s32(0);
5874
+ int32x4_t sumi1 = vdupq_n_s32(0);
5875
+ #else
5876
+ int16x8_t sumi0 = vdupq_n_s16(0);
5877
+ int16x8_t sumi1 = vdupq_n_s16(0);
5878
+ #endif
5879
+
5880
+ // first 32 bytes of 5 elements
5881
+ {
5882
+ uint8x16_t qx0 = vld1q_u8(x[i].qs + 0);
5883
+ uint8x16_t qx1 = vld1q_u8(x[i].qs + 16);
5884
+ uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(3));
5885
+ uint8x16_t qx3 = vmulq_u8(qx1, vdupq_n_u8(3));
5886
+ uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(9));
5887
+ uint8x16_t qx5 = vmulq_u8(qx1, vdupq_n_u8(9));
5888
+ uint8x16_t qx6 = vmulq_u8(qx0, vdupq_n_u8(27));
5889
+ uint8x16_t qx7 = vmulq_u8(qx1, vdupq_n_u8(27));
5890
+ uint8x16_t qx8 = vmulq_u8(qx0, vdupq_n_u8(81));
5891
+ uint8x16_t qx9 = vmulq_u8(qx1, vdupq_n_u8(81));
5892
+
5893
+ // multiply by 3 and keep the 2 bits above 8 bits
5894
+ int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));
5895
+ int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));
5896
+ int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));
5897
+ int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));
5898
+ int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));
5899
+ int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));
5900
+ int8x16_t sqx6 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx6, vshrq_n_u8(qx6, 1)), 6));
5901
+ int8x16_t sqx7 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx7, vshrq_n_u8(qx7, 1)), 6));
5902
+ int8x16_t sqx8 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx8, vshrq_n_u8(qx8, 1)), 6));
5903
+ int8x16_t sqx9 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx9, vshrq_n_u8(qx9, 1)), 6));
5904
+
5905
+ const int8x16_t qy0 = vld1q_s8(y[i].qs + 0);
5906
+ const int8x16_t qy1 = vld1q_s8(y[i].qs + 16);
5907
+ const int8x16_t qy2 = vld1q_s8(y[i].qs + 32);
5908
+ const int8x16_t qy3 = vld1q_s8(y[i].qs + 48);
5909
+ const int8x16_t qy4 = vld1q_s8(y[i].qs + 64);
5910
+ const int8x16_t qy5 = vld1q_s8(y[i].qs + 80);
5911
+ const int8x16_t qy6 = vld1q_s8(y[i].qs + 96);
5912
+ const int8x16_t qy7 = vld1q_s8(y[i].qs + 112);
5913
+ const int8x16_t qy8 = vld1q_s8(y[i].qs + 128);
5914
+ const int8x16_t qy9 = vld1q_s8(y[i].qs + 144);
5915
+
5916
+ #if defined(__ARM_FEATURE_DOTPROD)
5917
+ sumi0 = vdotq_s32(sumi0, sqx0, qy0);
5918
+ sumi1 = vdotq_s32(sumi1, sqx1, qy1);
5919
+ sumi0 = vdotq_s32(sumi0, sqx2, qy2);
5920
+ sumi1 = vdotq_s32(sumi1, sqx3, qy3);
5921
+ sumi0 = vdotq_s32(sumi0, sqx4, qy4);
5922
+ sumi1 = vdotq_s32(sumi1, sqx5, qy5);
5923
+ sumi0 = vdotq_s32(sumi0, sqx6, qy6);
5924
+ sumi1 = vdotq_s32(sumi1, sqx7, qy7);
5925
+ sumi0 = vdotq_s32(sumi0, sqx8, qy8);
5926
+ sumi1 = vdotq_s32(sumi1, sqx9, qy9);
5927
+ #else
5928
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
5929
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
5930
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
5931
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
5932
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
5933
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
5934
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
5935
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
5936
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
5937
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
5938
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
5939
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
5940
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6));
5941
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6));
5942
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7));
5943
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7));
5944
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx8), vget_low_s8(qy8));
5945
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx8), vget_high_s8(qy8));
5946
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx9), vget_low_s8(qy9));
5947
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx9), vget_high_s8(qy9));
5948
+ #endif
5949
+ }
5950
+
5951
+ // last 16 bytes of 5-element, along with the 4 bytes of 4 elements
5952
+ {
5953
+ uint8x16_t qx0 = vld1q_u8(x[i].qs + 32);
5954
+ uint8x16_t qx1 = vmulq_u8(qx0, vdupq_n_u8(3));
5955
+ uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(9));
5956
+ uint8x16_t qx3 = vmulq_u8(qx0, vdupq_n_u8(27));
5957
+ uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(81));
5958
+ uint32_t qh;
5959
+ memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned
5960
+ uint8x16_t qx5 = vreinterpretq_u8_u32(vdupq_n_u32(qh));
5961
+ qx5 = vmulq_u8(qx5, shift);
5962
+
5963
+ // multiply by 3 and keep the 2 bits above 8 bits
5964
+ int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));
5965
+ int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));
5966
+ int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));
5967
+ int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));
5968
+ int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));
5969
+ int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));
5970
+
5971
+ const int8x16_t qy0 = vld1q_s8(y[i].qs + 160);
5972
+ const int8x16_t qy1 = vld1q_s8(y[i].qs + 176);
5973
+ const int8x16_t qy2 = vld1q_s8(y[i].qs + 192);
5974
+ const int8x16_t qy3 = vld1q_s8(y[i].qs + 208);
5975
+ const int8x16_t qy4 = vld1q_s8(y[i].qs + 224);
5976
+ const int8x16_t qy5 = vld1q_s8(y[i].qs + 240);
5977
+
5978
+ #if defined(__ARM_FEATURE_DOTPROD)
5979
+ sumi0 = vdotq_s32(sumi0, sqx0, qy0);
5980
+ sumi1 = vdotq_s32(sumi1, sqx1, qy1);
5981
+ sumi0 = vdotq_s32(sumi0, sqx2, qy2);
5982
+ sumi1 = vdotq_s32(sumi1, sqx3, qy3);
5983
+ sumi0 = vdotq_s32(sumi0, sqx4, qy4);
5984
+ sumi1 = vdotq_s32(sumi1, sqx5, qy5);
5985
+ #else
5986
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
5987
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
5988
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
5989
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
5990
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
5991
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
5992
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
5993
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
5994
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
5995
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
5996
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
5997
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
5998
+ #endif
5999
+ }
6000
+
6001
+ const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
6002
+ const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);
6003
+
6004
+ const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
6005
+
6006
+ #if defined(__ARM_FEATURE_DOTPROD)
6007
+ sumi0 = vaddq_s32(sumi0, sumi1);
6008
+ sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
6009
+
6010
+ sumf += d * (float) vaddvq_s32(sumi0);
6011
+ #else
6012
+ sumi0 = vaddq_s16(sumi0, sumi1);
6013
+ sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1));
6014
+
6015
+ sumf += d * (float) vaddlvq_s16(sumi0);
6016
+ #endif
6017
+ }
6018
+
6019
+ *s = sumf;
6020
+
6021
+ #elif defined(__AVX2__)
6022
+ __m256 sumf = _mm256_setzero_ps();
6023
+
6024
+ for (int i = 0; i < nb; ++i) {
6025
+ // 16-bit sums
6026
+ __m256i sumi0 = _mm256_setzero_si256();
6027
+ __m256i sumi1 = _mm256_setzero_si256();
6028
+ __m256i sumi2 = _mm256_setzero_si256();
6029
+
6030
+ // first 32 bytes of 5 elements
6031
+ {
6032
+ __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs));
6033
+ // 8-bit multiplies with shifts, masks and adds
6034
+ __m256i qx1 = _mm256_add_epi8(qx0, _mm256_add_epi8(qx0, qx0)); // 1 * 3
6035
+ __m256i qx2 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx0, 3), _mm256_set1_epi8(-8)), qx0); // 1 * 9
6036
+ __m256i qx3 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx1, 3), _mm256_set1_epi8(-8)), qx1); // 3 * 9
6037
+ __m256i qx4 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx2, 3), _mm256_set1_epi8(-8)), qx2); // 9 * 9
6038
+
6039
+ // TODO: can _mm256_mulhi_epu16 be faster even if 16-bits?
6040
+
6041
+ // Cancel the +1 from avg so that it behaves like a halving add
6042
+ qx0 = _mm256_subs_epu8(qx0, _mm256_set1_epi8(1));
6043
+ qx1 = _mm256_subs_epu8(qx1, _mm256_set1_epi8(1));
6044
+ qx2 = _mm256_subs_epu8(qx2, _mm256_set1_epi8(1));
6045
+ qx3 = _mm256_subs_epu8(qx3, _mm256_set1_epi8(1));
6046
+ qx4 = _mm256_subs_epu8(qx4, _mm256_set1_epi8(1));
6047
+ // Multiply by 3 and get the top 2 bits
6048
+ qx0 = _mm256_avg_epu8(qx0, _mm256_avg_epu8(qx0, _mm256_setzero_si256()));
6049
+ qx1 = _mm256_avg_epu8(qx1, _mm256_avg_epu8(qx1, _mm256_setzero_si256()));
6050
+ qx2 = _mm256_avg_epu8(qx2, _mm256_avg_epu8(qx2, _mm256_setzero_si256()));
6051
+ qx3 = _mm256_avg_epu8(qx3, _mm256_avg_epu8(qx3, _mm256_setzero_si256()));
6052
+ qx4 = _mm256_avg_epu8(qx4, _mm256_avg_epu8(qx4, _mm256_setzero_si256()));
6053
+ qx0 = _mm256_and_si256(_mm256_srli_epi16(qx0, 6), _mm256_set1_epi8(3));
6054
+ qx1 = _mm256_and_si256(_mm256_srli_epi16(qx1, 6), _mm256_set1_epi8(3));
6055
+ qx2 = _mm256_and_si256(_mm256_srli_epi16(qx2, 6), _mm256_set1_epi8(3));
6056
+ qx3 = _mm256_and_si256(_mm256_srli_epi16(qx3, 6), _mm256_set1_epi8(3));
6057
+ qx4 = _mm256_and_si256(_mm256_srli_epi16(qx4, 6), _mm256_set1_epi8(3));
6058
+
6059
+ const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 0));
6060
+ const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 32));
6061
+ const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 64));
6062
+ const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 96));
6063
+ const __m256i qy4 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 128));
6064
+
6065
+ qx0 = _mm256_maddubs_epi16(qx0, qy0);
6066
+ qx1 = _mm256_maddubs_epi16(qx1, qy1);
6067
+ qx2 = _mm256_maddubs_epi16(qx2, qy2);
6068
+ qx3 = _mm256_maddubs_epi16(qx3, qy3);
6069
+ qx4 = _mm256_maddubs_epi16(qx4, qy4);
6070
+
6071
+ sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1));
6072
+ sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3));
6073
+ sumi2 = _mm256_add_epi16(sumi2, qx4);
6074
+ }
6075
+
6076
+ // last 16 bytes of 5-element, along with the 4 bytes of 4 elements
6077
+ {
6078
+ __m128i qx0 = _mm_loadu_si128((const __m128i *) (x[i].qs + 32));
6079
+ uint32_t qh;
6080
+ memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned
6081
+ __m256i qx5_l = _mm256_cvtepu8_epi16(_mm_set1_epi32(qh));
6082
+ __m128i qx1 = _mm_add_epi8(qx0, _mm_add_epi8(qx0, qx0)); // 1 * 3
6083
+ __m128i qx2 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx0, 3), _mm_set1_epi8(-8)), qx0); // 1 * 9
6084
+ __m128i qx3 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx1, 3), _mm_set1_epi8(-8)), qx1); // 3 * 9
6085
+ __m128i qx4 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx2, 3), _mm_set1_epi8(-8)), qx2); // 9 * 9
6086
+ __m256i qx01 = MM256_SET_M128I(qx1, qx0);
6087
+ __m256i qx23 = MM256_SET_M128I(qx3, qx2);
6088
+
6089
+ // avx2 does not have 8-bit multiplies, so 16-bit it is.
6090
+ qx5_l = _mm256_mullo_epi16(qx5_l, _mm256_set_epi16(27, 27, 27, 27, 9, 9, 9, 9, 3, 3, 3, 3, 1, 1, 1, 1));
6091
+ qx5_l = _mm256_and_si256(qx5_l, _mm256_set1_epi16(0xFF));
6092
+ __m128i qx5 = _mm_packus_epi16(_mm256_castsi256_si128(qx5_l), _mm256_extracti128_si256(qx5_l, 1));
6093
+
6094
+ __m256i qx45 = MM256_SET_M128I(qx5, qx4);
6095
+
6096
+ // Cancel the +1 from avg so that it behaves like a halving add
6097
+ qx01 = _mm256_subs_epu8(qx01, _mm256_set1_epi8(1));
6098
+ qx23 = _mm256_subs_epu8(qx23, _mm256_set1_epi8(1));
6099
+ qx45 = _mm256_subs_epu8(qx45, _mm256_set1_epi8(1));
6100
+ // Multiply by 3 and get the top 2 bits
6101
+ qx01 = _mm256_avg_epu8(qx01, _mm256_avg_epu8(qx01, _mm256_setzero_si256()));
6102
+ qx23 = _mm256_avg_epu8(qx23, _mm256_avg_epu8(qx23, _mm256_setzero_si256()));
6103
+ qx45 = _mm256_avg_epu8(qx45, _mm256_avg_epu8(qx45, _mm256_setzero_si256()));
6104
+ qx01 = _mm256_and_si256(_mm256_srli_epi16(qx01, 6), _mm256_set1_epi8(3));
6105
+ qx23 = _mm256_and_si256(_mm256_srli_epi16(qx23, 6), _mm256_set1_epi8(3));
6106
+ qx45 = _mm256_and_si256(_mm256_srli_epi16(qx45, 6), _mm256_set1_epi8(3));
6107
+
6108
+ const __m256i qy01 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 160));
6109
+ const __m256i qy23 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 192));
6110
+ const __m256i qy45 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 224));
6111
+
6112
+ qx01 = _mm256_maddubs_epi16(qx01, qy01);
6113
+ qx23 = _mm256_maddubs_epi16(qx23, qy23);
6114
+ qx45 = _mm256_maddubs_epi16(qx45, qy45);
6115
+
6116
+ sumi0 = _mm256_add_epi16(sumi0, qx01);
6117
+ sumi1 = _mm256_add_epi16(sumi1, qx23);
6118
+ sumi2 = _mm256_add_epi16(sumi2, qx45);
6119
+ }
6120
+
6121
+ const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums);
6122
+ const __m256 d = _mm256_set1_ps(y[i].d * LM_GGML_FP16_TO_FP32(x[i].d));
6123
+
6124
+ sumi0 = _mm256_sub_epi16(sumi0, ysum);
6125
+ sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(sumi1, sumi2));
6126
+ sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1));
6127
+
6128
+ sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf);
6129
+ }
6130
+
6131
+ *s = hsum_float_8(sumf);
6132
+
6133
+ #else
6134
+ const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243};
6135
+
6136
+ float sumf = 0.0f;
6137
+
6138
+ for (int i = 0; i < nb; ++i) {
6139
+ int sum = 0;
6140
+
6141
+ for (size_t j = 0; j < sizeof(x->qs) - sizeof(x->qs) % 32; j += 32) {
6142
+ for (size_t l = 0; l < 5; ++l) {
6143
+ for (size_t m = 0; m < 32; ++m) {
6144
+ uint8_t q = x[i].qs[j + m] * pow3[l];
6145
+ uint16_t xi = ((uint16_t) q * 3) >> 8;
6146
+ sum += (xi - 1) * y[i].qs[j*5 + l*32 + m];
6147
+ }
6148
+ }
6149
+ }
6150
+ for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) {
6151
+ for (size_t l = 0; l < 5; ++l) {
6152
+ for (size_t m = 0; m < 16; ++m) {
6153
+ uint8_t q = x[i].qs[j + m] * pow3[l];
6154
+ uint16_t xi = ((uint16_t) q * 3) >> 8;
6155
+ sum += (xi - 1) * y[i].qs[j*5 + l*16 + m];
6156
+ }
6157
+ }
6158
+ }
6159
+
6160
+ for (size_t l = 0; l < 4; ++l) {
6161
+ for (size_t j = 0; j < sizeof(x->qh); ++j) {
6162
+ uint8_t q = x[i].qh[j] * pow3[l];
6163
+ uint16_t xi = ((uint16_t) q * 3) >> 8;
6164
+ sum += (xi - 1) * y[i].qs[sizeof(x->qs)*5 + l*sizeof(x->qh) + j];
6165
+ }
6166
+ }
6167
+
6168
+ sumf += (float) sum * (LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d);
6169
+ }
6170
+
6171
+ *s = sumf;
6172
+ #endif
6173
+ }
6174
+
6175
+ void lm_ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
6176
+ assert(nrc == 1);
6177
+ UNUSED(nrc);
6178
+ UNUSED(bx);
6179
+ UNUSED(by);
6180
+ UNUSED(bs);
6181
+
6182
+ const block_tq2_0 * restrict x = vx;
6183
+ const block_q8_K * restrict y = vy;
6184
+
6185
+ const int nb = n / QK_K;
6186
+
6187
+ #if defined(__ARM_NEON)
6188
+ float sumf = 0.0f;
6189
+
6190
+ const uint8x16_t m3 = vdupq_n_u8(3);
6191
+
6192
+ for (int i = 0; i < nb; ++i) {
6193
+ #if defined(__ARM_FEATURE_DOTPROD)
6194
+ int32x4_t sumi0 = vdupq_n_s32(0);
6195
+ int32x4_t sumi1 = vdupq_n_s32(0);
6196
+ #else
6197
+ int16x8_t sumi0 = vdupq_n_s16(0);
6198
+ int16x8_t sumi1 = vdupq_n_s16(0);
6199
+ #endif
6200
+
6201
+ for (size_t j = 0; j < sizeof(x->qs); j += 32) {
6202
+ uint8x16_t qx0 = vld1q_u8(x[i].qs + j);
6203
+ uint8x16_t qx1 = vld1q_u8(x[i].qs + j + 16);
6204
+ uint8x16_t qx2 = vshrq_n_u8(qx0, 2);
6205
+ uint8x16_t qx3 = vshrq_n_u8(qx1, 2);
6206
+ uint8x16_t qx4 = vshrq_n_u8(qx0, 4);
6207
+ uint8x16_t qx5 = vshrq_n_u8(qx1, 4);
6208
+ uint8x16_t qx6 = vshrq_n_u8(qx0, 6);
6209
+ uint8x16_t qx7 = vshrq_n_u8(qx1, 6);
6210
+
6211
+ int8x16_t sqx0 = vreinterpretq_s8_u8(vandq_u8(qx0, m3));
6212
+ int8x16_t sqx1 = vreinterpretq_s8_u8(vandq_u8(qx1, m3));
6213
+ int8x16_t sqx2 = vreinterpretq_s8_u8(vandq_u8(qx2, m3));
6214
+ int8x16_t sqx3 = vreinterpretq_s8_u8(vandq_u8(qx3, m3));
6215
+ int8x16_t sqx4 = vreinterpretq_s8_u8(vandq_u8(qx4, m3));
6216
+ int8x16_t sqx5 = vreinterpretq_s8_u8(vandq_u8(qx5, m3));
6217
+ int8x16_t sqx6 = vreinterpretq_s8_u8(vandq_u8(qx6, m3));
6218
+ int8x16_t sqx7 = vreinterpretq_s8_u8(vandq_u8(qx7, m3));
6219
+
6220
+ const int8x16_t qy0 = vld1q_s8(y[i].qs + j*4 + 0);
6221
+ const int8x16_t qy1 = vld1q_s8(y[i].qs + j*4 + 16);
6222
+ const int8x16_t qy2 = vld1q_s8(y[i].qs + j*4 + 32);
6223
+ const int8x16_t qy3 = vld1q_s8(y[i].qs + j*4 + 48);
6224
+ const int8x16_t qy4 = vld1q_s8(y[i].qs + j*4 + 64);
6225
+ const int8x16_t qy5 = vld1q_s8(y[i].qs + j*4 + 80);
6226
+ const int8x16_t qy6 = vld1q_s8(y[i].qs + j*4 + 96);
6227
+ const int8x16_t qy7 = vld1q_s8(y[i].qs + j*4 + 112);
6228
+
6229
+ #if defined(__ARM_FEATURE_DOTPROD)
6230
+ sumi0 = vdotq_s32(sumi0, sqx0, qy0);
6231
+ sumi1 = vdotq_s32(sumi1, sqx1, qy1);
6232
+ sumi0 = vdotq_s32(sumi0, sqx2, qy2);
6233
+ sumi1 = vdotq_s32(sumi1, sqx3, qy3);
6234
+ sumi0 = vdotq_s32(sumi0, sqx4, qy4);
6235
+ sumi1 = vdotq_s32(sumi1, sqx5, qy5);
6236
+ sumi0 = vdotq_s32(sumi0, sqx6, qy6);
6237
+ sumi1 = vdotq_s32(sumi1, sqx7, qy7);
6238
+ #else
6239
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
6240
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
6241
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
6242
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
6243
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
6244
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
6245
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
6246
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
6247
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
6248
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
6249
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
6250
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
6251
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6));
6252
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6));
6253
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7));
6254
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7));
6255
+ #endif
6256
+ }
6257
+
6258
+ const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
6259
+ const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);
6260
+
6261
+ const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
6262
+
6263
+ #if defined(__ARM_FEATURE_DOTPROD)
6264
+ sumi0 = vaddq_s32(sumi0, sumi1);
6265
+ sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
6266
+
6267
+ sumf += d * (float) vaddvq_s32(sumi0);
6268
+ #else
6269
+ sumi0 = vaddq_s16(sumi0, sumi1);
6270
+ sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1));
6271
+
6272
+ sumf += d * (float) vaddlvq_s16(sumi0);
6273
+ #endif
6274
+ }
6275
+
6276
+ *s = sumf;
6277
+
6278
+ #elif defined(__AVX2__)
6279
+ __m256 sumf = _mm256_setzero_ps();
6280
+
6281
+ for (int i = 0; i < nb; ++i) {
6282
+ // 16-bit sums, because 256*127 still fits
6283
+ __m256i sumi0 = _mm256_setzero_si256();
6284
+ __m256i sumi1 = _mm256_setzero_si256();
6285
+
6286
+ for (size_t j = 0; j < sizeof(x->qs); j += 32) {
6287
+ __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs + j));
6288
+ __m256i qx1 = _mm256_srli_epi16(qx0, 2);
6289
+ __m256i qx2 = _mm256_srli_epi16(qx0, 4);
6290
+ __m256i qx3 = _mm256_srli_epi16(qx0, 6);
6291
+
6292
+ // 0, 1, 2 (should not be 3)
6293
+ qx0 = _mm256_and_si256(qx0, _mm256_set1_epi8(3));
6294
+ qx1 = _mm256_and_si256(qx1, _mm256_set1_epi8(3));
6295
+ qx2 = _mm256_and_si256(qx2, _mm256_set1_epi8(3));
6296
+ qx3 = _mm256_and_si256(qx3, _mm256_set1_epi8(3));
6297
+
6298
+ const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 0));
6299
+ const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 32));
6300
+ const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 64));
6301
+ const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 96));
6302
+
6303
+ qx0 = _mm256_maddubs_epi16(qx0, qy0);
6304
+ qx1 = _mm256_maddubs_epi16(qx1, qy1);
6305
+ qx2 = _mm256_maddubs_epi16(qx2, qy2);
6306
+ qx3 = _mm256_maddubs_epi16(qx3, qy3);
6307
+
6308
+ sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1));
6309
+ sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3));
6310
+ }
6311
+
6312
+ const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums);
6313
+ const __m256 d = _mm256_set1_ps(y[i].d * LM_GGML_FP16_TO_FP32(x[i].d));
6314
+
6315
+ sumi0 = _mm256_add_epi16(sumi0, sumi1);
6316
+ sumi0 = _mm256_sub_epi16(sumi0, ysum);
6317
+ sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1));
6318
+
6319
+ sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf);
6320
+ }
6321
+
6322
+ *s = hsum_float_8(sumf);
6323
+
6324
+ #else
6325
+ float sumf = 0.0f;
6326
+
6327
+ for (int i = 0; i < nb; ++i) {
6328
+ int32_t sumi = 0;
6329
+
6330
+ for (size_t j = 0; j < sizeof(x->qs); j += 32) {
6331
+ for (size_t l = 0; l < 4; ++l) {
6332
+ for (size_t k = 0; k < 32; ++k) {
6333
+ sumi += y[i].qs[j*4 + l*32 + k] * (((x[i].qs[j + k] >> (l*2)) & 3) - 1);
6334
+ }
6335
+ }
6336
+ }
6337
+
6338
+ const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d);
6339
+
6340
+ sumf += (float) sumi * d;
6341
+ }
6342
+
6343
+ *s = sumf;
6344
+ #endif
6345
+ }
6346
+
5473
6347
  void lm_ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
5474
6348
  assert(nrc == 1);
5475
6349
  UNUSED(nrc);
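
Both new ternary dot products accumulate the decoded values as {0, 1, 2} and correct for the missing -1 only once per block, by subtracting the precomputed q8_K partial sums (bsums). A minimal numeric sketch of the identity they rely on:

    #include <assert.h>
    #include <stdint.h>

    // sum((t - 1) * y) == sum(t * y) - sum(y), with t in {0, 1, 2}:
    // the per-block subtraction of bsums replaces a per-element subtraction of 1.
    int main(void) {
        const uint8_t tq[8] = {0, 2, 1, 1, 2, 0, 0, 1};    // decoded trits mapped to {0, 1, 2}
        const int8_t  y[8]  = {5, -3, 7, 0, -1, 4, -8, 2}; // q8 activations

        int direct = 0, via_bsum = 0, ysum = 0;
        for (int i = 0; i < 8; ++i) {
            direct   += ((int) tq[i] - 1) * y[i]; // subtract 1 per element
            via_bsum += (int) tq[i] * y[i];       // accumulate the unshifted values
            ysum     += y[i];                     // what q8_K precomputes in bsums
        }
        assert(direct == via_bsum - ysum);
        return 0;
    }
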
@@ -14800,6 +15674,14 @@ bool lm_ggml_validate_row_data(enum lm_ggml_type type, const void * data, size_t
14800
15674
  }
14801
15675
  }
14802
15676
  } break;
15677
+ case LM_GGML_TYPE_TQ1_0:
15678
+ {
15679
+ VALIDATE_ROW_DATA_D_F16_IMPL(block_tq1_0, data, nb);
15680
+ } break;
15681
+ case LM_GGML_TYPE_TQ2_0:
15682
+ {
15683
+ VALIDATE_ROW_DATA_D_F16_IMPL(block_tq2_0, data, nb);
15684
+ } break;
14803
15685
  case LM_GGML_TYPE_IQ1_S:
14804
15686
  {
14805
15687
  VALIDATE_ROW_DATA_D_F16_IMPL(block_iq1_s, data, nb);
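
The two new validation cases route the ternary types through VALIDATE_ROW_DATA_D_F16_IMPL, whose body is not shown in this diff; assuming it validates the per-block fp16 scale d, the check has roughly this shape (names are illustrative, not package API):

    #include <stdbool.h>
    #include <stddef.h>
    #include <stdint.h>

    // fp16 layout: 1 sign, 5 exponent, 10 mantissa bits; an all-ones exponent
    // field (0x7C00) means Inf or NaN, which a valid block scale should never be.
    static bool fp16_is_finite(uint16_t h) {
        return (h & 0x7C00) != 0x7C00;
    }

    // d points at the nb per-block fp16 scales extracted from the row data.
    static bool validate_row_scales(const uint16_t * d, size_t nb) {
        for (size_t i = 0; i < nb; ++i) {
            if (!fp16_is_finite(d[i])) {
                return false;
            }
        }
        return true;
    }
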