llama_cpp 0.12.5 → 0.12.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +17 -0
- data/ext/llama_cpp/llama_cpp.cpp +67 -10
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +15 -1
- data/vendor/tmp/llama.cpp/Makefile +51 -12
- data/vendor/tmp/llama.cpp/ggml-alloc.c +595 -492
- data/vendor/tmp/llama.cpp/ggml-alloc.h +39 -65
- data/vendor/tmp/llama.cpp/ggml-backend.c +268 -271
- data/vendor/tmp/llama.cpp/ggml-backend.h +8 -12
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +560 -346
- data/vendor/tmp/llama.cpp/ggml-impl.h +20 -7
- data/vendor/tmp/llama.cpp/ggml-metal.m +101 -11
- data/vendor/tmp/llama.cpp/ggml-metal.metal +608 -9
- data/vendor/tmp/llama.cpp/ggml-quants.c +1255 -94
- data/vendor/tmp/llama.cpp/ggml-quants.h +39 -16
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +95 -264
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +213 -58
- data/vendor/tmp/llama.cpp/ggml.c +1082 -564
- data/vendor/tmp/llama.cpp/ggml.h +50 -17
- data/vendor/tmp/llama.cpp/llama.cpp +1329 -280
- data/vendor/tmp/llama.cpp/llama.h +43 -1
- data/vendor/tmp/llama.cpp/scripts/get-flags.mk +1 -1
- data/vendor/tmp/llama.cpp/unicode.h +42 -30
- metadata +2 -2
@@ -49,6 +49,8 @@
|
|
49
49
|
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
50
50
|
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
51
51
|
|
52
|
+
#define UNUSED GGML_UNUSED
|
53
|
+
|
52
54
|
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
|
53
55
|
|
54
56
|
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
|
@@ -268,6 +270,17 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128
|
|
268
270
|
#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
|
269
271
|
|
270
272
|
#if defined(__ARM_NEON)
|
273
|
+
|
274
|
+
#ifdef _MSC_VER
|
275
|
+
|
276
|
+
#define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) }
|
277
|
+
|
278
|
+
#else
|
279
|
+
|
280
|
+
#define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) }
|
281
|
+
|
282
|
+
#endif
|
283
|
+
|
271
284
|
#if !defined(__aarch64__)
|
272
285
|
|
273
286
|
// 64-bit compatibility
|
@@ -425,6 +438,30 @@ inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
|
|
425
438
|
return res;
|
426
439
|
}
|
427
440
|
|
441
|
+
// NOTE: not tested
|
442
|
+
inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) {
|
443
|
+
int8x16_t res;
|
444
|
+
|
445
|
+
res[ 0] = a[b[ 0]];
|
446
|
+
res[ 1] = a[b[ 1]];
|
447
|
+
res[ 2] = a[b[ 2]];
|
448
|
+
res[ 3] = a[b[ 3]];
|
449
|
+
res[ 4] = a[b[ 4]];
|
450
|
+
res[ 5] = a[b[ 5]];
|
451
|
+
res[ 6] = a[b[ 6]];
|
452
|
+
res[ 7] = a[b[ 7]];
|
453
|
+
res[ 8] = a[b[ 8]];
|
454
|
+
res[ 9] = a[b[ 9]];
|
455
|
+
res[10] = a[b[10]];
|
456
|
+
res[11] = a[b[11]];
|
457
|
+
res[12] = a[b[12]];
|
458
|
+
res[13] = a[b[13]];
|
459
|
+
res[14] = a[b[14]];
|
460
|
+
res[15] = a[b[15]];
|
461
|
+
|
462
|
+
return res;
|
463
|
+
}
|
464
|
+
|
428
465
|
#else
|
429
466
|
|
430
467
|
#define ggml_int16x8x2_t int16x8x2_t
|
@@ -438,6 +475,7 @@ inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
|
|
438
475
|
#define ggml_vld1q_u8_x4 vld1q_u8_x4
|
439
476
|
#define ggml_vld1q_s8_x2 vld1q_s8_x2
|
440
477
|
#define ggml_vld1q_s8_x4 vld1q_s8_x4
|
478
|
+
#define ggml_vqtbl1q_s8 vqtbl1q_s8
|
441
479
|
|
442
480
|
#endif
|
443
481
|
|
@@ -1824,9 +1862,9 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri
|
|
1824
1862
|
float sigma2 = sumx2/QK_K;
|
1825
1863
|
for (int j = 0; j < QK_K/16; ++j) {
|
1826
1864
|
const float * restrict qw = quant_weights + QK_K * i + 16*j;
|
1827
|
-
for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]);
|
1828
|
-
for (int l = 0; l < 16; ++l) sw[j] += weight[l];
|
1829
|
-
scales[j] = make_qkx3_quants(16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
|
1865
|
+
for (int l = 0; l < QK_K/16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]);
|
1866
|
+
for (int l = 0; l < QK_K/16; ++l) sw[j] += weight[l];
|
1867
|
+
scales[j] = make_qkx3_quants(QK_K/16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
|
1830
1868
|
}
|
1831
1869
|
|
1832
1870
|
float dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw);
|
@@ -3467,6 +3505,139 @@ static const uint32_t iq3xxs_grid[256] = {
|
|
3467
3505
|
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
3468
3506
|
};
|
3469
3507
|
|
3508
|
+
#define NGRID_IQ2XXS 512
|
3509
|
+
static const uint64_t iq1s_grid[NGRID_IQ2XXS] = {
|
3510
|
+
0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
|
3511
|
+
0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01,
|
3512
|
+
0xffffff00ff000000, 0xffffff000000ff00, 0xffffff00000000ff, 0xffffff0000000100,
|
3513
|
+
0xffffff0000010000, 0xffffff0001000000, 0xffffff01ffff00ff, 0xffffff01ff01ff00,
|
3514
|
+
0xffffff01ff010100, 0xffffff0100000001, 0xffffff0101ffff00, 0xffffff0101ff0101,
|
3515
|
+
0xffffff0101010100, 0xffff00ffff00ff01, 0xffff00ffff0000ff, 0xffff00ff00ff0100,
|
3516
|
+
0xffff00ff0100ff00, 0xffff00ff010001ff, 0xffff0000ff0101ff, 0xffff000000ffff00,
|
3517
|
+
0xffff000000000000, 0xffff00000001ff01, 0xffff000001000101, 0xffff0000010100ff,
|
3518
|
+
0xffff0001ffff0100, 0xffff00010000ff00, 0xffff000100010101, 0xffff000101000000,
|
3519
|
+
0xffff01ffffff0000, 0xffff01ffff01ffff, 0xffff01ffff010100, 0xffff01ff00000000,
|
3520
|
+
0xffff01ff01ffffff, 0xffff01ff01ff0001, 0xffff01ff0101ffff, 0xffff01ff01010001,
|
3521
|
+
0xffff0100ffffff01, 0xffff01000000ffff, 0xffff010000000100, 0xffff010001ff01ff,
|
3522
|
+
0xffff010001000000, 0xffff0101ff000000, 0xffff0101000101ff, 0xffff010101ffff01,
|
3523
|
+
0xffff01010101ff00, 0xff00ffffff000000, 0xff00ffff00ffff00, 0xff00ffff00000001,
|
3524
|
+
0xff00ffff000001ff, 0xff00ffff01010000, 0xff00ff00ffff0000, 0xff00ff00ff00ff00,
|
3525
|
+
0xff00ff00ff0000ff, 0xff00ff00ff000100, 0xff00ff00ff010001, 0xff00ff0000ff0001,
|
3526
|
+
0xff00ff000000ffff, 0xff00ff0000000000, 0xff00ff000001ff00, 0xff00ff0000010100,
|
3527
|
+
0xff00ff0001ff0000, 0xff00ff000100ff00, 0xff00ff0001000100, 0xff00ff01ff000000,
|
3528
|
+
0xff00ff0100ff0000, 0xff00ff01000001ff, 0xff00ff0101010001, 0xff0000ff00000000,
|
3529
|
+
0xff0000ff0001ff00, 0xff0000ff00010100, 0xff000000ffff0101, 0xff000000ff000000,
|
3530
|
+
0xff000000ff01ff00, 0xff00000000ff0000, 0xff0000000000ff00, 0xff000000000000ff,
|
3531
|
+
0xff00000000000000, 0xff00000000000001, 0xff00000000000100, 0xff0000000001ffff,
|
3532
|
+
0xff00000000010000, 0xff00000001000000, 0xff00000001010100, 0xff000001ff00ff01,
|
3533
|
+
0xff000001ff0100ff, 0xff00000100000000, 0xff0000010001ff00, 0xff00000101ff0100,
|
3534
|
+
0xff0000010100ff00, 0xff0001ff00ff00ff, 0xff0001ff00000101, 0xff0001ff000100ff,
|
3535
|
+
0xff0001ff01000000, 0xff000100ff0001ff, 0xff0001000000ff01, 0xff00010000000000,
|
3536
|
+
0xff00010000010001, 0xff00010000010100, 0xff00010001ffff00, 0xff00010001ff0101,
|
3537
|
+
0xff00010001010000, 0xff000101ffffffff, 0xff000101ff000101, 0xff00010101ff00ff,
|
3538
|
+
0xff00010101000001, 0xff000101010100ff, 0xff01ffffff000101, 0xff01ffffff01ffff,
|
3539
|
+
0xff01ffffff01ff01, 0xff01ffffff0101ff, 0xff01ffff00000000, 0xff01ffff01ff0001,
|
3540
|
+
0xff01ffff0101ff01, 0xff01ff00ff000000, 0xff01ff0000ff0100, 0xff01ff000000ff01,
|
3541
|
+
0xff01ff0000010000, 0xff01ff00010000ff, 0xff01ff01ff01ff00, 0xff01ff0100000101,
|
3542
|
+
0xff0100ffffff0000, 0xff0100ffff010000, 0xff0100ff01ff00ff, 0xff0100ff01000100,
|
3543
|
+
0xff0100ff010100ff, 0xff010000ffffff01, 0xff01000000000000, 0xff0100000101ff00,
|
3544
|
+
0xff010001ffff00ff, 0xff010001ff000100, 0xff01000100ffff00, 0xff01000100010001,
|
3545
|
+
0xff01000101ff0001, 0xff010001010001ff, 0xff0101ffffffffff, 0xff0101ffff01ffff,
|
3546
|
+
0xff0101ffff010101, 0xff0101ff0000ff00, 0xff0101ff01010001, 0xff010100ff000000,
|
3547
|
+
0xff010100ff01ff01, 0xff01010000ff0001, 0xff01010000000100, 0xff01010001000000,
|
3548
|
+
0xff0101010100ffff, 0x00ffffff0000ff01, 0x00ffffff000000ff, 0x00ffffff00000100,
|
3549
|
+
0x00ffffff00010000, 0x00ffff00ffff0001, 0x00ffff00ff0000ff, 0x00ffff00ff000100,
|
3550
|
+
0x00ffff0000000000, 0x00ffff0001000100, 0x00ffff0001010001, 0x00ffff01ff00ff01,
|
3551
|
+
0x00ffff0100ff0100, 0x00ffff010000ff00, 0x00ffff01000100ff, 0x00ffff0101ff00ff,
|
3552
|
+
0x00ffff010101ff00, 0x00ff00ffffffffff, 0x00ff00ffffff01ff, 0x00ff00ffff000101,
|
3553
|
+
0x00ff00ff00000000, 0x00ff00ff000101ff, 0x00ff00ff01010101, 0x00ff0000ff000000,
|
3554
|
+
0x00ff0000ff01ffff, 0x00ff000000ff0000, 0x00ff00000000ff00, 0x00ff0000000000ff,
|
3555
|
+
0x00ff000000000000, 0x00ff000000000001, 0x00ff000000000100, 0x00ff000000010000,
|
3556
|
+
0x00ff000001ffff01, 0x00ff000001000000, 0x00ff0001ff000101, 0x00ff000100ffffff,
|
3557
|
+
0x00ff000100000000, 0x00ff0001010001ff, 0x00ff01ffff000000, 0x00ff01ff0001ff00,
|
3558
|
+
0x00ff01ff01ff0100, 0x00ff0100ff01ff01, 0x00ff010000ff00ff, 0x00ff010000ff0101,
|
3559
|
+
0x00ff010000000000, 0x00ff010000010101, 0x00ff01000100ff00, 0x00ff010001010000,
|
3560
|
+
0x00ff0101ffffff00, 0x00ff01010000ff01, 0x00ff010100000100, 0x00ff010101ff0000,
|
3561
|
+
0x0000ffffffff0100, 0x0000ffffff00ff00, 0x0000ffffff0000ff, 0x0000ffffff010000,
|
3562
|
+
0x0000ffff00000000, 0x0000ffff00010101, 0x0000ffff01ffff01, 0x0000ffff01000100,
|
3563
|
+
0x0000ff00ff000000, 0x0000ff00ff01ff00, 0x0000ff00ff0101ff, 0x0000ff0000ff0000,
|
3564
|
+
0x0000ff000000ff00, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001,
|
3565
|
+
0x0000ff0000000100, 0x0000ff0000010000, 0x0000ff0001ffffff, 0x0000ff0001ff01ff,
|
3566
|
+
0x0000ff0001000000, 0x0000ff000101ffff, 0x0000ff01ffff0101, 0x0000ff01ff010000,
|
3567
|
+
0x0000ff0100000000, 0x0000ff0101000101, 0x000000ffffff0001, 0x000000ffff000000,
|
3568
|
+
0x000000ff00ff0000, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000,
|
3569
|
+
0x000000ff00000001, 0x000000ff00000100, 0x000000ff00010000, 0x000000ff01000000,
|
3570
|
+
0x000000ff0101ff00, 0x00000000ffff0000, 0x00000000ff00ff00, 0x00000000ff0000ff,
|
3571
|
+
0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff010000,
|
3572
|
+
0x0000000000ffff00, 0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001,
|
3573
|
+
0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01,
|
3574
|
+
0x00000000000000ff, 0x0000000000000001, 0x00000000000001ff, 0x0000000000000100,
|
3575
|
+
0x0000000000000101, 0x000000000001ff00, 0x00000000000100ff, 0x0000000000010000,
|
3576
|
+
0x0000000000010001, 0x0000000000010100, 0x0000000001ff0000, 0x000000000100ff00,
|
3577
|
+
0x00000000010000ff, 0x0000000001000000, 0x0000000001000001, 0x0000000001000100,
|
3578
|
+
0x0000000001010000, 0x00000001ffff01ff, 0x00000001ff000000, 0x0000000100ff0000,
|
3579
|
+
0x000000010000ff00, 0x00000001000000ff, 0x0000000100000000, 0x0000000100000001,
|
3580
|
+
0x0000000100000100, 0x0000000100010000, 0x0000000101000000, 0x000001ffff00ff00,
|
3581
|
+
0x000001ffff010001, 0x000001ffff0101ff, 0x000001ff00ffff01, 0x000001ff0000ffff,
|
3582
|
+
0x000001ff00000000, 0x000001ff010000ff, 0x000001ff01010100, 0x00000100ffff0100,
|
3583
|
+
0x00000100ff000000, 0x0000010000ff0000, 0x000001000000ff00, 0x00000100000000ff,
|
3584
|
+
0x0000010000000000, 0x0000010000000001, 0x0000010000000100, 0x0000010000010000,
|
3585
|
+
0x0000010001000000, 0x000001000101ff01, 0x00000101ffff0001, 0x00000101ff01ffff,
|
3586
|
+
0x0000010100000000, 0x0000010101010100, 0x0001ffffff000000, 0x0001ffff00ffffff,
|
3587
|
+
0x0001ffff00000100, 0x0001ffff0001ff00, 0x0001ffff01000000, 0x0001ff00ffffff00,
|
3588
|
+
0x0001ff00ffff01ff, 0x0001ff00ff010000, 0x0001ff0000000000, 0x0001ff0000010001,
|
3589
|
+
0x0001ff0001ff0000, 0x0001ff0001010100, 0x0001ff01ff0000ff, 0x0001ff01ff000001,
|
3590
|
+
0x0001ff0100ffffff, 0x0001ff010001ffff, 0x0001ff01000101ff, 0x0001ff010100ff01,
|
3591
|
+
0x000100ffff00ffff, 0x000100ffff00ff01, 0x000100ffff000100, 0x000100ff00000000,
|
3592
|
+
0x000100ff000101ff, 0x000100ff01ff0101, 0x000100ff0100ffff, 0x000100ff01010101,
|
3593
|
+
0x00010000ff000000, 0x00010000ff010100, 0x0001000000ff0000, 0x000100000000ff00,
|
3594
|
+
0x00010000000000ff, 0x0001000000000000, 0x0001000000000001, 0x0001000000000100,
|
3595
|
+
0x0001000000010000, 0x0001000001ffff01, 0x0001000001000000, 0x0001000100ff0101,
|
3596
|
+
0x0001000100000000, 0x00010001010100ff, 0x000101ffffff01ff, 0x000101ffffff0101,
|
3597
|
+
0x000101ff00010000, 0x000101ff01ff0000, 0x000101ff0100ff01, 0x00010100ffff0000,
|
3598
|
+
0x0001010000000000, 0x000101000001ffff, 0x0001010000010101, 0x00010100010001ff,
|
3599
|
+
0x00010101ff00ff00, 0x00010101ff010001, 0x0001010100ffffff, 0x0001010100ff01ff,
|
3600
|
+
0x00010101000101ff, 0x0001010101ff0000, 0x000101010100ff01, 0x0001010101000101,
|
3601
|
+
0x01ffffffffff0101, 0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff,
|
3602
|
+
0x01ffffffff010101, 0x01ffffff00000000, 0x01ffffff01ff01ff, 0x01ffffff01000101,
|
3603
|
+
0x01ffffff0101ff01, 0x01ffffff010100ff, 0x01ffff000000ff00, 0x01ffff0000000001,
|
3604
|
+
0x01ffff00000001ff, 0x01ffff0000010000, 0x01ffff0001ff0000, 0x01ffff01ffffffff,
|
3605
|
+
0x01ffff01ffff01ff, 0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff0101ff,
|
3606
|
+
0x01ffff010100ffff, 0x01ff00ffffff0000, 0x01ff00ffff010000, 0x01ff00ff00ffff01,
|
3607
|
+
0x01ff0000ff0000ff, 0x01ff000000000000, 0x01ff00000001ff01, 0x01ff000001ffffff,
|
3608
|
+
0x01ff000001010100, 0x01ff0001ffffff01, 0x01ff0001ff010001, 0x01ff000101ff0100,
|
3609
|
+
0x01ff000101000001, 0x01ff0001010100ff, 0x01ff01ffff00ffff, 0x01ff01ff00010001,
|
3610
|
+
0x01ff01ff01000000, 0x01ff01ff010101ff, 0x01ff0100ff000001, 0x01ff010000ffff00,
|
3611
|
+
0x01ff010000000100, 0x01ff010001ff01ff, 0x01ff01000101ffff, 0x01ff0101ffff00ff,
|
3612
|
+
0x01ff0101ffff0101, 0x01ff0101ff0101ff, 0x01ff010100010000, 0x0100ffff00ff00ff,
|
3613
|
+
0x0100ffff00ff0001, 0x0100ffff00000100, 0x0100ffff0100ff00, 0x0100ff00ffff0000,
|
3614
|
+
0x0100ff00ff00ffff, 0x0100ff00ff00ff01, 0x0100ff00ff000100, 0x0100ff00ff010000,
|
3615
|
+
0x0100ff0000000000, 0x0100ff00000100ff, 0x0100ff0001ff0101, 0x0100ff0001010101,
|
3616
|
+
0x0100ff0100ff00ff, 0x0100ff0100ff0001, 0x0100ff0100000100, 0x0100ff0100010001,
|
3617
|
+
0x0100ff0101000000, 0x010000ffff00ff00, 0x010000ff0000ffff, 0x010000ff00000000,
|
3618
|
+
0x010000ff010001ff, 0x010000ff01010001, 0x01000000ffffff00, 0x01000000ffff0101,
|
3619
|
+
0x01000000ff000000, 0x01000000ff0100ff, 0x01000000ff010101, 0x0100000000ff0000,
|
3620
|
+
0x010000000000ff00, 0x01000000000000ff, 0x0100000000000000, 0x0100000000000001,
|
3621
|
+
0x0100000000000100, 0x0100000000010000, 0x0100000001000000, 0x0100000100000000,
|
3622
|
+
0x01000001000101ff, 0x0100000101ffff01, 0x010001ffff000101, 0x010001ff00ff0100,
|
3623
|
+
0x010001ff0000ff00, 0x010001ff000100ff, 0x010001ff01ffffff, 0x01000100ffff0000,
|
3624
|
+
0x01000100ff0001ff, 0x0100010000000000, 0x010001000001ff00, 0x0100010001ff0000,
|
3625
|
+
0x01000100010000ff, 0x0100010001000101, 0x01000101ff00ff01, 0x0100010100ff0100,
|
3626
|
+
0x010001010000ffff, 0x0100010101010001, 0x0101ffffffff0101, 0x0101ffffff0001ff,
|
3627
|
+
0x0101ffffff01ffff, 0x0101ffffff010101, 0x0101ffff00000000, 0x0101ffff0101ffff,
|
3628
|
+
0x0101ffff010101ff, 0x0101ff00ff000000, 0x0101ff0000ff0100, 0x0101ff000000ff00,
|
3629
|
+
0x0101ff0000010000, 0x0101ff00010000ff, 0x0101ff0001000001, 0x0101ff01ff010101,
|
3630
|
+
0x0101ff0100000000, 0x0101ff010101ff00, 0x010100ffffff0000, 0x010100ffff010000,
|
3631
|
+
0x010100ff00ff01ff, 0x010100ff000000ff, 0x010100ff00000101, 0x010100ff01ffff00,
|
3632
|
+
0x01010000ffffff01, 0x01010000ff000100, 0x01010000ff01ff01, 0x0101000000000000,
|
3633
|
+
0x01010000000100ff, 0x010100000101ff01, 0x01010001ffff0000, 0x01010001ff00ffff,
|
3634
|
+
0x01010001ff010000, 0x0101000101ffffff, 0x0101000101ff01ff, 0x0101000101010101,
|
3635
|
+
0x010101ffff01ffff, 0x010101ff00000000, 0x010101ff0001ff01, 0x010101ff0101ffff,
|
3636
|
+
0x010101ff010101ff, 0x01010100ffffffff, 0x01010100ff000001, 0x010101000000ff00,
|
3637
|
+
0x0101010001010000, 0x0101010100ff0001, 0x010101010001ff01, 0x010101010101ffff,
|
3638
|
+
|
3639
|
+
};
|
3640
|
+
|
3470
3641
|
static const uint8_t ksigns_iq2xs[128] = {
|
3471
3642
|
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
|
3472
3643
|
144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
|
@@ -3565,6 +3736,69 @@ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y
|
|
3565
3736
|
}
|
3566
3737
|
}
|
3567
3738
|
|
3739
|
+
// ====================== 1.5625 bpw (de)-quantization
|
3740
|
+
|
3741
|
+
void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, int k) {
|
3742
|
+
assert(k % QK_K == 0);
|
3743
|
+
const int nb = k / QK_K;
|
3744
|
+
|
3745
|
+
float db[4];
|
3746
|
+
uint16_t idx[4];
|
3747
|
+
//const int8_t * grid[4];
|
3748
|
+
|
3749
|
+
for (int i = 0; i < nb; i++) {
|
3750
|
+
|
3751
|
+
const float d = GGML_FP16_TO_FP32(x[i].d);
|
3752
|
+
const uint8_t * sc = x[i].scales;
|
3753
|
+
const uint8_t * qs = x[i].qs;
|
3754
|
+
|
3755
|
+
for (int i8 = 0; i8 < QK_K/8; i8 += 4) {
|
3756
|
+
idx[0] = qs[0] | ((sc[0] & 0x08) << 5);
|
3757
|
+
idx[1] = qs[1] | ((sc[0] & 0x80) << 1);
|
3758
|
+
idx[2] = qs[2] | ((sc[1] & 0x08) << 5);
|
3759
|
+
idx[3] = qs[3] | ((sc[1] & 0x80) << 1);
|
3760
|
+
//grid[0] = (const int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
|
3761
|
+
//grid[1] = (const int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
|
3762
|
+
//grid[2] = (const int8_t *)(iq1s_grid + (qs[2] | ((sc[1] & 0x08) << 5)));
|
3763
|
+
//grid[3] = (const int8_t *)(iq1s_grid + (qs[3] | ((sc[1] & 0x80) << 1)));
|
3764
|
+
db[0] = d * (2*(sc[0] & 7) + 1);
|
3765
|
+
db[1] = d * (2*((sc[0] >> 4) & 7) + 1);
|
3766
|
+
db[2] = d * (2*(sc[1] & 7) + 1);
|
3767
|
+
db[3] = d * (2*((sc[1] >> 4) & 7) + 1);
|
3768
|
+
for (int l = 0; l < 4; ++l) {
|
3769
|
+
const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]);
|
3770
|
+
for (int j = 0; j < 8; ++j) {
|
3771
|
+
//y[j] = db[l] * grid[l][j];
|
3772
|
+
y[j] = db[l] * grid[j];
|
3773
|
+
}
|
3774
|
+
y += 8;
|
3775
|
+
}
|
3776
|
+
qs += 4;
|
3777
|
+
sc += 2;
|
3778
|
+
}
|
3779
|
+
}
|
3780
|
+
}
|
3781
|
+
|
3782
|
+
static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
|
3783
|
+
|
3784
|
+
void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y, int k) {
|
3785
|
+
assert(k % QK4_NL == 0);
|
3786
|
+
const int nb = k / QK4_NL;
|
3787
|
+
|
3788
|
+
for (int i = 0; i < nb; i++) {
|
3789
|
+
|
3790
|
+
const uint8_t * qs = x[i].qs;
|
3791
|
+
|
3792
|
+
const float d = GGML_FP16_TO_FP32(x[i].d);
|
3793
|
+
for (int j = 0; j < QK4_NL/2; ++j) {
|
3794
|
+
y[j+ 0] = d * kvalues_iq4nl[qs[j] & 0xf];
|
3795
|
+
y[j+QK4_NL/2] = d * kvalues_iq4nl[qs[j] >> 4];
|
3796
|
+
}
|
3797
|
+
y += QK4_NL;
|
3798
|
+
qs += QK4_NL/2;
|
3799
|
+
}
|
3800
|
+
}
|
3801
|
+
|
3568
3802
|
//===================================== Q8_K ==============================================
|
3569
3803
|
|
3570
3804
|
void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
|
@@ -3666,15 +3900,92 @@ static inline __m128i get_scale_shuffle(int i) {
|
|
3666
3900
|
}
|
3667
3901
|
#endif
|
3668
3902
|
|
3669
|
-
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
3903
|
+
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
3670
3904
|
const int qk = QK8_0;
|
3671
3905
|
const int nb = n / qk;
|
3672
3906
|
|
3673
3907
|
assert(n % qk == 0);
|
3908
|
+
#if defined(__ARM_FEATURE_MATMUL_INT8)
|
3909
|
+
assert((nrc == 2) || (nrc == 1));
|
3910
|
+
#else
|
3911
|
+
assert(nrc == 1);
|
3912
|
+
#endif
|
3913
|
+
UNUSED(nrc);
|
3914
|
+
UNUSED(bx);
|
3915
|
+
UNUSED(by);
|
3916
|
+
UNUSED(bs);
|
3674
3917
|
|
3675
3918
|
const block_q4_0 * restrict x = vx;
|
3676
3919
|
const block_q8_0 * restrict y = vy;
|
3677
3920
|
|
3921
|
+
#if defined(__ARM_FEATURE_MATMUL_INT8)
|
3922
|
+
if (nrc == 2) {
|
3923
|
+
const block_q4_0 * restrict vx0 = vx;
|
3924
|
+
const block_q4_0 * restrict vx1 = vx + bx;
|
3925
|
+
|
3926
|
+
const block_q8_0 * restrict vy0 = vy;
|
3927
|
+
const block_q8_0 * restrict vy1 = vy + by;
|
3928
|
+
|
3929
|
+
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
3930
|
+
|
3931
|
+
for (int i = 0; i < nb; i++) {
|
3932
|
+
const block_q4_0 * restrict b_x0 = &vx0[i];
|
3933
|
+
const block_q4_0 * restrict b_x1 = &vx1[i];
|
3934
|
+
const block_q8_0 * restrict b_y0 = &vy0[i];
|
3935
|
+
const block_q8_0 * restrict b_y1 = &vy1[i];
|
3936
|
+
|
3937
|
+
const uint8x16_t m4b = vdupq_n_u8(0x0F);
|
3938
|
+
const int8x16_t s8b = vdupq_n_s8(0x8);
|
3939
|
+
|
3940
|
+
const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);
|
3941
|
+
const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);
|
3942
|
+
|
3943
|
+
// 4-bit -> 8-bit
|
3944
|
+
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
|
3945
|
+
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
|
3946
|
+
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
|
3947
|
+
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
|
3948
|
+
|
3949
|
+
// sub 8
|
3950
|
+
const int8x16_t x0_l = vsubq_s8(v0_0l, s8b);
|
3951
|
+
const int8x16_t x0_h = vsubq_s8(v0_0h, s8b);
|
3952
|
+
const int8x16_t x1_l = vsubq_s8(v0_1l, s8b);
|
3953
|
+
const int8x16_t x1_h = vsubq_s8(v0_1h, s8b);
|
3954
|
+
|
3955
|
+
// load y
|
3956
|
+
const int8x16_t y0_l = vld1q_s8(b_y0->qs);
|
3957
|
+
const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
|
3958
|
+
const int8x16_t y1_l = vld1q_s8(b_y1->qs);
|
3959
|
+
const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
|
3960
|
+
|
3961
|
+
float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
|
3962
|
+
GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
|
3963
|
+
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
|
3964
|
+
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
|
3965
|
+
|
3966
|
+
int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
|
3967
|
+
int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
|
3968
|
+
|
3969
|
+
int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
|
3970
|
+
int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
|
3971
|
+
|
3972
|
+
int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
|
3973
|
+
int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
|
3974
|
+
|
3975
|
+
int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
|
3976
|
+
int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
|
3977
|
+
|
3978
|
+
sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
|
3979
|
+
l1, r1)), l2, r2)), l3, r3))), scale);
|
3980
|
+
}
|
3981
|
+
float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
|
3982
|
+
float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
|
3983
|
+
|
3984
|
+
vst1_f32(s, vget_low_f32(sumv2));
|
3985
|
+
vst1_f32(s + bs, vget_high_f32(sumv2));
|
3986
|
+
return;
|
3987
|
+
}
|
3988
|
+
#endif
|
3678
3989
|
#if defined(__ARM_NEON)
|
3679
3990
|
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
3680
3991
|
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
@@ -3729,15 +4040,15 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx,
|
|
3729
4040
|
/* Compute combined scale for the block */
|
3730
4041
|
const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
|
3731
4042
|
|
3732
|
-
__m256i
|
4043
|
+
__m256i qx = bytes_from_nibbles_32(x[i].qs);
|
3733
4044
|
|
3734
4045
|
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
|
3735
4046
|
const __m256i off = _mm256_set1_epi8( 8 );
|
3736
|
-
|
4047
|
+
qx = _mm256_sub_epi8( qx, off );
|
3737
4048
|
|
3738
|
-
__m256i
|
4049
|
+
__m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
3739
4050
|
|
3740
|
-
const __m256 q = mul_sum_i8_pairs_float(
|
4051
|
+
const __m256 q = mul_sum_i8_pairs_float(qx, qy);
|
3741
4052
|
|
3742
4053
|
/* Multiply q with scale and accumulate */
|
3743
4054
|
acc = _mm256_fmadd_ps( d, q, acc );
|
@@ -3758,15 +4069,15 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx,
|
|
3758
4069
|
|
3759
4070
|
const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs);
|
3760
4071
|
|
3761
|
-
__m128i
|
3762
|
-
__m128i
|
3763
|
-
|
3764
|
-
const __m128i i32_0 = mul_sum_i8_pairs(
|
4072
|
+
__m128i bx_0 = _mm_and_si128(lowMask, tmp);
|
4073
|
+
__m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs);
|
4074
|
+
bx_0 = _mm_sub_epi8(bx_0, off);
|
4075
|
+
const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
|
3765
4076
|
|
3766
|
-
|
3767
|
-
|
3768
|
-
|
3769
|
-
const __m128i i32_1 = mul_sum_i8_pairs(
|
4077
|
+
bx_0 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
|
4078
|
+
by_0 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
|
4079
|
+
bx_0 = _mm_sub_epi8(bx_0, off);
|
4080
|
+
const __m128i i32_1 = mul_sum_i8_pairs(bx_0, by_0);
|
3770
4081
|
|
3771
4082
|
// Convert int32_t to float
|
3772
4083
|
__m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1));
|
@@ -3956,15 +4267,93 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx,
|
|
3956
4267
|
#endif
|
3957
4268
|
}
|
3958
4269
|
|
3959
|
-
void ggml_vec_dot_q4_1_q8_1(
|
4270
|
+
void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
3960
4271
|
const int qk = QK8_1;
|
3961
4272
|
const int nb = n / qk;
|
3962
4273
|
|
3963
4274
|
assert(n % qk == 0);
|
4275
|
+
#if defined(__ARM_FEATURE_MATMUL_INT8)
|
4276
|
+
assert((nrc == 2) || (nrc == 1));
|
4277
|
+
#else
|
4278
|
+
assert(nrc == 1);
|
4279
|
+
#endif
|
4280
|
+
UNUSED(nrc);
|
4281
|
+
UNUSED(bx);
|
4282
|
+
UNUSED(by);
|
4283
|
+
UNUSED(bs);
|
3964
4284
|
|
3965
4285
|
const block_q4_1 * restrict x = vx;
|
3966
4286
|
const block_q8_1 * restrict y = vy;
|
3967
4287
|
|
4288
|
+
#if defined(__ARM_FEATURE_MATMUL_INT8)
|
4289
|
+
if (nrc == 2) {
|
4290
|
+
const block_q4_1 * restrict vx0 = vx;
|
4291
|
+
const block_q4_1 * restrict vx1 = vx + bx;
|
4292
|
+
const block_q8_1 * restrict vy0 = vy;
|
4293
|
+
const block_q8_1 * restrict vy1 = vy + by;
|
4294
|
+
|
4295
|
+
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
4296
|
+
float32x4_t summs0 = vdupq_n_f32(0.0f);
|
4297
|
+
|
4298
|
+
for (int i = 0; i < nb; i++) {
|
4299
|
+
const block_q4_1 * restrict b_x0 = &vx0[i];
|
4300
|
+
const block_q4_1 * restrict b_x1 = &vx1[i];
|
4301
|
+
const block_q8_1 * restrict b_y0 = &vy0[i];
|
4302
|
+
const block_q8_1 * restrict b_y1 = &vy1[i];
|
4303
|
+
|
4304
|
+
float32x4_t summs_t = {GGML_FP16_TO_FP32(b_x0->m) * b_y0->s,
|
4305
|
+
GGML_FP16_TO_FP32(b_x1->m) * b_y0->s,
|
4306
|
+
GGML_FP16_TO_FP32(b_x0->m) * b_y1->s,
|
4307
|
+
GGML_FP16_TO_FP32(b_x1->m) * b_y1->s};
|
4308
|
+
summs0 += summs_t;
|
4309
|
+
|
4310
|
+
const uint8x16_t m4b = vdupq_n_u8(0x0F);
|
4311
|
+
|
4312
|
+
const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);
|
4313
|
+
const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);
|
4314
|
+
|
4315
|
+
// 4-bit -> 8-bit
|
4316
|
+
const int8x16_t x0_l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
|
4317
|
+
const int8x16_t x0_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
|
4318
|
+
const int8x16_t x1_l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
|
4319
|
+
const int8x16_t x1_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
|
4320
|
+
|
4321
|
+
// load y
|
4322
|
+
const int8x16_t y0_l = vld1q_s8(b_y0->qs);
|
4323
|
+
const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
|
4324
|
+
const int8x16_t y1_l = vld1q_s8(b_y1->qs);
|
4325
|
+
const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
|
4326
|
+
|
4327
|
+
// mmla into int32x4_t
|
4328
|
+
float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
|
4329
|
+
GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
|
4330
|
+
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
|
4331
|
+
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
|
4332
|
+
|
4333
|
+
int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
|
4334
|
+
int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
|
4335
|
+
|
4336
|
+
int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
|
4337
|
+
int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
|
4338
|
+
|
4339
|
+
int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
|
4340
|
+
int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
|
4341
|
+
|
4342
|
+
int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
|
4343
|
+
int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
|
4344
|
+
sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
|
4345
|
+
l1, r1)), l2, r2)), l3, r3))), scale);
|
4346
|
+
}
|
4347
|
+
|
4348
|
+
float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
|
4349
|
+
float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
|
4350
|
+
sumv2 = sumv2 + summs0;
|
4351
|
+
|
4352
|
+
vst1_f32(s, vget_low_f32(sumv2));
|
4353
|
+
vst1_f32(s + bs, vget_high_f32(sumv2));
|
4354
|
+
return;
|
4355
|
+
}
|
4356
|
+
#endif
|
3968
4357
|
// TODO: add WASM SIMD
|
3969
4358
|
#if defined(__ARM_NEON)
|
3970
4359
|
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
@@ -4028,10 +4417,10 @@ void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restri
|
|
4028
4417
|
const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
|
4029
4418
|
|
4030
4419
|
// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
|
4031
|
-
const __m256i
|
4032
|
-
const __m256i
|
4420
|
+
const __m256i qx = bytes_from_nibbles_32(x[i].qs);
|
4421
|
+
const __m256i qy = _mm256_loadu_si256( (const __m256i *)y[i].qs );
|
4033
4422
|
|
4034
|
-
const __m256 xy = mul_sum_us8_pairs_float(
|
4423
|
+
const __m256 xy = mul_sum_us8_pairs_float(qx, qy);
|
4035
4424
|
|
4036
4425
|
// Accumulate d0*d1*x*y
|
4037
4426
|
#if defined(__AVX2__)
|
@@ -4096,12 +4485,17 @@ void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restri
|
|
4096
4485
|
#endif
|
4097
4486
|
}
|
4098
4487
|
|
4099
|
-
void ggml_vec_dot_q5_0_q8_0(
|
4488
|
+
void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
4100
4489
|
const int qk = QK8_0;
|
4101
4490
|
const int nb = n / qk;
|
4102
4491
|
|
4103
4492
|
assert(n % qk == 0);
|
4104
4493
|
assert(qk == QK5_0);
|
4494
|
+
assert(nrc == 1);
|
4495
|
+
UNUSED(nrc);
|
4496
|
+
UNUSED(bx);
|
4497
|
+
UNUSED(by);
|
4498
|
+
UNUSED(bs);
|
4105
4499
|
|
4106
4500
|
const block_q5_0 * restrict x = vx;
|
4107
4501
|
const block_q8_0 * restrict y = vy;
|
@@ -4245,14 +4639,14 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri
|
|
4245
4639
|
/* Compute combined scale for the block */
|
4246
4640
|
const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
|
4247
4641
|
|
4248
|
-
__m256i
|
4642
|
+
__m256i qx = bytes_from_nibbles_32(x[i].qs);
|
4249
4643
|
__m256i bxhi = bytes_from_bits_32(x[i].qh);
|
4250
4644
|
bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));
|
4251
|
-
|
4645
|
+
qx = _mm256_or_si256(qx, bxhi);
|
4252
4646
|
|
4253
|
-
__m256i
|
4647
|
+
__m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
4254
4648
|
|
4255
|
-
const __m256 q = mul_sum_i8_pairs_float(
|
4649
|
+
const __m256 q = mul_sum_i8_pairs_float(qx, qy);
|
4256
4650
|
|
4257
4651
|
/* Multiply q with scale and accumulate */
|
4258
4652
|
acc = _mm256_fmadd_ps(d, q, acc);
|
@@ -4269,21 +4663,21 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri
|
|
4269
4663
|
/* Compute combined scale for the block */
|
4270
4664
|
const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
|
4271
4665
|
|
4272
|
-
__m256i
|
4666
|
+
__m256i bx_0 = bytes_from_nibbles_32(x[i].qs);
|
4273
4667
|
const __m256i bxhi = bytes_from_bits_32(x[i].qh);
|
4274
4668
|
__m128i bxhil = _mm256_castsi256_si128(bxhi);
|
4275
4669
|
__m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
|
4276
4670
|
bxhil = _mm_andnot_si128(bxhil, mask);
|
4277
4671
|
bxhih = _mm_andnot_si128(bxhih, mask);
|
4278
|
-
__m128i bxl = _mm256_castsi256_si128(
|
4279
|
-
__m128i bxh = _mm256_extractf128_si256(
|
4672
|
+
__m128i bxl = _mm256_castsi256_si128(bx_0);
|
4673
|
+
__m128i bxh = _mm256_extractf128_si256(bx_0, 1);
|
4280
4674
|
bxl = _mm_or_si128(bxl, bxhil);
|
4281
4675
|
bxh = _mm_or_si128(bxh, bxhih);
|
4282
|
-
|
4676
|
+
bx_0 = MM256_SET_M128I(bxh, bxl);
|
4283
4677
|
|
4284
|
-
const __m256i
|
4678
|
+
const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
4285
4679
|
|
4286
|
-
const __m256 q = mul_sum_i8_pairs_float(
|
4680
|
+
const __m256 q = mul_sum_i8_pairs_float(bx_0, by_0);
|
4287
4681
|
|
4288
4682
|
/* Multiply q with scale and accumulate */
|
4289
4683
|
acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
|
@@ -4382,12 +4776,17 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri
|
|
4382
4776
|
#endif
|
4383
4777
|
}
|
4384
4778
|
|
4385
|
-
void ggml_vec_dot_q5_1_q8_1(
|
4779
|
+
void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
4386
4780
|
const int qk = QK8_1;
|
4387
4781
|
const int nb = n / qk;
|
4388
4782
|
|
4389
4783
|
assert(n % qk == 0);
|
4390
4784
|
assert(qk == QK5_1);
|
4785
|
+
assert(nrc == 1);
|
4786
|
+
UNUSED(nrc);
|
4787
|
+
UNUSED(bx);
|
4788
|
+
UNUSED(by);
|
4789
|
+
UNUSED(bs);
|
4391
4790
|
|
4392
4791
|
const block_q5_1 * restrict x = vx;
|
4393
4792
|
const block_q8_1 * restrict y = vy;
|
@@ -4544,15 +4943,15 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri
|
|
4544
4943
|
|
4545
4944
|
summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
|
4546
4945
|
|
4547
|
-
__m256i
|
4946
|
+
__m256i qx = bytes_from_nibbles_32(x[i].qs);
|
4548
4947
|
__m256i bxhi = bytes_from_bits_32(x[i].qh);
|
4549
4948
|
bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
|
4550
|
-
|
4949
|
+
qx = _mm256_or_si256(qx, bxhi);
|
4551
4950
|
|
4552
4951
|
const __m256 dy = _mm256_set1_ps(y[i].d);
|
4553
|
-
const __m256i
|
4952
|
+
const __m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
4554
4953
|
|
4555
|
-
const __m256 q = mul_sum_us8_pairs_float(
|
4954
|
+
const __m256 q = mul_sum_us8_pairs_float(qx, qy);
|
4556
4955
|
|
4557
4956
|
acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
|
4558
4957
|
}
|
@@ -4571,22 +4970,22 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri
|
|
4571
4970
|
|
4572
4971
|
summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
|
4573
4972
|
|
4574
|
-
__m256i
|
4973
|
+
__m256i bx_0 = bytes_from_nibbles_32(x[i].qs);
|
4575
4974
|
const __m256i bxhi = bytes_from_bits_32(x[i].qh);
|
4576
4975
|
__m128i bxhil = _mm256_castsi256_si128(bxhi);
|
4577
4976
|
__m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
|
4578
4977
|
bxhil = _mm_and_si128(bxhil, mask);
|
4579
4978
|
bxhih = _mm_and_si128(bxhih, mask);
|
4580
|
-
__m128i bxl = _mm256_castsi256_si128(
|
4581
|
-
__m128i bxh = _mm256_extractf128_si256(
|
4979
|
+
__m128i bxl = _mm256_castsi256_si128(bx_0);
|
4980
|
+
__m128i bxh = _mm256_extractf128_si256(bx_0, 1);
|
4582
4981
|
bxl = _mm_or_si128(bxl, bxhil);
|
4583
4982
|
bxh = _mm_or_si128(bxh, bxhih);
|
4584
|
-
|
4983
|
+
bx_0 = MM256_SET_M128I(bxh, bxl);
|
4585
4984
|
|
4586
4985
|
const __m256 dy = _mm256_set1_ps(y[i].d);
|
4587
|
-
const __m256i
|
4986
|
+
const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
4588
4987
|
|
4589
|
-
const __m256 q = mul_sum_us8_pairs_float(
|
4988
|
+
const __m256 q = mul_sum_us8_pairs_float(bx_0, by_0);
|
4590
4989
|
|
4591
4990
|
acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
|
4592
4991
|
}
|
@@ -4681,15 +5080,79 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri
|
|
4681
5080
|
#endif
|
4682
5081
|
}
|
4683
5082
|
|
4684
|
-
void ggml_vec_dot_q8_0_q8_0(
|
5083
|
+
void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
4685
5084
|
const int qk = QK8_0;
|
4686
5085
|
const int nb = n / qk;
|
4687
5086
|
|
4688
5087
|
assert(n % qk == 0);
|
5088
|
+
#if defined(__ARM_FEATURE_MATMUL_INT8)
|
5089
|
+
assert((nrc == 2) || (nrc == 1));
|
5090
|
+
#else
|
5091
|
+
assert(nrc == 1);
|
5092
|
+
#endif
|
5093
|
+
UNUSED(nrc);
|
5094
|
+
UNUSED(bx);
|
5095
|
+
UNUSED(by);
|
5096
|
+
UNUSED(bs);
|
4689
5097
|
|
4690
5098
|
const block_q8_0 * restrict x = vx;
|
4691
5099
|
const block_q8_0 * restrict y = vy;
|
4692
5100
|
|
5101
|
+
#if defined(__ARM_FEATURE_MATMUL_INT8)
|
5102
|
+
if (nrc == 2) {
|
5103
|
+
const block_q8_0 * restrict vx0 = vx;
|
5104
|
+
const block_q8_0 * restrict vx1 = vx + bx;
|
5105
|
+
const block_q8_0 * restrict vy0 = vy;
|
5106
|
+
const block_q8_0 * restrict vy1 = vy + by;
|
5107
|
+
|
5108
|
+
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
5109
|
+
|
5110
|
+
for (int i = 0; i < nb; i++) {
|
5111
|
+
const block_q8_0 * restrict b_x0 = &vx0[i];
|
5112
|
+
const block_q8_0 * restrict b_y0 = &vy0[i];
|
5113
|
+
|
5114
|
+
const block_q8_0 * restrict b_x1 = &vx1[i];
|
5115
|
+
const block_q8_0 * restrict b_y1 = &vy1[i];
|
5116
|
+
|
5117
|
+
const int8x16_t x0_l = vld1q_s8(b_x0->qs);
|
5118
|
+
const int8x16_t x0_h = vld1q_s8(b_x0->qs + 16);
|
5119
|
+
const int8x16_t x1_l = vld1q_s8(b_x1->qs);
|
5120
|
+
const int8x16_t x1_h = vld1q_s8(b_x1->qs + 16);
|
5121
|
+
|
5122
|
+
// load y
|
5123
|
+
const int8x16_t y0_l = vld1q_s8(b_y0->qs);
|
5124
|
+
const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
|
5125
|
+
const int8x16_t y1_l = vld1q_s8(b_y1->qs);
|
5126
|
+
const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
|
5127
|
+
|
5128
|
+
float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
|
5129
|
+
GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
|
5130
|
+
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
|
5131
|
+
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
|
5132
|
+
|
5133
|
+
int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
|
5134
|
+
int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
|
5135
|
+
|
5136
|
+
int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
|
5137
|
+
int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
|
5138
|
+
|
5139
|
+
int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
|
5140
|
+
int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
|
5141
|
+
|
5142
|
+
int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
|
5143
|
+
int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
|
5144
|
+
|
5145
|
+
sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
|
5146
|
+
l1, r1)), l2, r2)), l3, r3))), scale);
|
5147
|
+
}
|
5148
|
+
float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
|
5149
|
+
float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
|
5150
|
+
|
5151
|
+
vst1_f32(s, vget_low_f32(sumv2));
|
5152
|
+
vst1_f32(s + bs, vget_high_f32(sumv2));
|
5153
|
+
return;
|
5154
|
+
}
|
5155
|
+
#endif
|
4693
5156
|
#if defined(__ARM_NEON)
|
4694
5157
|
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
4695
5158
|
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
@@ -4731,10 +5194,10 @@ void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restri
|
|
4731
5194
|
for (int i = 0; i < nb; ++i) {
|
4732
5195
|
// Compute combined scale for the block
|
4733
5196
|
const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
|
4734
|
-
__m256i
|
4735
|
-
__m256i
|
5197
|
+
__m256i qx = _mm256_loadu_si256((const __m256i *)x[i].qs);
|
5198
|
+
__m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
4736
5199
|
|
4737
|
-
const __m256 q = mul_sum_i8_pairs_float(
|
5200
|
+
const __m256 q = mul_sum_i8_pairs_float(qx, qy);
|
4738
5201
|
|
4739
5202
|
// Multiply q with scale and accumulate
|
4740
5203
|
#if defined(__AVX2__)
|
@@ -4751,10 +5214,10 @@ void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restri
|
|
4751
5214
|
|
4752
5215
|
for (int i = 0; i < nb; i++) {
|
4753
5216
|
// load elements
|
4754
|
-
vint8m1_t
|
4755
|
-
vint8m1_t
|
5217
|
+
vint8m1_t bx_0 = __riscv_vle8_v_i8m1(x[i].qs, vl);
|
5218
|
+
vint8m1_t by_0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
|
4756
5219
|
|
4757
|
-
vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(
|
5220
|
+
vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx_0, by_0, vl);
|
4758
5221
|
|
4759
5222
|
vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl);
|
4760
5223
|
vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl);
|
@@ -4784,7 +5247,12 @@ void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restri
|
|
4784
5247
|
}
|
4785
5248
|
|
4786
5249
|
#if QK_K == 256
|
4787
|
-
void ggml_vec_dot_q2_K_q8_K(
|
5250
|
+
void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
5251
|
+
assert(nrc == 1);
|
5252
|
+
UNUSED(nrc);
|
5253
|
+
UNUSED(bx);
|
5254
|
+
UNUSED(by);
|
5255
|
+
UNUSED(bs);
|
4788
5256
|
|
4789
5257
|
const block_q2_K * restrict x = vx;
|
4790
5258
|
const block_q8_K * restrict y = vy;
|
@@ -5160,7 +5628,12 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
|
|
5160
5628
|
|
5161
5629
|
#else
|
5162
5630
|
|
5163
|
-
void ggml_vec_dot_q2_K_q8_K(
|
5631
|
+
void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
5632
|
+
assert(nrc == 1);
|
5633
|
+
UNUSED(nrc);
|
5634
|
+
UNUSED(bx);
|
5635
|
+
UNUSED(by);
|
5636
|
+
UNUSED(bs);
|
5164
5637
|
|
5165
5638
|
const block_q2_K * restrict x = vx;
|
5166
5639
|
const block_q8_K * restrict y = vy;
|
@@ -5181,8 +5654,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
|
|
5181
5654
|
|
5182
5655
|
for (int i = 0; i < nb; ++i) {
|
5183
5656
|
|
5184
|
-
const float d
|
5185
|
-
const float dmin = -y[i].d * (
|
5657
|
+
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
|
5658
|
+
const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
|
5186
5659
|
|
5187
5660
|
const uint8_t * restrict q2 = x[i].qs;
|
5188
5661
|
const int8_t * restrict q8 = y[i].qs;
|
@@ -5331,8 +5804,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
|
|
5331
5804
|
|
5332
5805
|
for (int i = 0; i < nb; ++i) {
|
5333
5806
|
|
5334
|
-
const float d
|
5335
|
-
const float dmin = -y[i].d * (
|
5807
|
+
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
|
5808
|
+
const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
|
5336
5809
|
|
5337
5810
|
const uint8_t * restrict q2 = x[i].qs;
|
5338
5811
|
const int8_t * restrict q8 = y[i].qs;
|
@@ -5418,8 +5891,13 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
|
|
5418
5891
|
#endif
|
5419
5892
|
|
5420
5893
|
#if QK_K == 256
|
5421
|
-
void ggml_vec_dot_q3_K_q8_K(
|
5894
|
+
void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
5422
5895
|
assert(n % QK_K == 0);
|
5896
|
+
assert(nrc == 1);
|
5897
|
+
UNUSED(nrc);
|
5898
|
+
UNUSED(bx);
|
5899
|
+
UNUSED(by);
|
5900
|
+
UNUSED(bs);
|
5423
5901
|
|
5424
5902
|
const uint32_t kmask1 = 0x03030303;
|
5425
5903
|
const uint32_t kmask2 = 0x0f0f0f0f;
|
@@ -5938,8 +6416,13 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
|
5938
6416
|
|
5939
6417
|
#else
|
5940
6418
|
|
5941
|
-
void ggml_vec_dot_q3_K_q8_K(
|
6419
|
+
void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
5942
6420
|
assert(n % QK_K == 0);
|
6421
|
+
assert(nrc == 1);
|
6422
|
+
UNUSED(nrc);
|
6423
|
+
UNUSED(bx);
|
6424
|
+
UNUSED(by);
|
6425
|
+
UNUSED(bs);
|
5943
6426
|
|
5944
6427
|
const block_q3_K * restrict x = vx;
|
5945
6428
|
const block_q8_K * restrict y = vy;
|
@@ -5975,7 +6458,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
|
5975
6458
|
|
5976
6459
|
int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]);
|
5977
6460
|
|
5978
|
-
const float d = y[i].d * (
|
6461
|
+
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
|
5979
6462
|
|
5980
6463
|
const uint8x16_t htmp = vcombine_u8(hbits, vshr_n_u8(hbits, 1));
|
5981
6464
|
q3h.val[0] = vandq_u8(mh, vshlq_n_u8(htmp, 2));
|
@@ -6177,7 +6660,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
|
6177
6660
|
|
6178
6661
|
int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]);
|
6179
6662
|
|
6180
|
-
const float d = y[i].d * (
|
6663
|
+
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
|
6181
6664
|
|
6182
6665
|
vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
|
6183
6666
|
|
@@ -6281,8 +6764,13 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
|
6281
6764
|
#endif
|
6282
6765
|
|
6283
6766
|
#if QK_K == 256
|
6284
|
-
void ggml_vec_dot_q4_K_q8_K(
|
6767
|
+
void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
6285
6768
|
assert(n % QK_K == 0);
|
6769
|
+
assert(nrc == 1);
|
6770
|
+
UNUSED(nrc);
|
6771
|
+
UNUSED(bx);
|
6772
|
+
UNUSED(by);
|
6773
|
+
UNUSED(bs);
|
6286
6774
|
|
6287
6775
|
const block_q4_K * restrict x = vx;
|
6288
6776
|
const block_q8_K * restrict y = vy;
|
@@ -6637,8 +7125,13 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
|
6637
7125
|
#endif
|
6638
7126
|
}
|
6639
7127
|
#else
|
6640
|
-
void ggml_vec_dot_q4_K_q8_K(
|
7128
|
+
void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
6641
7129
|
assert(n % QK_K == 0);
|
7130
|
+
assert(nrc == 1);
|
7131
|
+
UNUSED(nrc);
|
7132
|
+
UNUSED(bx);
|
7133
|
+
UNUSED(by);
|
7134
|
+
UNUSED(bs);
|
6642
7135
|
|
6643
7136
|
const block_q4_K * restrict x = vx;
|
6644
7137
|
const block_q8_K * restrict y = vy;
|
@@ -6670,9 +7163,9 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
|
6670
7163
|
aux16[1] = (a[0] >> 4) & 0x0f0f;
|
6671
7164
|
|
6672
7165
|
const int32_t summi = scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]);
|
6673
|
-
sum_mins += y[i].d * (
|
7166
|
+
sum_mins += y[i].d * GGML_FP16_TO_FP32(x[i].d[1]) * summi;
|
6674
7167
|
|
6675
|
-
const float d = y[i].d * (
|
7168
|
+
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d[0]);
|
6676
7169
|
|
6677
7170
|
const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4);
|
6678
7171
|
|
@@ -6880,8 +7373,13 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
|
6880
7373
|
#endif
|
6881
7374
|
|
6882
7375
|
#if QK_K == 256
|
6883
|
-
void ggml_vec_dot_q5_K_q8_K(
|
7376
|
+
void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
6884
7377
|
assert(n % QK_K == 0);
|
7378
|
+
assert(nrc == 1);
|
7379
|
+
UNUSED(nrc);
|
7380
|
+
UNUSED(bx);
|
7381
|
+
UNUSED(by);
|
7382
|
+
UNUSED(bs);
|
6885
7383
|
|
6886
7384
|
const block_q5_K * restrict x = vx;
|
6887
7385
|
const block_q8_K * restrict y = vy;
|
@@ -7300,8 +7798,13 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
|
|
7300
7798
|
|
7301
7799
|
#else
|
7302
7800
|
|
7303
|
-
void ggml_vec_dot_q5_K_q8_K(
|
7801
|
+
void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
7304
7802
|
assert(n % QK_K == 0);
|
7803
|
+
assert(nrc == 1);
|
7804
|
+
UNUSED(nrc);
|
7805
|
+
UNUSED(bx);
|
7806
|
+
UNUSED(by);
|
7807
|
+
UNUSED(bs);
|
7305
7808
|
|
7306
7809
|
const block_q5_K * restrict x = vx;
|
7307
7810
|
const block_q8_K * restrict y = vy;
|
@@ -7320,7 +7823,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
|
|
7320
7823
|
|
7321
7824
|
for (int i = 0; i < nb; ++i) {
|
7322
7825
|
|
7323
|
-
const float d = y[i].d * (
|
7826
|
+
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
|
7324
7827
|
const int8_t * sc = x[i].scales;
|
7325
7828
|
|
7326
7829
|
const uint8_t * restrict q5 = x[i].qs;
|
@@ -7462,7 +7965,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
|
|
7462
7965
|
|
7463
7966
|
for (int i = 0; i < nb; ++i) {
|
7464
7967
|
|
7465
|
-
const float d = y[i].d * (
|
7968
|
+
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
|
7466
7969
|
const int8_t * sc = x[i].scales;
|
7467
7970
|
|
7468
7971
|
const uint8_t * restrict q5 = x[i].qs;
|
@@ -7566,8 +8069,13 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
|
|
7566
8069
|
|
7567
8070
|
|
7568
8071
|
#if QK_K == 256
|
7569
|
-
void ggml_vec_dot_q6_K_q8_K(
|
8072
|
+
void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
7570
8073
|
assert(n % QK_K == 0);
|
8074
|
+
assert(nrc == 1);
|
8075
|
+
UNUSED(nrc);
|
8076
|
+
UNUSED(bx);
|
8077
|
+
UNUSED(by);
|
8078
|
+
UNUSED(bs);
|
7571
8079
|
|
7572
8080
|
const block_q6_K * restrict x = vx;
|
7573
8081
|
const block_q8_K * restrict y = vy;
|
@@ -7998,8 +8506,13 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
|
|
7998
8506
|
|
7999
8507
|
#else
|
8000
8508
|
|
8001
|
-
void ggml_vec_dot_q6_K_q8_K(
|
8509
|
+
void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
8002
8510
|
assert(n % QK_K == 0);
|
8511
|
+
assert(nrc == 1);
|
8512
|
+
UNUSED(nrc);
|
8513
|
+
UNUSED(bx);
|
8514
|
+
UNUSED(by);
|
8515
|
+
UNUSED(bs);
|
8003
8516
|
|
8004
8517
|
const block_q6_K * restrict x = vx;
|
8005
8518
|
const block_q8_K * restrict y = vy;
|
@@ -8020,7 +8533,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
|
|
8020
8533
|
|
8021
8534
|
for (int i = 0; i < nb; ++i) {
|
8022
8535
|
|
8023
|
-
const float d_all = (
|
8536
|
+
const float d_all = GGML_FP16_TO_FP32(x[i].d);
|
8024
8537
|
|
8025
8538
|
const uint8_t * restrict q6 = x[i].ql;
|
8026
8539
|
const uint8_t * restrict qh = x[i].qh;
|
@@ -8191,7 +8704,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
|
|
8191
8704
|
|
8192
8705
|
for (int i = 0; i < nb; ++i) {
|
8193
8706
|
|
8194
|
-
const float d_all = (
|
8707
|
+
const float d_all = GGML_FP16_TO_FP32(x[i].d);
|
8195
8708
|
|
8196
8709
|
const uint8_t * restrict q6 = x[i].ql;
|
8197
8710
|
const uint8_t * restrict qh = x[i].qh;
|
@@ -8328,8 +8841,13 @@ static const int8_t keven_signs_q2xs[1024] = {
|
|
8328
8841
|
1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
|
8329
8842
|
};
|
8330
8843
|
|
8331
|
-
void ggml_vec_dot_iq2_xxs_q8_K(
|
8844
|
+
void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
8332
8845
|
assert(n % QK_K == 0);
|
8846
|
+
assert(nrc == 1);
|
8847
|
+
UNUSED(nrc);
|
8848
|
+
UNUSED(bx);
|
8849
|
+
UNUSED(by);
|
8850
|
+
UNUSED(bs);
|
8333
8851
|
|
8334
8852
|
const block_iq2_xxs * restrict x = vx;
|
8335
8853
|
const block_q8_K * restrict y = vy;
|
@@ -8451,8 +8969,13 @@ void ggml_vec_dot_iq2_xxs_q8_K(const int n, float * restrict s, const void * res
|
|
8451
8969
|
#endif
|
8452
8970
|
}
|
8453
8971
|
|
8454
|
-
void ggml_vec_dot_iq2_xs_q8_K(
|
8972
|
+
void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
8455
8973
|
assert(n % QK_K == 0);
|
8974
|
+
assert(nrc == 1);
|
8975
|
+
UNUSED(nrc);
|
8976
|
+
UNUSED(bx);
|
8977
|
+
UNUSED(by);
|
8978
|
+
UNUSED(bs);
|
8456
8979
|
|
8457
8980
|
const block_iq2_xs * restrict x = vx;
|
8458
8981
|
const block_q8_K * restrict y = vy;
|
@@ -8670,9 +9193,13 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
|
|
8670
9193
|
#endif
|
8671
9194
|
}
|
8672
9195
|
|
8673
|
-
|
8674
|
-
void ggml_vec_dot_iq3_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
9196
|
+
void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
8675
9197
|
assert(n % QK_K == 0);
|
9198
|
+
assert(nrc == 1);
|
9199
|
+
UNUSED(nrc);
|
9200
|
+
UNUSED(bx);
|
9201
|
+
UNUSED(by);
|
9202
|
+
UNUSED(bs);
|
8676
9203
|
|
8677
9204
|
const block_iq3_xxs * restrict x = vx;
|
8678
9205
|
const block_q8_K * restrict y = vy;
|
@@ -8698,10 +9225,10 @@ void ggml_vec_dot_iq3_xxs_q8_K(const int n, float * restrict s, const void * res
|
|
8698
9225
|
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
|
8699
9226
|
q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
|
8700
9227
|
memcpy(aux32, gas, 2*sizeof(uint32_t)); gas += 2*sizeof(uint32_t);
|
8701
|
-
const uint32x4_t aux32x4_0 =
|
8702
|
-
const uint32x4_t aux32x4_1 =
|
8703
|
-
const uint32x4_t aux32x4_2 =
|
8704
|
-
const uint32x4_t aux32x4_3 =
|
9228
|
+
const uint32x4_t aux32x4_0 = ggml_vld1q_u32(iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]);
|
9229
|
+
const uint32x4_t aux32x4_1 = ggml_vld1q_u32(iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]);
|
9230
|
+
const uint32x4_t aux32x4_2 = ggml_vld1q_u32(iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]);
|
9231
|
+
const uint32x4_t aux32x4_3 = ggml_vld1q_u32(iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]);
|
8705
9232
|
q3 += 16;
|
8706
9233
|
q3s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 7) & 127))));
|
8707
9234
|
q3s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 21) & 127))));
|
@@ -8800,6 +9327,271 @@ void ggml_vec_dot_iq3_xxs_q8_K(const int n, float * restrict s, const void * res
|
|
8800
9327
|
#endif
|
8801
9328
|
}
|
8802
9329
|
|
9330
|
+
#ifdef __AVX2__
|
9331
|
+
static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
|
9332
|
+
const __m256i ax = _mm256_sign_epi8(x, x);
|
9333
|
+
const __m256i sy = _mm256_sign_epi8(y, x);
|
9334
|
+
return _mm256_maddubs_epi16(ax, sy);
|
9335
|
+
}
|
9336
|
+
#endif
|
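The helper above computes, for each pair of adjacent bytes, x*y summed into a signed 16-bit lane: _mm256_sign_epi8(x, x) produces |x|, _mm256_sign_epi8(y, x) moves x's sign onto y (and zeroes it where x is zero), and _mm256_maddubs_epi16 multiplies the unsigned |x| by the sign-adjusted y and adds horizontal pairs. A scalar sketch of the same per-lane result (illustrative only; the saturating behaviour of maddubs cannot trigger for these operand ranges):

    #include <stdint.h>

    /* Scalar equivalent of mul_add_epi8 for the 16 output lanes of a 32-byte vector:
     * out[k] = x[2k]*y[2k] + x[2k+1]*y[2k+1], with signed 8-bit inputs. */
    static void mul_add_epi8_scalar(const int8_t x[32], const int8_t y[32], int16_t out[16]) {
        for (int k = 0; k < 16; ++k) {
            out[k] = (int16_t)(x[2*k + 0] * y[2*k + 0] + x[2*k + 1] * y[2*k + 1]);
        }
    }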
9337
|
+
|
9338
|
+
void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
9339
|
+
assert(n % QK_K == 0);
|
9340
|
+
assert(nrc == 1);
|
9341
|
+
UNUSED(nrc);
|
9342
|
+
UNUSED(bx);
|
9343
|
+
UNUSED(by);
|
9344
|
+
UNUSED(bs);
|
9345
|
+
|
9346
|
+
const block_iq1_s * restrict x = vx;
|
9347
|
+
const block_q8_K * restrict y = vy;
|
9348
|
+
|
9349
|
+
const int nb = n / QK_K;
|
9350
|
+
|
9351
|
+
#if defined __ARM_NEON
|
9352
|
+
|
9353
|
+
const uint8x16_t m8 = vdupq_n_u8(0x08);
|
9354
|
+
const uint8x16_t m7 = vdupq_n_u8(0x07);
|
9355
|
+
const uint8x16_t m1 = vdupq_n_u8(0x01);
|
9356
|
+
const int32x4_t vzero = vdupq_n_s32(0);
|
9357
|
+
|
9358
|
+
uint16_t gindex[8];
|
9359
|
+
uint16x8x2_t vindex;
|
9360
|
+
int8x16x4_t q1b;
|
9361
|
+
ggml_int8x16x4_t q8b;
|
9362
|
+
uint16x8x4_t scales;
|
9363
|
+
int32x4x2_t sumi;
|
9364
|
+
int32x4x2_t dotq;
|
9365
|
+
|
9366
|
+
float sumf = 0;
|
9367
|
+
for (int i = 0; i < nb; ++i) {
|
9368
|
+
|
9369
|
+
const int8_t * q8 = y[i].qs;
|
9370
|
+
const uint8_t * qs = x[i].qs;
|
9371
|
+
const uint8_t * sc = x[i].scales;
|
9372
|
+
|
9373
|
+
sumi.val[0] = sumi.val[1] = vzero;
|
9374
|
+
|
9375
|
+
for (int i128 = 0; i128 < QK_K/128; ++i128) {
|
9376
|
+
const uint8x16_t ql = vld1q_u8(qs); qs += 16;
|
9377
|
+
const uint8x8_t tm1 = vld1_u8 (sc); sc += 8;
|
9378
|
+
const uint8x8_t tm2 = vshr_n_u8(tm1, 4);
|
9379
|
+
const uint8x16_t qh = vcombine_u8(vzip1_u8(tm1, tm2), vzip2_u8(tm1, tm2));
|
9380
|
+
const uint8x16_t hbit = vandq_u8(qh, m8);
|
9381
|
+
vindex.val[0] = vorrq_u16(vmovl_u8(vget_low_u8 (ql)), vshlq_n_u16(vmovl_u8(vget_low_u8 (hbit)), 5));
|
9382
|
+
vindex.val[1] = vorrq_u16(vmovl_u8(vget_high_u8(ql)), vshlq_n_u16(vmovl_u8(vget_high_u8(hbit)), 5));
|
9383
|
+
const uint8x16_t scales8 = vorrq_u8(vshlq_n_u8(vandq_u8(qh, m7), 1), m1);
|
9384
|
+
scales.val[0] = vmovl_u8(vget_low_u8 (scales8));
|
9385
|
+
scales.val[1] = vmovl_u8(vget_high_u8 (scales8));
|
9386
|
+
|
9387
|
+
for (int l = 0; l < 2; ++l) {
|
9388
|
+
vst1q_u16(gindex+0, vindex.val[l]);
|
9389
|
+
q1b.val[0] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[0])), vld1_s8((const void *)(iq1s_grid+gindex[1])));
|
9390
|
+
q1b.val[1] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[2])), vld1_s8((const void *)(iq1s_grid+gindex[3])));
|
9391
|
+
q1b.val[2] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[4])), vld1_s8((const void *)(iq1s_grid+gindex[5])));
|
9392
|
+
q1b.val[3] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[6])), vld1_s8((const void *)(iq1s_grid+gindex[7])));
|
9393
|
+
q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
|
9394
|
+
|
9395
|
+
dotq.val[0] = vpaddq_s32(ggml_vdotq_s32(vzero, q1b.val[0], q8b.val[0]), ggml_vdotq_s32(vzero, q1b.val[1], q8b.val[1]));
|
9396
|
+
dotq.val[1] = vpaddq_s32(ggml_vdotq_s32(vzero, q1b.val[2], q8b.val[2]), ggml_vdotq_s32(vzero, q1b.val[3], q8b.val[3]));
|
9397
|
+
|
9398
|
+
sumi.val[0] = vmlaq_s32(sumi.val[0], dotq.val[0], vreinterpretq_s32_u32(vmovl_u16(vget_low_u16 (scales.val[l]))));
|
9399
|
+
sumi.val[1] = vmlaq_s32(sumi.val[1], dotq.val[1], vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales.val[l]))));
|
9400
|
+
}
|
9401
|
+
}
|
9402
|
+
|
9403
|
+
sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * vaddvq_s32(vaddq_s32(sumi.val[0], sumi.val[1]));
|
9404
|
+
}
|
9405
|
+
|
9406
|
+
*s = sumf;
|
9407
|
+
|
9408
|
+
#elif defined __AVX2__
|
9409
|
+
|
9410
|
+
const __m128i m8 = _mm_set1_epi8(0x08);
|
9411
|
+
const __m128i m7 = _mm_set1_epi8(0x07);
|
9412
|
+
const __m128i m1 = _mm_set1_epi8(0x01);
|
9413
|
+
const __m128i shuffle_h = _mm_set_epi8(15, 7, 14, 6, 13, 5, 12, 4, 11, 3, 10, 2, 9, 1, 8, 0);
|
9414
|
+
const __m128i shuffle_s[4] = {
|
9415
|
+
_mm_set_epi32(0x03030303, 0x02020202, 0x01010101, 0x00000000),
|
9416
|
+
_mm_set_epi32(0x07070707, 0x06060606, 0x05050505, 0x04040404),
|
9417
|
+
_mm_set_epi32(0x0b0b0b0b, 0x0a0a0a0a, 0x09090909, 0x08080808),
|
9418
|
+
_mm_set_epi32(0x0f0f0f0f, 0x0e0e0e0e, 0x0d0d0d0d, 0x0c0c0c0c)
|
9419
|
+
};
|
9420
|
+
|
9421
|
+
uint64_t aux64;
|
9422
|
+
|
9423
|
+
__m256i v_gindex;
|
9424
|
+
const uint16_t * gindex = (const uint16_t *)&v_gindex;
|
9425
|
+
|
9426
|
+
__m256 accum = _mm256_setzero_ps();
|
9427
|
+
for (int i = 0; i < nb; ++i) {
|
9428
|
+
|
9429
|
+
const int8_t * q8 = y[i].qs;
|
9430
|
+
const uint8_t * qs = x[i].qs;
|
9431
|
+
const uint8_t * sc = x[i].scales;
|
9432
|
+
|
9433
|
+
__m256i sumi = _mm256_setzero_si256();
|
9434
|
+
for (int i128 = 0; i128 < QK_K/128; ++i128) {
|
9435
|
+
const __m128i ql = _mm_loadu_si128((const __m128i*)qs); qs += 16;
|
9436
|
+
memcpy(&aux64, sc, 8); sc += 8;
|
9437
|
+
const __m128i qh = _mm_shuffle_epi8(_mm_set_epi64x(aux64 >> 4, aux64), shuffle_h);
|
9438
|
+
const __m256i hbit = _mm256_cvtepu8_epi16(_mm_and_si128(qh, m8));
|
9439
|
+
v_gindex = _mm256_or_si256(_mm256_cvtepu8_epi16(ql), _mm256_slli_epi16(hbit, 5));
|
9440
|
+
const __m128i scales = _mm_or_si128(_mm_slli_epi16(_mm_and_si128(qh, m7), 1), m1);
|
9441
|
+
|
9442
|
+
for (int i32 = 0; i32 < 4; ++i32) {
|
9443
|
+
const __m256i q8b = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
|
9444
|
+
const __m256i q1b = _mm256_set_epi64x(iq1s_grid[gindex[4*i32+3]], iq1s_grid[gindex[4*i32+2]],
|
9445
|
+
iq1s_grid[gindex[4*i32+1]], iq1s_grid[gindex[4*i32+0]]);
|
9446
|
+
const __m256i dot = mul_add_epi8(q1b, q8b);
|
9447
|
+
const __m256i s16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, shuffle_s[i32]));
|
9448
|
+
const __m256i p = _mm256_madd_epi16(s16, dot);
|
9449
|
+
sumi = _mm256_add_epi32(sumi, p);
|
9450
|
+
}
|
9451
|
+
|
9452
|
+
}
|
9453
|
+
|
9454
|
+
accum = _mm256_fmadd_ps(_mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d)), _mm256_cvtepi32_ps(sumi), accum);
|
9455
|
+
|
9456
|
+
}
|
9457
|
+
|
9458
|
+
*s = hsum_float_8(accum);
|
9459
|
+
|
9460
|
+
#else
|
9461
|
+
|
9462
|
+
int db[4];
|
9463
|
+
uint16_t idx[4];
|
9464
|
+
|
9465
|
+
float sumf = 0;
|
9466
|
+
for (int i = 0; i < nb; ++i) {
|
9467
|
+
|
9468
|
+
const int8_t * q8 = y[i].qs;
|
9469
|
+
const uint8_t * qs = x[i].qs;
|
9470
|
+
const uint8_t * sc = x[i].scales;
|
9471
|
+
|
9472
|
+
int sumi = 0;
|
9473
|
+
for (int i32 = 0; i32 < QK_K/32; ++i32) {
|
9474
|
+
idx[0] = qs[0] | ((sc[0] & 0x08) << 5);
|
9475
|
+
idx[1] = qs[1] | ((sc[0] & 0x80) << 1);
|
9476
|
+
idx[2] = qs[2] | ((sc[1] & 0x08) << 5);
|
9477
|
+
idx[3] = qs[3] | ((sc[1] & 0x80) << 1);
|
9478
|
+
db[0] = (2*(sc[0] & 7) + 1);
|
9479
|
+
db[1] = (2*((sc[0] >> 4) & 7) + 1);
|
9480
|
+
db[2] = (2*(sc[1] & 7) + 1);
|
9481
|
+
db[3] = (2*((sc[1] >> 4) & 7) + 1);
|
9482
|
+
for (int l = 0; l < 4; ++l) {
|
9483
|
+
const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]);
|
9484
|
+
int suml = 0;
|
9485
|
+
for (int j = 0; j < 8; ++j) suml += q8[j] * grid[j];
|
9486
|
+
sumi += db[l] * suml;
|
9487
|
+
q8 += 8;
|
9488
|
+
}
|
9489
|
+
qs += 4;
|
9490
|
+
sc += 2;
|
9491
|
+
}
|
9492
|
+
|
9493
|
+
sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * sumi;
|
9494
|
+
}
|
9495
|
+
|
9496
|
+
*s = sumf;
|
9497
|
+
|
9498
|
+
#endif
|
9499
|
+
}
|
9500
|
+
|
9501
|
+
void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
9502
|
+
assert(nrc == 1);
|
9503
|
+
UNUSED(nrc);
|
9504
|
+
UNUSED(bx);
|
9505
|
+
UNUSED(by);
|
9506
|
+
UNUSED(bs);
|
9507
|
+
assert(n % QK4_NL == 0);
|
9508
|
+
static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same");
|
9509
|
+
|
9510
|
+
const block_iq4_nl * restrict x = vx;
|
9511
|
+
const block_q8_0 * restrict y = vy;
|
9512
|
+
|
9513
|
+
const int nb = n / QK4_NL;
|
9514
|
+
|
9515
|
+
#if defined __ARM_NEON
|
9516
|
+
const int8x16_t values = vld1q_s8(kvalues_iq4nl);
|
9517
|
+
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
9518
|
+
uint8x16x2_t q4bits;
|
9519
|
+
int8x16x4_t q4b;
|
9520
|
+
int8x16x4_t q8b;
|
9521
|
+
int32x4_t prod_1, prod_2;
|
9522
|
+
|
9523
|
+
float sumf = 0;
|
9524
|
+
|
9525
|
+
for (int ib = 0; ib < nb; ib += 2) {
|
9526
|
+
q4bits.val[0] = vld1q_u8(x[ib+0].qs);
|
9527
|
+
q4bits.val[1] = vld1q_u8(x[ib+1].qs);
|
9528
|
+
q8b.val[0] = vld1q_s8(y[ib+0].qs);
|
9529
|
+
q8b.val[1] = vld1q_s8(y[ib+0].qs + 16);
|
9530
|
+
q8b.val[2] = vld1q_s8(y[ib+1].qs);
|
9531
|
+
q8b.val[3] = vld1q_s8(y[ib+1].qs + 16);
|
9532
|
+
|
9533
|
+
q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b));
|
9534
|
+
q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
|
9535
|
+
q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b));
|
9536
|
+
q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));
|
9537
|
+
|
9538
|
+
prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
|
9539
|
+
prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);
|
9540
|
+
|
9541
|
+
sumf +=
|
9542
|
+
GGML_FP16_TO_FP32(x[ib+0].d) * GGML_FP16_TO_FP32(y[ib+0].d) * vaddvq_s32(prod_1) +
|
9543
|
+
GGML_FP16_TO_FP32(x[ib+1].d) * GGML_FP16_TO_FP32(y[ib+1].d) * vaddvq_s32(prod_2);
|
9544
|
+
}
|
9545
|
+
|
9546
|
+
*s = sumf;
|
9547
|
+
|
9548
|
+
#elif defined __AVX2__
|
9549
|
+
|
9550
|
+
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
|
9551
|
+
const __m128i m4b = _mm_set1_epi8(0x0f);
|
9552
|
+
const __m256i mone = _mm256_set1_epi16(1);
|
9553
|
+
|
9554
|
+
__m256 accum1 = _mm256_setzero_ps();
|
9555
|
+
__m256 accum2 = _mm256_setzero_ps();
|
9556
|
+
for (int ib = 0; ib < nb; ib += 2) {
|
9557
|
+
const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[0].qs);
|
9558
|
+
const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[1].qs);
|
9559
|
+
const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[0].qs);
|
9560
|
+
const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[1].qs);
|
9561
|
+
const __m256i q4b_1 = _mm256_set_m128i(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
|
9562
|
+
_mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
|
9563
|
+
const __m256i q4b_2 = _mm256_set_m128i(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
|
9564
|
+
_mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
|
9565
|
+
const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
|
9566
|
+
const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
|
9567
|
+
const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
|
9568
|
+
const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
|
9569
|
+
accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[0].d)*GGML_FP16_TO_FP32(x[0].d)),
|
9570
|
+
_mm256_cvtepi32_ps(p_1), accum1);
|
9571
|
+
accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[1].d)*GGML_FP16_TO_FP32(x[1].d)),
|
9572
|
+
_mm256_cvtepi32_ps(p_2), accum2);
|
9573
|
+
|
9574
|
+
y += 2;
|
9575
|
+
x += 2;
|
9576
|
+
}
|
9577
|
+
|
9578
|
+
*s = hsum_float_8(_mm256_add_ps(accum1, accum2));
|
9579
|
+
|
9580
|
+
#else
|
9581
|
+
float sumf = 0;
|
9582
|
+
for (int ib = 0; ib < nb; ++ib) {
|
9583
|
+
const float d = GGML_FP16_TO_FP32(y[ib].d)*GGML_FP16_TO_FP32(x[ib].d);
|
9584
|
+
int sumi1 = 0, sumi2 = 0;
|
9585
|
+
for (int j = 0; j < QK4_NL/2; ++j) {
|
9586
|
+
sumi1 += y[ib].qs[j+ 0] * kvalues_iq4nl[x[ib].qs[j] & 0xf];
|
9587
|
+
sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >> 4];
|
9588
|
+
}
|
9589
|
+
sumf += d * (sumi1 + sumi2);
|
9590
|
+
}
|
9591
|
+
*s = sumf;
|
9592
|
+
#endif
|
9593
|
+
}
|
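In the scalar fallback above, each byte of x[ib].qs packs two 4-bit indices into the 16-entry non-linear codebook kvalues_iq4nl: the low nibble addresses element j of the block and the high nibble element j + QK4_NL/2, with the fp16 block scale converting the looked-up value to a float. A standalone dequantization sketch of that layout (the function name is made up; the codebook and scale are passed in rather than taken from ggml-quants.c):

    #include <stdint.h>

    /* Illustrative dequantization of one 32-value IQ4_NL block.
     * `values` is the 16-entry non-linear codebook (kvalues_iq4nl in ggml-quants.c),
     * `qs` the 16 packed bytes of the block, `d` its scale converted to float. */
    static void dequant_iq4_nl_block(const int8_t values[16], const uint8_t qs[16], float d, float out[32]) {
        for (int j = 0; j < 16; ++j) {
            out[j +  0] = d * values[qs[j] & 0x0f];  // low nibbles -> first 16 values
            out[j + 16] = d * values[qs[j] >> 4];    // high nibbles -> last 16 values
        }
    }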
9594
|
+
|
8803
9595
|
// ================================ IQ2 quantization =============================================
|
8804
9596
|
|
8805
9597
|
typedef struct {
|
@@ -8808,14 +9600,22 @@ typedef struct {
|
|
8808
9600
|
uint16_t * neighbours;
|
8809
9601
|
} iq2_entry_t;
|
8810
9602
|
|
8811
|
-
static iq2_entry_t iq2_data[2] = {
|
9603
|
+
static iq2_entry_t iq2_data[3] = {
|
9604
|
+
{NULL, NULL, NULL},
|
8812
9605
|
{NULL, NULL, NULL},
|
8813
9606
|
{NULL, NULL, NULL},
|
8814
9607
|
};
|
8815
9608
|
|
8816
|
-
static inline int iq2_data_index(int grid_size) {
|
8817
|
-
GGML_ASSERT(grid_size == 256 || grid_size == 512);
|
8818
|
-
return grid_size == 256 ? 0 : 1;
|
9609
|
+
static inline int iq2_data_index(enum ggml_type type) {
|
9610
|
+
GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S);
|
9611
|
+
return type == GGML_TYPE_IQ2_XXS ? 0 :
|
9612
|
+
type == GGML_TYPE_IQ2_XS ? 1 : 2;
|
9613
|
+
}
|
9614
|
+
|
9615
|
+
static inline int iq2_grid_size(enum ggml_type type) {
|
9616
|
+
GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S);
|
9617
|
+
return type == GGML_TYPE_IQ2_XXS ? 256 :
|
9618
|
+
type == GGML_TYPE_IQ2_XS ? 512 : 512;
|
8819
9619
|
}
|
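iq2_data_index and iq2_grid_size now key the shared lattice tables by quant type rather than by raw grid size: slot 0 holds the 256-entry 2-bit grid for IQ2_XXS, slot 1 the 512-entry 2-bit grid for IQ2_XS, and slot 2 the new 512-entry 1-bit grid for IQ1_S. A small usage sketch, assuming ggml-quants.h exposes the init/free entry points with these type-based signatures:

    #include "ggml.h"
    #include "ggml-quants.h"   // assumed to declare iq2xs_init_impl / iq2xs_free_impl

    /* Prepare and release the lattice tables needed before quantizing to IQ1_S. */
    static void with_iq1_s_tables(void) {
        iq2xs_init_impl(GGML_TYPE_IQ1_S);   // builds the 512-entry 1-bit grid and neighbour lists
        /* ... quantize_iq1_s(...) calls would go here ... */
        iq2xs_free_impl(GGML_TYPE_IQ1_S);
    }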
8820
9620
|
|
8821
9621
|
static int iq2_compare_func(const void * left, const void * right) {
|
@@ -8824,12 +9624,13 @@ static int iq2_compare_func(const void * left, const void * right) {
|
|
8824
9624
|
return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 1 : 0;
|
8825
9625
|
}
|
8826
9626
|
|
8827
|
-
void iq2xs_init_impl(int grid_size) {
|
8828
|
-
const int gindex = iq2_data_index(grid_size);
|
9627
|
+
void iq2xs_init_impl(enum ggml_type type) {
|
9628
|
+
const int gindex = iq2_data_index(type);
|
9629
|
+
const int grid_size = iq2_grid_size(type);
|
8829
9630
|
if (iq2_data[gindex].grid) {
|
8830
9631
|
return;
|
8831
9632
|
}
|
8832
|
-
static const uint16_t kgrid_256[256] = {
|
9633
|
+
static const uint16_t kgrid_2bit_256[256] = {
|
8833
9634
|
0, 2, 5, 8, 10, 17, 20, 32, 34, 40, 42, 65, 68, 80, 88, 97,
|
8834
9635
|
100, 128, 130, 138, 162, 257, 260, 272, 277, 320, 388, 408, 512, 514, 546, 642,
|
8835
9636
|
1025, 1028, 1040, 1057, 1060, 1088, 1090, 1096, 1120, 1153, 1156, 1168, 1188, 1280, 1282, 1288,
|
@@ -8847,7 +9648,7 @@ void iq2xs_init_impl(int grid_size) {
|
|
8847
9648
|
33888, 34048, 34118, 34196, 34313, 34368, 34400, 34818, 35076, 35345, 36868, 36880, 36900, 36928, 37025, 37142,
|
8848
9649
|
37248, 37445, 37888, 37922, 37956, 38225, 39041, 39200, 40962, 41040, 41093, 41225, 41472, 42008, 43088, 43268,
|
8849
9650
|
};
|
8850
|
-
static const uint16_t kgrid_512[512] = {
|
9651
|
+
static const uint16_t kgrid_2bit_512[512] = {
|
8851
9652
|
0, 2, 5, 8, 10, 17, 20, 22, 25, 32, 34, 37, 40, 65, 68, 70,
|
8852
9653
|
73, 80, 82, 85, 88, 97, 100, 128, 130, 133, 136, 145, 148, 153, 160, 257,
|
8853
9654
|
260, 262, 265, 272, 274, 277, 280, 282, 289, 292, 320, 322, 325, 328, 337, 340,
|
@@ -8881,9 +9682,45 @@ void iq2xs_init_impl(int grid_size) {
|
|
8881
9682
|
40962, 40968, 40970, 40992, 41002, 41120, 41297, 41305, 41382, 41472, 41474, 41480, 41514, 41600, 41632, 42048,
|
8882
9683
|
42133, 42597, 42648, 43018, 43040, 43042, 43048, 43168, 43176, 43268, 43396, 43398, 43560, 43562, 43665, 43690,
|
8883
9684
|
};
|
9685
|
+
static const uint16_t kgrid_1bit_512[512] = {
|
9686
|
+
10, 33, 41, 85, 132, 134, 160, 162, 277, 337, 340, 345, 357, 405, 516, 545,
|
9687
|
+
553, 598, 641, 650, 681, 1042, 1044, 1097, 1169, 1176, 1320, 1345, 1365, 1378, 1434, 1444,
|
9688
|
+
1545, 1617, 1642, 1685, 2053, 2080, 2089, 2133, 2176, 2182, 2208, 2214, 2306, 2384, 2393, 2440,
|
9689
|
+
2453, 2581, 2664, 2690, 2721, 4117, 4161, 4182, 4184, 4261, 4357, 4369, 4372, 4377, 4390, 4422,
|
9690
|
+
4432, 4437, 4449, 4457, 4485, 4497, 4505, 4629, 4677, 4696, 4774, 5205, 5217, 5225, 5386, 5397,
|
9691
|
+
5409, 5445, 5457, 5460, 5461, 5462, 5465, 5472, 5477, 5525, 5545, 5650, 5668, 5717, 5729, 5769,
|
9692
|
+
5777, 6212, 6234, 6244, 6293, 6424, 6482, 6485, 6502, 6505, 6529, 6538, 6565, 6656, 6682, 6788,
|
9693
|
+
6806, 6820, 8218, 8224, 8226, 8232, 8277, 8326, 8354, 8469, 8521, 8530, 8549, 8596, 8737, 8794,
|
9694
|
+
9221, 9253, 9348, 9369, 9380, 9474, 9557, 9633, 9732, 9753, 9793, 9830, 9862, 9880, 10240, 10272,
|
9695
|
+
10282, 10321, 10406, 10517, 10530, 10566, 10585, 10645, 10896, 16466, 16468, 16473, 16485, 16646, 16660, 16665,
|
9696
|
+
16725, 16793, 16806, 16914, 16969, 16977, 16996, 17028, 17057, 17408, 17416, 17434, 17493, 17512, 17578, 17685,
|
9697
|
+
17696, 17733, 17745, 17748, 17749, 17750, 17753, 17765, 17794, 17813, 17946, 17984, 18005, 18072, 18453, 18529,
|
9698
|
+
18569, 18722, 18756, 18762, 18773, 18794, 18833, 18853, 18945, 19026, 19033, 19077, 20489, 20497, 20500, 20517,
|
9699
|
+
20565, 20586, 20610, 20633, 20757, 20769, 20776, 20805, 20817, 20820, 20821, 20822, 20825, 20837, 20864, 20872,
|
9700
|
+
20885, 20896, 21002, 21029, 21077, 21146, 21510, 21525, 21573, 21585, 21588, 21589, 21590, 21593, 21605, 21653,
|
9701
|
+
21665, 21765, 21777, 21780, 21781, 21782, 21785, 21797, 21825, 21828, 21829, 21830, 21833, 21840, 21841, 21842,
|
9702
|
+
21844, 21846, 21848, 21849, 21850, 21857, 21860, 21861, 21862, 21865, 21893, 21905, 21908, 21909, 21910, 21913,
|
9703
|
+
21925, 22024, 22037, 22085, 22097, 22100, 22101, 22102, 22105, 22117, 22165, 22545, 22566, 22568, 22594, 22608,
|
9704
|
+
22613, 22676, 22697, 22793, 22805, 22853, 22865, 22868, 22869, 22870, 22873, 22885, 22933, 22946, 23046, 23072,
|
9705
|
+
23125, 23209, 24597, 24640, 24665, 24673, 24725, 24833, 24840, 24869, 24917, 24934, 24965, 25001, 25108, 25110,
|
9706
|
+
25152, 25184, 25192, 25234, 25616, 25618, 25625, 25685, 25704, 25738, 25744, 25770, 25877, 25897, 25925, 25937,
|
9707
|
+
25940, 25941, 25942, 25945, 25957, 25986, 26005, 26186, 26197, 26276, 26632, 26634, 26725, 26757, 26770, 26885,
|
9708
|
+
26965, 26976, 26986, 27032, 27153, 27174, 27200, 27208, 27240, 27269, 27282, 27290, 32778, 32800, 32802, 32808,
|
9709
|
+
32810, 32853, 32904, 32922, 32930, 32932, 33105, 33110, 33112, 33125, 33157, 33280, 33288, 33301, 33312, 33320,
|
9710
|
+
33424, 33797, 33829, 33858, 34068, 34133, 34146, 34176, 34217, 34306, 34342, 34441, 34454, 34468, 34832, 34918,
|
9711
|
+
34965, 34984, 35094, 35137, 35161, 35208, 35232, 35332, 35338, 35368, 35429, 36932, 36934, 36953, 37009, 37125,
|
9712
|
+
37136, 37138, 37145, 37157, 37205, 37220, 37258, 37290, 37444, 37446, 37465, 37478, 37525, 37905, 37968, 37973,
|
9713
|
+
38040, 38054, 38145, 38154, 38165, 38180, 38186, 38213, 38225, 38228, 38229, 38230, 38233, 38245, 38293, 38485,
|
9714
|
+
38504, 38530, 38938, 38985, 38993, 39012, 39040, 39173, 39192, 39253, 39265, 39301, 39316, 39322, 39442, 39497,
|
9715
|
+
39504, 39590, 40970, 40984, 40992, 41002, 41045, 41120, 41128, 41237, 41289, 41297, 41317, 41364, 41366, 41514,
|
9716
|
+
41557, 41633, 41989, 42021, 42056, 42068, 42074, 42113, 42242, 42265, 42274, 42325, 42340, 42402, 42501, 42512,
|
9717
|
+
42533, 42624, 42632, 42666, 43040, 43093, 43106, 43168, 43176, 43264, 43286, 43345, 43429, 43590, 43618, 43680,
|
9718
|
+
};
|
9719
|
+
|
8884
9720
|
const int kmap_size = 43692;
|
8885
|
-
const int nwant = 2;
|
8886
|
-
const uint16_t * kgrid = grid_size == 256 ? kgrid_256 : kgrid_512;
|
9721
|
+
const int nwant = type == GGML_TYPE_IQ1_S ? 3 : 2;
|
9722
|
+
const uint16_t * kgrid = type == GGML_TYPE_IQ2_XXS ? kgrid_2bit_256 :
|
9723
|
+
type == GGML_TYPE_IQ2_XS ? kgrid_2bit_512 : kgrid_1bit_512;
|
8887
9724
|
uint64_t * kgrid_q2xs;
|
8888
9725
|
int * kmap_q2xs;
|
8889
9726
|
uint16_t * kneighbors_q2xs;
|
@@ -8979,9 +9816,9 @@ void iq2xs_init_impl(int grid_size) {
|
|
8979
9816
|
free(dist2);
|
8980
9817
|
}
|
8981
9818
|
|
8982
|
-
void iq2xs_free_impl(int grid_size) {
|
8983
|
-
GGML_ASSERT(grid_size == 256 || grid_size == 512);
|
8984
|
-
const int gindex = iq2_data_index(grid_size);
|
9819
|
+
void iq2xs_free_impl(enum ggml_type type) {
|
9820
|
+
GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S);
|
9821
|
+
const int gindex = iq2_data_index(type);
|
8985
9822
|
if (iq2_data[gindex].grid) {
|
8986
9823
|
free(iq2_data[gindex].grid); iq2_data[gindex].grid = NULL;
|
8987
9824
|
free(iq2_data[gindex].map); iq2_data[gindex].map = NULL;
|
@@ -9015,7 +9852,7 @@ static int iq2_find_best_neighbour(const uint16_t * restrict neighbours, const u
|
|
9015
9852
|
|
9016
9853
|
static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
|
9017
9854
|
|
9018
|
-
const int gindex = iq2_data_index(256);
|
9855
|
+
const int gindex = iq2_data_index(GGML_TYPE_IQ2_XXS);
|
9019
9856
|
|
9020
9857
|
const uint64_t * kgrid_q2xs = iq2_data[gindex].grid;
|
9021
9858
|
const int * kmap_q2xs = iq2_data[gindex].map;
|
@@ -9188,7 +10025,7 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
|
|
9188
10025
|
|
9189
10026
|
static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
|
9190
10027
|
|
9191
|
-
const int gindex = iq2_data_index(512);
|
10028
|
+
const int gindex = iq2_data_index(GGML_TYPE_IQ2_XS);
|
9192
10029
|
|
9193
10030
|
const uint64_t * kgrid_q2xs = iq2_data[gindex].grid;
|
9194
10031
|
const int * kmap_q2xs = iq2_data[gindex].map;
|
@@ -9825,3 +10662,327 @@ void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * re
|
|
9825
10662
|
assert(k % QK_K == 0);
|
9826
10663
|
quantize_row_iq3_xxs_impl(x, y, k, NULL);
|
9827
10664
|
}
|
10665
|
+
|
10666
|
+
// =================================== 1.5 bpw ===================================================
|
10667
|
+
|
10668
|
+
static int iq1_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid,
|
10669
|
+
const float * restrict xval, const float * restrict weight, float * scale, int8_t * restrict L, int ngrid) {
|
10670
|
+
int num_neighbors = neighbours[0];
|
10671
|
+
GGML_ASSERT(num_neighbors > 0);
|
10672
|
+
float best_score = 0;
|
10673
|
+
int grid_index = -1;
|
10674
|
+
for (int j = 1; j <= num_neighbors; ++j) {
|
10675
|
+
const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
|
10676
|
+
float sumqx = 0, sumq2 = 0;
|
10677
|
+
for (int i = 0; i < 8; ++i) {
|
10678
|
+
float q = (pg[i] - 3)/2;
|
10679
|
+
float w = weight[i];
|
10680
|
+
sumqx += w*q*xval[i];
|
10681
|
+
sumq2 += w*q*q;
|
10682
|
+
}
|
10683
|
+
if (sumqx > 0 && sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
|
10684
|
+
*scale = sumqx/sumq2; best_score = *scale * sumqx;
|
10685
|
+
grid_index = neighbours[j];
|
10686
|
+
}
|
10687
|
+
}
|
10688
|
+
if (grid_index < 0) {
|
10689
|
+
for (int i = 0; i < ngrid; ++i) {
|
10690
|
+
const int8_t * grid_i = (const int8_t *)(grid + i);
|
10691
|
+
float sumqx = 0, sumq2 = 0;
|
10692
|
+
for (int j = 0; j < 8; ++j) {
|
10693
|
+
float w = weight[j];
|
10694
|
+
float q = (grid_i[j] - 3)/2;
|
10695
|
+
sumqx += w*q*xval[j];
|
10696
|
+
sumq2 += w*q*q;
|
10697
|
+
}
|
10698
|
+
if (sumqx > 0 && sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
|
10699
|
+
*scale = sumqx/sumq2; best_score = *scale*sumqx;
|
10700
|
+
grid_index = i;
|
10701
|
+
}
|
10702
|
+
}
|
10703
|
+
}
|
10704
|
+
if (grid_index < 0) {
|
10705
|
+
printf("Oops, did not find grid point\n");
|
10706
|
+
printf("Have %d neighbours\n", num_neighbors);
|
10707
|
+
for (int j = 1; j <= num_neighbors; ++j) {
|
10708
|
+
const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
|
10709
|
+
float sumqx = 0, sumq2 = 0;
|
10710
|
+
for (int i = 0; i < 8; ++i) {
|
10711
|
+
float q = (pg[i] - 3)/2;
|
10712
|
+
float w = weight[i];
|
10713
|
+
sumqx += w*q*xval[i];
|
10714
|
+
sumq2 += w*q*q;
|
10715
|
+
}
|
10716
|
+
printf(" neighbour %d: sumqx = %g sumq2 = %g\n", j, (double)sumqx, (double)sumq2);
|
10717
|
+
}
|
10718
|
+
}
|
10719
|
+
GGML_ASSERT(grid_index >= 0);
|
10720
|
+
//!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
|
10721
|
+
*scale *= 1.05f; // This is a fudge factor. Don't ask me why it improves the result.
|
10722
|
+
//!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
|
10723
|
+
const int8_t * pg = (const int8_t *)(grid + grid_index);
|
10724
|
+
for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2;
|
10725
|
+
return grid_index;
|
10726
|
+
}
|
10727
|
+
|
10728
|
+
static int iq1_sort_helper(const void * left, const void * right) {
|
10729
|
+
const float * l = left;
|
10730
|
+
const float * r = right;
|
10731
|
+
return *l < *r ? -1 : *l > *r ? 1 : 0;
|
10732
|
+
}
|
10733
|
+
|
10734
|
+
static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
|
10735
|
+
|
10736
|
+
const int gindex = iq2_data_index(GGML_TYPE_IQ1_S);
|
10737
|
+
|
10738
|
+
const uint64_t * kgrid_q2xs = iq2_data[gindex].grid;
|
10739
|
+
const int * kmap_q2xs = iq2_data[gindex].map;
|
10740
|
+
const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
|
10741
|
+
|
10742
|
+
GGML_ASSERT(quant_weights && "missing quantization weights");
|
10743
|
+
GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?");
|
10744
|
+
GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?");
|
10745
|
+
GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
|
10746
|
+
GGML_ASSERT(n%QK_K == 0);
|
10747
|
+
|
10748
|
+
const int nbl = n/256;
|
10749
|
+
|
10750
|
+
block_iq1_s * y = vy;
|
10751
|
+
|
10752
|
+
float scales[QK_K/8];
|
10753
|
+
float weight[8];
|
10754
|
+
int8_t L[8];
|
10755
|
+
float sumx[9];
|
10756
|
+
float sumw[9];
|
10757
|
+
float pairs[16];
|
10758
|
+
int * idx = (int *)(pairs + 1);
|
10759
|
+
uint8_t hbit[QK_K/8];
|
10760
|
+
|
10761
|
+
for (int ibl = 0; ibl < nbl; ++ibl) {
|
10762
|
+
|
10763
|
+
y[ibl].d = GGML_FP32_TO_FP16(0.f);
|
10764
|
+
memset(y[ibl].qs, 0, QK_K/8);
|
10765
|
+
memset(y[ibl].scales, 0, QK_K/16);
|
10766
|
+
|
10767
|
+
float max_scale = 0;
|
10768
|
+
|
10769
|
+
const float * xbl = x + QK_K*ibl;
|
10770
|
+
float sumx2 = 0;
|
10771
|
+
for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
|
10772
|
+
float sigma2 = sumx2/QK_K;
|
10773
|
+
|
10774
|
+
for (int ib = 0; ib < QK_K/8; ++ib) {
|
10775
|
+
const float * xb = xbl + 8*ib;
|
10776
|
+
const float * qw = quant_weights + QK_K*ibl + 8*ib;
|
10777
|
+
for (int i = 0; i < 8; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
|
10778
|
+
float max = fabsf(xb[0]);
|
10779
|
+
for (int i = 1; i < 8; ++i) max = MAX(max, fabsf(xb[i]));
|
10780
|
+
if (!max) {
|
10781
|
+
scales[ib] = 0;
|
10782
|
+
memset(L, 1, 8);
|
10783
|
+
continue;
|
10784
|
+
}
|
10785
|
+
// Here we solve exactly the sum of squared difference (SSD) weighted minimization problem.
|
10786
|
+
// With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two
|
10787
|
+
// boundaries that split the weights xb[i] into 3 groups. To do so, we sort the weights
|
10788
|
+
// in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and
|
10789
|
+
// Wi = sum[weight[j], j = 0...i], and use these to quickly get the optimum scale
|
10790
|
+
// and the corresponding score for each possible split.
|
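// Why maximizing sumqx*sumqx/sumq2 is the right criterion: for a fixed assignment
// q[j] in {-1, 0, +1}, minimizing sum_j weight[j]*(xb[j] - scale*q[j])^2 over the
// scale gives scale = sumqx/sumq2, with sumqx = sum_j weight[j]*q[j]*xb[j] and
// sumq2 = sum_j weight[j]*q[j]*q[j]; substituting back, the weighted error becomes
// sum_j weight[j]*xb[j]*xb[j] - sumqx*sumqx/sumq2, so the best split is the one with
// the largest sumqx*sumqx/sumq2 (tracked below as best_score = scale*sumqx).
// In the sorted order, a split (i1, i2) assigns -1 to the first i1 values, 0 to the
// next i2-i1, and +1 to the remaining 8-i2, which is exactly why the loop below uses
// sumqx = -(sumx[i1]-sumx[0]) + (sumx[8]-sumx[i2]) and
// sumq2 =  (sumw[i1]-sumw[0]) + (sumw[8]-sumw[i2]).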
10791
|
+
for (int j = 0; j < 8; ++j) {
|
10792
|
+
pairs[2*j] = xb[j];
|
10793
|
+
idx[2*j] = j;
|
10794
|
+
}
|
10795
|
+
qsort(pairs, 8, 2*sizeof(float), iq1_sort_helper);
|
10796
|
+
{
|
10797
|
+
sumx[0] = sumw[0] = 0;
|
10798
|
+
for (int j = 0; j < 8; ++j) {
|
10799
|
+
int i = idx[2*j];
|
10800
|
+
sumx[j+1] = sumx[j] + weight[i]*xb[i];
|
10801
|
+
sumw[j+1] = sumw[j] + weight[i];
|
10802
|
+
}
|
10803
|
+
}
|
10804
|
+
float best_score = 0, scale = max;
|
10805
|
+
int besti1 = 0, besti2 = 0;
|
10806
|
+
for (int i1 = 0; i1 <= 8; ++i1) {
|
10807
|
+
for (int i2 = i1; i2 <= 8; ++i2) {
|
10808
|
+
float sumqx = -(sumx[i1] - sumx[0]) + (sumx[8] - sumx[i2]);
|
10809
|
+
float sumq2 = (sumw[i1] - sumw[0]) + (sumw[8] - sumw[i2]);
|
10810
|
+
if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
|
10811
|
+
scale = sumqx/sumq2; best_score = scale*sumqx;
|
10812
|
+
besti1 = i1; besti2 = i2;
|
10813
|
+
}
|
10814
|
+
}
|
10815
|
+
}
|
10816
|
+
for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0;
|
10817
|
+
for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1;
|
10818
|
+
for (int j = besti2; j < 8; ++j) L[idx[2*j]] = 2;
|
10819
|
+
if (scale < 0) {
|
10820
|
+
for (int j = 0; j < 8; ++j) L[j] = 2 - L[j];
|
10821
|
+
scale = -scale;
|
10822
|
+
}
|
10823
|
+
// Now we check if the solution found above corresponds to a grid point and, if not, use a neighbouring
|
10824
|
+
// grid point that minimizes SSD.
|
10825
|
+
uint16_t u = 0;
|
10826
|
+
for (int j = 0; j < 8; ++j) u |= (L[j] << 2*j);
|
10827
|
+
int grid_index = kmap_q2xs[u];
|
10828
|
+
if (grid_index < 0) {
|
10829
|
+
const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
|
10830
|
+
grid_index = iq1_find_best_neighbour(neighbours, kgrid_q2xs, xb, weight, &scale, L, NGRID_IQ2XXS);
|
10831
|
+
GGML_ASSERT(grid_index >= 0);
|
10832
|
+
}
|
10833
|
+
y[ibl].qs[ib] = grid_index & 255;
|
10834
|
+
hbit[ib] = grid_index >> 8;
|
10835
|
+
GGML_ASSERT(scale >= 0);
|
10836
|
+
scales[ib] = scale;
|
10837
|
+
max_scale = MAX(max_scale, scale);
|
10838
|
+
}
|
10839
|
+
|
10840
|
+
if (!max_scale) {
|
10841
|
+
memset(y[ibl].qs, 0, QK_K/8);
|
10842
|
+
continue;
|
10843
|
+
}
|
10844
|
+
|
10845
|
+
float d = max_scale/15;
|
10846
|
+
y[ibl].d = GGML_FP32_TO_FP16(d*1.085f); // 1.085f is another fudge factor. Don't ask me why it is needed.
|
10847
|
+
float id = 1/d;
|
10848
|
+
for (int ib = 0; ib < QK_K/8; ++ib) {
|
10849
|
+
int l = nearest_int(0.5f*(id*scales[ib]-1));
|
10850
|
+
l = MAX(0, MIN(7, l));
|
10851
|
+
if (hbit[ib]) l |= 8;
|
10852
|
+
y[ibl].scales[ib/2] |= (l << 4*(ib%2));
|
10853
|
+
}
|
10854
|
+
}
|
10855
|
+
}
|
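quantize_row_iq1_s_impl above packs each group of 8 weights as the low 8 bits of its grid index in qs[ib] plus a 4-bit field in scales[ib/2]: 3 bits of per-group scale and, in bit 3, the ninth bit of the grid index, with the fp16 block scale d covering the whole 256-value super-block. A decoding sketch that mirrors that packing (the helper name is made up; the layout is read off the code above):

    #include <stdint.h>

    /* Recover the grid index and effective scale of sub-block `ib` of one IQ1_S block.
     * `qs` and `scales` are the block's packed arrays, `d` its fp16 scale as a float. */
    static void decode_iq1_s_subblock(const uint8_t * qs, const uint8_t * scales, float d,
                                      int ib, int * grid_index, float * scale) {
        const int nib = (scales[ib/2] >> (4*(ib%2))) & 0x0f;  // 4-bit field for sub-block ib
        *grid_index   = qs[ib] | ((nib & 0x08) << 5);          // re-attach the ninth index bit
        *scale        = d * (2*(nib & 0x07) + 1);              // 3-bit scale -> odd multiplier
    }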
10856
|
+
|
10857
|
+
size_t quantize_iq1_s(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
|
10858
|
+
(void)hist;
|
10859
|
+
GGML_ASSERT(n_per_row%QK_K == 0);
|
10860
|
+
int nblock = n_per_row/QK_K;
|
10861
|
+
char * qrow = (char *)dst;
|
10862
|
+
for (int row = 0; row < nrow; ++row) {
|
10863
|
+
quantize_row_iq1_s_impl(src, qrow, n_per_row, quant_weights);
|
10864
|
+
src += n_per_row;
|
10865
|
+
qrow += nblock*sizeof(block_iq1_s);
|
10866
|
+
}
|
10867
|
+
return nrow * nblock * sizeof(block_iq1_s);
|
10868
|
+
}
|
10869
|
+
|
10870
|
+
// ============================ 4-bit non-linear quants
|
10871
|
+
|
10872
|
+
static inline int best_index_int8(int n, const int8_t * val, float x) {
|
10873
|
+
if (x <= val[0]) return 0;
|
10874
|
+
if (x >= val[n-1]) return n-1;
|
10875
|
+
int ml = 0, mu = n-1;
|
10876
|
+
while (mu-ml > 1) {
|
10877
|
+
int mav = (ml+mu)/2;
|
10878
|
+
if (x < val[mav]) mu = mav; else ml = mav;
|
10879
|
+
}
|
10880
|
+
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
|
10881
|
+
}
|
10882
|
+
|
10883
|
+
static void quantize_row_iq4_nl_impl(const int block_size, const float * GGML_RESTRICT x,
|
10884
|
+
ggml_fp16_t * dh, uint8_t * q4,
|
10885
|
+
float * weight, uint8_t * L,
|
10886
|
+
const int8_t * values,
|
10887
|
+
const float * quant_weights) {
|
10888
|
+
|
10889
|
+
const int ntry = 7;
|
10890
|
+
|
10891
|
+
float sigma2 = 0;
|
10892
|
+
for (int j = 0; j < QK4_NL; ++j) sigma2 += x[j]*x[j];
|
10893
|
+
sigma2 *= 2.f/QK4_NL;
|
10894
|
+
|
10895
|
+
const int nb = QK4_NL/block_size;
|
10896
|
+
|
10897
|
+
memset(q4, 0, QK4_NL/2);
|
10898
|
+
for (int ib = 0; ib < nb; ++ib) {
|
10899
|
+
dh[ib] = GGML_FP32_TO_FP16(0.f);
|
10900
|
+
const float * xb = x + ib*block_size;
|
10901
|
+
if (quant_weights) {
|
10902
|
+
const float * qw = quant_weights + ib*block_size;
|
10903
|
+
for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
|
10904
|
+
} else {
|
10905
|
+
for (int j = 0; j < block_size; ++j) weight[j] = xb[j]*xb[j];
|
10906
|
+
}
|
10907
|
+
float amax = 0, max = 0;
|
10908
|
+
for (int j = 0; j < block_size; ++j) {
|
10909
|
+
float ax = fabsf(xb[j]);
|
10910
|
+
if (ax > amax) {
|
10911
|
+
amax = ax; max = xb[j];
|
10912
|
+
}
|
10913
|
+
}
|
10914
|
+
if (!amax) {
|
10915
|
+
continue;
|
10916
|
+
}
|
10917
|
+
float d = -max/values[0];
|
10918
|
+
float id = 1/d;
|
10919
|
+
float sumqx = 0, sumq2 = 0;
|
10920
|
+
for (int j = 0; j < block_size; ++j) {
|
10921
|
+
float al = id*xb[j];
|
10922
|
+
int l = best_index_int8(16, values, al);
|
10923
|
+
float q = values[l];
|
10924
|
+
float w = weight[j];
|
10925
|
+
sumqx += w*q*xb[j];
|
10926
|
+
sumq2 += w*q*q;
|
10927
|
+
}
|
10928
|
+
float best_id = id;
|
10929
|
+
d = sumqx/sumq2;
|
10930
|
+
float best = d*sumqx;
|
10931
|
+
for (int itry = -ntry; itry <= ntry; ++itry) {
|
10932
|
+
id = (itry + values[0])/max;
|
10933
|
+
sumqx = sumq2 = 0;
|
10934
|
+
for (int j = 0; j < block_size; ++j) {
|
10935
|
+
float al = id*xb[j];
|
10936
|
+
int l = best_index_int8(16, values, al);
|
10937
|
+
float q = values[l];
|
10938
|
+
float w = weight[j];
|
10939
|
+
sumqx += w*q*xb[j];
|
10940
|
+
sumq2 += w*q*q;
|
10941
|
+
}
|
10942
|
+
if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
|
10943
|
+
d = sumqx/sumq2; best = d * sumqx;
|
10944
|
+
best_id = id;
|
10945
|
+
}
|
10946
|
+
}
|
10947
|
+
dh[ib] = GGML_FP32_TO_FP16(d);
|
10948
|
+
for (int j = 0; j < block_size; ++j) {
|
10949
|
+
L[ib*block_size + j] = best_index_int8(16, values, best_id*xb[j]);
|
10950
|
+
}
|
10951
|
+
}
|
10952
|
+
for (int i = 0; i < QK4_NL/32; ++i) {
|
10953
|
+
for (int j = 0; j < 16; ++j) {
|
10954
|
+
q4[16*i + j] = L[32*i + j] | (L[32*i + 16 + j] << 4);
|
10955
|
+
}
|
10956
|
+
}
|
10957
|
+
}
|
10958
|
+
|
10959
|
+
size_t quantize_iq4_nl(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
|
10960
|
+
(void)hist;
|
10961
|
+
GGML_ASSERT(n_per_row%QK4_NL == 0);
|
10962
|
+
int nblock = n_per_row/QK4_NL;
|
10963
|
+
char * qrow = (char *)dst;
|
10964
|
+
uint8_t L[QK4_NL];
|
10965
|
+
float weight[32];
|
10966
|
+
for (int row = 0; row < nrow; ++row) {
|
10967
|
+
block_iq4_nl * iq4 = (block_iq4_nl *)qrow;
|
10968
|
+
for (int ibl = 0; ibl < nblock; ++ibl) {
|
10969
|
+
const float * qw = quant_weights ? quant_weights + QK4_NL*ibl : NULL;
|
10970
|
+
quantize_row_iq4_nl_impl(32, src + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, weight, L, kvalues_iq4nl, qw);
|
10971
|
+
}
|
10972
|
+
src += n_per_row;
|
10973
|
+
qrow += nblock*sizeof(block_iq4_nl);
|
10974
|
+
}
|
10975
|
+
return nrow * nblock * sizeof(block_iq4_nl);
|
10976
|
+
}
|
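quantize_iq4_nl writes nblock consecutive block_iq4_nl structures per row and returns the total byte count; unlike quantize_iq1_s above, it also accepts a NULL importance-weight pointer. A hypothetical caller sketch for a row-major matrix (wrapper name illustrative; the prototype is assumed to be in ggml-quants.h):

    #include <stddef.h>
    #include "ggml-quants.h"   // assumed to declare quantize_iq4_nl as in this version

    /* Quantize a row-major [nrow x n_per_row] float matrix to IQ4_NL without
     * importance weights; the return value is the number of bytes written to dst. */
    static size_t quantize_matrix_iq4_nl(const float * src, void * dst, int nrow, int n_per_row) {
        return quantize_iq4_nl(src, dst, nrow, n_per_row, /*hist=*/NULL, /*quant_weights=*/NULL);
    }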
10977
|
+
|
10978
|
+
void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int k) {
|
10979
|
+
assert(k % QK4_NL == 0);
|
10980
|
+
block_iq4_nl * restrict y = vy;
|
10981
|
+
quantize_row_iq4_nl_reference(x, y, k);
|
10982
|
+
}
|
10983
|
+
|
10984
|
+
void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * restrict y, int k) {
|
10985
|
+
assert(k % QK4_NL == 0);
|
10986
|
+
quantize_iq4_nl(x, y, 1, k, NULL, NULL);
|
10987
|
+
}
|
10988
|
+
|