llama_cpp 0.12.6 → 0.12.7

@@ -438,6 +438,30 @@ inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
      return res;
  }

+ // NOTE: not tested
+ inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) {
+     int8x16_t res;
+
+     res[ 0] = a[b[ 0]];
+     res[ 1] = a[b[ 1]];
+     res[ 2] = a[b[ 2]];
+     res[ 3] = a[b[ 3]];
+     res[ 4] = a[b[ 4]];
+     res[ 5] = a[b[ 5]];
+     res[ 6] = a[b[ 6]];
+     res[ 7] = a[b[ 7]];
+     res[ 8] = a[b[ 8]];
+     res[ 9] = a[b[ 9]];
+     res[10] = a[b[10]];
+     res[11] = a[b[11]];
+     res[12] = a[b[12]];
+     res[13] = a[b[13]];
+     res[14] = a[b[14]];
+     res[15] = a[b[15]];
+
+     return res;
+ }
+
  #else

  #define ggml_int16x8x2_t int16x8x2_t
@@ -451,6 +475,7 @@ inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
  #define ggml_vld1q_u8_x4 vld1q_u8_x4
  #define ggml_vld1q_s8_x2 vld1q_s8_x2
  #define ggml_vld1q_s8_x4 vld1q_s8_x4
+ #define ggml_vqtbl1q_s8 vqtbl1q_s8

  #endif
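
A note on the new table lookup: on AArch64, `ggml_vqtbl1q_s8` maps directly to the `vqtbl1q_s8` intrinsic (a single TBL instruction), which returns 0 for any lane whose index is 16 or larger. The scalar fallback above simply indexes `a[b[i]]`, so it must only be fed indices below 16 — which holds for its use in `ggml_vec_dot_iq4_nl_q8_0` later in this diff, where every index is a nibble masked or shifted into 0..15. A hedged sketch of a fallback with the hardware's zeroing semantics (hypothetical variant, not part of this release):

    // Sketch only: mirrors the TBL contract of returning 0 for
    // out-of-range indices instead of reading past the table.
    inline static int8x16_t ggml_vqtbl1q_s8_safe(int8x16_t a, uint8x16_t b) {
        int8x16_t res;
        for (int i = 0; i < 16; ++i) {
            res[i] = b[i] < 16 ? a[b[i]] : 0;
        }
        return res;
    }
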
@@ -1837,9 +1862,9 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri
      float sigma2 = sumx2/QK_K;
      for (int j = 0; j < QK_K/16; ++j) {
          const float * restrict qw = quant_weights + QK_K * i + 16*j;
-         for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]);
-         for (int l = 0; l < 16; ++l) sw[j] += weight[l];
-         scales[j] = make_qkx3_quants(16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
+         for (int l = 0; l < QK_K/16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]);
+         for (int l = 0; l < QK_K/16; ++l) sw[j] += weight[l];
+         scales[j] = make_qkx3_quants(QK_K/16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
      }

      float dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw);
@@ -3480,6 +3505,139 @@ static const uint32_t iq3xxs_grid[256] = {
      0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
  };

+ #define NGRID_IQ2XXS 512
+ static const uint64_t iq1s_grid[NGRID_IQ2XXS] = {
+     0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
+     0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01,
+     0xffffff00ff000000, 0xffffff000000ff00, 0xffffff00000000ff, 0xffffff0000000100,
+     0xffffff0000010000, 0xffffff0001000000, 0xffffff01ffff00ff, 0xffffff01ff01ff00,
+     0xffffff01ff010100, 0xffffff0100000001, 0xffffff0101ffff00, 0xffffff0101ff0101,
+     0xffffff0101010100, 0xffff00ffff00ff01, 0xffff00ffff0000ff, 0xffff00ff00ff0100,
+     0xffff00ff0100ff00, 0xffff00ff010001ff, 0xffff0000ff0101ff, 0xffff000000ffff00,
+     0xffff000000000000, 0xffff00000001ff01, 0xffff000001000101, 0xffff0000010100ff,
+     0xffff0001ffff0100, 0xffff00010000ff00, 0xffff000100010101, 0xffff000101000000,
+     0xffff01ffffff0000, 0xffff01ffff01ffff, 0xffff01ffff010100, 0xffff01ff00000000,
+     0xffff01ff01ffffff, 0xffff01ff01ff0001, 0xffff01ff0101ffff, 0xffff01ff01010001,
+     0xffff0100ffffff01, 0xffff01000000ffff, 0xffff010000000100, 0xffff010001ff01ff,
+     0xffff010001000000, 0xffff0101ff000000, 0xffff0101000101ff, 0xffff010101ffff01,
+     0xffff01010101ff00, 0xff00ffffff000000, 0xff00ffff00ffff00, 0xff00ffff00000001,
+     0xff00ffff000001ff, 0xff00ffff01010000, 0xff00ff00ffff0000, 0xff00ff00ff00ff00,
+     0xff00ff00ff0000ff, 0xff00ff00ff000100, 0xff00ff00ff010001, 0xff00ff0000ff0001,
+     0xff00ff000000ffff, 0xff00ff0000000000, 0xff00ff000001ff00, 0xff00ff0000010100,
+     0xff00ff0001ff0000, 0xff00ff000100ff00, 0xff00ff0001000100, 0xff00ff01ff000000,
+     0xff00ff0100ff0000, 0xff00ff01000001ff, 0xff00ff0101010001, 0xff0000ff00000000,
+     0xff0000ff0001ff00, 0xff0000ff00010100, 0xff000000ffff0101, 0xff000000ff000000,
+     0xff000000ff01ff00, 0xff00000000ff0000, 0xff0000000000ff00, 0xff000000000000ff,
+     0xff00000000000000, 0xff00000000000001, 0xff00000000000100, 0xff0000000001ffff,
+     0xff00000000010000, 0xff00000001000000, 0xff00000001010100, 0xff000001ff00ff01,
+     0xff000001ff0100ff, 0xff00000100000000, 0xff0000010001ff00, 0xff00000101ff0100,
+     0xff0000010100ff00, 0xff0001ff00ff00ff, 0xff0001ff00000101, 0xff0001ff000100ff,
+     0xff0001ff01000000, 0xff000100ff0001ff, 0xff0001000000ff01, 0xff00010000000000,
+     0xff00010000010001, 0xff00010000010100, 0xff00010001ffff00, 0xff00010001ff0101,
+     0xff00010001010000, 0xff000101ffffffff, 0xff000101ff000101, 0xff00010101ff00ff,
+     0xff00010101000001, 0xff000101010100ff, 0xff01ffffff000101, 0xff01ffffff01ffff,
+     0xff01ffffff01ff01, 0xff01ffffff0101ff, 0xff01ffff00000000, 0xff01ffff01ff0001,
+     0xff01ffff0101ff01, 0xff01ff00ff000000, 0xff01ff0000ff0100, 0xff01ff000000ff01,
+     0xff01ff0000010000, 0xff01ff00010000ff, 0xff01ff01ff01ff00, 0xff01ff0100000101,
+     0xff0100ffffff0000, 0xff0100ffff010000, 0xff0100ff01ff00ff, 0xff0100ff01000100,
+     0xff0100ff010100ff, 0xff010000ffffff01, 0xff01000000000000, 0xff0100000101ff00,
+     0xff010001ffff00ff, 0xff010001ff000100, 0xff01000100ffff00, 0xff01000100010001,
+     0xff01000101ff0001, 0xff010001010001ff, 0xff0101ffffffffff, 0xff0101ffff01ffff,
+     0xff0101ffff010101, 0xff0101ff0000ff00, 0xff0101ff01010001, 0xff010100ff000000,
+     0xff010100ff01ff01, 0xff01010000ff0001, 0xff01010000000100, 0xff01010001000000,
+     0xff0101010100ffff, 0x00ffffff0000ff01, 0x00ffffff000000ff, 0x00ffffff00000100,
+     0x00ffffff00010000, 0x00ffff00ffff0001, 0x00ffff00ff0000ff, 0x00ffff00ff000100,
+     0x00ffff0000000000, 0x00ffff0001000100, 0x00ffff0001010001, 0x00ffff01ff00ff01,
+     0x00ffff0100ff0100, 0x00ffff010000ff00, 0x00ffff01000100ff, 0x00ffff0101ff00ff,
+     0x00ffff010101ff00, 0x00ff00ffffffffff, 0x00ff00ffffff01ff, 0x00ff00ffff000101,
+     0x00ff00ff00000000, 0x00ff00ff000101ff, 0x00ff00ff01010101, 0x00ff0000ff000000,
+     0x00ff0000ff01ffff, 0x00ff000000ff0000, 0x00ff00000000ff00, 0x00ff0000000000ff,
+     0x00ff000000000000, 0x00ff000000000001, 0x00ff000000000100, 0x00ff000000010000,
+     0x00ff000001ffff01, 0x00ff000001000000, 0x00ff0001ff000101, 0x00ff000100ffffff,
+     0x00ff000100000000, 0x00ff0001010001ff, 0x00ff01ffff000000, 0x00ff01ff0001ff00,
+     0x00ff01ff01ff0100, 0x00ff0100ff01ff01, 0x00ff010000ff00ff, 0x00ff010000ff0101,
+     0x00ff010000000000, 0x00ff010000010101, 0x00ff01000100ff00, 0x00ff010001010000,
+     0x00ff0101ffffff00, 0x00ff01010000ff01, 0x00ff010100000100, 0x00ff010101ff0000,
+     0x0000ffffffff0100, 0x0000ffffff00ff00, 0x0000ffffff0000ff, 0x0000ffffff010000,
+     0x0000ffff00000000, 0x0000ffff00010101, 0x0000ffff01ffff01, 0x0000ffff01000100,
+     0x0000ff00ff000000, 0x0000ff00ff01ff00, 0x0000ff00ff0101ff, 0x0000ff0000ff0000,
+     0x0000ff000000ff00, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001,
+     0x0000ff0000000100, 0x0000ff0000010000, 0x0000ff0001ffffff, 0x0000ff0001ff01ff,
+     0x0000ff0001000000, 0x0000ff000101ffff, 0x0000ff01ffff0101, 0x0000ff01ff010000,
+     0x0000ff0100000000, 0x0000ff0101000101, 0x000000ffffff0001, 0x000000ffff000000,
+     0x000000ff00ff0000, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000,
+     0x000000ff00000001, 0x000000ff00000100, 0x000000ff00010000, 0x000000ff01000000,
+     0x000000ff0101ff00, 0x00000000ffff0000, 0x00000000ff00ff00, 0x00000000ff0000ff,
+     0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff010000,
+     0x0000000000ffff00, 0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001,
+     0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01,
+     0x00000000000000ff, 0x0000000000000001, 0x00000000000001ff, 0x0000000000000100,
+     0x0000000000000101, 0x000000000001ff00, 0x00000000000100ff, 0x0000000000010000,
+     0x0000000000010001, 0x0000000000010100, 0x0000000001ff0000, 0x000000000100ff00,
+     0x00000000010000ff, 0x0000000001000000, 0x0000000001000001, 0x0000000001000100,
+     0x0000000001010000, 0x00000001ffff01ff, 0x00000001ff000000, 0x0000000100ff0000,
+     0x000000010000ff00, 0x00000001000000ff, 0x0000000100000000, 0x0000000100000001,
+     0x0000000100000100, 0x0000000100010000, 0x0000000101000000, 0x000001ffff00ff00,
+     0x000001ffff010001, 0x000001ffff0101ff, 0x000001ff00ffff01, 0x000001ff0000ffff,
+     0x000001ff00000000, 0x000001ff010000ff, 0x000001ff01010100, 0x00000100ffff0100,
+     0x00000100ff000000, 0x0000010000ff0000, 0x000001000000ff00, 0x00000100000000ff,
+     0x0000010000000000, 0x0000010000000001, 0x0000010000000100, 0x0000010000010000,
+     0x0000010001000000, 0x000001000101ff01, 0x00000101ffff0001, 0x00000101ff01ffff,
+     0x0000010100000000, 0x0000010101010100, 0x0001ffffff000000, 0x0001ffff00ffffff,
+     0x0001ffff00000100, 0x0001ffff0001ff00, 0x0001ffff01000000, 0x0001ff00ffffff00,
+     0x0001ff00ffff01ff, 0x0001ff00ff010000, 0x0001ff0000000000, 0x0001ff0000010001,
+     0x0001ff0001ff0000, 0x0001ff0001010100, 0x0001ff01ff0000ff, 0x0001ff01ff000001,
+     0x0001ff0100ffffff, 0x0001ff010001ffff, 0x0001ff01000101ff, 0x0001ff010100ff01,
+     0x000100ffff00ffff, 0x000100ffff00ff01, 0x000100ffff000100, 0x000100ff00000000,
+     0x000100ff000101ff, 0x000100ff01ff0101, 0x000100ff0100ffff, 0x000100ff01010101,
+     0x00010000ff000000, 0x00010000ff010100, 0x0001000000ff0000, 0x000100000000ff00,
+     0x00010000000000ff, 0x0001000000000000, 0x0001000000000001, 0x0001000000000100,
+     0x0001000000010000, 0x0001000001ffff01, 0x0001000001000000, 0x0001000100ff0101,
+     0x0001000100000000, 0x00010001010100ff, 0x000101ffffff01ff, 0x000101ffffff0101,
+     0x000101ff00010000, 0x000101ff01ff0000, 0x000101ff0100ff01, 0x00010100ffff0000,
+     0x0001010000000000, 0x000101000001ffff, 0x0001010000010101, 0x00010100010001ff,
+     0x00010101ff00ff00, 0x00010101ff010001, 0x0001010100ffffff, 0x0001010100ff01ff,
+     0x00010101000101ff, 0x0001010101ff0000, 0x000101010100ff01, 0x0001010101000101,
+     0x01ffffffffff0101, 0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff,
+     0x01ffffffff010101, 0x01ffffff00000000, 0x01ffffff01ff01ff, 0x01ffffff01000101,
+     0x01ffffff0101ff01, 0x01ffffff010100ff, 0x01ffff000000ff00, 0x01ffff0000000001,
+     0x01ffff00000001ff, 0x01ffff0000010000, 0x01ffff0001ff0000, 0x01ffff01ffffffff,
+     0x01ffff01ffff01ff, 0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff0101ff,
+     0x01ffff010100ffff, 0x01ff00ffffff0000, 0x01ff00ffff010000, 0x01ff00ff00ffff01,
+     0x01ff0000ff0000ff, 0x01ff000000000000, 0x01ff00000001ff01, 0x01ff000001ffffff,
+     0x01ff000001010100, 0x01ff0001ffffff01, 0x01ff0001ff010001, 0x01ff000101ff0100,
+     0x01ff000101000001, 0x01ff0001010100ff, 0x01ff01ffff00ffff, 0x01ff01ff00010001,
+     0x01ff01ff01000000, 0x01ff01ff010101ff, 0x01ff0100ff000001, 0x01ff010000ffff00,
+     0x01ff010000000100, 0x01ff010001ff01ff, 0x01ff01000101ffff, 0x01ff0101ffff00ff,
+     0x01ff0101ffff0101, 0x01ff0101ff0101ff, 0x01ff010100010000, 0x0100ffff00ff00ff,
+     0x0100ffff00ff0001, 0x0100ffff00000100, 0x0100ffff0100ff00, 0x0100ff00ffff0000,
+     0x0100ff00ff00ffff, 0x0100ff00ff00ff01, 0x0100ff00ff000100, 0x0100ff00ff010000,
+     0x0100ff0000000000, 0x0100ff00000100ff, 0x0100ff0001ff0101, 0x0100ff0001010101,
+     0x0100ff0100ff00ff, 0x0100ff0100ff0001, 0x0100ff0100000100, 0x0100ff0100010001,
+     0x0100ff0101000000, 0x010000ffff00ff00, 0x010000ff0000ffff, 0x010000ff00000000,
+     0x010000ff010001ff, 0x010000ff01010001, 0x01000000ffffff00, 0x01000000ffff0101,
+     0x01000000ff000000, 0x01000000ff0100ff, 0x01000000ff010101, 0x0100000000ff0000,
+     0x010000000000ff00, 0x01000000000000ff, 0x0100000000000000, 0x0100000000000001,
+     0x0100000000000100, 0x0100000000010000, 0x0100000001000000, 0x0100000100000000,
+     0x01000001000101ff, 0x0100000101ffff01, 0x010001ffff000101, 0x010001ff00ff0100,
+     0x010001ff0000ff00, 0x010001ff000100ff, 0x010001ff01ffffff, 0x01000100ffff0000,
+     0x01000100ff0001ff, 0x0100010000000000, 0x010001000001ff00, 0x0100010001ff0000,
+     0x01000100010000ff, 0x0100010001000101, 0x01000101ff00ff01, 0x0100010100ff0100,
+     0x010001010000ffff, 0x0100010101010001, 0x0101ffffffff0101, 0x0101ffffff0001ff,
+     0x0101ffffff01ffff, 0x0101ffffff010101, 0x0101ffff00000000, 0x0101ffff0101ffff,
+     0x0101ffff010101ff, 0x0101ff00ff000000, 0x0101ff0000ff0100, 0x0101ff000000ff00,
+     0x0101ff0000010000, 0x0101ff00010000ff, 0x0101ff0001000001, 0x0101ff01ff010101,
+     0x0101ff0100000000, 0x0101ff010101ff00, 0x010100ffffff0000, 0x010100ffff010000,
+     0x010100ff00ff01ff, 0x010100ff000000ff, 0x010100ff00000101, 0x010100ff01ffff00,
+     0x01010000ffffff01, 0x01010000ff000100, 0x01010000ff01ff01, 0x0101000000000000,
+     0x01010000000100ff, 0x010100000101ff01, 0x01010001ffff0000, 0x01010001ff00ffff,
+     0x01010001ff010000, 0x0101000101ffffff, 0x0101000101ff01ff, 0x0101000101010101,
+     0x010101ffff01ffff, 0x010101ff00000000, 0x010101ff0001ff01, 0x010101ff0101ffff,
+     0x010101ff010101ff, 0x01010100ffffffff, 0x01010100ff000001, 0x010101000000ff00,
+     0x0101010001010000, 0x0101010100ff0001, 0x010101010001ff01, 0x010101010101ffff,
+
+ };
+
  static const uint8_t ksigns_iq2xs[128] = {
      0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
      144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
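
Every byte in the new `iq1s_grid` table is 0xff, 0x00, or 0x01: each 64-bit entry packs one 8-dimensional codebook point with one signed ternary weight (-1, 0, +1) per byte, read straight through an `int8_t` pointer by the dequantization and dot-product code below. A minimal decode sketch, assuming the little-endian layout those readers rely on:

    #include <stdint.h>
    #include <string.h>

    // Sketch: unpack one grid entry into its 8 ternary weights
    // (byte 0 of the little-endian word is weight 0).
    static void iq1s_grid_unpack(uint64_t entry, int8_t w[8]) {
        memcpy(w, &entry, sizeof(entry));
    }
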
@@ -3578,6 +3736,69 @@ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y
      }
  }

+ // ====================== 1.5625 bpw (de)-quantization
+
+ void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, int k) {
+     assert(k % QK_K == 0);
+     const int nb = k / QK_K;
+
+     float db[4];
+     uint16_t idx[4];
+     //const int8_t * grid[4];
+
+     for (int i = 0; i < nb; i++) {
+
+         const float d = GGML_FP16_TO_FP32(x[i].d);
+         const uint8_t * sc = x[i].scales;
+         const uint8_t * qs = x[i].qs;
+
+         for (int i8 = 0; i8 < QK_K/8; i8 += 4) {
+             idx[0] = qs[0] | ((sc[0] & 0x08) << 5);
+             idx[1] = qs[1] | ((sc[0] & 0x80) << 1);
+             idx[2] = qs[2] | ((sc[1] & 0x08) << 5);
+             idx[3] = qs[3] | ((sc[1] & 0x80) << 1);
+             //grid[0] = (const int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
+             //grid[1] = (const int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
+             //grid[2] = (const int8_t *)(iq1s_grid + (qs[2] | ((sc[1] & 0x08) << 5)));
+             //grid[3] = (const int8_t *)(iq1s_grid + (qs[3] | ((sc[1] & 0x80) << 1)));
+             db[0] = d * (2*(sc[0] & 7) + 1);
+             db[1] = d * (2*((sc[0] >> 4) & 7) + 1);
+             db[2] = d * (2*(sc[1] & 7) + 1);
+             db[3] = d * (2*((sc[1] >> 4) & 7) + 1);
+             for (int l = 0; l < 4; ++l) {
+                 const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]);
+                 for (int j = 0; j < 8; ++j) {
+                     //y[j] = db[l] * grid[l][j];
+                     y[j] = db[l] * grid[j];
+                 }
+                 y += 8;
+             }
+             qs += 4;
+             sc += 2;
+         }
+     }
+ }
+
+ static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
+
+ void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y, int k) {
+     assert(k % QK4_NL == 0);
+     const int nb = k / QK4_NL;
+
+     for (int i = 0; i < nb; i++) {
+
+         const uint8_t * qs = x[i].qs;
+
+         const float d = GGML_FP16_TO_FP32(x[i].d);
+         for (int j = 0; j < QK4_NL/2; ++j) {
+             y[j+ 0]       = d * kvalues_iq4nl[qs[j] & 0xf];
+             y[j+QK4_NL/2] = d * kvalues_iq4nl[qs[j] >> 4];
+         }
+         y  += QK4_NL;
+         qs += QK4_NL/2;
+     }
+ }
+
  //===================================== Q8_K ==============================================

  void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
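
In `dequantize_row_iq1_s`, each group of 8 weights is addressed by a 9-bit grid index: the low 8 bits come from a `qs` byte and the 9th bit from bit 3 (or bit 7) of a packed `scales` nibble, whose remaining 3 bits hold the group scale s, applied as d*(2s+1). A worked example with invented byte values:

    #include <stdint.h>

    // Illustration only: the byte values here are made up.
    static void iq1_s_index_example(void) {
        const float    d   = 0.25f;                       // example fp16 super-scale
        const uint8_t  qs0 = 0x2a;                        // low 8 index bits
        const uint8_t  sc0 = 0x0b;                        // nibble 1011b: bit 3 -> index bit 8, scale = 3
        const uint16_t idx = qs0 | ((sc0 & 0x08) << 5);   // 0x02a | 0x100 = 0x12a
        const float    db  = d * (2*(sc0 & 7) + 1);       // 0.25 * 7 = 1.75
        (void)idx; (void)db;  // the 8 values are db times the int8 bytes of iq1s_grid[idx]
    }
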
@@ -3848,15 +4069,15 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r

      const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs);

-     __m128i bx = _mm_and_si128(lowMask, tmp);
-     __m128i by = _mm_loadu_si128((const __m128i *)y[i].qs);
-     bx = _mm_sub_epi8(bx, off);
-     const __m128i i32_0 = mul_sum_i8_pairs(bx, by);
+     __m128i bx_0 = _mm_and_si128(lowMask, tmp);
+     __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs);
+     bx_0 = _mm_sub_epi8(bx_0, off);
+     const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);

-     bx = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
-     by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
-     bx = _mm_sub_epi8(bx, off);
-     const __m128i i32_1 = mul_sum_i8_pairs(bx, by);
+     bx_0 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
+     by_0 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
+     bx_0 = _mm_sub_epi8(bx_0, off);
+     const __m128i i32_1 = mul_sum_i8_pairs(bx_0, by_0);

      // Convert int32_t to float
      __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1));
@@ -4442,21 +4663,21 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
      /* Compute combined scale for the block */
      const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));

-     __m256i bx = bytes_from_nibbles_32(x[i].qs);
+     __m256i bx_0 = bytes_from_nibbles_32(x[i].qs);
      const __m256i bxhi = bytes_from_bits_32(x[i].qh);
      __m128i bxhil = _mm256_castsi256_si128(bxhi);
      __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
      bxhil = _mm_andnot_si128(bxhil, mask);
      bxhih = _mm_andnot_si128(bxhih, mask);
-     __m128i bxl = _mm256_castsi256_si128(bx);
-     __m128i bxh = _mm256_extractf128_si256(bx, 1);
+     __m128i bxl = _mm256_castsi256_si128(bx_0);
+     __m128i bxh = _mm256_extractf128_si256(bx_0, 1);
      bxl = _mm_or_si128(bxl, bxhil);
      bxh = _mm_or_si128(bxh, bxhih);
-     bx = MM256_SET_M128I(bxh, bxl);
+     bx_0 = MM256_SET_M128I(bxh, bxl);

-     const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+     const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[i].qs);

-     const __m256 q = mul_sum_i8_pairs_float(bx, by);
+     const __m256 q = mul_sum_i8_pairs_float(bx_0, by_0);

      /* Multiply q with scale and accumulate */
      acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
@@ -4749,22 +4970,22 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r

      summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;

-     __m256i bx = bytes_from_nibbles_32(x[i].qs);
+     __m256i bx_0 = bytes_from_nibbles_32(x[i].qs);
      const __m256i bxhi = bytes_from_bits_32(x[i].qh);
      __m128i bxhil = _mm256_castsi256_si128(bxhi);
      __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
      bxhil = _mm_and_si128(bxhil, mask);
      bxhih = _mm_and_si128(bxhih, mask);
-     __m128i bxl = _mm256_castsi256_si128(bx);
-     __m128i bxh = _mm256_extractf128_si256(bx, 1);
+     __m128i bxl = _mm256_castsi256_si128(bx_0);
+     __m128i bxh = _mm256_extractf128_si256(bx_0, 1);
      bxl = _mm_or_si128(bxl, bxhil);
      bxh = _mm_or_si128(bxh, bxhih);
-     bx = MM256_SET_M128I(bxh, bxl);
+     bx_0 = MM256_SET_M128I(bxh, bxl);

      const __m256 dy = _mm256_set1_ps(y[i].d);
-     const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+     const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[i].qs);

-     const __m256 q = mul_sum_us8_pairs_float(bx, by);
+     const __m256 q = mul_sum_us8_pairs_float(bx_0, by_0);

      acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
  }
@@ -4993,10 +5214,10 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r

      for (int i = 0; i < nb; i++) {
          // load elements
-         vint8m1_t bx = __riscv_vle8_v_i8m1(x[i].qs, vl);
-         vint8m1_t by = __riscv_vle8_v_i8m1(y[i].qs, vl);
+         vint8m1_t bx_0 = __riscv_vle8_v_i8m1(x[i].qs, vl);
+         vint8m1_t by_0 = __riscv_vle8_v_i8m1(y[i].qs, vl);

-         vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx, by, vl);
+         vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx_0, by_0, vl);

          vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl);
          vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl);
@@ -5433,8 +5654,8 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r

      for (int i = 0; i < nb; ++i) {

-         const float d = y[i].d * (float)x[i].d;
-         const float dmin = -y[i].d * (float)x[i].dmin;
+         const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+         const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);

          const uint8_t * restrict q2 = x[i].qs;
          const int8_t * restrict q8 = y[i].qs;
@@ -5583,8 +5804,8 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r

      for (int i = 0; i < nb; ++i) {

-         const float d = y[i].d * (float)x[i].d;
-         const float dmin = -y[i].d * (float)x[i].dmin;
+         const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+         const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);

          const uint8_t * restrict q2 = x[i].qs;
          const int8_t * restrict q8 = y[i].qs;
@@ -6237,7 +6458,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r

      int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]);

-     const float d = y[i].d * (float)x[i].d;
+     const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);

      const uint8x16_t htmp = vcombine_u8(hbits, vshr_n_u8(hbits, 1));
      q3h.val[0] = vandq_u8(mh, vshlq_n_u8(htmp, 2));
@@ -6439,7 +6660,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r

      int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]);

-     const float d = y[i].d * (float)x[i].d;
+     const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);

      vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);

@@ -6942,9 +7163,9 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
      aux16[1] = (a[0] >> 4) & 0x0f0f;

      const int32_t summi = scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]);
-     sum_mins += y[i].d * (float)x[i].d[1] * summi;
+     sum_mins += y[i].d * GGML_FP16_TO_FP32(x[i].d[1]) * summi;

-     const float d = y[i].d * (float)x[i].d[0];
+     const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d[0]);

      const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4);

@@ -7602,7 +7823,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r

      for (int i = 0; i < nb; ++i) {

-         const float d = y[i].d * (float)x[i].d;
+         const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
          const int8_t * sc = x[i].scales;

          const uint8_t * restrict q5 = x[i].qs;
@@ -7744,7 +7965,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r

      for (int i = 0; i < nb; ++i) {

-         const float d = y[i].d * (float)x[i].d;
+         const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
          const int8_t * sc = x[i].scales;

          const uint8_t * restrict q5 = x[i].qs;
@@ -8312,7 +8533,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r

      for (int i = 0; i < nb; ++i) {

-         const float d_all = (float)x[i].d;
+         const float d_all = GGML_FP16_TO_FP32(x[i].d);

          const uint8_t * restrict q6 = x[i].ql;
          const uint8_t * restrict qh = x[i].qh;
@@ -8483,7 +8704,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r

      for (int i = 0; i < nb; ++i) {

-         const float d_all = (float)x[i].d;
+         const float d_all = GGML_FP16_TO_FP32(x[i].d);

          const uint8_t * restrict q6 = x[i].ql;
          const uint8_t * restrict qh = x[i].qh;
@@ -8972,7 +9193,6 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
  #endif
  }

- // TODO
  void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
      assert(n % QK_K == 0);
      assert(nrc == 1);
@@ -9107,6 +9327,271 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void
  #endif
  }

+ #ifdef __AVX2__
+ static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
+     const __m256i ax = _mm256_sign_epi8(x, x);
+     const __m256i sy = _mm256_sign_epi8(y, x);
+     return _mm256_maddubs_epi16(ax, sy);
+ }
+ #endif
+
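`mul_add_epi8` leans on a standard AVX2 idiom: `_mm256_maddubs_epi16` multiplies unsigned bytes by signed bytes and sums adjacent pairs into int16 lanes, so the helper feeds it |x| (`_mm256_sign_epi8(x, x)`) and y with x's sign transferred (`_mm256_sign_epi8(y, x)`, which also zeroes lanes where x is 0), keeping every product equal to x*y. With the ternary iq1_s grid values and the iq4_nl codebook (|q| <= 127), each pair sum stays inside int16 range, so the instruction's saturation never triggers. A scalar model of one output lane, for intuition only:

    // Sketch: one 16-bit lane of mul_add_epi8 (not part of the release).
    static int16_t mul_add_epi8_lane(int8_t x0, int8_t x1, int8_t y0, int8_t y1) {
        int sy0 = x0 < 0 ? -y0 : x0 == 0 ? 0 : y0;   // _mm256_sign_epi8(y, x)
        int sy1 = x1 < 0 ? -y1 : x1 == 0 ? 0 : y1;
        int ax0 = x0 < 0 ? -x0 : x0;                 // _mm256_sign_epi8(x, x) == |x|
        int ax1 = x1 < 0 ? -x1 : x1;
        return (int16_t)(ax0*sy0 + ax1*sy1);         // == x0*y0 + x1*y1
    }
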
+ void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+     assert(n % QK_K == 0);
+     assert(nrc == 1);
+     UNUSED(nrc);
+     UNUSED(bx);
+     UNUSED(by);
+     UNUSED(bs);
+
+     const block_iq1_s * restrict x = vx;
+     const block_q8_K  * restrict y = vy;
+
+     const int nb = n / QK_K;
+
+ #if defined __ARM_NEON
+
+     const uint8x16_t m8 = vdupq_n_u8(0x08);
+     const uint8x16_t m7 = vdupq_n_u8(0x07);
+     const uint8x16_t m1 = vdupq_n_u8(0x01);
+     const int32x4_t vzero = vdupq_n_s32(0);
+
+     uint16_t gindex[8];
+     uint16x8x2_t vindex;
+     int8x16x4_t q1b;
+     ggml_int8x16x4_t q8b;
+     uint16x8x4_t scales;
+     int32x4x2_t sumi;
+     int32x4x2_t dotq;
+
+     float sumf = 0;
+     for (int i = 0; i < nb; ++i) {
+
+         const int8_t  * q8 = y[i].qs;
+         const uint8_t * qs = x[i].qs;
+         const uint8_t * sc = x[i].scales;
+
+         sumi.val[0] = sumi.val[1] = vzero;
+
+         for (int i128 = 0; i128 < QK_K/128; ++i128) {
+             const uint8x16_t ql  = vld1q_u8(qs); qs += 16;
+             const uint8x8_t  tm1 = vld1_u8 (sc); sc += 8;
+             const uint8x8_t  tm2 = vshr_n_u8(tm1, 4);
+             const uint8x16_t qh  = vcombine_u8(vzip1_u8(tm1, tm2), vzip2_u8(tm1, tm2));
+             const uint8x16_t hbit = vandq_u8(qh, m8);
+             vindex.val[0] = vorrq_u16(vmovl_u8(vget_low_u8 (ql)), vshlq_n_u16(vmovl_u8(vget_low_u8 (hbit)), 5));
+             vindex.val[1] = vorrq_u16(vmovl_u8(vget_high_u8(ql)), vshlq_n_u16(vmovl_u8(vget_high_u8(hbit)), 5));
+             const uint8x16_t scales8 = vorrq_u8(vshlq_n_u8(vandq_u8(qh, m7), 1), m1);
+             scales.val[0] = vmovl_u8(vget_low_u8 (scales8));
+             scales.val[1] = vmovl_u8(vget_high_u8 (scales8));
+
+             for (int l = 0; l < 2; ++l) {
+                 vst1q_u16(gindex+0, vindex.val[l]);
+                 q1b.val[0] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[0])), vld1_s8((const void *)(iq1s_grid+gindex[1])));
+                 q1b.val[1] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[2])), vld1_s8((const void *)(iq1s_grid+gindex[3])));
+                 q1b.val[2] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[4])), vld1_s8((const void *)(iq1s_grid+gindex[5])));
+                 q1b.val[3] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[6])), vld1_s8((const void *)(iq1s_grid+gindex[7])));
+                 q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
+
+                 dotq.val[0] = vpaddq_s32(ggml_vdotq_s32(vzero, q1b.val[0], q8b.val[0]), ggml_vdotq_s32(vzero, q1b.val[1], q8b.val[1]));
+                 dotq.val[1] = vpaddq_s32(ggml_vdotq_s32(vzero, q1b.val[2], q8b.val[2]), ggml_vdotq_s32(vzero, q1b.val[3], q8b.val[3]));
+
+                 sumi.val[0] = vmlaq_s32(sumi.val[0], dotq.val[0], vreinterpretq_s32_u32(vmovl_u16(vget_low_u16 (scales.val[l]))));
+                 sumi.val[1] = vmlaq_s32(sumi.val[1], dotq.val[1], vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales.val[l]))));
+             }
+         }
+
+         sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * vaddvq_s32(vaddq_s32(sumi.val[0], sumi.val[1]));
+     }
+
+     *s = sumf;
+
+ #elif defined __AVX2__
+
+     const __m128i m8 = _mm_set1_epi8(0x08);
+     const __m128i m7 = _mm_set1_epi8(0x07);
+     const __m128i m1 = _mm_set1_epi8(0x01);
+     const __m128i shuffle_h = _mm_set_epi8(15, 7, 14, 6, 13, 5, 12, 4, 11, 3, 10, 2, 9, 1, 8, 0);
+     const __m128i shuffle_s[4] = {
+         _mm_set_epi32(0x03030303, 0x02020202, 0x01010101, 0x00000000),
+         _mm_set_epi32(0x07070707, 0x06060606, 0x05050505, 0x04040404),
+         _mm_set_epi32(0x0b0b0b0b, 0x0a0a0a0a, 0x09090909, 0x08080808),
+         _mm_set_epi32(0x0f0f0f0f, 0x0e0e0e0e, 0x0d0d0d0d, 0x0c0c0c0c)
+     };
+
+     uint64_t aux64;
+
+     __m256i v_gindex;
+     const uint16_t * gindex = (const uint16_t *)&v_gindex;
+
+     __m256 accum = _mm256_setzero_ps();
+     for (int i = 0; i < nb; ++i) {
+
+         const int8_t  * q8 = y[i].qs;
+         const uint8_t * qs = x[i].qs;
+         const uint8_t * sc = x[i].scales;
+
+         __m256i sumi = _mm256_setzero_si256();
+         for (int i128 = 0; i128 < QK_K/128; ++i128) {
+             const __m128i ql = _mm_loadu_si128((const __m128i*)qs); qs += 16;
+             memcpy(&aux64, sc, 8); sc += 8;
+             const __m128i qh = _mm_shuffle_epi8(_mm_set_epi64x(aux64 >> 4, aux64), shuffle_h);
+             const __m256i hbit = _mm256_cvtepu8_epi16(_mm_and_si128(qh, m8));
+             v_gindex = _mm256_or_si256(_mm256_cvtepu8_epi16(ql), _mm256_slli_epi16(hbit, 5));
+             const __m128i scales = _mm_or_si128(_mm_slli_epi16(_mm_and_si128(qh, m7), 1), m1);
+
+             for (int i32 = 0; i32 < 4; ++i32) {
+                 const __m256i q8b = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+                 const __m256i q1b = _mm256_set_epi64x(iq1s_grid[gindex[4*i32+3]], iq1s_grid[gindex[4*i32+2]],
+                                                       iq1s_grid[gindex[4*i32+1]], iq1s_grid[gindex[4*i32+0]]);
+                 const __m256i dot = mul_add_epi8(q1b, q8b);
+                 const __m256i s16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, shuffle_s[i32]));
+                 const __m256i p   = _mm256_madd_epi16(s16, dot);
+                 sumi = _mm256_add_epi32(sumi, p);
+             }
+
+         }
+
+         accum = _mm256_fmadd_ps(_mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d)), _mm256_cvtepi32_ps(sumi), accum);
+
+     }
+
+     *s = hsum_float_8(accum);
+
+ #else
+
+     int db[4];
+     uint16_t idx[4];
+
+     float sumf = 0;
+     for (int i = 0; i < nb; ++i) {
+
+         const int8_t  * q8 = y[i].qs;
+         const uint8_t * qs = x[i].qs;
+         const uint8_t * sc = x[i].scales;
+
+         int sumi = 0;
+         for (int i32 = 0; i32 < QK_K/32; ++i32) {
+             idx[0] = qs[0] | ((sc[0] & 0x08) << 5);
+             idx[1] = qs[1] | ((sc[0] & 0x80) << 1);
+             idx[2] = qs[2] | ((sc[1] & 0x08) << 5);
+             idx[3] = qs[3] | ((sc[1] & 0x80) << 1);
+             db[0] = (2*(sc[0] & 7) + 1);
+             db[1] = (2*((sc[0] >> 4) & 7) + 1);
+             db[2] = (2*(sc[1] & 7) + 1);
+             db[3] = (2*((sc[1] >> 4) & 7) + 1);
+             for (int l = 0; l < 4; ++l) {
+                 const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]);
+                 int suml = 0;
+                 for (int j = 0; j < 8; ++j) suml += q8[j] * grid[j];
+                 sumi += db[l] * suml;
+                 q8 += 8;
+             }
+             qs += 4;
+             sc += 2;
+         }
+
+         sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * sumi;
+     }
+
+     *s = sumf;
+
+ #endif
+ }
+
+ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+     assert(nrc == 1);
+     UNUSED(nrc);
+     UNUSED(bx);
+     UNUSED(by);
+     UNUSED(bs);
+     assert(n % QK4_NL == 0);
+     static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same");
+
+     const block_iq4_nl * restrict x = vx;
+     const block_q8_0   * restrict y = vy;
+
+     const int nb = n / QK4_NL;
+
+ #if defined __ARM_NEON
+     const int8x16_t values = vld1q_s8(kvalues_iq4nl);
+     const uint8x16_t m4b = vdupq_n_u8(0x0f);
+     uint8x16x2_t q4bits;
+     int8x16x4_t q4b;
+     int8x16x4_t q8b;
+     int32x4_t prod_1, prod_2;
+
+     float sumf = 0;
+
+     for (int ib = 0; ib < nb; ib += 2) {
+         q4bits.val[0] = vld1q_u8(x[ib+0].qs);
+         q4bits.val[1] = vld1q_u8(x[ib+1].qs);
+         q8b.val[0] = vld1q_s8(y[ib+0].qs);
+         q8b.val[1] = vld1q_s8(y[ib+0].qs + 16);
+         q8b.val[2] = vld1q_s8(y[ib+1].qs);
+         q8b.val[3] = vld1q_s8(y[ib+1].qs + 16);
+
+         q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b));
+         q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
+         q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b));
+         q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));
+
+         prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
+         prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);
+
+         sumf +=
+             GGML_FP16_TO_FP32(x[ib+0].d) * GGML_FP16_TO_FP32(y[ib+0].d) * vaddvq_s32(prod_1) +
+             GGML_FP16_TO_FP32(x[ib+1].d) * GGML_FP16_TO_FP32(y[ib+1].d) * vaddvq_s32(prod_2);
+     }
+
+     *s = sumf;
+
+ #elif defined __AVX2__
+
+     const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
+     const __m128i m4b  = _mm_set1_epi8(0x0f);
+     const __m256i mone = _mm256_set1_epi16(1);
+
+     __m256 accum1 = _mm256_setzero_ps();
+     __m256 accum2 = _mm256_setzero_ps();
+     for (int ib = 0; ib < nb; ib += 2) {
+         const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[0].qs);
+         const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[1].qs);
+         const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[0].qs);
+         const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[1].qs);
+         const __m256i q4b_1 = _mm256_set_m128i(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
+                                                _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
+         const __m256i q4b_2 = _mm256_set_m128i(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
+                                                _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
+         const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
+         const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
+         const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
+         const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
+         accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[0].d)*GGML_FP16_TO_FP32(x[0].d)),
+                                  _mm256_cvtepi32_ps(p_1), accum1);
+         accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[1].d)*GGML_FP16_TO_FP32(x[1].d)),
+                                  _mm256_cvtepi32_ps(p_2), accum2);
+
+         y += 2;
+         x += 2;
+     }
+
+     *s = hsum_float_8(_mm256_add_ps(accum1, accum2));
+
+ #else
+     float sumf = 0;
+     for (int ib = 0; ib < nb; ++ib) {
+         const float d = GGML_FP16_TO_FP32(y[ib].d)*GGML_FP16_TO_FP32(x[ib].d);
+         int sumi1 = 0, sumi2 = 0;
+         for (int j = 0; j < QK4_NL/2; ++j) {
+             sumi1 += y[ib].qs[j+ 0]       * kvalues_iq4nl[x[ib].qs[j] & 0xf];
+             sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >> 4];
+         }
+         sumf += d * (sumi1 + sumi2);
+     }
+     *s = sumf;
+ #endif
+ }
+
  // ================================ IQ2 quantization =============================================

  typedef struct {
@@ -9115,14 +9600,22 @@ typedef struct {
      uint16_t * neighbours;
  } iq2_entry_t;

- static iq2_entry_t iq2_data[2] = {
+ static iq2_entry_t iq2_data[3] = {
+     {NULL, NULL, NULL},
      {NULL, NULL, NULL},
      {NULL, NULL, NULL},
  };

- static inline int iq2_data_index(int grid_size) {
-     GGML_ASSERT(grid_size == 256 || grid_size == 512);
-     return grid_size == 256 ? 0 : 1;
+ static inline int iq2_data_index(enum ggml_type type) {
+     GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S);
+     return type == GGML_TYPE_IQ2_XXS ? 0 :
+            type == GGML_TYPE_IQ2_XS  ? 1 : 2;
+ }
+
+ static inline int iq2_grid_size(enum ggml_type type) {
+     GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S);
+     return type == GGML_TYPE_IQ2_XXS ? 256 :
+            type == GGML_TYPE_IQ2_XS  ? 512 : 512;
  }

  static int iq2_compare_func(const void * left, const void * right) {
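
The lookup cache is now keyed by tensor type rather than grid size because two of the three quant types share a size: IQ2_XS and the new IQ1_S both use 512-point grids, so the size alone can no longer pick the right slot. An equivalent switch-form sketch of the mapping:

    // Sketch of the slot mapping (equivalent to the ternary chain above).
    static int iq2_data_index_sketch(enum ggml_type type) {
        switch (type) {
            case GGML_TYPE_IQ2_XXS: return 0;  // 256-point grid
            case GGML_TYPE_IQ2_XS:  return 1;  // 512-point grid
            default:                return 2;  // GGML_TYPE_IQ1_S, also 512 points
        }
    }
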
@@ -9131,12 +9624,13 @@ static int iq2_compare_func(const void * left, const void * right) {
      return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 1 : 0;
  }

- void iq2xs_init_impl(int grid_size) {
-     const int gindex = iq2_data_index(grid_size);
+ void iq2xs_init_impl(enum ggml_type type) {
+     const int gindex = iq2_data_index(type);
+     const int grid_size = iq2_grid_size(type);
      if (iq2_data[gindex].grid) {
          return;
      }
-     static const uint16_t kgrid_256[256] = {
+     static const uint16_t kgrid_2bit_256[256] = {
          0, 2, 5, 8, 10, 17, 20, 32, 34, 40, 42, 65, 68, 80, 88, 97,
          100, 128, 130, 138, 162, 257, 260, 272, 277, 320, 388, 408, 512, 514, 546, 642,
          1025, 1028, 1040, 1057, 1060, 1088, 1090, 1096, 1120, 1153, 1156, 1168, 1188, 1280, 1282, 1288,
@@ -9154,7 +9648,7 @@ void iq2xs_init_impl(int grid_size) {
          33888, 34048, 34118, 34196, 34313, 34368, 34400, 34818, 35076, 35345, 36868, 36880, 36900, 36928, 37025, 37142,
          37248, 37445, 37888, 37922, 37956, 38225, 39041, 39200, 40962, 41040, 41093, 41225, 41472, 42008, 43088, 43268,
      };
-     static const uint16_t kgrid_512[512] = {
+     static const uint16_t kgrid_2bit_512[512] = {
          0, 2, 5, 8, 10, 17, 20, 22, 25, 32, 34, 37, 40, 65, 68, 70,
          73, 80, 82, 85, 88, 97, 100, 128, 130, 133, 136, 145, 148, 153, 160, 257,
          260, 262, 265, 272, 274, 277, 280, 282, 289, 292, 320, 322, 325, 328, 337, 340,
@@ -9188,9 +9682,45 @@ void iq2xs_init_impl(int grid_size) {
          40962, 40968, 40970, 40992, 41002, 41120, 41297, 41305, 41382, 41472, 41474, 41480, 41514, 41600, 41632, 42048,
          42133, 42597, 42648, 43018, 43040, 43042, 43048, 43168, 43176, 43268, 43396, 43398, 43560, 43562, 43665, 43690,
      };
+     static const uint16_t kgrid_1bit_512[512] = {
+         10, 33, 41, 85, 132, 134, 160, 162, 277, 337, 340, 345, 357, 405, 516, 545,
+         553, 598, 641, 650, 681, 1042, 1044, 1097, 1169, 1176, 1320, 1345, 1365, 1378, 1434, 1444,
+         1545, 1617, 1642, 1685, 2053, 2080, 2089, 2133, 2176, 2182, 2208, 2214, 2306, 2384, 2393, 2440,
+         2453, 2581, 2664, 2690, 2721, 4117, 4161, 4182, 4184, 4261, 4357, 4369, 4372, 4377, 4390, 4422,
+         4432, 4437, 4449, 4457, 4485, 4497, 4505, 4629, 4677, 4696, 4774, 5205, 5217, 5225, 5386, 5397,
+         5409, 5445, 5457, 5460, 5461, 5462, 5465, 5472, 5477, 5525, 5545, 5650, 5668, 5717, 5729, 5769,
+         5777, 6212, 6234, 6244, 6293, 6424, 6482, 6485, 6502, 6505, 6529, 6538, 6565, 6656, 6682, 6788,
+         6806, 6820, 8218, 8224, 8226, 8232, 8277, 8326, 8354, 8469, 8521, 8530, 8549, 8596, 8737, 8794,
+         9221, 9253, 9348, 9369, 9380, 9474, 9557, 9633, 9732, 9753, 9793, 9830, 9862, 9880, 10240, 10272,
+         10282, 10321, 10406, 10517, 10530, 10566, 10585, 10645, 10896, 16466, 16468, 16473, 16485, 16646, 16660, 16665,
+         16725, 16793, 16806, 16914, 16969, 16977, 16996, 17028, 17057, 17408, 17416, 17434, 17493, 17512, 17578, 17685,
+         17696, 17733, 17745, 17748, 17749, 17750, 17753, 17765, 17794, 17813, 17946, 17984, 18005, 18072, 18453, 18529,
+         18569, 18722, 18756, 18762, 18773, 18794, 18833, 18853, 18945, 19026, 19033, 19077, 20489, 20497, 20500, 20517,
+         20565, 20586, 20610, 20633, 20757, 20769, 20776, 20805, 20817, 20820, 20821, 20822, 20825, 20837, 20864, 20872,
+         20885, 20896, 21002, 21029, 21077, 21146, 21510, 21525, 21573, 21585, 21588, 21589, 21590, 21593, 21605, 21653,
+         21665, 21765, 21777, 21780, 21781, 21782, 21785, 21797, 21825, 21828, 21829, 21830, 21833, 21840, 21841, 21842,
+         21844, 21846, 21848, 21849, 21850, 21857, 21860, 21861, 21862, 21865, 21893, 21905, 21908, 21909, 21910, 21913,
+         21925, 22024, 22037, 22085, 22097, 22100, 22101, 22102, 22105, 22117, 22165, 22545, 22566, 22568, 22594, 22608,
+         22613, 22676, 22697, 22793, 22805, 22853, 22865, 22868, 22869, 22870, 22873, 22885, 22933, 22946, 23046, 23072,
+         23125, 23209, 24597, 24640, 24665, 24673, 24725, 24833, 24840, 24869, 24917, 24934, 24965, 25001, 25108, 25110,
+         25152, 25184, 25192, 25234, 25616, 25618, 25625, 25685, 25704, 25738, 25744, 25770, 25877, 25897, 25925, 25937,
+         25940, 25941, 25942, 25945, 25957, 25986, 26005, 26186, 26197, 26276, 26632, 26634, 26725, 26757, 26770, 26885,
+         26965, 26976, 26986, 27032, 27153, 27174, 27200, 27208, 27240, 27269, 27282, 27290, 32778, 32800, 32802, 32808,
+         32810, 32853, 32904, 32922, 32930, 32932, 33105, 33110, 33112, 33125, 33157, 33280, 33288, 33301, 33312, 33320,
+         33424, 33797, 33829, 33858, 34068, 34133, 34146, 34176, 34217, 34306, 34342, 34441, 34454, 34468, 34832, 34918,
+         34965, 34984, 35094, 35137, 35161, 35208, 35232, 35332, 35338, 35368, 35429, 36932, 36934, 36953, 37009, 37125,
+         37136, 37138, 37145, 37157, 37205, 37220, 37258, 37290, 37444, 37446, 37465, 37478, 37525, 37905, 37968, 37973,
+         38040, 38054, 38145, 38154, 38165, 38180, 38186, 38213, 38225, 38228, 38229, 38230, 38233, 38245, 38293, 38485,
+         38504, 38530, 38938, 38985, 38993, 39012, 39040, 39173, 39192, 39253, 39265, 39301, 39316, 39322, 39442, 39497,
+         39504, 39590, 40970, 40984, 40992, 41002, 41045, 41120, 41128, 41237, 41289, 41297, 41317, 41364, 41366, 41514,
+         41557, 41633, 41989, 42021, 42056, 42068, 42074, 42113, 42242, 42265, 42274, 42325, 42340, 42402, 42501, 42512,
+         42533, 42624, 42632, 42666, 43040, 43093, 43106, 43168, 43176, 43264, 43286, 43345, 43429, 43590, 43618, 43680,
+     };
+
      const int kmap_size = 43692;
-     const int nwant = 2;
-     const uint16_t * kgrid = grid_size == 256 ? kgrid_256 : kgrid_512;
+     const int nwant = type == GGML_TYPE_IQ1_S ? 3 : 2;
+     const uint16_t * kgrid = type == GGML_TYPE_IQ2_XXS ? kgrid_2bit_256 :
+                              type == GGML_TYPE_IQ2_XS  ? kgrid_2bit_512 : kgrid_1bit_512;
      uint64_t * kgrid_q2xs;
      int * kmap_q2xs;
      uint16_t * kneighbors_q2xs;
@@ -9286,9 +9816,9 @@ void iq2xs_init_impl(int grid_size) {
      free(dist2);
  }

- void iq2xs_free_impl(int grid_size) {
-     GGML_ASSERT(grid_size == 256 || grid_size == 512 || grid_size == 1024);
-     const int gindex = iq2_data_index(grid_size);
+ void iq2xs_free_impl(enum ggml_type type) {
+     GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S);
+     const int gindex = iq2_data_index(type);
      if (iq2_data[gindex].grid) {
          free(iq2_data[gindex].grid); iq2_data[gindex].grid = NULL;
          free(iq2_data[gindex].map);  iq2_data[gindex].map  = NULL;
@@ -9322,7 +9852,7 @@ static int iq2_find_best_neighbour(const uint16_t * restrict neighbours, const u

  static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {

-     const int gindex = iq2_data_index(256);
+     const int gindex = iq2_data_index(GGML_TYPE_IQ2_XXS);

      const uint64_t * kgrid_q2xs = iq2_data[gindex].grid;
      const int * kmap_q2xs = iq2_data[gindex].map;
@@ -9495,7 +10025,7 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict

  static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {

-     const int gindex = iq2_data_index(512);
+     const int gindex = iq2_data_index(GGML_TYPE_IQ2_XS);

      const uint64_t * kgrid_q2xs = iq2_data[gindex].grid;
      const int * kmap_q2xs = iq2_data[gindex].map;
@@ -10132,3 +10662,327 @@ void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * re
      assert(k % QK_K == 0);
      quantize_row_iq3_xxs_impl(x, y, k, NULL);
  }
+
+ // =================================== 1.5 bpw ===================================================
+
+ static int iq1_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid,
+         const float * restrict xval, const float * restrict weight, float * scale, int8_t * restrict L, int ngrid) {
+     int num_neighbors = neighbours[0];
+     GGML_ASSERT(num_neighbors > 0);
+     float best_score = 0;
+     int grid_index = -1;
+     for (int j = 1; j <= num_neighbors; ++j) {
+         const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
+         float sumqx = 0, sumq2 = 0;
+         for (int i = 0; i < 8; ++i) {
+             float q = (pg[i] - 3)/2;
+             float w = weight[i];
+             sumqx += w*q*xval[i];
+             sumq2 += w*q*q;
+         }
+         if (sumqx > 0 && sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
+             *scale = sumqx/sumq2; best_score = *scale * sumqx;
+             grid_index = neighbours[j];
+         }
+     }
+     if (grid_index < 0) {
+         for (int i = 0; i < ngrid; ++i) {
+             const int8_t * grid_i = (const int8_t *)(grid + i);
+             float sumqx = 0, sumq2 = 0;
+             for (int j = 0; j < 8; ++j) {
+                 float w = weight[j];
+                 float q = (grid_i[j] - 3)/2;
+                 sumqx += w*q*xval[j];
+                 sumq2 += w*q*q;
+             }
+             if (sumqx > 0 && sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
+                 *scale = sumqx/sumq2; best_score = *scale*sumqx;
+                 grid_index = i;
+             }
+         }
+     }
+     if (grid_index < 0) {
+         printf("Oops, did not find grid point\n");
+         printf("Have %d neighbours\n", num_neighbors);
+         for (int j = 1; j <= num_neighbors; ++j) {
+             const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
+             float sumqx = 0, sumq2 = 0;
+             for (int i = 0; i < 8; ++i) {
+                 float q = (pg[i] - 3)/2;
+                 float w = weight[i];
+                 sumqx += w*q*xval[i];
+                 sumq2 += w*q*q;
+             }
+             printf(" neighbour %d: sumqx = %g sumq2 = %g\n", j, (double)sumqx, (double)sumq2);
+         }
+     }
+     GGML_ASSERT(grid_index >= 0);
+     //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+     *scale *= 1.05f;  // This is a fudge factor. Don't ask me why it improves the result.
+     //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+     const int8_t * pg = (const int8_t *)(grid + grid_index);
+     for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2;
+     return grid_index;
+ }
+
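A note on the selection criterion used in `iq1_find_best_neighbour` above and in the exhaustive split search of `quantize_row_iq1_s_impl` below: both minimize the weighted SSD between the block values x and the scaled candidate s*q, and for a fixed candidate q the optimal scale and resulting score have closed forms, which is why the code only tracks `sumqx` and `sumq2`:

    % minimize \sum_i w_i (x_i - s q_i)^2 over the scale s
    s^* = \frac{\sum_i w_i q_i x_i}{\sum_i w_i q_i^2} = \frac{\mathtt{sumqx}}{\mathtt{sumq2}},
    \qquad
    \mathrm{score} = \frac{\bigl(\sum_i w_i q_i x_i\bigr)^2}{\sum_i w_i q_i^2} = s^* \cdot \mathtt{sumqx}.

The comparison `sumqx*sumqx > best_score*sumq2` is exactly `score > best_score` with the division cleared, so no divide is needed per rejected candidate.
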
+ static int iq1_sort_helper(const void * left, const void * right) {
+     const float * l = left;
+     const float * r = right;
+     return *l < *r ? -1 : *l > *r ? 1 : 0;
+ }
+
+ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
+
+     const int gindex = iq2_data_index(GGML_TYPE_IQ1_S);
+
+     const uint64_t * kgrid_q2xs = iq2_data[gindex].grid;
+     const int * kmap_q2xs = iq2_data[gindex].map;
+     const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
+
+     GGML_ASSERT(quant_weights   && "missing quantization weights");
+     GGML_ASSERT(kgrid_q2xs      && "forgot to call ggml_quantize_init()?");
+     GGML_ASSERT(kmap_q2xs       && "forgot to call ggml_quantize_init()?");
+     GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
+     GGML_ASSERT(n%QK_K == 0);
+
+     const int nbl = n/256;
+
+     block_iq1_s * y = vy;
+
+     float  scales[QK_K/8];
+     float  weight[8];
+     int8_t L[8];
+     float  sumx[9];
+     float  sumw[9];
+     float  pairs[16];
+     int  * idx = (int *)(pairs + 1);
+     uint8_t hbit[QK_K/8];
+
+     for (int ibl = 0; ibl < nbl; ++ibl) {
+
+         y[ibl].d = GGML_FP32_TO_FP16(0.f);
+         memset(y[ibl].qs, 0, QK_K/8);
+         memset(y[ibl].scales, 0, QK_K/16);
+
+         float max_scale = 0;
+
+         const float * xbl = x + QK_K*ibl;
+         float sumx2 = 0;
+         for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
+         float sigma2 = sumx2/QK_K;
+
+         for (int ib = 0; ib < QK_K/8; ++ib) {
+             const float * xb = xbl + 8*ib;
+             const float * qw = quant_weights + QK_K*ibl + 8*ib;
+             for (int i = 0; i < 8; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+             float max = fabsf(xb[0]);
+             for (int i = 1; i < 8; ++i) max = MAX(max, fabsf(xb[i]));
+             if (!max) {
+                 scales[ib] = 0;
+                 memset(L, 1, 8);
+                 continue;
+             }
+             // Here we solve exactly the weighted sum of squared difference (SSD) minimization problem.
+             // With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two
+             // boundaries that split the weights xb[i] into 3 groups. To do so, we sort the weights
+             // in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and
+             // Wi = sum[weight[j], j = 0...i], and use these to quickly get the optimum scale
+             // and score for each possible split.
+             for (int j = 0; j < 8; ++j) {
+                 pairs[2*j] = xb[j];
+                 idx[2*j] = j;
+             }
+             qsort(pairs, 8, 2*sizeof(float), iq1_sort_helper);
+             {
+                 sumx[0] = sumw[0] = 0;
+                 for (int j = 0; j < 8; ++j) {
+                     int i = idx[2*j];
+                     sumx[j+1] = sumx[j] + weight[i]*xb[i];
+                     sumw[j+1] = sumw[j] + weight[i];
+                 }
+             }
+             float best_score = 0, scale = max;
+             int besti1 = 0, besti2 = 0;
+             for (int i1 = 0; i1 <= 8; ++i1) {
+                 for (int i2 = i1; i2 <= 8; ++i2) {
+                     float sumqx = -(sumx[i1] - sumx[0]) + (sumx[8] - sumx[i2]);
+                     float sumq2 = (sumw[i1] - sumw[0]) + (sumw[8] - sumw[i2]);
+                     if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
+                         scale = sumqx/sumq2; best_score = scale*sumqx;
+                         besti1 = i1; besti2 = i2;
+                     }
+                 }
+             }
+             for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0;
+             for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1;
+             for (int j = besti2; j < 8; ++j) L[idx[2*j]] = 2;
+             if (scale < 0) {
+                 for (int j = 0; j < 8; ++j) L[j] = 2 - L[j];
+                 scale = -scale;
+             }
+             // Now we check if the solution found above corresponds to a grid point and, if not, use a neighbouring
+             // grid point that minimizes SSD.
+             uint16_t u = 0;
+             for (int j = 0; j < 8; ++j) u |= (L[j] << 2*j);
+             int grid_index = kmap_q2xs[u];
+             if (grid_index < 0) {
+                 const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+                 grid_index = iq1_find_best_neighbour(neighbours, kgrid_q2xs, xb, weight, &scale, L, NGRID_IQ2XXS);
+                 GGML_ASSERT(grid_index >= 0);
+             }
+             y[ibl].qs[ib] = grid_index & 255;
+             hbit[ib] = grid_index >> 8;
+             GGML_ASSERT(scale >= 0);
+             scales[ib] = scale;
+             max_scale = MAX(max_scale, scale);
+         }
+
+         if (!max_scale) {
+             memset(y[ibl].qs, 0, QK_K/8);
+             continue;
+         }
+
+         float d = max_scale/15;
+         y[ibl].d = GGML_FP32_TO_FP16(d*1.085f);  // 1.085f is another fudge factor. Don't ask me why it is needed.
+         float id = 1/d;
+         for (int ib = 0; ib < QK_K/8; ++ib) {
+             int l = nearest_int(0.5f*(id*scales[ib]-1));
+             l = MAX(0, MIN(7, l));
+             if (hbit[ib]) l |= 8;
+             y[ibl].scales[ib/2] |= (l << 4*(ib%2));
+         }
+     }
+ }
+
+ size_t quantize_iq1_s(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
+     (void)hist;
+     GGML_ASSERT(n_per_row%QK_K == 0);
+     int nblock = n_per_row/QK_K;
+     char * qrow = (char *)dst;
+     for (int row = 0; row < nrow; ++row) {
+         quantize_row_iq1_s_impl(src, qrow, n_per_row, quant_weights);
+         src += n_per_row;
+         qrow += nblock*sizeof(block_iq1_s);
+     }
+     return nrow * nblock * sizeof(block_iq1_s);
+ }
+
+ // ============================ 4-bit non-linear quants
+
+ static inline int best_index_int8(int n, const int8_t * val, float x) {
+     if (x <= val[0]) return 0;
+     if (x >= val[n-1]) return n-1;
+     int ml = 0, mu = n-1;
+     while (mu-ml > 1) {
+         int mav = (ml+mu)/2;
+         if (x < val[mav]) mu = mav; else ml = mav;
+     }
+     return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+ }
+
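`best_index_int8` assumes `val` is sorted ascending (true of `kvalues_iq4nl`) and binary-searches for the nearest entry, comparing the two neighbours of the insertion point. A small usage sketch with a made-up input:

    // Sketch: nearest codebook index for a scaled value of 50.
    static void best_index_example(void) {
        const int idx = best_index_int8(16, kvalues_iq4nl, 50.0f);
        (void)idx;  // idx == 12: 50 is nearer kvalues_iq4nl[12] == 53 than [11] == 38
    }
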
+ static void quantize_row_iq4_nl_impl(const int block_size, const float * GGML_RESTRICT x,
+         ggml_fp16_t * dh, uint8_t * q4,
+         float * weight, uint8_t * L,
+         const int8_t * values,
+         const float * quant_weights) {
+
+     const int ntry = 7;
+
+     float sigma2 = 0;
+     for (int j = 0; j < QK4_NL; ++j) sigma2 += x[j]*x[j];
+     sigma2 *= 2.f/QK4_NL;
+
+     const int nb = QK4_NL/block_size;
+
+     memset(q4, 0, QK4_NL/2);
+     for (int ib = 0; ib < nb; ++ib) {
+         dh[ib] = GGML_FP32_TO_FP16(0.f);
+         const float * xb = x + ib*block_size;
+         if (quant_weights) {
+             const float * qw = quant_weights + ib*block_size;
+             for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
+         } else {
+             for (int j = 0; j < block_size; ++j) weight[j] = xb[j]*xb[j];
+         }
+         float amax = 0, max = 0;
+         for (int j = 0; j < block_size; ++j) {
+             float ax = fabsf(xb[j]);
+             if (ax > amax) {
+                 amax = ax; max = xb[j];
+             }
+         }
+         if (!amax) {
+             continue;
+         }
+         float d = -max/values[0];
+         float id = 1/d;
+         float sumqx = 0, sumq2 = 0;
+         for (int j = 0; j < block_size; ++j) {
+             float al = id*xb[j];
+             int l = best_index_int8(16, values, al);
+             float q = values[l];
+             float w = weight[j];
+             sumqx += w*q*xb[j];
+             sumq2 += w*q*q;
+         }
+         float best_id = id;
+         d = sumqx/sumq2;
+         float best = d*sumqx;
+         for (int itry = -ntry; itry <= ntry; ++itry) {
+             id = (itry + values[0])/max;
+             sumqx = sumq2 = 0;
+             for (int j = 0; j < block_size; ++j) {
+                 float al = id*xb[j];
+                 int l = best_index_int8(16, values, al);
+                 float q = values[l];
+                 float w = weight[j];
+                 sumqx += w*q*xb[j];
+                 sumq2 += w*q*q;
+             }
+             if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+                 d = sumqx/sumq2; best = d * sumqx;
+                 best_id = id;
+             }
+         }
+         dh[ib] = GGML_FP32_TO_FP16(d);
+         for (int j = 0; j < block_size; ++j) {
+             L[ib*block_size + j] = best_index_int8(16, values, best_id*xb[j]);
+         }
+     }
+     for (int i = 0; i < QK4_NL/32; ++i) {
+         for (int j = 0; j < 16; ++j) {
+             q4[16*i + j] = L[32*i + j] | (L[32*i + 16 + j] << 4);
+         }
+     }
+ }
+
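The packing loop at the end of `quantize_row_iq4_nl_impl` stores two 4-bit codebook indices per byte: element j of the first 16-value half in the low nibble, element j of the second half in the high nibble, the mirror image of the `& 0xf` / `>> 4` unpacking in `dequantize_row_iq4_nl`. A toy illustration with invented indices:

    // Illustration only: pack index 5 (first half) with index 12 (second half).
    static void iq4_nl_pack_example(void) {
        const uint8_t packed = 5 | (12 << 4);          // == 0xc5
        const int v_lo = kvalues_iq4nl[packed & 0xf];  // kvalues_iq4nl[5]  == -35
        const int v_hi = kvalues_iq4nl[packed >> 4];   // kvalues_iq4nl[12] ==  53
        (void)v_lo; (void)v_hi;
    }
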
+ size_t quantize_iq4_nl(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
+     (void)hist;
+     GGML_ASSERT(n_per_row%QK4_NL == 0);
+     int nblock = n_per_row/QK4_NL;
+     char * qrow = (char *)dst;
+     uint8_t L[QK4_NL];
+     float weight[32];
+     for (int row = 0; row < nrow; ++row) {
+         block_iq4_nl * iq4 = (block_iq4_nl *)qrow;
+         for (int ibl = 0; ibl < nblock; ++ibl) {
+             const float * qw = quant_weights ? quant_weights + QK4_NL*ibl : NULL;
+             quantize_row_iq4_nl_impl(32, src + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, weight, L, kvalues_iq4nl, qw);
+         }
+         src += n_per_row;
+         qrow += nblock*sizeof(block_iq4_nl);
+     }
+     return nrow * nblock * sizeof(block_iq4_nl);
+ }
+
+ void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int k) {
+     assert(k % QK4_NL == 0);
+     block_iq4_nl * restrict y = vy;
+     quantize_row_iq4_nl_reference(x, y, k);
+ }
+
+ void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * restrict y, int k) {
+     assert(k % QK4_NL == 0);
+     quantize_iq4_nl(x, y, 1, k, NULL, NULL);
+ }
+