llama_cpp 0.12.5 → 0.12.7

This diff shows the changes between publicly released versions of the package, as they appear in its public registry. The information is provided for informational purposes only.
@@ -49,6 +49,8 @@
49
49
  #define MIN(a, b) ((a) < (b) ? (a) : (b))
50
50
  #define MAX(a, b) ((a) > (b) ? (a) : (b))
51
51
 
52
+ #define UNUSED GGML_UNUSED
53
+
52
54
  #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
53
55
 
54
56
  #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
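The only functional change in this hunk is the local `UNUSED` alias. It exists so the dot-product kernels later in the diff can silence unused-parameter warnings for the new `bs`, `bx`, `by` and `nrc` arguments on code paths that do not use them. A minimal sketch of the pattern, assuming `GGML_UNUSED(x)` expands to the usual `(void)(x)` cast (the stub name and body below are illustrative only):

```c
#include <assert.h>
#include <stddef.h>

#define GGML_UNUSED(x) (void)(x)   /* assumed typical definition */
#define UNUSED GGML_UNUSED

/* Illustrative stub with the new vec_dot argument list. */
static void vec_dot_stub(int n, float * s, size_t bs, const void * vx, size_t bx,
                         const void * vy, size_t by, int nrc) {
    assert(nrc == 1);              /* generic path produces a single result */
    UNUSED(nrc);                   /* keep -Wunused-parameter quiet in NDEBUG builds */
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);
    UNUSED(n); UNUSED(s); UNUSED(vx); UNUSED(vy);
}
```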
@@ -268,6 +270,17 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128
268
270
  #endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
269
271
 
270
272
  #if defined(__ARM_NEON)
273
+
274
+ #ifdef _MSC_VER
275
+
276
+ #define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) }
277
+
278
+ #else
279
+
280
+ #define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) }
281
+
282
+ #endif
283
+
271
284
  #if !defined(__aarch64__)
272
285
 
273
286
  // 64-bit compatibility
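The `ggml_vld1q_u32` macro papers over a compiler difference: GCC and clang allow a `uint32x4_t` to be brace-initialized lane by lane, while MSVC's NEON vector types are effectively backed by two 64-bit halves, so two 32-bit lanes have to be packed into each half. The arithmetic of the MSVC branch, shown here with plain integers (a standalone sketch, no NEON headers required):

```c
#include <stdint.h>
#include <stdio.h>

int main(void) {
    uint32_t w = 0x11111111u, x = 0x22222222u, y = 0x33333333u, z = 0x44444444u;
    uint64_t lo = (uint64_t)w + ((uint64_t)x << 32); /* lanes 0 and 1 */
    uint64_t hi = (uint64_t)y + ((uint64_t)z << 32); /* lanes 2 and 3 */
    printf("lo = 0x%016llx\nhi = 0x%016llx\n",
           (unsigned long long)lo, (unsigned long long)hi);
    /* On a little-endian target this matches the lane order of
       vld1q_u32((const uint32_t[4]){w, x, y, z}). */
    return 0;
}
```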
@@ -425,6 +438,30 @@ inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
425
438
  return res;
426
439
  }
427
440
 
441
+ // NOTE: not tested
442
+ inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) {
443
+ int8x16_t res;
444
+
445
+ res[ 0] = a[b[ 0]];
446
+ res[ 1] = a[b[ 1]];
447
+ res[ 2] = a[b[ 2]];
448
+ res[ 3] = a[b[ 3]];
449
+ res[ 4] = a[b[ 4]];
450
+ res[ 5] = a[b[ 5]];
451
+ res[ 6] = a[b[ 6]];
452
+ res[ 7] = a[b[ 7]];
453
+ res[ 8] = a[b[ 8]];
454
+ res[ 9] = a[b[ 9]];
455
+ res[10] = a[b[10]];
456
+ res[11] = a[b[11]];
457
+ res[12] = a[b[12]];
458
+ res[13] = a[b[13]];
459
+ res[14] = a[b[14]];
460
+ res[15] = a[b[15]];
461
+
462
+ return res;
463
+ }
464
+
428
465
  #else
429
466
 
430
467
  #define ggml_int16x8x2_t int16x8x2_t
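`ggml_vqtbl1q_s8` is a scalar stand-in for the AArch64-only `vqtbl1q_s8` byte table lookup, added so the 32-bit ARM build keeps compiling. A plain-C model of what the intrinsic does; note that the real TBL instruction returns 0 for indices >= 16, a case the fallback above does not handle (acceptable as long as callers only pass indices 0..15). The helper name below is illustrative:

```c
#include <stdint.h>

static void tbl1_model(int8_t res[16], const int8_t a[16], const uint8_t b[16]) {
    for (int i = 0; i < 16; ++i) {
        res[i] = (b[i] < 16) ? a[b[i]] : 0;  /* TBL: out-of-range index -> 0 */
    }
}
```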
@@ -438,6 +475,7 @@ inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
438
475
  #define ggml_vld1q_u8_x4 vld1q_u8_x4
439
476
  #define ggml_vld1q_s8_x2 vld1q_s8_x2
440
477
  #define ggml_vld1q_s8_x4 vld1q_s8_x4
478
+ #define ggml_vqtbl1q_s8 vqtbl1q_s8
441
479
 
442
480
  #endif
443
481
 
@@ -1824,9 +1862,9 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri
1824
1862
  float sigma2 = sumx2/QK_K;
1825
1863
  for (int j = 0; j < QK_K/16; ++j) {
1826
1864
  const float * restrict qw = quant_weights + QK_K * i + 16*j;
1827
- for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]);
1828
- for (int l = 0; l < 16; ++l) sw[j] += weight[l];
1829
- scales[j] = make_qkx3_quants(16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
1865
+ for (int l = 0; l < QK_K/16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]);
1866
+ for (int l = 0; l < QK_K/16; ++l) sw[j] += weight[l];
1867
+ scales[j] = make_qkx3_quants(QK_K/16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
1830
1868
  }
1831
1869
 
1832
1870
  float dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw);
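The loop-bound change in `quantize_row_q2_K_impl` replaces the hard-coded 16 with `QK_K/16`; with the default super-block size of 256 the two are the same value, so for the standard build this is a pure refactor. A one-line check of that arithmetic, assuming the default `QK_K`:

```c
#include <assert.h>

#define QK_K 256   /* default build-time value assumed */

static_assert(QK_K / 16 == 16, "sub-block loop bound unchanged for QK_K == 256");
```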
@@ -3467,6 +3505,139 @@ static const uint32_t iq3xxs_grid[256] = {
3467
3505
  0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
3468
3506
  };
3469
3507
 
3508
+ #define NGRID_IQ2XXS 512
3509
+ static const uint64_t iq1s_grid[NGRID_IQ2XXS] = {
3510
+ 0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
3511
+ 0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01,
3512
+ 0xffffff00ff000000, 0xffffff000000ff00, 0xffffff00000000ff, 0xffffff0000000100,
3513
+ 0xffffff0000010000, 0xffffff0001000000, 0xffffff01ffff00ff, 0xffffff01ff01ff00,
3514
+ 0xffffff01ff010100, 0xffffff0100000001, 0xffffff0101ffff00, 0xffffff0101ff0101,
3515
+ 0xffffff0101010100, 0xffff00ffff00ff01, 0xffff00ffff0000ff, 0xffff00ff00ff0100,
3516
+ 0xffff00ff0100ff00, 0xffff00ff010001ff, 0xffff0000ff0101ff, 0xffff000000ffff00,
3517
+ 0xffff000000000000, 0xffff00000001ff01, 0xffff000001000101, 0xffff0000010100ff,
3518
+ 0xffff0001ffff0100, 0xffff00010000ff00, 0xffff000100010101, 0xffff000101000000,
3519
+ 0xffff01ffffff0000, 0xffff01ffff01ffff, 0xffff01ffff010100, 0xffff01ff00000000,
3520
+ 0xffff01ff01ffffff, 0xffff01ff01ff0001, 0xffff01ff0101ffff, 0xffff01ff01010001,
3521
+ 0xffff0100ffffff01, 0xffff01000000ffff, 0xffff010000000100, 0xffff010001ff01ff,
3522
+ 0xffff010001000000, 0xffff0101ff000000, 0xffff0101000101ff, 0xffff010101ffff01,
3523
+ 0xffff01010101ff00, 0xff00ffffff000000, 0xff00ffff00ffff00, 0xff00ffff00000001,
3524
+ 0xff00ffff000001ff, 0xff00ffff01010000, 0xff00ff00ffff0000, 0xff00ff00ff00ff00,
3525
+ 0xff00ff00ff0000ff, 0xff00ff00ff000100, 0xff00ff00ff010001, 0xff00ff0000ff0001,
3526
+ 0xff00ff000000ffff, 0xff00ff0000000000, 0xff00ff000001ff00, 0xff00ff0000010100,
3527
+ 0xff00ff0001ff0000, 0xff00ff000100ff00, 0xff00ff0001000100, 0xff00ff01ff000000,
3528
+ 0xff00ff0100ff0000, 0xff00ff01000001ff, 0xff00ff0101010001, 0xff0000ff00000000,
3529
+ 0xff0000ff0001ff00, 0xff0000ff00010100, 0xff000000ffff0101, 0xff000000ff000000,
3530
+ 0xff000000ff01ff00, 0xff00000000ff0000, 0xff0000000000ff00, 0xff000000000000ff,
3531
+ 0xff00000000000000, 0xff00000000000001, 0xff00000000000100, 0xff0000000001ffff,
3532
+ 0xff00000000010000, 0xff00000001000000, 0xff00000001010100, 0xff000001ff00ff01,
3533
+ 0xff000001ff0100ff, 0xff00000100000000, 0xff0000010001ff00, 0xff00000101ff0100,
3534
+ 0xff0000010100ff00, 0xff0001ff00ff00ff, 0xff0001ff00000101, 0xff0001ff000100ff,
3535
+ 0xff0001ff01000000, 0xff000100ff0001ff, 0xff0001000000ff01, 0xff00010000000000,
3536
+ 0xff00010000010001, 0xff00010000010100, 0xff00010001ffff00, 0xff00010001ff0101,
3537
+ 0xff00010001010000, 0xff000101ffffffff, 0xff000101ff000101, 0xff00010101ff00ff,
3538
+ 0xff00010101000001, 0xff000101010100ff, 0xff01ffffff000101, 0xff01ffffff01ffff,
3539
+ 0xff01ffffff01ff01, 0xff01ffffff0101ff, 0xff01ffff00000000, 0xff01ffff01ff0001,
3540
+ 0xff01ffff0101ff01, 0xff01ff00ff000000, 0xff01ff0000ff0100, 0xff01ff000000ff01,
3541
+ 0xff01ff0000010000, 0xff01ff00010000ff, 0xff01ff01ff01ff00, 0xff01ff0100000101,
3542
+ 0xff0100ffffff0000, 0xff0100ffff010000, 0xff0100ff01ff00ff, 0xff0100ff01000100,
3543
+ 0xff0100ff010100ff, 0xff010000ffffff01, 0xff01000000000000, 0xff0100000101ff00,
3544
+ 0xff010001ffff00ff, 0xff010001ff000100, 0xff01000100ffff00, 0xff01000100010001,
3545
+ 0xff01000101ff0001, 0xff010001010001ff, 0xff0101ffffffffff, 0xff0101ffff01ffff,
3546
+ 0xff0101ffff010101, 0xff0101ff0000ff00, 0xff0101ff01010001, 0xff010100ff000000,
3547
+ 0xff010100ff01ff01, 0xff01010000ff0001, 0xff01010000000100, 0xff01010001000000,
3548
+ 0xff0101010100ffff, 0x00ffffff0000ff01, 0x00ffffff000000ff, 0x00ffffff00000100,
3549
+ 0x00ffffff00010000, 0x00ffff00ffff0001, 0x00ffff00ff0000ff, 0x00ffff00ff000100,
3550
+ 0x00ffff0000000000, 0x00ffff0001000100, 0x00ffff0001010001, 0x00ffff01ff00ff01,
3551
+ 0x00ffff0100ff0100, 0x00ffff010000ff00, 0x00ffff01000100ff, 0x00ffff0101ff00ff,
3552
+ 0x00ffff010101ff00, 0x00ff00ffffffffff, 0x00ff00ffffff01ff, 0x00ff00ffff000101,
3553
+ 0x00ff00ff00000000, 0x00ff00ff000101ff, 0x00ff00ff01010101, 0x00ff0000ff000000,
3554
+ 0x00ff0000ff01ffff, 0x00ff000000ff0000, 0x00ff00000000ff00, 0x00ff0000000000ff,
3555
+ 0x00ff000000000000, 0x00ff000000000001, 0x00ff000000000100, 0x00ff000000010000,
3556
+ 0x00ff000001ffff01, 0x00ff000001000000, 0x00ff0001ff000101, 0x00ff000100ffffff,
3557
+ 0x00ff000100000000, 0x00ff0001010001ff, 0x00ff01ffff000000, 0x00ff01ff0001ff00,
3558
+ 0x00ff01ff01ff0100, 0x00ff0100ff01ff01, 0x00ff010000ff00ff, 0x00ff010000ff0101,
3559
+ 0x00ff010000000000, 0x00ff010000010101, 0x00ff01000100ff00, 0x00ff010001010000,
3560
+ 0x00ff0101ffffff00, 0x00ff01010000ff01, 0x00ff010100000100, 0x00ff010101ff0000,
3561
+ 0x0000ffffffff0100, 0x0000ffffff00ff00, 0x0000ffffff0000ff, 0x0000ffffff010000,
3562
+ 0x0000ffff00000000, 0x0000ffff00010101, 0x0000ffff01ffff01, 0x0000ffff01000100,
3563
+ 0x0000ff00ff000000, 0x0000ff00ff01ff00, 0x0000ff00ff0101ff, 0x0000ff0000ff0000,
3564
+ 0x0000ff000000ff00, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001,
3565
+ 0x0000ff0000000100, 0x0000ff0000010000, 0x0000ff0001ffffff, 0x0000ff0001ff01ff,
3566
+ 0x0000ff0001000000, 0x0000ff000101ffff, 0x0000ff01ffff0101, 0x0000ff01ff010000,
3567
+ 0x0000ff0100000000, 0x0000ff0101000101, 0x000000ffffff0001, 0x000000ffff000000,
3568
+ 0x000000ff00ff0000, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000,
3569
+ 0x000000ff00000001, 0x000000ff00000100, 0x000000ff00010000, 0x000000ff01000000,
3570
+ 0x000000ff0101ff00, 0x00000000ffff0000, 0x00000000ff00ff00, 0x00000000ff0000ff,
3571
+ 0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff010000,
3572
+ 0x0000000000ffff00, 0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001,
3573
+ 0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01,
3574
+ 0x00000000000000ff, 0x0000000000000001, 0x00000000000001ff, 0x0000000000000100,
3575
+ 0x0000000000000101, 0x000000000001ff00, 0x00000000000100ff, 0x0000000000010000,
3576
+ 0x0000000000010001, 0x0000000000010100, 0x0000000001ff0000, 0x000000000100ff00,
3577
+ 0x00000000010000ff, 0x0000000001000000, 0x0000000001000001, 0x0000000001000100,
3578
+ 0x0000000001010000, 0x00000001ffff01ff, 0x00000001ff000000, 0x0000000100ff0000,
3579
+ 0x000000010000ff00, 0x00000001000000ff, 0x0000000100000000, 0x0000000100000001,
3580
+ 0x0000000100000100, 0x0000000100010000, 0x0000000101000000, 0x000001ffff00ff00,
3581
+ 0x000001ffff010001, 0x000001ffff0101ff, 0x000001ff00ffff01, 0x000001ff0000ffff,
3582
+ 0x000001ff00000000, 0x000001ff010000ff, 0x000001ff01010100, 0x00000100ffff0100,
3583
+ 0x00000100ff000000, 0x0000010000ff0000, 0x000001000000ff00, 0x00000100000000ff,
3584
+ 0x0000010000000000, 0x0000010000000001, 0x0000010000000100, 0x0000010000010000,
3585
+ 0x0000010001000000, 0x000001000101ff01, 0x00000101ffff0001, 0x00000101ff01ffff,
3586
+ 0x0000010100000000, 0x0000010101010100, 0x0001ffffff000000, 0x0001ffff00ffffff,
3587
+ 0x0001ffff00000100, 0x0001ffff0001ff00, 0x0001ffff01000000, 0x0001ff00ffffff00,
3588
+ 0x0001ff00ffff01ff, 0x0001ff00ff010000, 0x0001ff0000000000, 0x0001ff0000010001,
3589
+ 0x0001ff0001ff0000, 0x0001ff0001010100, 0x0001ff01ff0000ff, 0x0001ff01ff000001,
3590
+ 0x0001ff0100ffffff, 0x0001ff010001ffff, 0x0001ff01000101ff, 0x0001ff010100ff01,
3591
+ 0x000100ffff00ffff, 0x000100ffff00ff01, 0x000100ffff000100, 0x000100ff00000000,
3592
+ 0x000100ff000101ff, 0x000100ff01ff0101, 0x000100ff0100ffff, 0x000100ff01010101,
3593
+ 0x00010000ff000000, 0x00010000ff010100, 0x0001000000ff0000, 0x000100000000ff00,
3594
+ 0x00010000000000ff, 0x0001000000000000, 0x0001000000000001, 0x0001000000000100,
3595
+ 0x0001000000010000, 0x0001000001ffff01, 0x0001000001000000, 0x0001000100ff0101,
3596
+ 0x0001000100000000, 0x00010001010100ff, 0x000101ffffff01ff, 0x000101ffffff0101,
3597
+ 0x000101ff00010000, 0x000101ff01ff0000, 0x000101ff0100ff01, 0x00010100ffff0000,
3598
+ 0x0001010000000000, 0x000101000001ffff, 0x0001010000010101, 0x00010100010001ff,
3599
+ 0x00010101ff00ff00, 0x00010101ff010001, 0x0001010100ffffff, 0x0001010100ff01ff,
3600
+ 0x00010101000101ff, 0x0001010101ff0000, 0x000101010100ff01, 0x0001010101000101,
3601
+ 0x01ffffffffff0101, 0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff,
3602
+ 0x01ffffffff010101, 0x01ffffff00000000, 0x01ffffff01ff01ff, 0x01ffffff01000101,
3603
+ 0x01ffffff0101ff01, 0x01ffffff010100ff, 0x01ffff000000ff00, 0x01ffff0000000001,
3604
+ 0x01ffff00000001ff, 0x01ffff0000010000, 0x01ffff0001ff0000, 0x01ffff01ffffffff,
3605
+ 0x01ffff01ffff01ff, 0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff0101ff,
3606
+ 0x01ffff010100ffff, 0x01ff00ffffff0000, 0x01ff00ffff010000, 0x01ff00ff00ffff01,
3607
+ 0x01ff0000ff0000ff, 0x01ff000000000000, 0x01ff00000001ff01, 0x01ff000001ffffff,
3608
+ 0x01ff000001010100, 0x01ff0001ffffff01, 0x01ff0001ff010001, 0x01ff000101ff0100,
3609
+ 0x01ff000101000001, 0x01ff0001010100ff, 0x01ff01ffff00ffff, 0x01ff01ff00010001,
3610
+ 0x01ff01ff01000000, 0x01ff01ff010101ff, 0x01ff0100ff000001, 0x01ff010000ffff00,
3611
+ 0x01ff010000000100, 0x01ff010001ff01ff, 0x01ff01000101ffff, 0x01ff0101ffff00ff,
3612
+ 0x01ff0101ffff0101, 0x01ff0101ff0101ff, 0x01ff010100010000, 0x0100ffff00ff00ff,
3613
+ 0x0100ffff00ff0001, 0x0100ffff00000100, 0x0100ffff0100ff00, 0x0100ff00ffff0000,
3614
+ 0x0100ff00ff00ffff, 0x0100ff00ff00ff01, 0x0100ff00ff000100, 0x0100ff00ff010000,
3615
+ 0x0100ff0000000000, 0x0100ff00000100ff, 0x0100ff0001ff0101, 0x0100ff0001010101,
3616
+ 0x0100ff0100ff00ff, 0x0100ff0100ff0001, 0x0100ff0100000100, 0x0100ff0100010001,
3617
+ 0x0100ff0101000000, 0x010000ffff00ff00, 0x010000ff0000ffff, 0x010000ff00000000,
3618
+ 0x010000ff010001ff, 0x010000ff01010001, 0x01000000ffffff00, 0x01000000ffff0101,
3619
+ 0x01000000ff000000, 0x01000000ff0100ff, 0x01000000ff010101, 0x0100000000ff0000,
3620
+ 0x010000000000ff00, 0x01000000000000ff, 0x0100000000000000, 0x0100000000000001,
3621
+ 0x0100000000000100, 0x0100000000010000, 0x0100000001000000, 0x0100000100000000,
3622
+ 0x01000001000101ff, 0x0100000101ffff01, 0x010001ffff000101, 0x010001ff00ff0100,
3623
+ 0x010001ff0000ff00, 0x010001ff000100ff, 0x010001ff01ffffff, 0x01000100ffff0000,
3624
+ 0x01000100ff0001ff, 0x0100010000000000, 0x010001000001ff00, 0x0100010001ff0000,
3625
+ 0x01000100010000ff, 0x0100010001000101, 0x01000101ff00ff01, 0x0100010100ff0100,
3626
+ 0x010001010000ffff, 0x0100010101010001, 0x0101ffffffff0101, 0x0101ffffff0001ff,
3627
+ 0x0101ffffff01ffff, 0x0101ffffff010101, 0x0101ffff00000000, 0x0101ffff0101ffff,
3628
+ 0x0101ffff010101ff, 0x0101ff00ff000000, 0x0101ff0000ff0100, 0x0101ff000000ff00,
3629
+ 0x0101ff0000010000, 0x0101ff00010000ff, 0x0101ff0001000001, 0x0101ff01ff010101,
3630
+ 0x0101ff0100000000, 0x0101ff010101ff00, 0x010100ffffff0000, 0x010100ffff010000,
3631
+ 0x010100ff00ff01ff, 0x010100ff000000ff, 0x010100ff00000101, 0x010100ff01ffff00,
3632
+ 0x01010000ffffff01, 0x01010000ff000100, 0x01010000ff01ff01, 0x0101000000000000,
3633
+ 0x01010000000100ff, 0x010100000101ff01, 0x01010001ffff0000, 0x01010001ff00ffff,
3634
+ 0x01010001ff010000, 0x0101000101ffffff, 0x0101000101ff01ff, 0x0101000101010101,
3635
+ 0x010101ffff01ffff, 0x010101ff00000000, 0x010101ff0001ff01, 0x010101ff0101ffff,
3636
+ 0x010101ff010101ff, 0x01010100ffffffff, 0x01010100ff000001, 0x010101000000ff00,
3637
+ 0x0101010001010000, 0x0101010100ff0001, 0x010101010001ff01, 0x010101010101ffff,
3638
+
3639
+ };
3640
+
3470
3641
  static const uint8_t ksigns_iq2xs[128] = {
3471
3642
  0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
3472
3643
  144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
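The new `iq1s_grid` table backs the 1.5625-bpw IQ1_S format introduced later in this diff: each 64-bit entry packs eight ternary weights, stored byte-wise as 0xff (-1), 0x00 (0) and 0x01 (+1), so the dequantizer can reinterpret an entry directly as eight `int8_t` values. A small standalone decode of the first table entry (little-endian byte order assumed):

```c
#include <stdint.h>
#include <string.h>
#include <stdio.h>

int main(void) {
    const uint64_t entry = 0xffffffffffff0101ULL;  /* first iq1s_grid entry */
    int8_t vals[8];
    memcpy(vals, &entry, sizeof(vals));            /* low byte becomes vals[0] */
    for (int j = 0; j < 8; ++j) {
        printf("%d ", vals[j]);                    /* prints: 1 1 -1 -1 -1 -1 -1 -1 */
    }
    printf("\n");
    return 0;
}
```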
@@ -3565,6 +3736,69 @@ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y
3565
3736
  }
3566
3737
  }
3567
3738
 
3739
+ // ====================== 1.5625 bpw (de)-quantization
3740
+
3741
+ void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, int k) {
3742
+ assert(k % QK_K == 0);
3743
+ const int nb = k / QK_K;
3744
+
3745
+ float db[4];
3746
+ uint16_t idx[4];
3747
+ //const int8_t * grid[4];
3748
+
3749
+ for (int i = 0; i < nb; i++) {
3750
+
3751
+ const float d = GGML_FP16_TO_FP32(x[i].d);
3752
+ const uint8_t * sc = x[i].scales;
3753
+ const uint8_t * qs = x[i].qs;
3754
+
3755
+ for (int i8 = 0; i8 < QK_K/8; i8 += 4) {
3756
+ idx[0] = qs[0] | ((sc[0] & 0x08) << 5);
3757
+ idx[1] = qs[1] | ((sc[0] & 0x80) << 1);
3758
+ idx[2] = qs[2] | ((sc[1] & 0x08) << 5);
3759
+ idx[3] = qs[3] | ((sc[1] & 0x80) << 1);
3760
+ //grid[0] = (const int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
3761
+ //grid[1] = (const int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
3762
+ //grid[2] = (const int8_t *)(iq1s_grid + (qs[2] | ((sc[1] & 0x08) << 5)));
3763
+ //grid[3] = (const int8_t *)(iq1s_grid + (qs[3] | ((sc[1] & 0x80) << 1)));
3764
+ db[0] = d * (2*(sc[0] & 7) + 1);
3765
+ db[1] = d * (2*((sc[0] >> 4) & 7) + 1);
3766
+ db[2] = d * (2*(sc[1] & 7) + 1);
3767
+ db[3] = d * (2*((sc[1] >> 4) & 7) + 1);
3768
+ for (int l = 0; l < 4; ++l) {
3769
+ const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]);
3770
+ for (int j = 0; j < 8; ++j) {
3771
+ //y[j] = db[l] * grid[l][j];
3772
+ y[j] = db[l] * grid[j];
3773
+ }
3774
+ y += 8;
3775
+ }
3776
+ qs += 4;
3777
+ sc += 2;
3778
+ }
3779
+ }
3780
+ }
3781
+
3782
+ static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
3783
+
3784
+ void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y, int k) {
3785
+ assert(k % QK4_NL == 0);
3786
+ const int nb = k / QK4_NL;
3787
+
3788
+ for (int i = 0; i < nb; i++) {
3789
+
3790
+ const uint8_t * qs = x[i].qs;
3791
+
3792
+ const float d = GGML_FP16_TO_FP32(x[i].d);
3793
+ for (int j = 0; j < QK4_NL/2; ++j) {
3794
+ y[j+ 0] = d * kvalues_iq4nl[qs[j] & 0xf];
3795
+ y[j+QK4_NL/2] = d * kvalues_iq4nl[qs[j] >> 4];
3796
+ }
3797
+ y += QK4_NL;
3798
+ qs += QK4_NL/2;
3799
+ }
3800
+ }
3801
+
3568
3802
  //===================================== Q8_K ==============================================
3569
3803
 
3570
3804
  void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
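Two scalar dequantizers are introduced in this hunk. For IQ1_S, each group of eight weights is `d * (2*scale + 1)` times a ternary pattern from `iq1s_grid`, where the 9-bit grid index is an 8-bit `qs` byte plus one extra bit pulled from the packed `scales` byte (the `0x08 << 5` / `0x80 << 1` terms). IQ4_NL is simpler: a 4-bit index into the non-uniform `kvalues_iq4nl` codebook times the block scale. A self-contained scalar model of the IQ4_NL block decode, using a plain `float d` in place of the fp16 block field:

```c
#include <stdint.h>

#define QK4_NL 32

static const int8_t kvalues_iq4nl_ref[16] = {
    -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113
};

static void dequant_iq4_nl_block(float d, const uint8_t qs[QK4_NL/2], float y[QK4_NL]) {
    for (int j = 0; j < QK4_NL/2; ++j) {
        y[j]            = d * kvalues_iq4nl_ref[qs[j] & 0x0f];  /* low nibble  */
        y[j + QK4_NL/2] = d * kvalues_iq4nl_ref[qs[j] >> 4];    /* high nibble */
    }
}
```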
@@ -3666,15 +3900,92 @@ static inline __m128i get_scale_shuffle(int i) {
3666
3900
  }
3667
3901
  #endif
3668
3902
 
3669
- void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3903
+ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
3670
3904
  const int qk = QK8_0;
3671
3905
  const int nb = n / qk;
3672
3906
 
3673
3907
  assert(n % qk == 0);
3908
+ #if defined(__ARM_FEATURE_MATMUL_INT8)
3909
+ assert((nrc == 2) || (nrc == 1));
3910
+ #else
3911
+ assert(nrc == 1);
3912
+ #endif
3913
+ UNUSED(nrc);
3914
+ UNUSED(bx);
3915
+ UNUSED(by);
3916
+ UNUSED(bs);
3674
3917
 
3675
3918
  const block_q4_0 * restrict x = vx;
3676
3919
  const block_q8_0 * restrict y = vy;
3677
3920
 
3921
+ #if defined(__ARM_FEATURE_MATMUL_INT8)
3922
+ if (nrc == 2) {
3923
+ const block_q4_0 * restrict vx0 = vx;
3924
+ const block_q4_0 * restrict vx1 = vx + bx;
3925
+
3926
+ const block_q8_0 * restrict vy0 = vy;
3927
+ const block_q8_0 * restrict vy1 = vy + by;
3928
+
3929
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
3930
+
3931
+ for (int i = 0; i < nb; i++) {
3932
+ const block_q4_0 * restrict b_x0 = &vx0[i];
3933
+ const block_q4_0 * restrict b_x1 = &vx1[i];
3934
+ const block_q8_0 * restrict b_y0 = &vy0[i];
3935
+ const block_q8_0 * restrict b_y1 = &vy1[i];
3936
+
3937
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
3938
+ const int8x16_t s8b = vdupq_n_s8(0x8);
3939
+
3940
+ const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);
3941
+ const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);
3942
+
3943
+ // 4-bit -> 8-bit
3944
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
3945
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
3946
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
3947
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
3948
+
3949
+ // sub 8
3950
+ const int8x16_t x0_l = vsubq_s8(v0_0l, s8b);
3951
+ const int8x16_t x0_h = vsubq_s8(v0_0h, s8b);
3952
+ const int8x16_t x1_l = vsubq_s8(v0_1l, s8b);
3953
+ const int8x16_t x1_h = vsubq_s8(v0_1h, s8b);
3954
+
3955
+ // load y
3956
+ const int8x16_t y0_l = vld1q_s8(b_y0->qs);
3957
+ const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
3958
+ const int8x16_t y1_l = vld1q_s8(b_y1->qs);
3959
+ const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
3960
+
3961
+ float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
3962
+ GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
3963
+ GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
3964
+ GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
3965
+
3966
+ int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
3967
+ int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
3968
+
3969
+ int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
3970
+ int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
3971
+
3972
+ int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
3973
+ int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
3974
+
3975
+ int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
3976
+ int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
3977
+
3978
+ sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
3979
+ l1, r1)), l2, r2)), l3, r3))), scale);
3980
+ }
3981
+ float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
3982
+ float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
3983
+
3984
+ vst1_f32(s, vget_low_f32(sumv2));
3985
+ vst1_f32(s + bs, vget_high_f32(sumv2));
3986
+ return;
3987
+ }
3988
+ #endif
3678
3989
  #if defined(__ARM_NEON)
3679
3990
  float32x4_t sumv0 = vdupq_n_f32(0.0f);
3680
3991
  float32x4_t sumv1 = vdupq_n_f32(0.0f);
@@ -3729,15 +4040,15 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx,
3729
4040
  /* Compute combined scale for the block */
3730
4041
  const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
3731
4042
 
3732
- __m256i bx = bytes_from_nibbles_32(x[i].qs);
4043
+ __m256i qx = bytes_from_nibbles_32(x[i].qs);
3733
4044
 
3734
4045
  // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
3735
4046
  const __m256i off = _mm256_set1_epi8( 8 );
3736
- bx = _mm256_sub_epi8( bx, off );
4047
+ qx = _mm256_sub_epi8( qx, off );
3737
4048
 
3738
- __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
4049
+ __m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs);
3739
4050
 
3740
- const __m256 q = mul_sum_i8_pairs_float(bx, by);
4051
+ const __m256 q = mul_sum_i8_pairs_float(qx, qy);
3741
4052
 
3742
4053
  /* Multiply q with scale and accumulate */
3743
4054
  acc = _mm256_fmadd_ps( d, q, acc );
@@ -3758,15 +4069,15 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx,
3758
4069
 
3759
4070
  const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs);
3760
4071
 
3761
- __m128i bx = _mm_and_si128(lowMask, tmp);
3762
- __m128i by = _mm_loadu_si128((const __m128i *)y[i].qs);
3763
- bx = _mm_sub_epi8(bx, off);
3764
- const __m128i i32_0 = mul_sum_i8_pairs(bx, by);
4072
+ __m128i bx_0 = _mm_and_si128(lowMask, tmp);
4073
+ __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs);
4074
+ bx_0 = _mm_sub_epi8(bx_0, off);
4075
+ const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
3765
4076
 
3766
- bx = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
3767
- by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
3768
- bx = _mm_sub_epi8(bx, off);
3769
- const __m128i i32_1 = mul_sum_i8_pairs(bx, by);
4077
+ bx_0 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
4078
+ by_0 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
4079
+ bx_0 = _mm_sub_epi8(bx_0, off);
4080
+ const __m128i i32_1 = mul_sum_i8_pairs(bx_0, by_0);
3770
4081
 
3771
4082
  // Convert int32_t to float
3772
4083
  __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1));
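This hunk changes the `ggml_vec_dot_*` signature: the destination gains a stride `bs`, the inputs gain strides `bx`/`by`, and `nrc` says how many results the call should produce. Most paths still assert `nrc == 1`; on AArch64 with the int8 matrix-multiply extension (`__ARM_FEATURE_MATMUL_INT8`), the Q4_0·Q8_0 kernel also accepts `nrc == 2` and computes a 2x2 tile of dot products in one pass via `vmmlaq_s32`. A scalar model of what that 2x2 call returns, with the output layout inferred from the NEON store sequence above (dequantized float rows stand in for the block formats, so this is an interpretation, not the implementation):

```c
#include <stddef.h>

static void vec_dot_2x2_model(int n, float * s, size_t bs,
                              const float * x0, const float * x1,
                              const float * y0, const float * y1) {
    float r00 = 0, r01 = 0, r10 = 0, r11 = 0;
    for (int k = 0; k < n; ++k) {
        r00 += x0[k] * y0[k];
        r01 += x0[k] * y1[k];
        r10 += x1[k] * y0[k];
        r11 += x1[k] * y1[k];
    }
    s[0]      = r00;   /* x row 0 . y row 0 */
    s[1]      = r10;   /* x row 1 . y row 0 */
    s[bs + 0] = r01;   /* x row 0 . y row 1, bs floats further on */
    s[bs + 1] = r11;   /* x row 1 . y row 1 */
}
```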
@@ -3956,15 +4267,93 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx,
3956
4267
  #endif
3957
4268
  }
3958
4269
 
3959
- void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
4270
+ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
3960
4271
  const int qk = QK8_1;
3961
4272
  const int nb = n / qk;
3962
4273
 
3963
4274
  assert(n % qk == 0);
4275
+ #if defined(__ARM_FEATURE_MATMUL_INT8)
4276
+ assert((nrc == 2) || (nrc == 1));
4277
+ #else
4278
+ assert(nrc == 1);
4279
+ #endif
4280
+ UNUSED(nrc);
4281
+ UNUSED(bx);
4282
+ UNUSED(by);
4283
+ UNUSED(bs);
3964
4284
 
3965
4285
  const block_q4_1 * restrict x = vx;
3966
4286
  const block_q8_1 * restrict y = vy;
3967
4287
 
4288
+ #if defined(__ARM_FEATURE_MATMUL_INT8)
4289
+ if (nrc == 2) {
4290
+ const block_q4_1 * restrict vx0 = vx;
4291
+ const block_q4_1 * restrict vx1 = vx + bx;
4292
+ const block_q8_1 * restrict vy0 = vy;
4293
+ const block_q8_1 * restrict vy1 = vy + by;
4294
+
4295
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
4296
+ float32x4_t summs0 = vdupq_n_f32(0.0f);
4297
+
4298
+ for (int i = 0; i < nb; i++) {
4299
+ const block_q4_1 * restrict b_x0 = &vx0[i];
4300
+ const block_q4_1 * restrict b_x1 = &vx1[i];
4301
+ const block_q8_1 * restrict b_y0 = &vy0[i];
4302
+ const block_q8_1 * restrict b_y1 = &vy1[i];
4303
+
4304
+ float32x4_t summs_t = {GGML_FP16_TO_FP32(b_x0->m) * b_y0->s,
4305
+ GGML_FP16_TO_FP32(b_x1->m) * b_y0->s,
4306
+ GGML_FP16_TO_FP32(b_x0->m) * b_y1->s,
4307
+ GGML_FP16_TO_FP32(b_x1->m) * b_y1->s};
4308
+ summs0 += summs_t;
4309
+
4310
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
4311
+
4312
+ const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);
4313
+ const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);
4314
+
4315
+ // 4-bit -> 8-bit
4316
+ const int8x16_t x0_l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
4317
+ const int8x16_t x0_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
4318
+ const int8x16_t x1_l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
4319
+ const int8x16_t x1_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
4320
+
4321
+ // load y
4322
+ const int8x16_t y0_l = vld1q_s8(b_y0->qs);
4323
+ const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
4324
+ const int8x16_t y1_l = vld1q_s8(b_y1->qs);
4325
+ const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
4326
+
4327
+ // mmla into int32x4_t
4328
+ float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
4329
+ GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
4330
+ GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
4331
+ GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
4332
+
4333
+ int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
4334
+ int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
4335
+
4336
+ int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
4337
+ int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
4338
+
4339
+ int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
4340
+ int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
4341
+
4342
+ int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
4343
+ int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
4344
+ sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
4345
+ l1, r1)), l2, r2)), l3, r3))), scale);
4346
+ }
4347
+
4348
+ float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
4349
+ float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
4350
+ sumv2 = sumv2 + summs0;
4351
+
4352
+ vst1_f32(s, vget_low_f32(sumv2));
4353
+ vst1_f32(s + bs, vget_high_f32(sumv2));
4354
+ return;
4355
+ }
4356
+ #endif
3968
4357
  // TODO: add WASM SIMD
3969
4358
  #if defined(__ARM_NEON)
3970
4359
  float32x4_t sumv0 = vdupq_n_f32(0.0f);
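The Q4_1·Q8_1 variant gets the same treatment, plus a separate `summs0` accumulator. That term exists because a Q4_1 weight reconstructs as `d*q + m`, so the block dot product splits into an integer dot product scaled by `d*dy` plus the block minimum times the Q8_1 block sum (the `s` field of the Q8_1 block, assumed here to already include the Q8_1 scale). A scalar sanity check of that identity, with made-up example values:

```c
#include <assert.h>
#include <math.h>

int main(void) {
    const float d = 0.5f, m = -1.0f, dy = 0.25f;   /* block scale, min, y scale */
    const int   qx[4] = {1, 7, 3, 0};              /* 4-bit weights             */
    const int   qy[4] = {10, -3, 5, 2};            /* int8 activations          */
    float direct = 0, idot = 0; int qsum = 0;
    for (int k = 0; k < 4; ++k) {
        direct += (d*qx[k] + m) * (dy*qy[k]);
        idot   += qx[k]*qy[k];
        qsum   += qy[k];
    }
    const float ys = dy * qsum;                    /* what block_q8_1 stores in s */
    assert(fabsf(direct - (d*dy*idot + m*ys)) < 1e-5f);
    return 0;
}
```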
@@ -4028,10 +4417,10 @@ void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restri
4028
4417
  const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
4029
4418
 
4030
4419
  // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
4031
- const __m256i bx = bytes_from_nibbles_32(x[i].qs);
4032
- const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
4420
+ const __m256i qx = bytes_from_nibbles_32(x[i].qs);
4421
+ const __m256i qy = _mm256_loadu_si256( (const __m256i *)y[i].qs );
4033
4422
 
4034
- const __m256 xy = mul_sum_us8_pairs_float(bx, by);
4423
+ const __m256 xy = mul_sum_us8_pairs_float(qx, qy);
4035
4424
 
4036
4425
  // Accumulate d0*d1*x*y
4037
4426
  #if defined(__AVX2__)
@@ -4096,12 +4485,17 @@ void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restri
4096
4485
  #endif
4097
4486
  }
4098
4487
 
4099
- void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
4488
+ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
4100
4489
  const int qk = QK8_0;
4101
4490
  const int nb = n / qk;
4102
4491
 
4103
4492
  assert(n % qk == 0);
4104
4493
  assert(qk == QK5_0);
4494
+ assert(nrc == 1);
4495
+ UNUSED(nrc);
4496
+ UNUSED(bx);
4497
+ UNUSED(by);
4498
+ UNUSED(bs);
4105
4499
 
4106
4500
  const block_q5_0 * restrict x = vx;
4107
4501
  const block_q8_0 * restrict y = vy;
@@ -4245,14 +4639,14 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri
4245
4639
  /* Compute combined scale for the block */
4246
4640
  const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
4247
4641
 
4248
- __m256i bx = bytes_from_nibbles_32(x[i].qs);
4642
+ __m256i qx = bytes_from_nibbles_32(x[i].qs);
4249
4643
  __m256i bxhi = bytes_from_bits_32(x[i].qh);
4250
4644
  bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));
4251
- bx = _mm256_or_si256(bx, bxhi);
4645
+ qx = _mm256_or_si256(qx, bxhi);
4252
4646
 
4253
- __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
4647
+ __m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs);
4254
4648
 
4255
- const __m256 q = mul_sum_i8_pairs_float(bx, by);
4649
+ const __m256 q = mul_sum_i8_pairs_float(qx, qy);
4256
4650
 
4257
4651
  /* Multiply q with scale and accumulate */
4258
4652
  acc = _mm256_fmadd_ps(d, q, acc);
@@ -4269,21 +4663,21 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri
4269
4663
  /* Compute combined scale for the block */
4270
4664
  const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
4271
4665
 
4272
- __m256i bx = bytes_from_nibbles_32(x[i].qs);
4666
+ __m256i bx_0 = bytes_from_nibbles_32(x[i].qs);
4273
4667
  const __m256i bxhi = bytes_from_bits_32(x[i].qh);
4274
4668
  __m128i bxhil = _mm256_castsi256_si128(bxhi);
4275
4669
  __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
4276
4670
  bxhil = _mm_andnot_si128(bxhil, mask);
4277
4671
  bxhih = _mm_andnot_si128(bxhih, mask);
4278
- __m128i bxl = _mm256_castsi256_si128(bx);
4279
- __m128i bxh = _mm256_extractf128_si256(bx, 1);
4672
+ __m128i bxl = _mm256_castsi256_si128(bx_0);
4673
+ __m128i bxh = _mm256_extractf128_si256(bx_0, 1);
4280
4674
  bxl = _mm_or_si128(bxl, bxhil);
4281
4675
  bxh = _mm_or_si128(bxh, bxhih);
4282
- bx = MM256_SET_M128I(bxh, bxl);
4676
+ bx_0 = MM256_SET_M128I(bxh, bxl);
4283
4677
 
4284
- const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
4678
+ const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[i].qs);
4285
4679
 
4286
- const __m256 q = mul_sum_i8_pairs_float(bx, by);
4680
+ const __m256 q = mul_sum_i8_pairs_float(bx_0, by_0);
4287
4681
 
4288
4682
  /* Multiply q with scale and accumulate */
4289
4683
  acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
@@ -4382,12 +4776,17 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri
4382
4776
  #endif
4383
4777
  }
4384
4778
 
4385
- void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
4779
+ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
4386
4780
  const int qk = QK8_1;
4387
4781
  const int nb = n / qk;
4388
4782
 
4389
4783
  assert(n % qk == 0);
4390
4784
  assert(qk == QK5_1);
4785
+ assert(nrc == 1);
4786
+ UNUSED(nrc);
4787
+ UNUSED(bx);
4788
+ UNUSED(by);
4789
+ UNUSED(bs);
4391
4790
 
4392
4791
  const block_q5_1 * restrict x = vx;
4393
4792
  const block_q8_1 * restrict y = vy;
@@ -4544,15 +4943,15 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri
4544
4943
 
4545
4944
  summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
4546
4945
 
4547
- __m256i bx = bytes_from_nibbles_32(x[i].qs);
4946
+ __m256i qx = bytes_from_nibbles_32(x[i].qs);
4548
4947
  __m256i bxhi = bytes_from_bits_32(x[i].qh);
4549
4948
  bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
4550
- bx = _mm256_or_si256(bx, bxhi);
4949
+ qx = _mm256_or_si256(qx, bxhi);
4551
4950
 
4552
4951
  const __m256 dy = _mm256_set1_ps(y[i].d);
4553
- const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
4952
+ const __m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs);
4554
4953
 
4555
- const __m256 q = mul_sum_us8_pairs_float(bx, by);
4954
+ const __m256 q = mul_sum_us8_pairs_float(qx, qy);
4556
4955
 
4557
4956
  acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
4558
4957
  }
@@ -4571,22 +4970,22 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri
4571
4970
 
4572
4971
  summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
4573
4972
 
4574
- __m256i bx = bytes_from_nibbles_32(x[i].qs);
4973
+ __m256i bx_0 = bytes_from_nibbles_32(x[i].qs);
4575
4974
  const __m256i bxhi = bytes_from_bits_32(x[i].qh);
4576
4975
  __m128i bxhil = _mm256_castsi256_si128(bxhi);
4577
4976
  __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
4578
4977
  bxhil = _mm_and_si128(bxhil, mask);
4579
4978
  bxhih = _mm_and_si128(bxhih, mask);
4580
- __m128i bxl = _mm256_castsi256_si128(bx);
4581
- __m128i bxh = _mm256_extractf128_si256(bx, 1);
4979
+ __m128i bxl = _mm256_castsi256_si128(bx_0);
4980
+ __m128i bxh = _mm256_extractf128_si256(bx_0, 1);
4582
4981
  bxl = _mm_or_si128(bxl, bxhil);
4583
4982
  bxh = _mm_or_si128(bxh, bxhih);
4584
- bx = MM256_SET_M128I(bxh, bxl);
4983
+ bx_0 = MM256_SET_M128I(bxh, bxl);
4585
4984
 
4586
4985
  const __m256 dy = _mm256_set1_ps(y[i].d);
4587
- const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
4986
+ const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[i].qs);
4588
4987
 
4589
- const __m256 q = mul_sum_us8_pairs_float(bx, by);
4988
+ const __m256 q = mul_sum_us8_pairs_float(bx_0, by_0);
4590
4989
 
4591
4990
  acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
4592
4991
  }
@@ -4681,15 +5080,79 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri
4681
5080
  #endif
4682
5081
  }
4683
5082
 
4684
- void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
5083
+ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
4685
5084
  const int qk = QK8_0;
4686
5085
  const int nb = n / qk;
4687
5086
 
4688
5087
  assert(n % qk == 0);
5088
+ #if defined(__ARM_FEATURE_MATMUL_INT8)
5089
+ assert((nrc == 2) || (nrc == 1));
5090
+ #else
5091
+ assert(nrc == 1);
5092
+ #endif
5093
+ UNUSED(nrc);
5094
+ UNUSED(bx);
5095
+ UNUSED(by);
5096
+ UNUSED(bs);
4689
5097
 
4690
5098
  const block_q8_0 * restrict x = vx;
4691
5099
  const block_q8_0 * restrict y = vy;
4692
5100
 
5101
+ #if defined(__ARM_FEATURE_MATMUL_INT8)
5102
+ if (nrc == 2) {
5103
+ const block_q8_0 * restrict vx0 = vx;
5104
+ const block_q8_0 * restrict vx1 = vx + bx;
5105
+ const block_q8_0 * restrict vy0 = vy;
5106
+ const block_q8_0 * restrict vy1 = vy + by;
5107
+
5108
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
5109
+
5110
+ for (int i = 0; i < nb; i++) {
5111
+ const block_q8_0 * restrict b_x0 = &vx0[i];
5112
+ const block_q8_0 * restrict b_y0 = &vy0[i];
5113
+
5114
+ const block_q8_0 * restrict b_x1 = &vx1[i];
5115
+ const block_q8_0 * restrict b_y1 = &vy1[i];
5116
+
5117
+ const int8x16_t x0_l = vld1q_s8(b_x0->qs);
5118
+ const int8x16_t x0_h = vld1q_s8(b_x0->qs + 16);
5119
+ const int8x16_t x1_l = vld1q_s8(b_x1->qs);
5120
+ const int8x16_t x1_h = vld1q_s8(b_x1->qs + 16);
5121
+
5122
+ // load y
5123
+ const int8x16_t y0_l = vld1q_s8(b_y0->qs);
5124
+ const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
5125
+ const int8x16_t y1_l = vld1q_s8(b_y1->qs);
5126
+ const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
5127
+
5128
+ float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
5129
+ GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
5130
+ GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
5131
+ GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
5132
+
5133
+ int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
5134
+ int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
5135
+
5136
+ int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
5137
+ int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
5138
+
5139
+ int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
5140
+ int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
5141
+
5142
+ int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
5143
+ int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
5144
+
5145
+ sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
5146
+ l1, r1)), l2, r2)), l3, r3))), scale);
5147
+ }
5148
+ float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
5149
+ float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
5150
+
5151
+ vst1_f32(s, vget_low_f32(sumv2));
5152
+ vst1_f32(s + bs, vget_high_f32(sumv2));
5153
+ return;
5154
+ }
5155
+ #endif
4693
5156
  #if defined(__ARM_NEON)
4694
5157
  float32x4_t sumv0 = vdupq_n_f32(0.0f);
4695
5158
  float32x4_t sumv1 = vdupq_n_f32(0.0f);
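The Q8_0·Q8_0 two-row path shows the i8mm building block most plainly, with no nibble unpacking in the way. `vmmlaq_s32` treats each 16-byte register as a 2x8 int8 matrix and accumulates the 2x2 int32 product A·Bᵀ; the `vzip1q_s64`/`vzip2q_s64` shuffles before it pair up the 8-byte halves of the two input rows to form those tiles. A scalar model of the instruction's accumulation (helper name is illustrative):

```c
#include <stdint.h>

static void smmla_model(int32_t acc[4],          /* 2x2 accumulator, row-major */
                        const int8_t a[16],      /* 2x8 matrix A, row-major    */
                        const int8_t b[16]) {    /* 2x8 matrix B, row-major    */
    for (int i = 0; i < 2; ++i) {
        for (int j = 0; j < 2; ++j) {
            for (int k = 0; k < 8; ++k) {
                acc[2*i + j] += (int32_t)a[8*i + k] * (int32_t)b[8*j + k];
            }
        }
    }
}
```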
@@ -4731,10 +5194,10 @@ void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restri
4731
5194
  for (int i = 0; i < nb; ++i) {
4732
5195
  // Compute combined scale for the block
4733
5196
  const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
4734
- __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs);
4735
- __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
5197
+ __m256i qx = _mm256_loadu_si256((const __m256i *)x[i].qs);
5198
+ __m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs);
4736
5199
 
4737
- const __m256 q = mul_sum_i8_pairs_float(bx, by);
5200
+ const __m256 q = mul_sum_i8_pairs_float(qx, qy);
4738
5201
 
4739
5202
  // Multiply q with scale and accumulate
4740
5203
  #if defined(__AVX2__)
@@ -4751,10 +5214,10 @@ void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restri
4751
5214
 
4752
5215
  for (int i = 0; i < nb; i++) {
4753
5216
  // load elements
4754
- vint8m1_t bx = __riscv_vle8_v_i8m1(x[i].qs, vl);
4755
- vint8m1_t by = __riscv_vle8_v_i8m1(y[i].qs, vl);
5217
+ vint8m1_t bx_0 = __riscv_vle8_v_i8m1(x[i].qs, vl);
5218
+ vint8m1_t by_0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
4756
5219
 
4757
- vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx, by, vl);
5220
+ vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx_0, by_0, vl);
4758
5221
 
4759
5222
  vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl);
4760
5223
  vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl);
@@ -4784,7 +5247,12 @@ void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restri
4784
5247
  }
4785
5248
 
4786
5249
  #if QK_K == 256
4787
- void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
5250
+ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
5251
+ assert(nrc == 1);
5252
+ UNUSED(nrc);
5253
+ UNUSED(bx);
5254
+ UNUSED(by);
5255
+ UNUSED(bs);
4788
5256
 
4789
5257
  const block_q2_K * restrict x = vx;
4790
5258
  const block_q8_K * restrict y = vy;
@@ -5160,7 +5628,12 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
5160
5628
 
5161
5629
  #else
5162
5630
 
5163
- void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
5631
+ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
5632
+ assert(nrc == 1);
5633
+ UNUSED(nrc);
5634
+ UNUSED(bx);
5635
+ UNUSED(by);
5636
+ UNUSED(bs);
5164
5637
 
5165
5638
  const block_q2_K * restrict x = vx;
5166
5639
  const block_q8_K * restrict y = vy;
@@ -5181,8 +5654,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
5181
5654
 
5182
5655
  for (int i = 0; i < nb; ++i) {
5183
5656
 
5184
- const float d = y[i].d * (float)x[i].d;
5185
- const float dmin = -y[i].d * (float)x[i].dmin;
5657
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
5658
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
5186
5659
 
5187
5660
  const uint8_t * restrict q2 = x[i].qs;
5188
5661
  const int8_t * restrict q8 = y[i].qs;
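Throughout the `QK_K == 64` fallbacks, `(float)x[i].d` becomes `GGML_FP16_TO_FP32(x[i].d)`. The block scale is a 16-bit half; a plain cast is only meaningful when `ggml_fp16_t` is a native half-precision type, whereas the macro performs the half-to-float conversion explicitly and so also behaves correctly where the half is stored as a raw `uint16_t`. For reference, a portable bit-level half-to-float conversion looks roughly like this (an illustrative sketch, not ggml's implementation):

```c
#include <stdint.h>
#include <string.h>

static float fp16_to_fp32_ref(uint16_t h) {
    const uint32_t sign = (uint32_t)(h & 0x8000) << 16;
    const uint32_t exp  = (h >> 10) & 0x1f;
    const uint32_t man  = h & 0x03ff;
    uint32_t bits;
    if (exp == 0) {
        if (man == 0) {
            bits = sign;                                   /* +/- zero */
        } else {                                           /* subnormal: renormalize */
            int e = -1; uint32_t m = man;
            do { ++e; m <<= 1; } while ((m & 0x0400) == 0);
            bits = sign | ((uint32_t)(127 - 15 - e) << 23) | ((m & 0x03ff) << 13);
        }
    } else if (exp == 0x1f) {
        bits = sign | 0x7f800000u | (man << 13);           /* inf / NaN */
    } else {
        bits = sign | ((exp - 15 + 127) << 23) | (man << 13);  /* normal */
    }
    float f; memcpy(&f, &bits, sizeof(f));
    return f;
}
```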
@@ -5331,8 +5804,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
5331
5804
 
5332
5805
  for (int i = 0; i < nb; ++i) {
5333
5806
 
5334
- const float d = y[i].d * (float)x[i].d;
5335
- const float dmin = -y[i].d * (float)x[i].dmin;
5807
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
5808
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
5336
5809
 
5337
5810
  const uint8_t * restrict q2 = x[i].qs;
5338
5811
  const int8_t * restrict q8 = y[i].qs;
@@ -5418,8 +5891,13 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
5418
5891
  #endif
5419
5892
 
5420
5893
  #if QK_K == 256
5421
- void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
5894
+ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
5422
5895
  assert(n % QK_K == 0);
5896
+ assert(nrc == 1);
5897
+ UNUSED(nrc);
5898
+ UNUSED(bx);
5899
+ UNUSED(by);
5900
+ UNUSED(bs);
5423
5901
 
5424
5902
  const uint32_t kmask1 = 0x03030303;
5425
5903
  const uint32_t kmask2 = 0x0f0f0f0f;
@@ -5938,8 +6416,13 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
5938
6416
 
5939
6417
  #else
5940
6418
 
5941
- void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
6419
+ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
5942
6420
  assert(n % QK_K == 0);
6421
+ assert(nrc == 1);
6422
+ UNUSED(nrc);
6423
+ UNUSED(bx);
6424
+ UNUSED(by);
6425
+ UNUSED(bs);
5943
6426
 
5944
6427
  const block_q3_K * restrict x = vx;
5945
6428
  const block_q8_K * restrict y = vy;
@@ -5975,7 +6458,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
5975
6458
 
5976
6459
  int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]);
5977
6460
 
5978
- const float d = y[i].d * (float)x[i].d;
6461
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
5979
6462
 
5980
6463
  const uint8x16_t htmp = vcombine_u8(hbits, vshr_n_u8(hbits, 1));
5981
6464
  q3h.val[0] = vandq_u8(mh, vshlq_n_u8(htmp, 2));
@@ -6177,7 +6660,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
6177
6660
 
6178
6661
  int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]);
6179
6662
 
6180
- const float d = y[i].d * (float)x[i].d;
6663
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
6181
6664
 
6182
6665
  vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
6183
6666
 
@@ -6281,8 +6764,13 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
6281
6764
  #endif
6282
6765
 
6283
6766
  #if QK_K == 256
6284
- void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
6767
+ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
6285
6768
  assert(n % QK_K == 0);
6769
+ assert(nrc == 1);
6770
+ UNUSED(nrc);
6771
+ UNUSED(bx);
6772
+ UNUSED(by);
6773
+ UNUSED(bs);
6286
6774
 
6287
6775
  const block_q4_K * restrict x = vx;
6288
6776
  const block_q8_K * restrict y = vy;
@@ -6637,8 +7125,13 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
6637
7125
  #endif
6638
7126
  }
6639
7127
  #else
6640
- void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
7128
+ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
6641
7129
  assert(n % QK_K == 0);
7130
+ assert(nrc == 1);
7131
+ UNUSED(nrc);
7132
+ UNUSED(bx);
7133
+ UNUSED(by);
7134
+ UNUSED(bs);
6642
7135
 
6643
7136
  const block_q4_K * restrict x = vx;
6644
7137
  const block_q8_K * restrict y = vy;
@@ -6670,9 +7163,9 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
6670
7163
  aux16[1] = (a[0] >> 4) & 0x0f0f;
6671
7164
 
6672
7165
  const int32_t summi = scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]);
6673
- sum_mins += y[i].d * (float)x[i].d[1] * summi;
7166
+ sum_mins += y[i].d * GGML_FP16_TO_FP32(x[i].d[1]) * summi;
6674
7167
 
6675
- const float d = y[i].d * (float)x[i].d[0];
7168
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d[0]);
6676
7169
 
6677
7170
  const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4);
6678
7171
 
@@ -6880,8 +7373,13 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
6880
7373
  #endif
6881
7374
 
6882
7375
  #if QK_K == 256
6883
- void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
7376
+ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
6884
7377
  assert(n % QK_K == 0);
7378
+ assert(nrc == 1);
7379
+ UNUSED(nrc);
7380
+ UNUSED(bx);
7381
+ UNUSED(by);
7382
+ UNUSED(bs);
6885
7383
 
6886
7384
  const block_q5_K * restrict x = vx;
6887
7385
  const block_q8_K * restrict y = vy;
@@ -7300,8 +7798,13 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
7300
7798
 
7301
7799
  #else
7302
7800
 
7303
- void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
7801
+ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
7304
7802
  assert(n % QK_K == 0);
7803
+ assert(nrc == 1);
7804
+ UNUSED(nrc);
7805
+ UNUSED(bx);
7806
+ UNUSED(by);
7807
+ UNUSED(bs);
7305
7808
 
7306
7809
  const block_q5_K * restrict x = vx;
7307
7810
  const block_q8_K * restrict y = vy;
@@ -7320,7 +7823,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
7320
7823
 
7321
7824
  for (int i = 0; i < nb; ++i) {
7322
7825
 
7323
- const float d = y[i].d * (float)x[i].d;
7826
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
7324
7827
  const int8_t * sc = x[i].scales;
7325
7828
 
7326
7829
  const uint8_t * restrict q5 = x[i].qs;
@@ -7462,7 +7965,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
7462
7965
 
7463
7966
  for (int i = 0; i < nb; ++i) {
7464
7967
 
7465
- const float d = y[i].d * (float)x[i].d;
7968
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
7466
7969
  const int8_t * sc = x[i].scales;
7467
7970
 
7468
7971
  const uint8_t * restrict q5 = x[i].qs;
@@ -7566,8 +8069,13 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
7566
8069
 
7567
8070
 
7568
8071
  #if QK_K == 256
7569
- void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
8072
+ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
7570
8073
  assert(n % QK_K == 0);
8074
+ assert(nrc == 1);
8075
+ UNUSED(nrc);
8076
+ UNUSED(bx);
8077
+ UNUSED(by);
8078
+ UNUSED(bs);
7571
8079
 
7572
8080
  const block_q6_K * restrict x = vx;
7573
8081
  const block_q8_K * restrict y = vy;
@@ -7998,8 +8506,13 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
7998
8506
 
7999
8507
  #else
8000
8508
 
8001
- void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
8509
+ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
8002
8510
  assert(n % QK_K == 0);
8511
+ assert(nrc == 1);
8512
+ UNUSED(nrc);
8513
+ UNUSED(bx);
8514
+ UNUSED(by);
8515
+ UNUSED(bs);
8003
8516
 
8004
8517
  const block_q6_K * restrict x = vx;
8005
8518
  const block_q8_K * restrict y = vy;
@@ -8020,7 +8533,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
8020
8533
 
8021
8534
  for (int i = 0; i < nb; ++i) {
8022
8535
 
8023
- const float d_all = (float)x[i].d;
8536
+ const float d_all = GGML_FP16_TO_FP32(x[i].d);
8024
8537
 
8025
8538
  const uint8_t * restrict q6 = x[i].ql;
8026
8539
  const uint8_t * restrict qh = x[i].qh;
@@ -8191,7 +8704,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
8191
8704
 
8192
8705
  for (int i = 0; i < nb; ++i) {
8193
8706
 
8194
- const float d_all = (float)x[i].d;
8707
+ const float d_all = GGML_FP16_TO_FP32(x[i].d);
8195
8708
 
8196
8709
  const uint8_t * restrict q6 = x[i].ql;
8197
8710
  const uint8_t * restrict qh = x[i].qh;
@@ -8328,8 +8841,13 @@ static const int8_t keven_signs_q2xs[1024] = {
8328
8841
  1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
8329
8842
  };
8330
8843
 
8331
- void ggml_vec_dot_iq2_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
8844
+ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
8332
8845
  assert(n % QK_K == 0);
8846
+ assert(nrc == 1);
8847
+ UNUSED(nrc);
8848
+ UNUSED(bx);
8849
+ UNUSED(by);
8850
+ UNUSED(bs);
8333
8851
 
8334
8852
  const block_iq2_xxs * restrict x = vx;
8335
8853
  const block_q8_K * restrict y = vy;
@@ -8451,8 +8969,13 @@ void ggml_vec_dot_iq2_xxs_q8_K(const int n, float * restrict s, const void * res
8451
8969
  #endif
8452
8970
  }
8453
8971
 
8454
- void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
8972
+ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
8455
8973
  assert(n % QK_K == 0);
8974
+ assert(nrc == 1);
8975
+ UNUSED(nrc);
8976
+ UNUSED(bx);
8977
+ UNUSED(by);
8978
+ UNUSED(bs);
8456
8979
 
8457
8980
  const block_iq2_xs * restrict x = vx;
8458
8981
  const block_q8_K * restrict y = vy;
@@ -8670,9 +9193,13 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
8670
9193
  #endif
8671
9194
  }
8672
9195
 
8673
- // TODO
8674
- void ggml_vec_dot_iq3_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
9196
+ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
8675
9197
  assert(n % QK_K == 0);
9198
+ assert(nrc == 1);
9199
+ UNUSED(nrc);
9200
+ UNUSED(bx);
9201
+ UNUSED(by);
9202
+ UNUSED(bs);
8676
9203
 
8677
9204
  const block_iq3_xxs * restrict x = vx;
8678
9205
  const block_q8_K * restrict y = vy;
@@ -8698,10 +9225,10 @@ void ggml_vec_dot_iq3_xxs_q8_K(const int n, float * restrict s, const void * res
8698
9225
  for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
8699
9226
  q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
8700
9227
  memcpy(aux32, gas, 2*sizeof(uint32_t)); gas += 2*sizeof(uint32_t);
8701
- const uint32x4_t aux32x4_0 = {iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]};
8702
- const uint32x4_t aux32x4_1 = {iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]};
8703
- const uint32x4_t aux32x4_2 = {iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]};
8704
- const uint32x4_t aux32x4_3 = {iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]};
9228
+ const uint32x4_t aux32x4_0 = ggml_vld1q_u32(iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]);
9229
+ const uint32x4_t aux32x4_1 = ggml_vld1q_u32(iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]);
9230
+ const uint32x4_t aux32x4_2 = ggml_vld1q_u32(iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]);
9231
+ const uint32x4_t aux32x4_3 = ggml_vld1q_u32(iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]);
8705
9232
  q3 += 16;
8706
9233
  q3s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 7) & 127))));
8707
9234
  q3s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 21) & 127))));
@@ -8800,6 +9327,271 @@ void ggml_vec_dot_iq3_xxs_q8_K(const int n, float * restrict s, const void * res
8800
9327
  #endif
8801
9328
  }
8802
9329
 
9330
+ #ifdef __AVX2__
9331
+ static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
9332
+ const __m256i ax = _mm256_sign_epi8(x, x);
9333
+ const __m256i sy = _mm256_sign_epi8(y, x);
9334
+ return _mm256_maddubs_epi16(ax, sy);
9335
+ }
9336
+ #endif
9337
+
9338
+ void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
9339
+ assert(n % QK_K == 0);
9340
+ assert(nrc == 1);
9341
+ UNUSED(nrc);
9342
+ UNUSED(bx);
9343
+ UNUSED(by);
9344
+ UNUSED(bs);
9345
+
9346
+ const block_iq1_s * restrict x = vx;
9347
+ const block_q8_K * restrict y = vy;
9348
+
9349
+ const int nb = n / QK_K;
9350
+
9351
+ #if defined __ARM_NEON
9352
+
9353
+ const uint8x16_t m8 = vdupq_n_u8(0x08);
9354
+ const uint8x16_t m7 = vdupq_n_u8(0x07);
9355
+ const uint8x16_t m1 = vdupq_n_u8(0x01);
9356
+ const int32x4_t vzero = vdupq_n_s32(0);
9357
+
9358
+ uint16_t gindex[8];
9359
+ uint16x8x2_t vindex;
9360
+ int8x16x4_t q1b;
9361
+ ggml_int8x16x4_t q8b;
9362
+ uint16x8x4_t scales;
9363
+ int32x4x2_t sumi;
9364
+ int32x4x2_t dotq;
9365
+
9366
+ float sumf = 0;
9367
+ for (int i = 0; i < nb; ++i) {
9368
+
9369
+ const int8_t * q8 = y[i].qs;
9370
+ const uint8_t * qs = x[i].qs;
9371
+ const uint8_t * sc = x[i].scales;
9372
+
9373
+ sumi.val[0] = sumi.val[1] = vzero;
9374
+
9375
+ for (int i128 = 0; i128 < QK_K/128; ++i128) {
9376
+ const uint8x16_t ql = vld1q_u8(qs); qs += 16;
9377
+ const uint8x8_t tm1 = vld1_u8 (sc); sc += 8;
9378
+ const uint8x8_t tm2 = vshr_n_u8(tm1, 4);
9379
+ const uint8x16_t qh = vcombine_u8(vzip1_u8(tm1, tm2), vzip2_u8(tm1, tm2));
9380
+ const uint8x16_t hbit = vandq_u8(qh, m8);
9381
+ vindex.val[0] = vorrq_u16(vmovl_u8(vget_low_u8 (ql)), vshlq_n_u16(vmovl_u8(vget_low_u8 (hbit)), 5));
9382
+ vindex.val[1] = vorrq_u16(vmovl_u8(vget_high_u8(ql)), vshlq_n_u16(vmovl_u8(vget_high_u8(hbit)), 5));
9383
+ const uint8x16_t scales8 = vorrq_u8(vshlq_n_u8(vandq_u8(qh, m7), 1), m1);
9384
+ scales.val[0] = vmovl_u8(vget_low_u8 (scales8));
9385
+ scales.val[1] = vmovl_u8(vget_high_u8 (scales8));
9386
+
9387
+ for (int l = 0; l < 2; ++l) {
9388
+ vst1q_u16(gindex+0, vindex.val[l]);
9389
+ q1b.val[0] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[0])), vld1_s8((const void *)(iq1s_grid+gindex[1])));
9390
+ q1b.val[1] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[2])), vld1_s8((const void *)(iq1s_grid+gindex[3])));
9391
+ q1b.val[2] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[4])), vld1_s8((const void *)(iq1s_grid+gindex[5])));
9392
+ q1b.val[3] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[6])), vld1_s8((const void *)(iq1s_grid+gindex[7])));
9393
+ q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
9394
+
9395
+ dotq.val[0] = vpaddq_s32(ggml_vdotq_s32(vzero, q1b.val[0], q8b.val[0]), ggml_vdotq_s32(vzero, q1b.val[1], q8b.val[1]));
9396
+ dotq.val[1] = vpaddq_s32(ggml_vdotq_s32(vzero, q1b.val[2], q8b.val[2]), ggml_vdotq_s32(vzero, q1b.val[3], q8b.val[3]));
9397
+
9398
+ sumi.val[0] = vmlaq_s32(sumi.val[0], dotq.val[0], vreinterpretq_s32_u32(vmovl_u16(vget_low_u16 (scales.val[l]))));
9399
+ sumi.val[1] = vmlaq_s32(sumi.val[1], dotq.val[1], vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales.val[l]))));
9400
+ }
9401
+ }
9402
+
9403
+ sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * vaddvq_s32(vaddq_s32(sumi.val[0], sumi.val[1]));
9404
+ }
9405
+
9406
+ *s = sumf;
9407
+
9408
+ #elif defined __AVX2__
9409
+
9410
+ const __m128i m8 = _mm_set1_epi8(0x08);
9411
+ const __m128i m7 = _mm_set1_epi8(0x07);
9412
+ const __m128i m1 = _mm_set1_epi8(0x01);
9413
+ const __m128i shuffle_h = _mm_set_epi8(15, 7, 14, 6, 13, 5, 12, 4, 11, 3, 10, 2, 9, 1, 8, 0);
9414
+ const __m128i shuffle_s[4] = {
9415
+ _mm_set_epi32(0x03030303, 0x02020202, 0x01010101, 0x00000000),
9416
+ _mm_set_epi32(0x07070707, 0x06060606, 0x05050505, 0x04040404),
9417
+ _mm_set_epi32(0x0b0b0b0b, 0x0a0a0a0a, 0x09090909, 0x08080808),
9418
+ _mm_set_epi32(0x0f0f0f0f, 0x0e0e0e0e, 0x0d0d0d0d, 0x0c0c0c0c)
9419
+ };
9420
+
9421
+ uint64_t aux64;
9422
+
9423
+ __m256i v_gindex;
9424
+ const uint16_t * gindex = (const uint16_t *)&v_gindex;
9425
+
9426
+ __m256 accum = _mm256_setzero_ps();
9427
+ for (int i = 0; i < nb; ++i) {
9428
+
9429
+ const int8_t * q8 = y[i].qs;
9430
+ const uint8_t * qs = x[i].qs;
9431
+ const uint8_t * sc = x[i].scales;
9432
+
9433
+ __m256i sumi = _mm256_setzero_si256();
9434
+ for (int i128 = 0; i128 < QK_K/128; ++i128) {
9435
+ const __m128i ql = _mm_loadu_si128((const __m128i*)qs); qs += 16;
9436
+ memcpy(&aux64, sc, 8); sc += 8;
9437
+ const __m128i qh = _mm_shuffle_epi8(_mm_set_epi64x(aux64 >> 4, aux64), shuffle_h);
9438
+ const __m256i hbit = _mm256_cvtepu8_epi16(_mm_and_si128(qh, m8));
9439
+ v_gindex = _mm256_or_si256(_mm256_cvtepu8_epi16(ql), _mm256_slli_epi16(hbit, 5));
9440
+ const __m128i scales = _mm_or_si128(_mm_slli_epi16(_mm_and_si128(qh, m7), 1), m1);
9441
+
9442
+ for (int i32 = 0; i32 < 4; ++i32) {
9443
+ const __m256i q8b = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
9444
+ const __m256i q1b = _mm256_set_epi64x(iq1s_grid[gindex[4*i32+3]], iq1s_grid[gindex[4*i32+2]],
9445
+ iq1s_grid[gindex[4*i32+1]], iq1s_grid[gindex[4*i32+0]]);
9446
+ const __m256i dot = mul_add_epi8(q1b, q8b);
9447
+ const __m256i s16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, shuffle_s[i32]));
9448
+ const __m256i p = _mm256_madd_epi16(s16, dot);
9449
+ sumi = _mm256_add_epi32(sumi, p);
9450
+ }
9451
+
9452
+ }
9453
+
9454
+ accum = _mm256_fmadd_ps(_mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d)), _mm256_cvtepi32_ps(sumi), accum);
9455
+
9456
+ }
9457
+
9458
+ *s = hsum_float_8(accum);
9459
+
9460
+ #else
9461
+
9462
+ int db[4];
9463
+ uint16_t idx[4];
9464
+
9465
+ float sumf = 0;
9466
+ for (int i = 0; i < nb; ++i) {
9467
+
9468
+ const int8_t * q8 = y[i].qs;
9469
+ const uint8_t * qs = x[i].qs;
9470
+ const uint8_t * sc = x[i].scales;
9471
+
9472
+ int sumi = 0;
9473
+ for (int i32 = 0; i32 < QK_K/32; ++i32) {
9474
+ idx[0] = qs[0] | ((sc[0] & 0x08) << 5);
9475
+ idx[1] = qs[1] | ((sc[0] & 0x80) << 1);
9476
+ idx[2] = qs[2] | ((sc[1] & 0x08) << 5);
9477
+ idx[3] = qs[3] | ((sc[1] & 0x80) << 1);
9478
+ db[0] = (2*(sc[0] & 7) + 1);
9479
+ db[1] = (2*((sc[0] >> 4) & 7) + 1);
9480
+ db[2] = (2*(sc[1] & 7) + 1);
9481
+ db[3] = (2*((sc[1] >> 4) & 7) + 1);
9482
+ for (int l = 0; l < 4; ++l) {
9483
+ const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]);
9484
+ int suml = 0;
9485
+ for (int j = 0; j < 8; ++j) suml += q8[j] * grid[j];
9486
+ sumi += db[l] * suml;
9487
+ q8 += 8;
9488
+ }
9489
+ qs += 4;
9490
+ sc += 2;
9491
+ }
9492
+
9493
+ sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * sumi;
9494
+ }
9495
+
9496
+ *s = sumf;
9497
+
9498
+ #endif
9499
+ }
9500
+
9501
+ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
9502
+ assert(nrc == 1);
9503
+ UNUSED(nrc);
9504
+ UNUSED(bx);
9505
+ UNUSED(by);
9506
+ UNUSED(bs);
9507
+ assert(n % QK4_NL == 0);
9508
+ static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same");
9509
+
9510
+ const block_iq4_nl * restrict x = vx;
9511
+ const block_q8_0 * restrict y = vy;
9512
+
9513
+ const int nb = n / QK4_NL;
9514
+
9515
+ #if defined __ARM_NEON
9516
+ const int8x16_t values = vld1q_s8(kvalues_iq4nl);
9517
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
9518
+ uint8x16x2_t q4bits;
9519
+ int8x16x4_t q4b;
9520
+ int8x16x4_t q8b;
9521
+ int32x4_t prod_1, prod_2;
9522
+
9523
+ float sumf = 0;
9524
+
9525
+ for (int ib = 0; ib < nb; ib += 2) {
9526
+ q4bits.val[0] = vld1q_u8(x[ib+0].qs);
9527
+ q4bits.val[1] = vld1q_u8(x[ib+1].qs);
9528
+ q8b.val[0] = vld1q_s8(y[ib+0].qs);
9529
+ q8b.val[1] = vld1q_s8(y[ib+0].qs + 16);
9530
+ q8b.val[2] = vld1q_s8(y[ib+1].qs);
9531
+ q8b.val[3] = vld1q_s8(y[ib+1].qs + 16);
9532
+
9533
+ q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b));
9534
+ q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
9535
+ q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b));
9536
+ q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));
9537
+
9538
+ prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
9539
+ prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);
9540
+
9541
+ sumf +=
9542
+ GGML_FP16_TO_FP32(x[ib+0].d) * GGML_FP16_TO_FP32(y[ib+0].d) * vaddvq_s32(prod_1) +
9543
+ GGML_FP16_TO_FP32(x[ib+1].d) * GGML_FP16_TO_FP32(y[ib+1].d) * vaddvq_s32(prod_2);
9544
+ }
9545
+
9546
+ *s = sumf;
9547
+
9548
+ #elif defined __AVX2__
9549
+
9550
+ const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
9551
+ const __m128i m4b = _mm_set1_epi8(0x0f);
9552
+ const __m256i mone = _mm256_set1_epi16(1);
9553
+
9554
+ __m256 accum1 = _mm256_setzero_ps();
9555
+ __m256 accum2 = _mm256_setzero_ps();
9556
+ for (int ib = 0; ib < nb; ib += 2) {
9557
+ const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[0].qs);
9558
+ const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[1].qs);
9559
+ const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[0].qs);
9560
+ const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[1].qs);
9561
+ const __m256i q4b_1 = _mm256_set_m128i(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
9562
+ _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
9563
+ const __m256i q4b_2 = _mm256_set_m128i(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
9564
+ _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
9565
+ const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
9566
+ const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
9567
+ const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
9568
+ const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
9569
+ accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[0].d)*GGML_FP16_TO_FP32(x[0].d)),
9570
+ _mm256_cvtepi32_ps(p_1), accum1);
9571
+ accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[1].d)*GGML_FP16_TO_FP32(x[1].d)),
9572
+ _mm256_cvtepi32_ps(p_2), accum2);
9573
+
9574
+ y += 2;
9575
+ x += 2;
9576
+ }
9577
+
9578
+ *s = hsum_float_8(_mm256_add_ps(accum1, accum2));
9579
+
9580
+ #else
9581
+ float sumf = 0;
9582
+ for (int ib = 0; ib < nb; ++ib) {
9583
+ const float d = GGML_FP16_TO_FP32(y[ib].d)*GGML_FP16_TO_FP32(x[ib].d);
9584
+ int sumi1 = 0, sumi2 = 0;
9585
+ for (int j = 0; j < QK4_NL/2; ++j) {
9586
+ sumi1 += y[ib].qs[j+ 0] * kvalues_iq4nl[x[ib].qs[j] & 0xf];
9587
+ sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >> 4];
9588
+ }
9589
+ sumf += d * (sumi1 + sumi2);
9590
+ }
9591
+ *s = sumf;
9592
+ #endif
9593
+ }
9594
+
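
Both new AVX2 kernels above (ggml_vec_dot_iq1_s_q8_K and ggml_vec_dot_iq4_nl_q8_0) lean on the mul_add_epi8() helper, which multiplies signed 8-bit lanes through _mm256_maddubs_epi16 (an unsigned-by-signed multiply) by first taking |x| and moving x's sign onto y. A scalar model of what one 16-bit output lane holds, assuming only the intrinsics already shown (the function name below is hypothetical):

    #include <stdint.h>

    /* One 16-bit lane of mul_add_epi8(x, y): the sign trick keeps
     * ax*sy == x*y (and zeroes the product where x == 0), and the adjacent
     * pair is summed just as _mm256_maddubs_epi16 does.  The intrinsic
     * saturates at +/-32767, which the 8-bit magnitudes used here stay below. */
    static int16_t mul_add_epi8_lane(const int8_t x[2], const int8_t y[2]) {
        int sum = 0;
        for (int k = 0; k < 2; ++k) {
            const int ax = x[k] < 0 ? -x[k] : x[k];                    /* _mm256_sign_epi8(x, x) */
            const int sy = x[k] > 0 ? y[k] : (x[k] < 0 ? -y[k] : 0);   /* _mm256_sign_epi8(y, x) */
            sum += ax * sy;                                            /* == x[k] * y[k]         */
        }
        return (int16_t)sum;
    }
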
8803
9595
  // ================================ IQ2 quantization =============================================
8804
9596
 
8805
9597
  typedef struct {
@@ -8808,14 +9600,22 @@ typedef struct {
8808
9600
  uint16_t * neighbours;
8809
9601
  } iq2_entry_t;
8810
9602
 
8811
- static iq2_entry_t iq2_data[2] = {
9603
+ static iq2_entry_t iq2_data[3] = {
9604
+ {NULL, NULL, NULL},
8812
9605
  {NULL, NULL, NULL},
8813
9606
  {NULL, NULL, NULL},
8814
9607
  };
8815
9608
 
8816
- static inline int iq2_data_index(int grid_size) {
8817
- GGML_ASSERT(grid_size == 256 || grid_size == 512);
8818
- return grid_size == 256 ? 0 : 1;
9609
+ static inline int iq2_data_index(enum ggml_type type) {
9610
+ GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S);
9611
+ return type == GGML_TYPE_IQ2_XXS ? 0 :
9612
+ type == GGML_TYPE_IQ2_XS ? 1 : 2;
9613
+ }
9614
+
9615
+ static inline int iq2_grid_size(enum ggml_type type) {
9616
+ GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S);
9617
+ return type == GGML_TYPE_IQ2_XXS ? 256 :
9618
+ type == GGML_TYPE_IQ2_XS ? 512 : 512;
8819
9619
  }
8820
9620
 
8821
9621
  static int iq2_compare_func(const void * left, const void * right) {
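
The table bookkeeping above now dispatches on the quantization type instead of the raw grid size, with a third iq2_data slot reserved for the new IQ1_S grid; the later hunks move iq2xs_init_impl()/iq2xs_free_impl() to the same convention. A hypothetical caller-side sketch (not part of the package) of setting up and releasing the IQ1_S tables:

    /* With the per-type signatures, the grid, map and neighbour tables
     * behind quantize_iq1_s() are selected by ggml_type. */
    iq2xs_init_impl(GGML_TYPE_IQ1_S);   /* builds the 512-entry 1-bit grid and its neighbour map */
    /* ... quantize rows with quantize_iq1_s(...) ... */
    iq2xs_free_impl(GGML_TYPE_IQ1_S);
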
@@ -8824,12 +9624,13 @@ static int iq2_compare_func(const void * left, const void * right) {
8824
9624
  return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 1 : 0;
8825
9625
  }
8826
9626
 
8827
- void iq2xs_init_impl(int grid_size) {
8828
- const int gindex = iq2_data_index(grid_size);
9627
+ void iq2xs_init_impl(enum ggml_type type) {
9628
+ const int gindex = iq2_data_index(type);
9629
+ const int grid_size = iq2_grid_size(type);
8829
9630
  if (iq2_data[gindex].grid) {
8830
9631
  return;
8831
9632
  }
8832
- static const uint16_t kgrid_256[256] = {
9633
+ static const uint16_t kgrid_2bit_256[256] = {
8833
9634
  0, 2, 5, 8, 10, 17, 20, 32, 34, 40, 42, 65, 68, 80, 88, 97,
8834
9635
  100, 128, 130, 138, 162, 257, 260, 272, 277, 320, 388, 408, 512, 514, 546, 642,
8835
9636
  1025, 1028, 1040, 1057, 1060, 1088, 1090, 1096, 1120, 1153, 1156, 1168, 1188, 1280, 1282, 1288,
@@ -8847,7 +9648,7 @@ void iq2xs_init_impl(int grid_size) {
8847
9648
  33888, 34048, 34118, 34196, 34313, 34368, 34400, 34818, 35076, 35345, 36868, 36880, 36900, 36928, 37025, 37142,
8848
9649
  37248, 37445, 37888, 37922, 37956, 38225, 39041, 39200, 40962, 41040, 41093, 41225, 41472, 42008, 43088, 43268,
8849
9650
  };
8850
- static const uint16_t kgrid_512[512] = {
9651
+ static const uint16_t kgrid_2bit_512[512] = {
8851
9652
  0, 2, 5, 8, 10, 17, 20, 22, 25, 32, 34, 37, 40, 65, 68, 70,
8852
9653
  73, 80, 82, 85, 88, 97, 100, 128, 130, 133, 136, 145, 148, 153, 160, 257,
8853
9654
  260, 262, 265, 272, 274, 277, 280, 282, 289, 292, 320, 322, 325, 328, 337, 340,
@@ -8881,9 +9682,45 @@ void iq2xs_init_impl(int grid_size) {
8881
9682
  40962, 40968, 40970, 40992, 41002, 41120, 41297, 41305, 41382, 41472, 41474, 41480, 41514, 41600, 41632, 42048,
8882
9683
  42133, 42597, 42648, 43018, 43040, 43042, 43048, 43168, 43176, 43268, 43396, 43398, 43560, 43562, 43665, 43690,
8883
9684
  };
9685
+ static const uint16_t kgrid_1bit_512[512] = {
9686
+ 10, 33, 41, 85, 132, 134, 160, 162, 277, 337, 340, 345, 357, 405, 516, 545,
9687
+ 553, 598, 641, 650, 681, 1042, 1044, 1097, 1169, 1176, 1320, 1345, 1365, 1378, 1434, 1444,
9688
+ 1545, 1617, 1642, 1685, 2053, 2080, 2089, 2133, 2176, 2182, 2208, 2214, 2306, 2384, 2393, 2440,
9689
+ 2453, 2581, 2664, 2690, 2721, 4117, 4161, 4182, 4184, 4261, 4357, 4369, 4372, 4377, 4390, 4422,
9690
+ 4432, 4437, 4449, 4457, 4485, 4497, 4505, 4629, 4677, 4696, 4774, 5205, 5217, 5225, 5386, 5397,
9691
+ 5409, 5445, 5457, 5460, 5461, 5462, 5465, 5472, 5477, 5525, 5545, 5650, 5668, 5717, 5729, 5769,
9692
+ 5777, 6212, 6234, 6244, 6293, 6424, 6482, 6485, 6502, 6505, 6529, 6538, 6565, 6656, 6682, 6788,
9693
+ 6806, 6820, 8218, 8224, 8226, 8232, 8277, 8326, 8354, 8469, 8521, 8530, 8549, 8596, 8737, 8794,
9694
+ 9221, 9253, 9348, 9369, 9380, 9474, 9557, 9633, 9732, 9753, 9793, 9830, 9862, 9880, 10240, 10272,
9695
+ 10282, 10321, 10406, 10517, 10530, 10566, 10585, 10645, 10896, 16466, 16468, 16473, 16485, 16646, 16660, 16665,
9696
+ 16725, 16793, 16806, 16914, 16969, 16977, 16996, 17028, 17057, 17408, 17416, 17434, 17493, 17512, 17578, 17685,
9697
+ 17696, 17733, 17745, 17748, 17749, 17750, 17753, 17765, 17794, 17813, 17946, 17984, 18005, 18072, 18453, 18529,
9698
+ 18569, 18722, 18756, 18762, 18773, 18794, 18833, 18853, 18945, 19026, 19033, 19077, 20489, 20497, 20500, 20517,
9699
+ 20565, 20586, 20610, 20633, 20757, 20769, 20776, 20805, 20817, 20820, 20821, 20822, 20825, 20837, 20864, 20872,
9700
+ 20885, 20896, 21002, 21029, 21077, 21146, 21510, 21525, 21573, 21585, 21588, 21589, 21590, 21593, 21605, 21653,
9701
+ 21665, 21765, 21777, 21780, 21781, 21782, 21785, 21797, 21825, 21828, 21829, 21830, 21833, 21840, 21841, 21842,
9702
+ 21844, 21846, 21848, 21849, 21850, 21857, 21860, 21861, 21862, 21865, 21893, 21905, 21908, 21909, 21910, 21913,
9703
+ 21925, 22024, 22037, 22085, 22097, 22100, 22101, 22102, 22105, 22117, 22165, 22545, 22566, 22568, 22594, 22608,
9704
+ 22613, 22676, 22697, 22793, 22805, 22853, 22865, 22868, 22869, 22870, 22873, 22885, 22933, 22946, 23046, 23072,
9705
+ 23125, 23209, 24597, 24640, 24665, 24673, 24725, 24833, 24840, 24869, 24917, 24934, 24965, 25001, 25108, 25110,
9706
+ 25152, 25184, 25192, 25234, 25616, 25618, 25625, 25685, 25704, 25738, 25744, 25770, 25877, 25897, 25925, 25937,
9707
+ 25940, 25941, 25942, 25945, 25957, 25986, 26005, 26186, 26197, 26276, 26632, 26634, 26725, 26757, 26770, 26885,
9708
+ 26965, 26976, 26986, 27032, 27153, 27174, 27200, 27208, 27240, 27269, 27282, 27290, 32778, 32800, 32802, 32808,
9709
+ 32810, 32853, 32904, 32922, 32930, 32932, 33105, 33110, 33112, 33125, 33157, 33280, 33288, 33301, 33312, 33320,
9710
+ 33424, 33797, 33829, 33858, 34068, 34133, 34146, 34176, 34217, 34306, 34342, 34441, 34454, 34468, 34832, 34918,
9711
+ 34965, 34984, 35094, 35137, 35161, 35208, 35232, 35332, 35338, 35368, 35429, 36932, 36934, 36953, 37009, 37125,
9712
+ 37136, 37138, 37145, 37157, 37205, 37220, 37258, 37290, 37444, 37446, 37465, 37478, 37525, 37905, 37968, 37973,
9713
+ 38040, 38054, 38145, 38154, 38165, 38180, 38186, 38213, 38225, 38228, 38229, 38230, 38233, 38245, 38293, 38485,
9714
+ 38504, 38530, 38938, 38985, 38993, 39012, 39040, 39173, 39192, 39253, 39265, 39301, 39316, 39322, 39442, 39497,
9715
+ 39504, 39590, 40970, 40984, 40992, 41002, 41045, 41120, 41128, 41237, 41289, 41297, 41317, 41364, 41366, 41514,
9716
+ 41557, 41633, 41989, 42021, 42056, 42068, 42074, 42113, 42242, 42265, 42274, 42325, 42340, 42402, 42501, 42512,
9717
+ 42533, 42624, 42632, 42666, 43040, 43093, 43106, 43168, 43176, 43264, 43286, 43345, 43429, 43590, 43618, 43680,
9718
+ };
9719
+
8884
9720
  const int kmap_size = 43692;
8885
- const int nwant = 2;
8886
- const uint16_t * kgrid = grid_size == 256 ? kgrid_256 : kgrid_512;
9721
+ const int nwant = type == GGML_TYPE_IQ1_S ? 3 : 2;
9722
+ const uint16_t * kgrid = type == GGML_TYPE_IQ2_XXS ? kgrid_2bit_256 :
9723
+ type == GGML_TYPE_IQ2_XS ? kgrid_2bit_512 : kgrid_1bit_512;
8887
9724
  uint64_t * kgrid_q2xs;
8888
9725
  int * kmap_q2xs;
8889
9726
  uint16_t * kneighbors_q2xs;
@@ -8979,9 +9816,9 @@ void iq2xs_init_impl(int grid_size) {
8979
9816
  free(dist2);
8980
9817
  }
8981
9818
 
8982
- void iq2xs_free_impl(int grid_size) {
8983
- GGML_ASSERT(grid_size == 256 || grid_size == 512 || grid_size == 1024);
8984
- const int gindex = iq2_data_index(grid_size);
9819
+ void iq2xs_free_impl(enum ggml_type type) {
9820
+ GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S);
9821
+ const int gindex = iq2_data_index(type);
8985
9822
  if (iq2_data[gindex].grid) {
8986
9823
  free(iq2_data[gindex].grid); iq2_data[gindex].grid = NULL;
8987
9824
  free(iq2_data[gindex].map); iq2_data[gindex].map = NULL;
@@ -9015,7 +9852,7 @@ static int iq2_find_best_neighbour(const uint16_t * restrict neighbours, const u
9015
9852
 
9016
9853
  static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
9017
9854
 
9018
- const int gindex = iq2_data_index(256);
9855
+ const int gindex = iq2_data_index(GGML_TYPE_IQ2_XXS);
9019
9856
 
9020
9857
  const uint64_t * kgrid_q2xs = iq2_data[gindex].grid;
9021
9858
  const int * kmap_q2xs = iq2_data[gindex].map;
@@ -9188,7 +10025,7 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
9188
10025
 
9189
10026
  static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
9190
10027
 
9191
- const int gindex = iq2_data_index(512);
10028
+ const int gindex = iq2_data_index(GGML_TYPE_IQ2_XS);
9192
10029
 
9193
10030
  const uint64_t * kgrid_q2xs = iq2_data[gindex].grid;
9194
10031
  const int * kmap_q2xs = iq2_data[gindex].map;
@@ -9825,3 +10662,327 @@ void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * re
9825
10662
  assert(k % QK_K == 0);
9826
10663
  quantize_row_iq3_xxs_impl(x, y, k, NULL);
9827
10664
  }
10665
+
10666
+ // =================================== 1.5 bpw ===================================================
10667
+
10668
+ static int iq1_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid,
10669
+ const float * restrict xval, const float * restrict weight, float * scale, int8_t * restrict L, int ngrid) {
10670
+ int num_neighbors = neighbours[0];
10671
+ GGML_ASSERT(num_neighbors > 0);
10672
+ float best_score = 0;
10673
+ int grid_index = -1;
10674
+ for (int j = 1; j <= num_neighbors; ++j) {
10675
+ const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
10676
+ float sumqx = 0, sumq2 = 0;
10677
+ for (int i = 0; i < 8; ++i) {
10678
+ float q = (pg[i] - 3)/2;
10679
+ float w = weight[i];
10680
+ sumqx += w*q*xval[i];
10681
+ sumq2 += w*q*q;
10682
+ }
10683
+ if (sumqx > 0 && sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
10684
+ *scale = sumqx/sumq2; best_score = *scale * sumqx;
10685
+ grid_index = neighbours[j];
10686
+ }
10687
+ }
10688
+ if (grid_index < 0) {
10689
+ for (int i = 0; i < ngrid; ++i) {
10690
+ const int8_t * grid_i = (const int8_t *)(grid + i);
10691
+ float sumqx = 0, sumq2 = 0;
10692
+ for (int j = 0; j < 8; ++j) {
10693
+ float w = weight[j];
10694
+ float q = (grid_i[j] - 3)/2;
10695
+ sumqx += w*q*xval[j];
10696
+ sumq2 += w*q*q;
10697
+ }
10698
+ if (sumqx > 0 && sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
10699
+ *scale = sumqx/sumq2; best_score = *scale*sumqx;
10700
+ grid_index = i;
10701
+ }
10702
+ }
10703
+ }
10704
+ if (grid_index < 0) {
10705
+ printf("Oops, did not find grid point\n");
10706
+ printf("Have %d neighbours\n", num_neighbors);
10707
+ for (int j = 1; j <= num_neighbors; ++j) {
10708
+ const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
10709
+ float sumqx = 0, sumq2 = 0;
10710
+ for (int i = 0; i < 8; ++i) {
10711
+ float q = (pg[i] - 3)/2;
10712
+ float w = weight[i];
10713
+ sumqx += w*q*xval[i];
10714
+ sumq2 += w*q*q;
10715
+ }
10716
+ printf(" neighbour %d: sumqx = %g sumq2 = %g\n", j, (double)sumqx, (double)sumq2);
10717
+ }
10718
+ }
10719
+ GGML_ASSERT(grid_index >= 0);
10720
+ //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
10721
+ *scale *= 1.05f; // This is a fudge factor. Don't ask me why it improves the result.
10722
+ //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
10723
+ const int8_t * pg = (const int8_t *)(grid + grid_index);
10724
+ for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2;
10725
+ return grid_index;
10726
+ }
10727
+
10728
+ static int iq1_sort_helper(const void * left, const void * right) {
10729
+ const float * l = left;
10730
+ const float * r = right;
10731
+ return *l < *r ? -1 : *l > *r ? 1 : 0;
10732
+ }
10733
+
10734
+ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
10735
+
10736
+ const int gindex = iq2_data_index(GGML_TYPE_IQ1_S);
10737
+
10738
+ const uint64_t * kgrid_q2xs = iq2_data[gindex].grid;
10739
+ const int * kmap_q2xs = iq2_data[gindex].map;
10740
+ const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
10741
+
10742
+ GGML_ASSERT(quant_weights && "missing quantization weights");
10743
+ GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?");
10744
+ GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?");
10745
+ GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
10746
+ GGML_ASSERT(n%QK_K == 0);
10747
+
10748
+ const int nbl = n/256;
10749
+
10750
+ block_iq1_s * y = vy;
10751
+
10752
+ float scales[QK_K/8];
10753
+ float weight[8];
10754
+ int8_t L[8];
10755
+ float sumx[9];
10756
+ float sumw[9];
10757
+ float pairs[16];
10758
+ int * idx = (int *)(pairs + 1);
10759
+ uint8_t hbit[QK_K/8];
10760
+
10761
+ for (int ibl = 0; ibl < nbl; ++ibl) {
10762
+
10763
+ y[ibl].d = GGML_FP32_TO_FP16(0.f);
10764
+ memset(y[ibl].qs, 0, QK_K/8);
10765
+ memset(y[ibl].scales, 0, QK_K/16);
10766
+
10767
+ float max_scale = 0;
10768
+
10769
+ const float * xbl = x + QK_K*ibl;
10770
+ float sumx2 = 0;
10771
+ for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
10772
+ float sigma2 = sumx2/QK_K;
10773
+
10774
+ for (int ib = 0; ib < QK_K/8; ++ib) {
10775
+ const float * xb = xbl + 8*ib;
10776
+ const float * qw = quant_weights + QK_K*ibl + 8*ib;
10777
+ for (int i = 0; i < 8; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
10778
+ float max = fabsf(xb[0]);
10779
+ for (int i = 1; i < 8; ++i) max = MAX(max, fabsf(xb[i]));
10780
+ if (!max) {
10781
+ scales[ib] = 0;
10782
+ memset(L, 1, 8);
10783
+ continue;
10784
+ }
10785
+ // Here we solve exactly the sum of squared difference (SSD) weighted minimization problem.
10786
+ // With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two
10787
+ // boundaries that split the weights xb[i] into 3 groups. To do so, we sort the weights
10788
+ // in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and
10789
+ // Wi = sum[weight[j], j = 0...i], and use these to quickly get the optimum scale
10790
+ // and score for each possible split.
10791
+ for (int j = 0; j < 8; ++j) {
10792
+ pairs[2*j] = xb[j];
10793
+ idx[2*j] = j;
10794
+ }
10795
+ qsort(pairs, 8, 2*sizeof(float), iq1_sort_helper);
10796
+ {
10797
+ sumx[0] = sumw[0] = 0;
10798
+ for (int j = 0; j < 8; ++j) {
10799
+ int i = idx[2*j];
10800
+ sumx[j+1] = sumx[j] + weight[i]*xb[i];
10801
+ sumw[j+1] = sumw[j] + weight[i];
10802
+ }
10803
+ }
10804
+ float best_score = 0, scale = max;
10805
+ int besti1 = 0, besti2 = 0;
10806
+ for (int i1 = 0; i1 <= 8; ++i1) {
10807
+ for (int i2 = i1; i2 <= 8; ++i2) {
10808
+ float sumqx = -(sumx[i1] - sumx[0]) + (sumx[8] - sumx[i2]);
10809
+ float sumq2 = (sumw[i1] - sumw[0]) + (sumw[8] - sumw[i2]);
10810
+ if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
10811
+ scale = sumqx/sumq2; best_score = scale*sumqx;
10812
+ besti1 = i1; besti2 = i2;
10813
+ }
10814
+ }
10815
+ }
10816
+ for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0;
10817
+ for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1;
10818
+ for (int j = besti2; j < 8; ++j) L[idx[2*j]] = 2;
10819
+ if (scale < 0) {
10820
+ for (int j = 0; j < 8; ++j) L[j] = 2 - L[j];
10821
+ scale = -scale;
10822
+ }
10823
+ // Now we check if the solution found above corresponds to a grid point and, if not, use a neighbouring
10824
+ // grid point that minimizes SSD.
10825
+ uint16_t u = 0;
10826
+ for (int j = 0; j < 8; ++j) u |= (L[j] << 2*j);
10827
+ int grid_index = kmap_q2xs[u];
10828
+ if (grid_index < 0) {
10829
+ const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
10830
+ grid_index = iq1_find_best_neighbour(neighbours, kgrid_q2xs, xb, weight, &scale, L, NGRID_IQ2XXS);
10831
+ GGML_ASSERT(grid_index >= 0);
10832
+ }
10833
+ y[ibl].qs[ib] = grid_index & 255;
10834
+ hbit[ib] = grid_index >> 8;
10835
+ GGML_ASSERT(scale >= 0);
10836
+ scales[ib] = scale;
10837
+ max_scale = MAX(max_scale, scale);
10838
+ }
10839
+
10840
+ if (!max_scale) {
10841
+ memset(y[ibl].qs, 0, QK_K/8);
10842
+ continue;
10843
+ }
10844
+
10845
+ float d = max_scale/15;
10846
+ y[ibl].d = GGML_FP32_TO_FP16(d*1.085f); // 1.085f is another fudge factor. Don't ask me why it is needed.
10847
+ float id = 1/d;
10848
+ for (int ib = 0; ib < QK_K/8; ++ib) {
10849
+ int l = nearest_int(0.5f*(id*scales[ib]-1));
10850
+ l = MAX(0, MIN(7, l));
10851
+ if (hbit[ib]) l |= 8;
10852
+ y[ibl].scales[ib/2] |= (l << 4*(ib%2));
10853
+ }
10854
+ }
10855
+ }
10856
+
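
For reference, the exhaustive (i1, i2) split search in quantize_row_iq1_s_impl() above is a small weighted least-squares problem; a sketch of the algebra behind the sumqx/sumq2 bookkeeping (notation mine): for fixed levels $q_j \in \{-1,0,+1\}$ the weighted error $E(d) = \sum_j w_j (x_j - d\,q_j)^2$ is minimized at $d^* = \frac{\sum_j w_j q_j x_j}{\sum_j w_j q_j^2}$, giving $E(d^*) = \sum_j w_j x_j^2 - \frac{(\sum_j w_j q_j x_j)^2}{\sum_j w_j q_j^2}$, so the best assignment maximizes $(\sum_j w_j q_j x_j)^2 / \sum_j w_j q_j^2$, i.e. best_score = scale*sumqx in the code. With the $x_j$ sorted ascending, an optimal assignment for a positive scale puts $q_j = -1$ for $j < i_1$, $0$ for $i_1 \le j < i_2$, and $+1$ for $j \ge i_2$; with the prefix sums $S_i$ and $W_i$ from the comment this becomes sumqx $= (S_8 - S_{i_2}) - (S_{i_1} - S_0)$ and sumq2 $= (W_{i_1} - W_0) + (W_8 - W_{i_2})$, which is exactly what the double loop evaluates.
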
10857
+ size_t quantize_iq1_s(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
10858
+ (void)hist;
10859
+ GGML_ASSERT(n_per_row%QK_K == 0);
10860
+ int nblock = n_per_row/QK_K;
10861
+ char * qrow = (char *)dst;
10862
+ for (int row = 0; row < nrow; ++row) {
10863
+ quantize_row_iq1_s_impl(src, qrow, n_per_row, quant_weights);
10864
+ src += n_per_row;
10865
+ qrow += nblock*sizeof(block_iq1_s);
10866
+ }
10867
+ return nrow * nblock * sizeof(block_iq1_s);
10868
+ }
10869
+
10870
+ // ============================ 4-bit non-linear quants
10871
+
10872
+ static inline int best_index_int8(int n, const int8_t * val, float x) {
10873
+ if (x <= val[0]) return 0;
10874
+ if (x >= val[n-1]) return n-1;
10875
+ int ml = 0, mu = n-1;
10876
+ while (mu-ml > 1) {
10877
+ int mav = (ml+mu)/2;
10878
+ if (x < val[mav]) mu = mav; else ml = mav;
10879
+ }
10880
+ return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
10881
+ }
10882
+
10883
+ static void quantize_row_iq4_nl_impl(const int block_size, const float * GGML_RESTRICT x,
10884
+ ggml_fp16_t * dh, uint8_t * q4,
10885
+ float * weight, uint8_t * L,
10886
+ const int8_t * values,
10887
+ const float * quant_weights) {
10888
+
10889
+ const int ntry = 7;
10890
+
10891
+ float sigma2 = 0;
10892
+ for (int j = 0; j < QK4_NL; ++j) sigma2 += x[j]*x[j];
10893
+ sigma2 *= 2.f/QK4_NL;
10894
+
10895
+ const int nb = QK4_NL/block_size;
10896
+
10897
+ memset(q4, 0, QK4_NL/2);
10898
+ for (int ib = 0; ib < nb; ++ib) {
10899
+ dh[ib] = GGML_FP32_TO_FP16(0.f);
10900
+ const float * xb = x + ib*block_size;
10901
+ if (quant_weights) {
10902
+ const float * qw = quant_weights + ib*block_size;
10903
+ for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
10904
+ } else {
10905
+ for (int j = 0; j < block_size; ++j) weight[j] = xb[j]*xb[j];
10906
+ }
10907
+ float amax = 0, max = 0;
10908
+ for (int j = 0; j < block_size; ++j) {
10909
+ float ax = fabsf(xb[j]);
10910
+ if (ax > amax) {
10911
+ amax = ax; max = xb[j];
10912
+ }
10913
+ }
10914
+ if (!amax) {
10915
+ continue;
10916
+ }
10917
+ float d = -max/values[0];
10918
+ float id = 1/d;
10919
+ float sumqx = 0, sumq2 = 0;
10920
+ for (int j = 0; j < block_size; ++j) {
10921
+ float al = id*xb[j];
10922
+ int l = best_index_int8(16, values, al);
10923
+ float q = values[l];
10924
+ float w = weight[j];
10925
+ sumqx += w*q*xb[j];
10926
+ sumq2 += w*q*q;
10927
+ }
10928
+ float best_id = id;
10929
+ d = sumqx/sumq2;
10930
+ float best = d*sumqx;
10931
+ for (int itry = -ntry; itry <= ntry; ++itry) {
10932
+ id = (itry + values[0])/max;
10933
+ sumqx = sumq2 = 0;
10934
+ for (int j = 0; j < block_size; ++j) {
10935
+ float al = id*xb[j];
10936
+ int l = best_index_int8(16, values, al);
10937
+ float q = values[l];
10938
+ float w = weight[j];
10939
+ sumqx += w*q*xb[j];
10940
+ sumq2 += w*q*q;
10941
+ }
10942
+ if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
10943
+ d = sumqx/sumq2; best = d * sumqx;
10944
+ best_id = id;
10945
+ }
10946
+ }
10947
+ dh[ib] = GGML_FP32_TO_FP16(d);
10948
+ for (int j = 0; j < block_size; ++j) {
10949
+ L[ib*block_size + j] = best_index_int8(16, values, best_id*xb[j]);
10950
+ }
10951
+ }
10952
+ for (int i = 0; i < QK4_NL/32; ++i) {
10953
+ for (int j = 0; j < 16; ++j) {
10954
+ q4[16*i + j] = L[32*i + j] | (L[32*i + 16 + j] << 4);
10955
+ }
10956
+ }
10957
+ }
10958
+
10959
+ size_t quantize_iq4_nl(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
10960
+ (void)hist;
10961
+ GGML_ASSERT(n_per_row%QK4_NL == 0);
10962
+ int nblock = n_per_row/QK4_NL;
10963
+ char * qrow = (char *)dst;
10964
+ uint8_t L[QK4_NL];
10965
+ float weight[32];
10966
+ for (int row = 0; row < nrow; ++row) {
10967
+ block_iq4_nl * iq4 = (block_iq4_nl *)qrow;
10968
+ for (int ibl = 0; ibl < nblock; ++ibl) {
10969
+ const float * qw = quant_weights ? quant_weights + QK4_NL*ibl : NULL;
10970
+ quantize_row_iq4_nl_impl(32, src + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, weight, L, kvalues_iq4nl, qw);
10971
+ }
10972
+ src += n_per_row;
10973
+ qrow += nblock*sizeof(block_iq4_nl);
10974
+ }
10975
+ return nrow * nblock * sizeof(block_iq4_nl);
10976
+ }
10977
+
10978
+ void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int k) {
10979
+ assert(k % QK4_NL == 0);
10980
+ block_iq4_nl * restrict y = vy;
10981
+ quantize_row_iq4_nl_reference(x, y, k);
10982
+ }
10983
+
10984
+ void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * restrict y, int k) {
10985
+ assert(k % QK4_NL == 0);
10986
+ quantize_iq4_nl(x, y, 1, k, NULL, NULL);
10987
+ }
10988
+
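
Finally, a hypothetical caller-side sketch (not from the package) of the new IQ4_NL entry point, quantizing one row of floats without importance weights:

    #include <stddef.h>

    /* 256 floats -> 256/QK4_NL block_iq4_nl blocks (QK4_NL == QK8_0, i.e. 32).
     * Passing NULL quant_weights makes quantize_row_iq4_nl_impl() fall back to
     * the xb[j]*xb[j] weighting shown above; hist is unused. */
    float row[256];                      /* filled by the caller */
    block_iq4_nl out[256 / QK4_NL];
    size_t written = quantize_iq4_nl(row, out, /*nrow=*/1, /*n_per_row=*/256,
                                     /*hist=*/NULL, /*quant_weights=*/NULL);
    /* written == sizeof(out) */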