llama_cpp 0.12.0 → 0.12.1

Sign up to get free protection for your applications and to get access to all the features.
@@ -2446,6 +2446,19 @@ typedef struct {
2446
2446
  } block_q6_K;
2447
2447
  // 210 bytes / block
2448
2448
 
2449
// IQ2_XXS block: QK_K weights at ~2.06 bits per weight.
//   d  - fp16 block scale
//   qs - 4 uint16_t per 32 weights: the first two hold four 8-bit indices into iq2xxs_grid,
//        the last two form a 32-bit word with four 7-bit sign indices (bits 0..27) and a
//        4-bit sub-block scale (bits 28..31)
typedef struct {
    half d;
    uint16_t qs[QK_K/8];
} block_iq2_xxs;
// 66 bytes / block for QK_K = 256, so 2.0625 bpw
// The mul_mv kernels advance between rows with nb*sizeof(block_iq2_xxs)/2 in 2-byte units,
// so the struct must be exactly packed (no tail padding).
static_assert(sizeof(block_iq2_xxs) == sizeof(half) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");
2454
+
2455
// IQ2_XS block: QK_K weights at ~2.31 bits per weight.
//   d      - fp16 block scale
//   qs     - one uint16_t per 8 weights: bits 0..8 index iq2xs_grid, bits 9..15 index ksigns_iq2xs
//   scales - one byte per 32 weights: low nibble scales weights 0..15, high nibble weights 16..31
typedef struct {
    half d;
    uint16_t qs[QK_K/8];
    uint8_t scales[QK_K/32];
} block_iq2_xs;
// 74 bytes / block for QK_K = 256, so 2.3125 bpw
// The mul_mv kernels advance between rows with nb*sizeof(block_iq2_xs)/2 (2-byte pointers) and
// nb*sizeof(block_iq2_xs) (byte pointer), so the struct must be exactly packed (no tail padding).
static_assert(sizeof(block_iq2_xs) == sizeof(half) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding");
2461
+
2449
2462
  //====================================== dot products =========================
2450
2463
 
2451
2464
  void kernel_mul_mv_q2_K_f32_impl(
@@ -3468,6 +3481,495 @@ kernel void kernel_mul_mv_q6_K_f32(
3468
3481
  kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3469
3482
  }
3470
3483
 
3484
+ // ======================= "True" 2-bit
3485
+
3486
// IQ2_XXS grid: 256 entries, selected by an 8-bit index from block_iq2_xxs::qs.
// Each uint64_t is read by the kernels as 8 consecutive uint8_t magnitudes (grid[j] for weight j);
// every byte is 0x08, 0x19 or 0x2b (8, 25, 43). Signs come separately from ksigns_iq2xs, and the
// kernels apply a final 0.25f factor to the products.
constexpr constant static uint64_t iq2xxs_grid[256] = {
    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
    0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
    0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
    0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
    0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
    0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
    0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
    0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
    0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
    0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
    0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
    0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
    0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
    0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
    0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
    0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
    0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
    0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
    0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
    0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
    0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
    0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
    0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
    0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
    0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
    0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
    0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
    0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
    0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
    0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
    0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
    0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
    0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
    0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
    0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
    0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
    0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
    0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
    0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
    0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
    0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
    0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
    0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
    0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
    0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
    0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
    0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
    0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
    0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
    0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
    0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
    0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
    0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
    0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
    0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
    0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
    0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
    0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
    0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
    0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
    0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
    0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
    0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
};
3552
+
3553
// IQ2_XS grid: 512 entries, selected by the low 9 bits (q & 511) of a block_iq2_xs::qs entry.
// Each uint64_t is read by the kernels as 8 consecutive uint8_t magnitudes (grid[j] for weight j);
// every byte is 0x08, 0x19 or 0x2b (8, 25, 43). Signs come from ksigns_iq2xs via the upper 7 bits.
constexpr constant static uint64_t iq2xs_grid[512] = {
    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
    0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
    0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
    0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
    0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808,
    0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819,
    0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819,
    0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808,
    0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b,
    0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b,
    0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908,
    0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908,
    0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919,
    0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808,
    0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919,
    0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908,
    0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b,
    0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908,
    0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08,
    0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808,
    0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808,
    0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819,
    0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908,
    0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819,
    0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808,
    0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b,
    0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819,
    0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819,
    0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808,
    0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908,
    0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19,
    0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b,
    0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b,
    0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919,
    0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808,
    0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819,
    0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819,
    0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b,
    0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908,
    0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808,
    0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819,
    0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808,
    0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919,
    0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808,
    0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808,
    0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908,
    0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908,
    0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808,
    0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b,
    0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819,
    0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919,
    0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908,
    0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808,
    0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908,
    0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919,
    0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08,
    0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19,
    0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b,
    0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b,
    0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808,
    0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08,
    0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b,
    0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908,
    0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b,
    0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908,
    0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08,
    0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808,
    0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808,
    0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08,
    0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819,
    0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919,
    0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808,
    0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808,
    0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819,
    0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819,
    0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908,
    0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908,
    0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b,
    0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908,
    0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908,
    0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908,
    0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808,
    0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819,
    0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819,
    0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819,
    0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808,
    0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b,
    0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819,
    0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819,
    0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08,
    0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808,
    0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19,
    0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919,
    0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808,
    0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19,
    0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b,
    0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808,
    0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b,
    0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b,
    0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08,
    0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b,
    0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808,
    0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819,
    0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,
    0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808,
    0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08,
    0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b,
    0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19,
    0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08,
    0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919,
    0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08,
    0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08,
    0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908,
    0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908,
    0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b,
    0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908,
    0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808,
    0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b,
    0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808,
    0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808,
    0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19,
    0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08,
    0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808,
    0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b,
    0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808,
    0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b,
    0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
};
3683
+
3684
// Sign masks for a group of 8 weights, indexed by a 7-bit sign index i.
// Entry i equals i with bit 7 set when i has an odd number of set bits (e.g. 1 -> 129, 3 -> 3),
// so every entry has an even number of set bits and the 8th sign is implied by the other 7.
// Bit j set means weight j of the group is negated (tested with kmask_iq2xs[j]).
constexpr constant static uint8_t ksigns_iq2xs[128] = {
      0, 129, 130,   3, 132,   5,   6, 135, 136,   9,  10, 139,  12, 141, 142,  15,
    144,  17,  18, 147,  20, 149, 150,  23,  24, 153, 154,  27, 156,  29,  30, 159,
    160,  33,  34, 163,  36, 165, 166,  39,  40, 169, 170,  43, 172,  45,  46, 175,
     48, 177, 178,  51, 180,  53,  54, 183, 184,  57,  58, 187,  60, 189, 190,  63,
    192,  65,  66, 195,  68, 197, 198,  71,  72, 201, 202,  75, 204,  77,  78, 207,
     80, 209, 210,  83, 212,  85,  86, 215, 216,  89,  90, 219,  92, 221, 222,  95,
     96, 225, 226,  99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
    240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
};
3694
+
3695
// kmask_iq2xs[j] == 1 << j: isolates the sign bit of weight j in a ksigns_iq2xs entry.
constexpr constant static uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
3696
+
3697
// Matrix-vector product with IQ2_XXS weights: each simdgroup produces N_DST consecutive outputs
// starting at first_row. Thread tiisg handles the 32-weight sub-blocks ib32 = tiisg, tiisg + 32, ...
// of those rows, and the per-thread partial sums are combined with simd_sum.
//
//   src0          - quantized matrix; each row is nb = ne00/QK_K block_iq2_xxs blocks
//   src1, dst     - f32 input and output
//   r2, r3        - divisors mapping the src1 batch indices (i12, i13) to src0 batch indices
//   shared_values - threadgroup scratch: 256 uint64_t grid entries followed by 128 sign bytes
//                   (2176 bytes)
// Assumes the threadgroup is N_SIMDGROUP simdgroups of 32 threads (first_row relies on this too).
void kernel_mul_mv_iq2_xxs_f32_impl(
        device const void * src0,
        device const float * src1,
        device float * dst,
        constant int64_t & ne00,
        constant int64_t & ne01,
        constant int64_t & ne02,
        constant int64_t & ne10,
        constant int64_t & ne12,
        constant int64_t & ne0,
        constant int64_t & ne1,
        constant uint & r2,
        constant uint & r3,
        threadgroup int8_t * shared_values [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]]) {

    const int nb = ne00/QK_K;
    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
    const int ib_row = first_row * nb;

    const uint i12 = im%ne12;
    const uint i13 = im/ne12;

    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);

    device const block_iq2_xxs * x = (device const block_iq2_xxs *) src0 + ib_row + offset0;
    device const float         * y = (device const float         *) src1 + r1*ne10 + im*ne00*ne1;

    float yl[32];
    float sumf[N_DST]={0.f}, all_sum;

    const int nb32 = nb * (QK_K / 32);

    // Stage the grid and sign tables in threadgroup memory. The copy is strided over every thread
    // of the threadgroup, so all 256 + 128 entries are written for any N_SIMDGROUP, and no thread
    // indexes past the end of either table.
    threadgroup uint64_t * values       = (threadgroup uint64_t *)shared_values;
    threadgroup uint8_t  * shared_signs = (threadgroup uint8_t  *)(values + 256);
    {
        const int nthreads = 32*N_SIMDGROUP;
        const int tid      = 32*sgitg + tiisg;
        for (int i = tid; i < 256; i += nthreads) values[i]       = iq2xxs_grid[i];
        for (int i = tid; i < 128; i += nthreads) shared_signs[i] = ksigns_iq2xs[i];
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

#if QK_K == 256
    const int ix = tiisg;

    device const float * y4 = y + 32 * ix;

    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {

        for (int i = 0; i < 32; ++i) {
            yl[i] = y4[i];
        }

        const int ibl = ib32 / (QK_K / 32);
        const int ib  = ib32 % (QK_K / 32);

        device const block_iq2_xxs * xr = x + ibl;
        device const uint16_t * q2 = xr->qs + 4 * ib;
        device const half * dh = &xr->d;

        for (int row = 0; row < N_DST; row++) {

            const float db = dh[0];
            // q2[0..1]: four 8-bit grid indices; q2[2..3]: four 7-bit sign indices + 4-bit scale.
            device const uint8_t * aux8 = (device const uint8_t *)q2;
            // Widen before shifting: q2[3] promotes to int, and << 16 overflows int for values >= 0x8000.
            const uint32_t aux32 = q2[2] | ((uint32_t)q2[3] << 16);
            const float d = db * (0.5f + (aux32 >> 28));

            float sum = 0;
            for (int l = 0; l < 4; ++l) {
                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]);
                const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
                for (int j = 0; j < 8; ++j) {
                    sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                }
            }
            sumf[row] += d * sum;

            // Same sub-block in the next row: a row is nb blocks, counted here in 2-byte units.
            dh += nb*sizeof(block_iq2_xxs)/2;
            q2 += nb*sizeof(block_iq2_xxs)/2;
        }

        y4 += 32 * 32;
    }
#else
    // TODO: QK_K != 256 is not implemented; sumf stays 0, so dst receives zeros.
#endif

    for (int row = 0; row < N_DST; ++row) {
        all_sum = simd_sum(sumf[row]);
        if (tiisg == 0) {
            // 0.25f rescales the grid magnitudes (the same factor dequantize_iq2_xxs applies).
            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
        }
    }
}
3800
+
3801
// Host-visible entry point for the IQ2_XXS matrix-vector product; forwards to
// kernel_mul_mv_iq2_xxs_f32_impl. nb00..nb02, ne11 and nb10..nb12 are accepted but unused here
// (NOTE(review): presumably kept so the argument order matches the other mul_mv kernels - confirm
// against the host-side encoder). shared_values must provide 256*8 + 128 bytes of threadgroup memory.
[[host_name("kernel_mul_mv_iq2_xxs_f32")]]
kernel void kernel_mul_mv_iq2_xxs_f32(
        device const void * src0,
        device const float * src1,
        device float * dst,
        constant int64_t & ne00,
        constant int64_t & ne01,
        constant int64_t & ne02,
        constant uint64_t & nb00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant int64_t & ne10,
        constant int64_t & ne11,
        constant int64_t & ne12,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant int64_t & ne0,
        constant int64_t & ne1,
        constant uint & r2,
        constant uint & r3,
        threadgroup int8_t * shared_values [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]]) {

    kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
}
3829
+
3830
// Matrix-vector product with IQ2_XS weights: each simdgroup produces N_DST consecutive outputs
// starting at first_row. Thread tiisg handles the 32-weight sub-blocks ib32 = tiisg, tiisg + 32, ...
// of those rows, and the per-thread partial sums are combined with simd_sum.
//
//   src0          - quantized matrix; each row is nb = ne00/QK_K block_iq2_xs blocks
//   src1, dst     - f32 input and output
//   r2, r3        - divisors mapping the src1 batch indices (i12, i13) to src0 batch indices
//   shared_values - threadgroup scratch: 512 uint64_t grid entries followed by 128 sign bytes
//                   (4224 bytes)
// Assumes the threadgroup is N_SIMDGROUP simdgroups of 32 threads (first_row relies on this too).
void kernel_mul_mv_iq2_xs_f32_impl(
        device const void * src0,
        device const float * src1,
        device float * dst,
        constant int64_t & ne00,
        constant int64_t & ne01,
        constant int64_t & ne02,
        constant int64_t & ne10,
        constant int64_t & ne12,
        constant int64_t & ne0,
        constant int64_t & ne1,
        constant uint & r2,
        constant uint & r3,
        threadgroup int8_t * shared_values [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]]) {

    const int nb = ne00/QK_K;
    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
    const int ib_row = first_row * nb;

    const uint i12 = im%ne12;
    const uint i13 = im/ne12;

    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);

    device const block_iq2_xs * x = (device const block_iq2_xs *) src0 + ib_row + offset0;
    device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne00*ne1;

    float yl[32];
    float sumf[N_DST]={0.f}, all_sum;

    const int nb32 = nb * (QK_K / 32);

    // Stage the grid and sign tables in threadgroup memory. The copy is strided over every thread
    // of the threadgroup, so all 512 + 128 entries are written for any N_SIMDGROUP, and no thread
    // indexes past the end of either table.
    threadgroup uint64_t * values       = (threadgroup uint64_t *)shared_values;
    threadgroup uint8_t  * shared_signs = (threadgroup uint8_t  *)(values + 512);
    {
        const int nthreads = 32*N_SIMDGROUP;
        const int tid      = 32*sgitg + tiisg;
        for (int i = tid; i < 512; i += nthreads) values[i]       = iq2xs_grid[i];
        for (int i = tid; i < 128; i += nthreads) shared_signs[i] = ksigns_iq2xs[i];
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

#if QK_K == 256
    const int ix = tiisg;

    device const float * y4 = y + 32 * ix;

    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {

        for (int i = 0; i < 32; ++i) {
            yl[i] = y4[i];
        }

        const int ibl = ib32 / (QK_K / 32);
        const int ib  = ib32 % (QK_K / 32);

        device const block_iq2_xs * xr = x + ibl;
        device const uint16_t * q2 = xr->qs + 4 * ib;
        device const uint8_t  * sc = xr->scales + ib;
        device const half     * dh = &xr->d;

        for (int row = 0; row < N_DST; row++) {

            const float db = dh[0];
            // Two 4-bit scales per 32 weights: low nibble for weights 0..15, high nibble for 16..31.
            const uint8_t ls1 = sc[0] & 0xf;
            const uint8_t ls2 = sc[0] >> 4;
            const float d1 = db * (0.5f + ls1);
            const float d2 = db * (0.5f + ls2);

            // Each q2[l] packs a 9-bit grid index (bits 0..8) and a 7-bit sign index (bits 9..15).
            float sum1 = 0, sum2 = 0;
            for (int l = 0; l < 2; ++l) {
                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
                const uint8_t signs = shared_signs[(q2[l] >> 9)];
                for (int j = 0; j < 8; ++j) {
                    sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                }
            }
            for (int l = 2; l < 4; ++l) {
                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
                const uint8_t signs = shared_signs[(q2[l] >> 9)];
                for (int j = 0; j < 8; ++j) {
                    sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                }
            }
            sumf[row] += d1 * sum1 + d2 * sum2;

            // Same sub-block in the next row: a row is nb blocks (2-byte units for dh/q2, bytes for sc).
            dh += nb*sizeof(block_iq2_xs)/2;
            q2 += nb*sizeof(block_iq2_xs)/2;
            sc += nb*sizeof(block_iq2_xs);
        }

        y4 += 32 * 32;
    }
#else
    // TODO: QK_K != 256 is not implemented; sumf stays 0, so dst receives zeros.
#endif

    for (int row = 0; row < N_DST; ++row) {
        all_sum = simd_sum(sumf[row]);
        if (tiisg == 0) {
            // 0.25f rescales the grid magnitudes (the same factor dequantize_iq2_xs applies).
            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
        }
    }
}
3943
+
3944
// Host-visible entry point for the IQ2_XS matrix-vector product; forwards to
// kernel_mul_mv_iq2_xs_f32_impl. nb00..nb02, ne11 and nb10..nb12 are accepted but unused here
// (NOTE(review): presumably kept so the argument order matches the other mul_mv kernels - confirm
// against the host-side encoder). shared_values must provide 512*8 + 128 bytes of threadgroup memory.
[[host_name("kernel_mul_mv_iq2_xs_f32")]]
kernel void kernel_mul_mv_iq2_xs_f32(
        device const void * src0,
        device const float * src1,
        device float * dst,
        constant int64_t & ne00,
        constant int64_t & ne01,
        constant int64_t & ne02,
        constant uint64_t & nb00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant int64_t & ne10,
        constant int64_t & ne11,
        constant int64_t & ne12,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant int64_t & ne0,
        constant int64_t & ne1,
        constant uint & r2,
        constant uint & r3,
        threadgroup int8_t * shared_values [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]]) {

    kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
}
3972
+
3471
3973
  //============================= templates and their specializations =============================
3472
3974
 
3473
3975
  // NOTE: this is not dequantizing - we are simply fitting the template
@@ -3620,8 +4122,8 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
3620
4122
  uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
3621
4123
  int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
3622
4124
  : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
3623
- half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
3624
- const half ml = 4.h * dl;
4125
+ float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
4126
+ const float ml = 4.f * dl;
3625
4127
 
3626
4128
  il = (il/2) & 3;
3627
4129
  const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
@@ -3688,7 +4190,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
3688
4190
  uint8_t ul = 1 << (il/2);
3689
4191
  il = il & 3;
3690
4192
  const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
3691
- const float d = il < 2 ? xb->d : xb->d / 16.h;
4193
+ const float d = il < 2 ? xb->d : xb->d / 16.f;
3692
4194
  const float min = xb->dmin;
3693
4195
  const float dl = d * sc[0];
3694
4196
  const float ml = min * sc[1];
@@ -3721,17 +4223,17 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
3721
4223
  #if QK_K == 256
3722
4224
  ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
3723
4225
  qh = qh + 32*(il/8) + 16*(il&1);
3724
- half sc = scales[(il%2) + 2 * ((il/2))];
4226
+ float sc = scales[(il%2) + 2 * ((il/2))];
3725
4227
  il = (il/2) & 3;
3726
4228
  #else
3727
4229
  ql = ql + 16 * (il&1);
3728
- half sc = scales[il];
4230
+ float sc = scales[il];
3729
4231
  #endif
3730
4232
  const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
3731
4233
  const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
3732
- const half coef = il>1 ? 1.f/16.h : 1.h;
3733
- const half ml = d_all * sc * 32.h;
3734
- const half dl = d_all * sc * coef;
4234
+ const float coef = il>1 ? 1.f/16.f : 1.f;
4235
+ const float ml = d_all * sc * 32.f;
4236
+ const float dl = d_all * sc * coef;
3735
4237
  for (int i = 0; i < 16; ++i) {
3736
4238
  const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
3737
4239
  : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
@@ -3739,6 +4241,52 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
3739
4241
  }
3740
4242
  }
3741
4243
 
4244
// Dequantizes one 16-weight slice of a block_iq2_xxs into reg (4x4).
// il is 0...15 for QK_K = 256: il/2 selects the 32-weight sub-block, il%2 its first or second half.
template <typename type4x4>
void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
    const float d = xb->d;
    const int ib32 = il/2;
    il = il%2;
    // Each 32-weight sub-block uses 4 uint16_t:
    //   q2[0..1] - four 8-bit grid indices, two per uint16_t (low byte first)
    //   q2[2..3] - a 32-bit word: four 7-bit sign indices in bits 0..27, 4-bit scale in bits 28..31
    device const uint16_t * q2 = xb->qs + 4*ib32;
    // Widen before shifting: q2[3] promotes to int, and << 16 overflows int for values >= 0x8000.
    const uint32_t aux32_s = q2[2] | ((uint32_t)q2[3] << 16);
    const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
    // This slice uses grid indices 2*il and 2*il+1, i.e. the low and high byte of q2[il]. Reading
    // them with shifts avoids reinterpreting a uint32_t through a byte pointer.
    const uint16_t gidx = q2[il];
    for (int k = 0; k < 2; ++k) {
        constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + ((gidx >> 8*k) & 0xff));
        const uint8_t signs = ksigns_iq2xs[(aux32_s >> (14*il + 7*k)) & 127];
        for (int i = 0; i < 8; ++i) {
            reg[2*k + i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
        }
    }
}
4268
+
4269
// Expands 16 weights of a block_iq2_xs into reg (4x4).
// il ranges over 0...15 for QK_K = 256; il/2 picks the 32-weight sub-block and il%2 picks which
// half of it (weights 0..15 or 16..31) is produced.
// Each qs entry covers 8 weights: bits 0..8 index iq2xs_grid, bits 9..15 index ksigns_iq2xs.
// Each scales byte carries two 4-bit scales, the low nibble for the first half of the sub-block.
template <typename type4x4>
void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
    const float scale = xb->d;
    const int   sub   = il/2;
    const int   part  = il%2;

    // The two qs entries feeding this half of the sub-block.
    device const uint16_t * packed = xb->qs + 4*sub + 2*part;
    const float dl = scale * (0.5f + ((xb->scales[sub] >> 4*part) & 0xf)) * 0.25f;

    for (int k = 0; k < 2; ++k) {
        const uint16_t entry = packed[k];
        constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (entry & 511));
        const uint8_t signs = ksigns_iq2xs[entry >> 9];
        // group k fills rows 2*k and 2*k+1 of reg
        for (int i = 0; i < 8; ++i) {
            reg[2*k + i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
        }
    }
}
4289
+
3742
4290
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
3743
4291
  kernel void kernel_get_rows(
3744
4292
  device const void * src0,
@@ -4278,6 +4826,8 @@ template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows
4278
4826
  // Explicit instantiations of kernel_get_rows, one per quantization type. Each is exported
  // under its host_name and binds that type's block struct, QK_NL (the nl template argument)
  // and the matching dequantize_* function.
  template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
4279
4827
  template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
4280
4828
  template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
4829
+ template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
4830
+ template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
4281
4831
 
4282
4832
  //
4283
4833
  // matrix-matrix multiplication
@@ -4314,6 +4864,8 @@ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
4314
4864
  // Explicit instantiations of kernel_mul_mm (matrix-matrix product), one per quantization type.
  // Each is exported under its host_name; the template arguments select the block struct,
  // QK_NL and the dequantize_* function used to expand that type's weights.
  template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
4315
4865
  template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
4316
4866
  template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
4867
+ template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
4868
+ template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
4317
4869
 
4318
4870
  //
4319
4871
  // indirect matrix-matrix multiplication
@@ -4362,6 +4914,8 @@ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mu
4362
4914
  // Explicit instantiations of kernel_mul_mm_id (indirect matrix-matrix product), one per
  // quantization type. Each is exported under its host_name; the template arguments select
  // the block struct, QK_NL and the dequantize_* function for that type.
  template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
4363
4915
  template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
4364
4916
  template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
4917
+ template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
4918
+ template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
4365
4919
 
4366
4920
  //
4367
4921
  // matrix-vector multiplication
@@ -5134,3 +5688,133 @@ kernel void kernel_mul_mv_id_q6_K_f32(
5134
5688
  tiisg,
5135
5689
  sgitg);
5136
5690
  }
5691
+
5692
// Indirect matrix-vector product for IQ2_XXS weights.
//
// One of eight candidate matrices (src00..src07) is chosen for each batch row:
//   - tgpig.z is split into a batch row `bid` and a remainder in [0, ne12*ne13);
//   - the matrix index is the int32 at column `idx` of row `bid` of `ids`
//     (row stride nbi1 bytes).
// The selected matrix, row `bid` of src1 (stride nb11 bytes) and row `bid` of dst
// (offset bid*ne0 floats) are forwarded to kernel_mul_mv_iq2_xxs_f32_impl, along
// with the rewritten tgpig and the SIMD-group coordinates.
//
// nb00, nb01, nb02, ne11, nb10, nb12, nb1 and tiitg are not read by this kernel.
// NOTE(review): the id read from `ids` is used unchecked as an index into an
// 8-element array; it is assumed to be in [0, 8). Confirm on the host side.
[[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
kernel void kernel_mul_mv_id_iq2_xxs_f32(
        device const    char * ids,
        device const    char * src1,
        device         float * dst,
        constant    uint64_t & nbi1,
        constant     int64_t & ne00,
        constant     int64_t & ne01,
        constant     int64_t & ne02,
        constant    uint64_t & nb00,
        constant    uint64_t & nb01,
        constant    uint64_t & nb02,
        constant     int64_t & ne10,
        constant     int64_t & ne11,
        constant     int64_t & ne12,
        constant     int64_t & ne13,
        constant    uint64_t & nb10,
        constant    uint64_t & nb11,
        constant    uint64_t & nb12,
        constant     int64_t & ne0,
        constant     int64_t & ne1,
        constant    uint64_t & nb1,
        constant        uint & r2,
        constant        uint & r3,
        constant         int & idx,
        device const    char * src00,
        device const    char * src01,
        device const    char * src02,
        device const    char * src03,
        device const    char * src04,
        device const    char * src05,
        device const    char * src06,
        device const    char * src07,
        threadgroup   int8_t * shared_values [[threadgroup(0)]],
        uint3                  tgpig[[threadgroup_position_in_grid]],
        uint                   tiitg[[thread_index_in_threadgroup]],
        uint                   tiisg[[thread_index_in_simdgroup]],
        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
    device const char * mats[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // Split the z coordinate into the batch row and the position within ne12*ne13.
    const int64_t ne1213 = ne12*ne13;
    const int64_t bid    = tgpig.z/ne1213;
    tgpig.z              = tgpig.z%ne1213;

    device const int32_t * id_row = (device const int32_t *) (ids + bid*nbi1);
    const int32_t id = id_row[idx];

    device const float * y_row   = (device const float *) (src1 + bid*nb11);
    device       float * dst_row = dst + bid*ne0;

    kernel_mul_mv_iq2_xxs_f32_impl(
        mats[id], y_row, dst_row,
        ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3,
        shared_values, tgpig, tiisg, sgitg);
}
5756
+
5757
// Indirect matrix-vector product for IQ2_XS weights.
//
// One of eight candidate matrices (src00..src07) is chosen for each batch row:
//   - tgpig.z is split into a batch row `bid` and a remainder in [0, ne12*ne13);
//   - the matrix index is the int32 at column `idx` of row `bid` of `ids`
//     (row stride nbi1 bytes).
// The selected matrix, row `bid` of src1 (stride nb11 bytes) and row `bid` of dst
// (offset bid*ne0 floats) are forwarded to kernel_mul_mv_iq2_xs_f32_impl, along
// with the rewritten tgpig and the SIMD-group coordinates.
//
// nb00, nb01, nb02, ne11, nb10, nb12, nb1 and tiitg are not read by this kernel.
// NOTE(review): the id read from `ids` is used unchecked as an index into an
// 8-element array; it is assumed to be in [0, 8). Confirm on the host side.
[[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
kernel void kernel_mul_mv_id_iq2_xs_f32(
        device const    char * ids,
        device const    char * src1,
        device         float * dst,
        constant    uint64_t & nbi1,
        constant     int64_t & ne00,
        constant     int64_t & ne01,
        constant     int64_t & ne02,
        constant    uint64_t & nb00,
        constant    uint64_t & nb01,
        constant    uint64_t & nb02,
        constant     int64_t & ne10,
        constant     int64_t & ne11,
        constant     int64_t & ne12,
        constant     int64_t & ne13,
        constant    uint64_t & nb10,
        constant    uint64_t & nb11,
        constant    uint64_t & nb12,
        constant     int64_t & ne0,
        constant     int64_t & ne1,
        constant    uint64_t & nb1,
        constant        uint & r2,
        constant        uint & r3,
        constant         int & idx,
        device const    char * src00,
        device const    char * src01,
        device const    char * src02,
        device const    char * src03,
        device const    char * src04,
        device const    char * src05,
        device const    char * src06,
        device const    char * src07,
        threadgroup   int8_t * shared_values [[threadgroup(0)]],
        uint3                  tgpig[[threadgroup_position_in_grid]],
        uint                   tiitg[[thread_index_in_threadgroup]],
        uint                   tiisg[[thread_index_in_simdgroup]],
        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
    device const char * mats[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // Split the z coordinate into the batch row and the position within ne12*ne13.
    const int64_t ne1213 = ne12*ne13;
    const int64_t bid    = tgpig.z/ne1213;
    tgpig.z              = tgpig.z%ne1213;

    device const int32_t * id_row = (device const int32_t *) (ids + bid*nbi1);
    const int32_t id = id_row[idx];

    device const float * y_row   = (device const float *) (src1 + bid*nb11);
    device       float * dst_row = dst + bid*ne0;

    kernel_mul_mv_iq2_xs_f32_impl(
        mats[id], y_row, dst_row,
        ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3,
        shared_values, tgpig, tiisg, sgitg);
}