cui-llama.rn 1.1.2 → 1.1.4
This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
- package/android/src/main/CMakeLists.txt +1 -2
- package/android/src/main/jni.cpp +26 -21
- package/cpp/common.cpp +2028 -1520
- package/cpp/common.h +134 -18
- package/cpp/ggml-aarch64.c +612 -0
- package/cpp/ggml-alloc.h +2 -2
- package/cpp/ggml-backend.c +33 -6
- package/cpp/ggml-backend.h +2 -0
- package/cpp/ggml-common.h +20 -0
- package/cpp/ggml-impl.h +4 -7
- package/cpp/ggml-metal.m +63 -2
- package/cpp/ggml-quants.c +690 -2
- package/cpp/ggml-quants.h +15 -0
- package/cpp/ggml.c +1650 -317
- package/cpp/ggml.h +155 -48
- package/cpp/llama-grammar.cpp +721 -122
- package/cpp/llama-grammar.h +120 -15
- package/cpp/llama-impl.h +132 -1
- package/cpp/llama-sampling.cpp +1361 -356
- package/cpp/llama-sampling.h +20 -48
- package/cpp/llama-vocab.cpp +140 -7
- package/cpp/llama-vocab.h +3 -2
- package/cpp/llama.cpp +810 -307
- package/cpp/llama.h +213 -259
- package/cpp/rn-llama.hpp +17 -14
- package/cpp/sampling.cpp +347 -355
- package/cpp/sampling.h +106 -135
- package/cpp/sgemm.cpp +153 -0
- package/package.json +1 -1
- package/cpp/grammar-parser.cpp +0 -539
- package/cpp/grammar-parser.h +0 -29
package/cpp/ggml-quants.c CHANGED
@@ -1630,7 +1630,7 @@ void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int6
 // ===================== Helper functions
 //
 static inline int nearest_int(float fval) {
-    assert(fval <= 4194303.f);
+    assert(fabsf(fval) <= 4194303.f);
     float val = fval + 12582912.f;
     int i; memcpy(&i, &val, sizeof(int));
     return (i & 0x007fffff) - 0x00400000;
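Note on the hunk above: `nearest_int` rounds by adding the magic constant 12582912.f (2^23 + 2^22), which pins the float's exponent so that the rounded integer lands in the low mantissa bits; the fix asserts on `fabsf(fval)` so that out-of-range *negative* inputs are caught as well as positive ones. A minimal standalone sketch of the same trick, with an illustrative name rather than the ggml-quants.c function itself:

```c
#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <string.h>

// Illustrative re-statement of the rounding trick guarded by the fixed assert.
static inline int nearest_int_sketch(float fval) {
    // |fval| must fit in 22 bits for the trick to be exact
    assert(fabsf(fval) <= 4194303.f);      // 2^22 - 1
    float val = fval + 12582912.f;         // 2^23 + 2^22 pins the exponent
    int i; memcpy(&i, &val, sizeof(int));  // reinterpret the IEEE-754 bits
    return (i & 0x007fffff) - 0x00400000;  // mantissa minus the 2^22 offset
}

int main(void) {
    // prints "2 -3 2" (2.5f rounds to even, like the FPU's default mode)
    printf("%d %d %d\n", nearest_int_sketch(2.4f),
                         nearest_int_sketch(-2.6f),
                         nearest_int_sketch(2.5f));
    return 0;
}
```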
@@ -3306,6 +3306,191 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr
     return nrow * row_size;
 }
 
+// ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
+
+void quantize_row_tq1_0_ref(const float * restrict x, block_tq1_0 * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    const int64_t nb = k / QK_K;
+
+    for (int64_t i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        for (int j = 0; j < QK_K; j++) {
+            const float v = x[j];
+            amax = MAX(amax, fabsf(v));
+        }
+
+        const float d = amax;
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = LM_GGML_FP32_TO_FP16(d);
+
+        // 5 elements per byte, along 32 bytes
+        for (size_t j = 0; j < sizeof(y->qs) - sizeof(y->qs) % 32; j += 32) {
+            for (size_t m = 0; m < 32; ++m) {
+                uint8_t q = 0;
+                for (size_t n = 0; n < 5; ++n) {
+                    int xi = lroundf(x[m + n*32] * id) + 1; // -1, 0, 1 -> 0, 1, 2
+                    q *= 3;
+                    q += xi;
+                }
+                // ceiling division (243 == pow(3, 5))
+                q = ((uint16_t)q * 256 + (243 - 1)) / 243;
+                y[i].qs[j + m] = q;
+            }
+            x += 5*32;
+        }
+        // along 16 bytes
+        for (size_t j = sizeof(y->qs) - sizeof(y->qs) % 32; j < sizeof(y->qs); j += 16) {
+            for (size_t m = 0; m < 16; ++m) {
+                uint8_t q = 0;
+                for (size_t n = 0; n < 5; ++n) {
+                    int xi = lroundf(x[m + n*16] * id) + 1; // -1, 0, 1 -> 0, 1, 2
+                    q *= 3;
+                    q += xi;
+                }
+                // ceiling division (243 == pow(3, 5))
+                q = ((uint16_t)q * 256 + (243 - 1)) / 243;
+                y[i].qs[j + m] = q;
+            }
+            x += 5*16;
+        }
+        // 4 elements per byte
+        for (size_t j = 0; j < sizeof(y->qh); ++j) {
+            uint8_t q = 0;
+            for (size_t m = 0; m < 4; ++m) {
+                // -1, 0, 1 -> 0, 1, 2
+                int xi = lroundf(x[j + m*sizeof(y->qh)] * id) + 1;
+                q *= 3;
+                q += xi;
+            }
+            // shift the first value to the most significant trit
+            q *= 3;
+            // ceiling division (243 == pow(3, 5))
+            q = ((uint16_t)q * 256 + (243 - 1)) / 243;
+            y[i].qh[j] = q;
+        }
+        x += 4*sizeof(y->qh);
+    }
+}
+
+void quantize_row_tq2_0_ref(const float * restrict x, block_tq2_0 * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    const int64_t nb = k / QK_K;
+
+    for (int64_t i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        for (int j = 0; j < QK_K; j++) {
+            const float v = x[j];
+            amax = MAX(amax, fabsf(v));
+        }
+
+        const float d = amax;
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = LM_GGML_FP32_TO_FP16(d);
+
+        for (size_t j = 0; j < sizeof(y->qs); j += 32) {
+            for (size_t m = 0; m < 32; ++m) {
+                uint8_t q = 0;
+                for (size_t n = 0; n < 4; ++n) {
+                    // -1, 0, 1 -> 0, 1, 2
+                    int xi = lroundf(x[m + n*32] * id) + 1;
+                    q += (xi & 3) << (2*n);
+                }
+                y[i].qs[j + m] = q;
+            }
+            x += 4*32;
+        }
+    }
+}
+
+void quantize_row_tq1_0(const float * restrict x, void * restrict vy, int64_t k) {
+    assert(k % QK_K == 0);
+    block_tq1_0 * restrict y = vy;
+    quantize_row_tq1_0_ref(x, y, k);
+}
+
+void quantize_row_tq2_0(const float * restrict x, void * restrict vy, int64_t k) {
+    assert(k % QK_K == 0);
+    block_tq2_0 * restrict y = vy;
+    quantize_row_tq2_0_ref(x, y, k);
+}
+
+size_t quantize_tq1_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    (void)quant_weights; // not used
+    const size_t row_size = lm_ggml_row_size(LM_GGML_TYPE_TQ1_0, n_per_row);
+    quantize_row_tq1_0(src, dst, (int64_t)nrow*n_per_row);
+    return nrow * row_size;
+}
+
+size_t quantize_tq2_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    (void)quant_weights; // not used
+    const size_t row_size = lm_ggml_row_size(LM_GGML_TYPE_TQ2_0, n_per_row);
+    quantize_row_tq2_0(src, dst, (int64_t)nrow*n_per_row);
+    return nrow * row_size;
+}
+
+
+void dequantize_row_tq1_0(const block_tq1_0 * restrict x, float * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    const int64_t nb = k / QK_K;
+
+    const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243};
+
+    for (int64_t i = 0; i < nb; ++i) {
+
+        const float d = LM_GGML_FP16_TO_FP32(x[i].d);
+
+        for (size_t j = 0; j < sizeof(x->qs) - sizeof(x->qs) % 32; j += 32) {
+            for (size_t n = 0; n < 5; ++n) {
+                for (size_t m = 0; m < 32; ++m) {
+                    uint8_t q = x[i].qs[j + m] * pow3[n];
+                    int16_t xi = ((uint16_t) q * 3) >> 8;
+                    *y++ = (float) (xi - 1) * d;
+                }
+            }
+        }
+        for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) {
+            for (size_t n = 0; n < 5; ++n) {
+                for (size_t m = 0; m < 16; ++m) {
+                    uint8_t q = x[i].qs[j + m] * pow3[n];
+                    int16_t xi = ((uint16_t) q * 3) >> 8;
+                    *y++ = (float) (xi - 1) * d;
+                }
+            }
+        }
+
+        for (size_t n = 0; n < 4; ++n) {
+            for (size_t j = 0; j < sizeof(x->qh); ++j) {
+                uint8_t q = x[i].qh[j] * pow3[n];
+                int16_t xi = ((uint16_t) q * 3) >> 8;
+                *y++ = (float) (xi - 1) * d;
+            }
+        }
+    }
+}
+
+void dequantize_row_tq2_0(const block_tq2_0 * restrict x, float * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    const int64_t nb = k / QK_K;
+
+    for (int64_t i = 0; i < nb; ++i) {
+
+        const float d = LM_GGML_FP16_TO_FP32(x[i].d);
+
+        for (size_t j = 0; j < sizeof(x->qs); j += 32) {
+            for (size_t l = 0; l < 4; ++l) {
+                for (size_t m = 0; m < 32; ++m) {
+                    int8_t q = (x[i].qs[j + m] >> (l*2)) & 3;
+                    *y++ = (float) (q - 1) * d;
+                }
+            }
+        }
+    }
+}
+
 // ====================== "True" 2-bit (de)-quantization
 
 void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int64_t k) {
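For context on the hunk above: TQ1_0 packs five ternary values (mapped from -1/0/+1 to 0/1/2) into one byte by accumulating them base-3 (a value in 0..242) and then mapping that range onto a full byte with the ceiling division by 243; dequantization multiplies the byte by powers of 3 (wrapping mod 256 drops the digits already read) and recovers the leading base-3 digit with `(q * 3) >> 8`. A hedged round-trip sketch of just that packing, with illustrative helper names that are not part of the diff:

```c
#include <stdint.h>
#include <stdio.h>

// trits[0..4] take values 0, 1 or 2 (the ternary weights offset by +1)
static uint8_t pack5(const int trits[5]) {
    uint16_t q = 0;
    for (int n = 0; n < 5; ++n) {
        q = q*3 + trits[n];                      // base-3 accumulate, most significant first
    }
    return (uint8_t) ((q*256 + 242) / 243);      // ceiling division maps [0,243) onto a byte
}

static void unpack5(uint8_t b, int trits[5]) {
    static const uint8_t pow3[5] = {1, 3, 9, 27, 81};
    for (int n = 0; n < 5; ++n) {
        uint8_t q = (uint8_t)(b * pow3[n]);      // wrap mod 256 discards the digits already read
        trits[n] = ((uint16_t)q * 3) >> 8;       // top digit of the base-3 fixed-point fraction
    }
}

int main(void) {
    int in[5] = {2, 0, 1, 2, 1}, out[5];
    unpack5(pack5(in), out);
    for (int n = 0; n < 5; ++n) printf("%d -> %d\n", in[n], out[n]);  // values round-trip
    return 0;
}
```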
@@ -3644,7 +3829,7 @@ void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) {
     quantize_row_q8_K_ref(x, y, k);
 }
 
-//===================================== Dot
+//===================================== Dot products =================================
 
 //
 // Helper functions
@@ -5470,6 +5655,501 @@ void lm_ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void
     *s = sumf;
 }
 
+void lm_ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_tq1_0 * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#if defined(__ARM_NEON)
+    float sumf = 0.0f;
+
+    uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27};
+
+    const uint8x16_t shift = vld1q_u8(k_shift);
+
+    for (int i = 0; i < nb; ++i) {
+#if defined(__ARM_FEATURE_DOTPROD)
+        int32x4_t sumi0 = vdupq_n_s32(0);
+        int32x4_t sumi1 = vdupq_n_s32(0);
+#else
+        int16x8_t sumi0 = vdupq_n_s16(0);
+        int16x8_t sumi1 = vdupq_n_s16(0);
+#endif
+
+        // first 32 bytes of 5 elements
+        {
+            uint8x16_t qx0 = vld1q_u8(x[i].qs + 0);
+            uint8x16_t qx1 = vld1q_u8(x[i].qs + 16);
+            uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(3));
+            uint8x16_t qx3 = vmulq_u8(qx1, vdupq_n_u8(3));
+            uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(9));
+            uint8x16_t qx5 = vmulq_u8(qx1, vdupq_n_u8(9));
+            uint8x16_t qx6 = vmulq_u8(qx0, vdupq_n_u8(27));
+            uint8x16_t qx7 = vmulq_u8(qx1, vdupq_n_u8(27));
+            uint8x16_t qx8 = vmulq_u8(qx0, vdupq_n_u8(81));
+            uint8x16_t qx9 = vmulq_u8(qx1, vdupq_n_u8(81));
+
+            // multiply by 3 and keep the 2 bits above 8 bits
+            int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));
+            int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));
+            int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));
+            int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));
+            int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));
+            int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));
+            int8x16_t sqx6 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx6, vshrq_n_u8(qx6, 1)), 6));
+            int8x16_t sqx7 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx7, vshrq_n_u8(qx7, 1)), 6));
+            int8x16_t sqx8 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx8, vshrq_n_u8(qx8, 1)), 6));
+            int8x16_t sqx9 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx9, vshrq_n_u8(qx9, 1)), 6));
+
+            const int8x16_t qy0 = vld1q_s8(y[i].qs + 0);
+            const int8x16_t qy1 = vld1q_s8(y[i].qs + 16);
+            const int8x16_t qy2 = vld1q_s8(y[i].qs + 32);
+            const int8x16_t qy3 = vld1q_s8(y[i].qs + 48);
+            const int8x16_t qy4 = vld1q_s8(y[i].qs + 64);
+            const int8x16_t qy5 = vld1q_s8(y[i].qs + 80);
+            const int8x16_t qy6 = vld1q_s8(y[i].qs + 96);
+            const int8x16_t qy7 = vld1q_s8(y[i].qs + 112);
+            const int8x16_t qy8 = vld1q_s8(y[i].qs + 128);
+            const int8x16_t qy9 = vld1q_s8(y[i].qs + 144);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+            sumi0 = vdotq_s32(sumi0, sqx0, qy0);
+            sumi1 = vdotq_s32(sumi1, sqx1, qy1);
+            sumi0 = vdotq_s32(sumi0, sqx2, qy2);
+            sumi1 = vdotq_s32(sumi1, sqx3, qy3);
+            sumi0 = vdotq_s32(sumi0, sqx4, qy4);
+            sumi1 = vdotq_s32(sumi1, sqx5, qy5);
+            sumi0 = vdotq_s32(sumi0, sqx6, qy6);
+            sumi1 = vdotq_s32(sumi1, sqx7, qy7);
+            sumi0 = vdotq_s32(sumi0, sqx8, qy8);
+            sumi1 = vdotq_s32(sumi1, sqx9, qy9);
+#else
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx8), vget_low_s8(qy8));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx8), vget_high_s8(qy8));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx9), vget_low_s8(qy9));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx9), vget_high_s8(qy9));
+#endif
+        }
+
+        // last 16 bytes of 5-element, along with the 4 bytes of 4 elements
+        {
+            uint8x16_t qx0 = vld1q_u8(x[i].qs + 32);
+            uint8x16_t qx1 = vmulq_u8(qx0, vdupq_n_u8(3));
+            uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(9));
+            uint8x16_t qx3 = vmulq_u8(qx0, vdupq_n_u8(27));
+            uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(81));
+            uint32_t qh;
+            memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned
+            uint8x16_t qx5 = vreinterpretq_u8_u32(vdupq_n_u32(qh));
+            qx5 = vmulq_u8(qx5, shift);
+
+            // multiply by 3 and keep the 2 bits above 8 bits
+            int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));
+            int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));
+            int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));
+            int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));
+            int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));
+            int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));
+
+            const int8x16_t qy0 = vld1q_s8(y[i].qs + 160);
+            const int8x16_t qy1 = vld1q_s8(y[i].qs + 176);
+            const int8x16_t qy2 = vld1q_s8(y[i].qs + 192);
+            const int8x16_t qy3 = vld1q_s8(y[i].qs + 208);
+            const int8x16_t qy4 = vld1q_s8(y[i].qs + 224);
+            const int8x16_t qy5 = vld1q_s8(y[i].qs + 240);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+            sumi0 = vdotq_s32(sumi0, sqx0, qy0);
+            sumi1 = vdotq_s32(sumi1, sqx1, qy1);
+            sumi0 = vdotq_s32(sumi0, sqx2, qy2);
+            sumi1 = vdotq_s32(sumi1, sqx3, qy3);
+            sumi0 = vdotq_s32(sumi0, sqx4, qy4);
+            sumi1 = vdotq_s32(sumi1, sqx5, qy5);
+#else
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
+#endif
+        }
+
+        const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
+        const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);
+
+        const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        sumi0 = vaddq_s32(sumi0, sumi1);
+        sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
+
+        sumf += d * (float) vaddvq_s32(sumi0);
+#else
+        sumi0 = vaddq_s16(sumi0, sumi1);
+        sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1));
+
+        sumf += d * (float) vaddlvq_s16(sumi0);
+#endif
+    }
+
+    *s = sumf;
+
+#elif defined(__AVX2__)
+    __m256 sumf = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+        // 16-bit sums
+        __m256i sumi0 = _mm256_setzero_si256();
+        __m256i sumi1 = _mm256_setzero_si256();
+        __m256i sumi2 = _mm256_setzero_si256();
+
+        // first 32 bytes of 5 elements
+        {
+            __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs));
+            // 8-bit multiplies with shifts, masks and adds
+            __m256i qx1 = _mm256_add_epi8(qx0, _mm256_add_epi8(qx0, qx0)); // 1 * 3
+            __m256i qx2 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx0, 3), _mm256_set1_epi8(-8)), qx0); // 1 * 9
+            __m256i qx3 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx1, 3), _mm256_set1_epi8(-8)), qx1); // 3 * 9
+            __m256i qx4 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx2, 3), _mm256_set1_epi8(-8)), qx2); // 9 * 9
+
+            // TODO: can _mm256_mulhi_epu16 be faster even if 16-bits?
+
+            // Cancel the +1 from avg so that it behaves like a halving add
+            qx0 = _mm256_subs_epu8(qx0, _mm256_set1_epi8(1));
+            qx1 = _mm256_subs_epu8(qx1, _mm256_set1_epi8(1));
+            qx2 = _mm256_subs_epu8(qx2, _mm256_set1_epi8(1));
+            qx3 = _mm256_subs_epu8(qx3, _mm256_set1_epi8(1));
+            qx4 = _mm256_subs_epu8(qx4, _mm256_set1_epi8(1));
+            // Multiply by 3 and get the top 2 bits
+            qx0 = _mm256_avg_epu8(qx0, _mm256_avg_epu8(qx0, _mm256_setzero_si256()));
+            qx1 = _mm256_avg_epu8(qx1, _mm256_avg_epu8(qx1, _mm256_setzero_si256()));
+            qx2 = _mm256_avg_epu8(qx2, _mm256_avg_epu8(qx2, _mm256_setzero_si256()));
+            qx3 = _mm256_avg_epu8(qx3, _mm256_avg_epu8(qx3, _mm256_setzero_si256()));
+            qx4 = _mm256_avg_epu8(qx4, _mm256_avg_epu8(qx4, _mm256_setzero_si256()));
+            qx0 = _mm256_and_si256(_mm256_srli_epi16(qx0, 6), _mm256_set1_epi8(3));
+            qx1 = _mm256_and_si256(_mm256_srli_epi16(qx1, 6), _mm256_set1_epi8(3));
+            qx2 = _mm256_and_si256(_mm256_srli_epi16(qx2, 6), _mm256_set1_epi8(3));
+            qx3 = _mm256_and_si256(_mm256_srli_epi16(qx3, 6), _mm256_set1_epi8(3));
+            qx4 = _mm256_and_si256(_mm256_srli_epi16(qx4, 6), _mm256_set1_epi8(3));
+
+            const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 0));
+            const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 32));
+            const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 64));
+            const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 96));
+            const __m256i qy4 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 128));
+
+            qx0 = _mm256_maddubs_epi16(qx0, qy0);
+            qx1 = _mm256_maddubs_epi16(qx1, qy1);
+            qx2 = _mm256_maddubs_epi16(qx2, qy2);
+            qx3 = _mm256_maddubs_epi16(qx3, qy3);
+            qx4 = _mm256_maddubs_epi16(qx4, qy4);
+
+            sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1));
+            sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3));
+            sumi2 = _mm256_add_epi16(sumi2, qx4);
+        }
+
+        // last 16 bytes of 5-element, along with the 4 bytes of 4 elements
+        {
+            __m128i qx0 = _mm_loadu_si128((const __m128i *) (x[i].qs + 32));
+            uint32_t qh;
+            memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned
+            __m256i qx5_l = _mm256_cvtepu8_epi16(_mm_set1_epi32(qh));
+            __m128i qx1 = _mm_add_epi8(qx0, _mm_add_epi8(qx0, qx0)); // 1 * 3
+            __m128i qx2 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx0, 3), _mm_set1_epi8(-8)), qx0); // 1 * 9
+            __m128i qx3 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx1, 3), _mm_set1_epi8(-8)), qx1); // 3 * 9
+            __m128i qx4 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx2, 3), _mm_set1_epi8(-8)), qx2); // 9 * 9
+            __m256i qx01 = MM256_SET_M128I(qx1, qx0);
+            __m256i qx23 = MM256_SET_M128I(qx3, qx2);
+
+            // avx2 does not have 8-bit multiplies, so 16-bit it is.
+            qx5_l = _mm256_mullo_epi16(qx5_l, _mm256_set_epi16(27, 27, 27, 27, 9, 9, 9, 9, 3, 3, 3, 3, 1, 1, 1, 1));
+            qx5_l = _mm256_and_si256(qx5_l, _mm256_set1_epi16(0xFF));
+            __m128i qx5 = _mm_packus_epi16(_mm256_castsi256_si128(qx5_l), _mm256_extracti128_si256(qx5_l, 1));
+
+            __m256i qx45 = MM256_SET_M128I(qx5, qx4);
+
+            // Cancel the +1 from avg so that it behaves like a halving add
+            qx01 = _mm256_subs_epu8(qx01, _mm256_set1_epi8(1));
+            qx23 = _mm256_subs_epu8(qx23, _mm256_set1_epi8(1));
+            qx45 = _mm256_subs_epu8(qx45, _mm256_set1_epi8(1));
+            // Multiply by 3 and get the top 2 bits
+            qx01 = _mm256_avg_epu8(qx01, _mm256_avg_epu8(qx01, _mm256_setzero_si256()));
+            qx23 = _mm256_avg_epu8(qx23, _mm256_avg_epu8(qx23, _mm256_setzero_si256()));
+            qx45 = _mm256_avg_epu8(qx45, _mm256_avg_epu8(qx45, _mm256_setzero_si256()));
+            qx01 = _mm256_and_si256(_mm256_srli_epi16(qx01, 6), _mm256_set1_epi8(3));
+            qx23 = _mm256_and_si256(_mm256_srli_epi16(qx23, 6), _mm256_set1_epi8(3));
+            qx45 = _mm256_and_si256(_mm256_srli_epi16(qx45, 6), _mm256_set1_epi8(3));
+
+            const __m256i qy01 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 160));
+            const __m256i qy23 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 192));
+            const __m256i qy45 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 224));
+
+            qx01 = _mm256_maddubs_epi16(qx01, qy01);
+            qx23 = _mm256_maddubs_epi16(qx23, qy23);
+            qx45 = _mm256_maddubs_epi16(qx45, qy45);
+
+            sumi0 = _mm256_add_epi16(sumi0, qx01);
+            sumi1 = _mm256_add_epi16(sumi1, qx23);
+            sumi2 = _mm256_add_epi16(sumi2, qx45);
+        }
+
+        const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums);
+        const __m256 d = _mm256_set1_ps(y[i].d * LM_GGML_FP16_TO_FP32(x[i].d));
+
+        sumi0 = _mm256_sub_epi16(sumi0, ysum);
+        sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(sumi1, sumi2));
+        sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1));
+
+        sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf);
+    }
+
+    *s = hsum_float_8(sumf);
+
+#else
+    const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243};
+
+    float sumf = 0.0f;
+
+    for (int i = 0; i < nb; ++i) {
+        int sum = 0;
+
+        for (size_t j = 0; j < sizeof(x->qs) - sizeof(x->qs) % 32; j += 32) {
+            for (size_t l = 0; l < 5; ++l) {
+                for (size_t m = 0; m < 32; ++m) {
+                    uint8_t q = x[i].qs[j + m] * pow3[l];
+                    uint16_t xi = ((uint16_t) q * 3) >> 8;
+                    sum += (xi - 1) * y[i].qs[j*5 + l*32 + m];
+                }
+            }
+        }
+        for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) {
+            for (size_t l = 0; l < 5; ++l) {
+                for (size_t m = 0; m < 16; ++m) {
+                    uint8_t q = x[i].qs[j + m] * pow3[l];
+                    uint16_t xi = ((uint16_t) q * 3) >> 8;
+                    sum += (xi - 1) * y[i].qs[j*5 + l*16 + m];
+                }
+            }
+        }
+
+        for (size_t l = 0; l < 4; ++l) {
+            for (size_t j = 0; j < sizeof(x->qh); ++j) {
+                uint8_t q = x[i].qh[j] * pow3[l];
+                uint16_t xi = ((uint16_t) q * 3) >> 8;
+                sum += (xi - 1) * y[i].qs[sizeof(x->qs)*5 + l*sizeof(x->qh) + j];
+            }
+        }
+
+        sumf += (float) sum * (LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d);
+    }
+
+    *s = sumf;
+#endif
+}
+
+void lm_ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_tq2_0 * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#if defined(__ARM_NEON)
+    float sumf = 0.0f;
+
+    const uint8x16_t m3 = vdupq_n_u8(3);
+
+    for (int i = 0; i < nb; ++i) {
+#if defined(__ARM_FEATURE_DOTPROD)
+        int32x4_t sumi0 = vdupq_n_s32(0);
+        int32x4_t sumi1 = vdupq_n_s32(0);
+#else
+        int16x8_t sumi0 = vdupq_n_s16(0);
+        int16x8_t sumi1 = vdupq_n_s16(0);
+#endif
+
+        for (size_t j = 0; j < sizeof(x->qs); j += 32) {
+            uint8x16_t qx0 = vld1q_u8(x[i].qs + j);
+            uint8x16_t qx1 = vld1q_u8(x[i].qs + j + 16);
+            uint8x16_t qx2 = vshrq_n_u8(qx0, 2);
+            uint8x16_t qx3 = vshrq_n_u8(qx1, 2);
+            uint8x16_t qx4 = vshrq_n_u8(qx0, 4);
+            uint8x16_t qx5 = vshrq_n_u8(qx1, 4);
+            uint8x16_t qx6 = vshrq_n_u8(qx0, 6);
+            uint8x16_t qx7 = vshrq_n_u8(qx1, 6);
+
+            int8x16_t sqx0 = vreinterpretq_s8_u8(vandq_u8(qx0, m3));
+            int8x16_t sqx1 = vreinterpretq_s8_u8(vandq_u8(qx1, m3));
+            int8x16_t sqx2 = vreinterpretq_s8_u8(vandq_u8(qx2, m3));
+            int8x16_t sqx3 = vreinterpretq_s8_u8(vandq_u8(qx3, m3));
+            int8x16_t sqx4 = vreinterpretq_s8_u8(vandq_u8(qx4, m3));
+            int8x16_t sqx5 = vreinterpretq_s8_u8(vandq_u8(qx5, m3));
+            int8x16_t sqx6 = vreinterpretq_s8_u8(vandq_u8(qx6, m3));
+            int8x16_t sqx7 = vreinterpretq_s8_u8(vandq_u8(qx7, m3));
+
+            const int8x16_t qy0 = vld1q_s8(y[i].qs + j*4 + 0);
+            const int8x16_t qy1 = vld1q_s8(y[i].qs + j*4 + 16);
+            const int8x16_t qy2 = vld1q_s8(y[i].qs + j*4 + 32);
+            const int8x16_t qy3 = vld1q_s8(y[i].qs + j*4 + 48);
+            const int8x16_t qy4 = vld1q_s8(y[i].qs + j*4 + 64);
+            const int8x16_t qy5 = vld1q_s8(y[i].qs + j*4 + 80);
+            const int8x16_t qy6 = vld1q_s8(y[i].qs + j*4 + 96);
+            const int8x16_t qy7 = vld1q_s8(y[i].qs + j*4 + 112);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+            sumi0 = vdotq_s32(sumi0, sqx0, qy0);
+            sumi1 = vdotq_s32(sumi1, sqx1, qy1);
+            sumi0 = vdotq_s32(sumi0, sqx2, qy2);
+            sumi1 = vdotq_s32(sumi1, sqx3, qy3);
+            sumi0 = vdotq_s32(sumi0, sqx4, qy4);
+            sumi1 = vdotq_s32(sumi1, sqx5, qy5);
+            sumi0 = vdotq_s32(sumi0, sqx6, qy6);
+            sumi1 = vdotq_s32(sumi1, sqx7, qy7);
+#else
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7));
+#endif
+        }
+
+        const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
+        const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);
+
+        const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        sumi0 = vaddq_s32(sumi0, sumi1);
+        sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
+
+        sumf += d * (float) vaddvq_s32(sumi0);
+#else
+        sumi0 = vaddq_s16(sumi0, sumi1);
+        sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1));
+
+        sumf += d * (float) vaddlvq_s16(sumi0);
+#endif
+    }
+
+    *s = sumf;
+
+#elif defined(__AVX2__)
+    __m256 sumf = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+        // 16-bit sums, because 256*127 still fits
+        __m256i sumi0 = _mm256_setzero_si256();
+        __m256i sumi1 = _mm256_setzero_si256();
+
+        for (size_t j = 0; j < sizeof(x->qs); j += 32) {
+            __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs + j));
+            __m256i qx1 = _mm256_srli_epi16(qx0, 2);
+            __m256i qx2 = _mm256_srli_epi16(qx0, 4);
+            __m256i qx3 = _mm256_srli_epi16(qx0, 6);
+
+            // 0, 1, 2 (should not be 3)
+            qx0 = _mm256_and_si256(qx0, _mm256_set1_epi8(3));
+            qx1 = _mm256_and_si256(qx1, _mm256_set1_epi8(3));
+            qx2 = _mm256_and_si256(qx2, _mm256_set1_epi8(3));
+            qx3 = _mm256_and_si256(qx3, _mm256_set1_epi8(3));
+
+            const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 0));
+            const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 32));
+            const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 64));
+            const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 96));
+
+            qx0 = _mm256_maddubs_epi16(qx0, qy0);
+            qx1 = _mm256_maddubs_epi16(qx1, qy1);
+            qx2 = _mm256_maddubs_epi16(qx2, qy2);
+            qx3 = _mm256_maddubs_epi16(qx3, qy3);
+
+            sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1));
+            sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3));
+        }
+
+        const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums);
+        const __m256 d = _mm256_set1_ps(y[i].d * LM_GGML_FP16_TO_FP32(x[i].d));
+
+        sumi0 = _mm256_add_epi16(sumi0, sumi1);
+        sumi0 = _mm256_sub_epi16(sumi0, ysum);
+        sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1));
+
+        sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf);
+    }
+
+    *s = hsum_float_8(sumf);
+
+#else
+    float sumf = 0.0f;
+
+    for (int i = 0; i < nb; ++i) {
+        int32_t sumi = 0;
+
+        for (size_t j = 0; j < sizeof(x->qs); j += 32) {
+            for (size_t l = 0; l < 4; ++l) {
+                for (size_t k = 0; k < 32; ++k) {
+                    sumi += y[i].qs[j*4 + l*32 + k] * (((x[i].qs[j + k] >> (l*2)) & 3) - 1);
+                }
+            }
+        }
+
+        const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d);
+
+        sumf += (float) sumi * d;
+    }
+
+    *s = sumf;
+#endif
+}
+
 void lm_ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     assert(nrc == 1);
     UNUSED(nrc);
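One detail worth noting in the kernels above: the stored trits stay in the unsigned range {0, 1, 2}, so both the NEON and AVX2 paths accumulate an unsigned dot product and then subtract the precomputed per-block sums of the q8_K activations (`bsums`) once, which is equivalent to using signed weights {-1, 0, +1}. A scalar sketch of that identity, with illustrative array names rather than the actual block layout:

```c
#include <stdio.h>

#define N 256  // a QK_K-sized block, as in ggml-quants.c

int main(void) {
    int t[N], y[N];                          // t[k] in {0,1,2}, y[k] plays the int8 activation
    for (int k = 0; k < N; ++k) {
        t[k] = (k * 7) % 3;                  // arbitrary test pattern
        y[k] = (k % 17) - 8;
    }

    int signed_dot = 0, unsigned_dot = 0, ysum = 0;
    for (int k = 0; k < N; ++k) {
        signed_dot   += (t[k] - 1) * y[k];   // the value the kernel must produce
        unsigned_dot +=  t[k]      * y[k];   // what maddubs/vdot actually accumulate
        ysum         +=  y[k];               // what q8_K precomputes as bsums
    }

    // unsigned_dot - ysum == signed_dot, so the -1 offset is fixed up once per block
    printf("%d == %d\n", unsigned_dot - ysum, signed_dot);
    return 0;
}
```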
@@ -14800,6 +15480,14 @@ bool lm_ggml_validate_row_data(enum lm_ggml_type type, const void * data, size_t
                 }
             }
         } break;
+        case LM_GGML_TYPE_TQ1_0:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_tq1_0, data, nb);
+            } break;
+        case LM_GGML_TYPE_TQ2_0:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_tq2_0, data, nb);
+            } break;
         case LM_GGML_TYPE_IQ1_S:
             {
                 VALIDATE_ROW_DATA_D_F16_IMPL(block_iq1_s, data, nb);