llama_cpp 0.12.0 → 0.12.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2446,6 +2446,19 @@ typedef struct {
2446
2446
  } block_q6_K;
2447
2447
  // 210 bytes / block
2448
2448
 
2449
+ typedef struct { // IQ2_XXS super-block covering QK_K weights
2450
+ half d; // block-wide scale; the readers in this file multiply it by (0.5 + 4-bit sub-scale) and by 0.25
2451
+ uint16_t qs[QK_K/8]; // 4 uint16_t per 32 weights: the first two are read as four 8-bit iq2xxs_grid indices, the last two as one 32-bit word = four 7-bit sign indices + 4-bit sub-scale in the top bits
2452
+ } block_iq2_xxs;
2453
+ // 66 bytes / block for QK_K = 256, so 2.0625 bpw
2454
+
2455
+ typedef struct { // IQ2_XS super-block covering QK_K weights
2456
+ half d; // block-wide scale; the readers in this file multiply it by (0.5 + 4-bit sub-scale) and by 0.25
2457
+ uint16_t qs[QK_K/8]; // 4 uint16_t per 32 weights; each entry is read as a 9-bit iq2xs_grid index (value & 511) plus a 7-bit sign index (value >> 9)
2458
+ uint8_t scales[QK_K/32]; // one byte per 32 weights: low nibble scales the first 16 weights, high nibble the second 16
2459
+ } block_iq2_xs;
2460
+ // 74 bytes / block for QK_K = 256, so 2.3125 bpw
2461
+
2449
2462
  //====================================== dot products =========================
2450
2463
 
2451
2464
  void kernel_mul_mv_q2_K_f32_impl(
@@ -3468,6 +3481,495 @@ kernel void kernel_mul_mv_q6_K_f32(
3468
3481
  kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
3469
3482
  }
3470
3483
 
3484
+ // ======================= "True" 2-bit
3485
+
3486
+ constexpr constant static uint64_t iq2xxs_grid[256] = {
3487
+ 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
3488
+ 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
3489
+ 0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
3490
+ 0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
3491
+ 0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
3492
+ 0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
3493
+ 0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
3494
+ 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
3495
+ 0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
3496
+ 0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
3497
+ 0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
3498
+ 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
3499
+ 0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
3500
+ 0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
3501
+ 0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
3502
+ 0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
3503
+ 0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
3504
+ 0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
3505
+ 0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
3506
+ 0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
3507
+ 0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
3508
+ 0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
3509
+ 0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
3510
+ 0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
3511
+ 0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
3512
+ 0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
3513
+ 0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
3514
+ 0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
3515
+ 0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
3516
+ 0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
3517
+ 0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
3518
+ 0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
3519
+ 0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
3520
+ 0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
3521
+ 0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
3522
+ 0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
3523
+ 0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
3524
+ 0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
3525
+ 0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
3526
+ 0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
3527
+ 0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
3528
+ 0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
3529
+ 0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
3530
+ 0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
3531
+ 0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
3532
+ 0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
3533
+ 0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
3534
+ 0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
3535
+ 0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
3536
+ 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
3537
+ 0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
3538
+ 0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
3539
+ 0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
3540
+ 0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
3541
+ 0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
3542
+ 0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
3543
+ 0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
3544
+ 0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
3545
+ 0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
3546
+ 0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
3547
+ 0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
3548
+ 0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
3549
+ 0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
3550
+ 0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
3551
+ };
3552
+
3553
+ constexpr constant static uint64_t iq2xs_grid[512] = {
3554
+ 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
3555
+ 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
3556
+ 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
3557
+ 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
3558
+ 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
3559
+ 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808,
3560
+ 0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819,
3561
+ 0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819,
3562
+ 0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808,
3563
+ 0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b,
3564
+ 0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b,
3565
+ 0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908,
3566
+ 0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908,
3567
+ 0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919,
3568
+ 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808,
3569
+ 0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919,
3570
+ 0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908,
3571
+ 0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b,
3572
+ 0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908,
3573
+ 0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08,
3574
+ 0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808,
3575
+ 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808,
3576
+ 0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819,
3577
+ 0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908,
3578
+ 0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819,
3579
+ 0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808,
3580
+ 0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b,
3581
+ 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819,
3582
+ 0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819,
3583
+ 0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808,
3584
+ 0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908,
3585
+ 0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19,
3586
+ 0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b,
3587
+ 0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b,
3588
+ 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919,
3589
+ 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808,
3590
+ 0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819,
3591
+ 0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819,
3592
+ 0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b,
3593
+ 0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908,
3594
+ 0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808,
3595
+ 0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819,
3596
+ 0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808,
3597
+ 0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919,
3598
+ 0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808,
3599
+ 0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808,
3600
+ 0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908,
3601
+ 0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908,
3602
+ 0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808,
3603
+ 0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b,
3604
+ 0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819,
3605
+ 0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919,
3606
+ 0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908,
3607
+ 0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808,
3608
+ 0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908,
3609
+ 0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919,
3610
+ 0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08,
3611
+ 0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19,
3612
+ 0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b,
3613
+ 0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b,
3614
+ 0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808,
3615
+ 0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08,
3616
+ 0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b,
3617
+ 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908,
3618
+ 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b,
3619
+ 0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908,
3620
+ 0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08,
3621
+ 0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808,
3622
+ 0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808,
3623
+ 0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08,
3624
+ 0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819,
3625
+ 0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919,
3626
+ 0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808,
3627
+ 0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808,
3628
+ 0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819,
3629
+ 0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819,
3630
+ 0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908,
3631
+ 0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908,
3632
+ 0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b,
3633
+ 0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908,
3634
+ 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908,
3635
+ 0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908,
3636
+ 0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808,
3637
+ 0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819,
3638
+ 0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819,
3639
+ 0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819,
3640
+ 0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808,
3641
+ 0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b,
3642
+ 0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819,
3643
+ 0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819,
3644
+ 0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08,
3645
+ 0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808,
3646
+ 0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19,
3647
+ 0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919,
3648
+ 0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808,
3649
+ 0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19,
3650
+ 0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b,
3651
+ 0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808,
3652
+ 0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b,
3653
+ 0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b,
3654
+ 0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08,
3655
+ 0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b,
3656
+ 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808,
3657
+ 0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819,
3658
+ 0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,
3659
+ 0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808,
3660
+ 0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08,
3661
+ 0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b,
3662
+ 0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19,
3663
+ 0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08,
3664
+ 0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919,
3665
+ 0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08,
3666
+ 0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08,
3667
+ 0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908,
3668
+ 0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908,
3669
+ 0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b,
3670
+ 0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908,
3671
+ 0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808,
3672
+ 0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b,
3673
+ 0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808,
3674
+ 0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808,
3675
+ 0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19,
3676
+ 0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08,
3677
+ 0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808,
3678
+ 0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b,
3679
+ 0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808,
3680
+ 0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b,
3681
+ 0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
3682
+ };
3683
+
3684
+ constexpr constant static uint8_t ksigns_iq2xs[128] = { // maps a 7-bit sign index to an 8-bit sign mask (tested bit-by-bit with kmask_iq2xs; a set bit means "negate")
3685
+ 0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15, // each entry keeps its index in the low 7 bits; bit 7 is set when the index has an odd number of set bits
3686
+ 144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
3687
+ 160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175,
3688
+ 48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63,
3689
+ 192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207,
3690
+ 80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95,
3691
+ 96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
3692
+ 240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
3693
+ };
3694
+
3695
+ constexpr constant static uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128}; // kmask_iq2xs[j] == 1 << j: selects the sign bit of the j-th weight in a group of 8
3696
+
3697
+ void kernel_mul_mv_iq2_xxs_f32_impl( // matrix-vector product: rows of block_iq2_xxs (src0) times float vectors (src1); each simdgroup accumulates N_DST consecutive output rows
3698
+ device const void * src0, // quantized matrix, reinterpreted below as block_iq2_xxs
3699
+ device const float * src1, // float input vectors
3700
+ device float * dst, // float output; only lane 0 of each simdgroup writes
3701
+ constant int64_t & ne00, // row length in weights; nb = ne00/QK_K blocks per row
3702
+ constant int64_t & ne01, // used only in the batch offset into src0
3703
+ constant int64_t & ne02, // used only in the batch offset into src0
3704
+ constant int64_t & ne10, // element stride between consecutive src1 vectors
3705
+ constant int64_t & ne12, // splits tgpig.z into (i12, i13)
3706
+ constant int64_t & ne0, // output row length used to index dst
3707
+ constant int64_t & ne1, // number of output rows per batch used to index dst
3708
+ constant uint & r2, // divisor applied to i12 when locating the src0 batch (presumably a broadcast ratio - confirm with caller)
3709
+ constant uint & r3, // divisor applied to i13 when locating the src0 batch (presumably a broadcast ratio - confirm with caller)
3710
+ threadgroup int8_t * shared_values [[threadgroup(0)]], // scratch: receives 256 uint64_t grid entries followed by 128 sign bytes (2176 bytes written below)
3711
+ uint3 tgpig[[threadgroup_position_in_grid]],
3712
+ uint tiisg[[thread_index_in_simdgroup]],
3713
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3714
+
3715
+ const int nb = ne00/QK_K; // blocks per matrix row
3716
+ const int r0 = tgpig.x;
3717
+ const int r1 = tgpig.y; // index of the src1 vector / dst row group
3718
+ const int im = tgpig.z; // flattened batch index
3719
+
3720
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; // first of the N_DST matrix rows owned by this simdgroup
3721
+ const int ib_row = first_row * nb; // block index where that row starts
3722
+
3723
+ const uint i12 = im%ne12;
3724
+ const uint i13 = im/ne12;
3725
+
3726
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); // block offset of the selected src0 batch
3727
+
3728
+ device const block_iq2_xxs * x = (device const block_iq2_xxs *) src0 + ib_row + offset0;
3729
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
3730
+
3731
+ float yl[32]; // private copy of the 32 src1 values for the current group
3732
+ float sumf[N_DST]={0.f}, all_sum; // per-row partial sums for this lane
3733
+
3734
+ const int nb32 = nb * (QK_K / 32); // number of 32-weight groups per row
3735
+
3736
+ threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; // threadgroup copy of iq2xxs_grid
3737
+ threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256); // threadgroup copy of ksigns_iq2xs, placed right after the grid
3738
+ { // cooperative staging of both lookup tables into threadgroup memory
3739
+ int nval = 4; // grid entries copied per thread
3740
+ int pos = (32*sgitg + tiisg)*nval; // NOTE(review): covers exactly 256 entries only with 64 threads (2 simdgroups of 32 lanes); no bounds check for other launch shapes - confirm dispatch
3741
+ for (int i = 0; i < nval; ++i) values[pos + i] = iq2xxs_grid[pos + i];
3742
+ nval = 2; // sign entries copied per thread (64 threads * 2 = 128)
3743
+ pos = (32*sgitg + tiisg)*nval;
3744
+ for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
3745
+ threadgroup_barrier(mem_flags::mem_threadgroup); // tables must be complete before any lane reads them
3746
+ }
3747
+
3748
+ #if QK_K == 256
3749
+ const int ix = tiisg; // each lane starts at its own 32-weight group
3750
+
3751
+ device const float * y4 = y + 32 * ix;
3752
+
3753
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) { // lanes stride over the row's groups, 32 groups apart
3754
+
3755
+ for (int i = 0; i < 32; ++i) {
3756
+ yl[i] = y4[i];
3757
+ }
3758
+
3759
+ const int ibl = ib32 / (QK_K / 32); // block index within the row
3760
+ const int ib = ib32 % (QK_K / 32); // group index within that block
3761
+
3762
+ device const block_iq2_xxs * xr = x + ibl;
3763
+ device const uint16_t * q2 = xr->qs + 4 * ib; // the 4 uint16_t describing this group
3764
+ device const half * dh = &xr->d;
3765
+
3766
+ for (int row = 0; row < N_DST; row++) { // same group position in each of the N_DST rows
3767
+
3768
+ const float db = dh[0];
3769
+ device const uint8_t * aux8 = (device const uint8_t *)q2; // first 4 bytes: grid indices
3770
+ const uint32_t aux32 = q2[2] | (q2[3] << 16); // 4 x 7-bit sign indices, sub-scale in the top 4 bits
3771
+ const float d = db * (0.5f + (aux32 >> 28));
3772
+
3773
+ float sum = 0;
3774
+ for (int l = 0; l < 4; ++l) { // four sub-groups of 8 weights
3775
+ const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]); // the 8 bytes of one grid entry are the 8 magnitudes
3776
+ const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
3777
+ for (int j = 0; j < 8; ++j) {
3778
+ sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
3779
+ }
3780
+ }
3781
+ sumf[row] += d * sum;
3782
+
3783
+ dh += nb*sizeof(block_iq2_xxs)/2; // advance one full row (nb blocks); /2 because dh points at 2-byte halves
3784
+ q2 += nb*sizeof(block_iq2_xxs)/2; // same row stride expressed in uint16_t units
3785
+ }
3786
+
3787
+ y4 += 32 * 32; // skip the 32 groups handled by the other lanes
3788
+ }
3789
+ #else
3790
+ // TODO: QK_K != 256 is not implemented; sumf stays zero on this path, so zeros are written to dst below
3791
+ #endif
3792
+
3793
+ for (int row = 0; row < N_DST; ++row) {
3794
+ all_sum = simd_sum(sumf[row]); // reduce the per-lane partial sums across the simdgroup
3795
+ if (tiisg == 0) {
3796
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; // NOTE(review): no visible check that first_row + row is a valid row - presumably guaranteed by the dispatch
3797
+ }
3798
+ }
3799
+ }
3800
+
3801
+ [[host_name("kernel_mul_mv_iq2_xxs_f32")]] // name under which the host looks this kernel up
3802
+ kernel void kernel_mul_mv_iq2_xxs_f32( // entry point: forwards to kernel_mul_mv_iq2_xxs_f32_impl
3803
+ device const void * src0,
3804
+ device const float * src1,
3805
+ device float * dst,
3806
+ constant int64_t & ne00,
3807
+ constant int64_t & ne01,
3808
+ constant int64_t & ne02,
3809
+ constant uint64_t & nb00, // not forwarded to the impl
3810
+ constant uint64_t & nb01, // not forwarded to the impl
3811
+ constant uint64_t & nb02, // not forwarded to the impl
3812
+ constant int64_t & ne10,
3813
+ constant int64_t & ne11, // not forwarded to the impl
3814
+ constant int64_t & ne12,
3815
+ constant uint64_t & nb10, // not forwarded to the impl
3816
+ constant uint64_t & nb11, // not forwarded to the impl
3817
+ constant uint64_t & nb12, // not forwarded to the impl
3818
+ constant int64_t & ne0,
3819
+ constant int64_t & ne1,
3820
+ constant uint & r2,
3821
+ constant uint & r3,
3822
+ threadgroup int8_t * shared_values [[threadgroup(0)]], // scratch the impl fills with the grid and sign tables
3823
+ uint3 tgpig[[threadgroup_position_in_grid]],
3824
+ uint tiisg[[thread_index_in_simdgroup]],
3825
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3826
+
3827
+ kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
3828
+ }
3829
+
3830
+ void kernel_mul_mv_iq2_xs_f32_impl( // matrix-vector product: rows of block_iq2_xs (src0) times float vectors (src1); each simdgroup accumulates N_DST consecutive output rows
3831
+ device const void * src0, // quantized matrix, reinterpreted below as block_iq2_xs
3832
+ device const float * src1, // float input vectors
3833
+ device float * dst, // float output; only lane 0 of each simdgroup writes
3834
+ constant int64_t & ne00, // row length in weights; nb = ne00/QK_K blocks per row
3835
+ constant int64_t & ne01, // used only in the batch offset into src0
3836
+ constant int64_t & ne02, // used only in the batch offset into src0
3837
+ constant int64_t & ne10, // element stride between consecutive src1 vectors
3838
+ constant int64_t & ne12, // splits tgpig.z into (i12, i13)
3839
+ constant int64_t & ne0, // output row length used to index dst
3840
+ constant int64_t & ne1, // number of output rows per batch used to index dst
3841
+ constant uint & r2, // divisor applied to i12 when locating the src0 batch (presumably a broadcast ratio - confirm with caller)
3842
+ constant uint & r3, // divisor applied to i13 when locating the src0 batch (presumably a broadcast ratio - confirm with caller)
3843
+ threadgroup int8_t * shared_values [[threadgroup(0)]], // scratch: receives 512 uint64_t grid entries followed by 128 sign bytes (4224 bytes written below)
3844
+ uint3 tgpig[[threadgroup_position_in_grid]],
3845
+ uint tiisg[[thread_index_in_simdgroup]],
3846
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3847
+
3848
+ const int nb = ne00/QK_K; // blocks per matrix row
3849
+ const int r0 = tgpig.x;
3850
+ const int r1 = tgpig.y; // index of the src1 vector / dst row group
3851
+ const int im = tgpig.z; // flattened batch index
3852
+
3853
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; // first of the N_DST matrix rows owned by this simdgroup
3854
+ const int ib_row = first_row * nb; // block index where that row starts
3855
+
3856
+ const uint i12 = im%ne12;
3857
+ const uint i13 = im/ne12;
3858
+
3859
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); // block offset of the selected src0 batch
3860
+
3861
+ device const block_iq2_xs * x = (device const block_iq2_xs *) src0 + ib_row + offset0;
3862
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
3863
+
3864
+ float yl[32]; // private copy of the 32 src1 values for the current group
3865
+ float sumf[N_DST]={0.f}, all_sum; // per-row partial sums for this lane
3866
+
3867
+ const int nb32 = nb * (QK_K / 32); // number of 32-weight groups per row
3868
+
3869
+ threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; // threadgroup copy of iq2xs_grid
3870
+ threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 512); // threadgroup copy of ksigns_iq2xs, placed right after the grid
3871
+ { // cooperative staging of both lookup tables into threadgroup memory
3872
+ int nval = 8; // grid entries copied per thread
3873
+ int pos = (32*sgitg + tiisg)*nval; // NOTE(review): covers exactly 512 entries only with 64 threads (2 simdgroups of 32 lanes); no bounds check for other launch shapes - confirm dispatch
3874
+ for (int i = 0; i < nval; ++i) values[pos + i] = iq2xs_grid[pos + i];
3875
+ nval = 2; // sign entries copied per thread (64 threads * 2 = 128)
3876
+ pos = (32*sgitg + tiisg)*nval;
3877
+ for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
3878
+ threadgroup_barrier(mem_flags::mem_threadgroup); // tables must be complete before any lane reads them
3879
+ }
3880
+
3881
+ #if QK_K == 256
3882
+ const int ix = tiisg; // each lane starts at its own 32-weight group
3883
+
3884
+ device const float * y4 = y + 32 * ix;
3885
+
3886
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) { // lanes stride over the row's groups, 32 groups apart
3887
+
3888
+ for (int i = 0; i < 32; ++i) {
3889
+ yl[i] = y4[i];
3890
+ }
3891
+
3892
+ const int ibl = ib32 / (QK_K / 32); // block index within the row
3893
+ const int ib = ib32 % (QK_K / 32); // group index within that block
3894
+
3895
+ device const block_iq2_xs * xr = x + ibl;
3896
+ device const uint16_t * q2 = xr->qs + 4 * ib; // the 4 uint16_t describing this group
3897
+ device const uint8_t * sc = xr->scales + ib; // the scale byte of this group
3898
+ device const half * dh = &xr->d;
3899
+
3900
+ for (int row = 0; row < N_DST; row++) { // same group position in each of the N_DST rows
3901
+
3902
+ const float db = dh[0];
3903
+ const uint8_t ls1 = sc[0] & 0xf; // sub-scale for weights 0..15
3904
+ const uint8_t ls2 = sc[0] >> 4; // sub-scale for weights 16..31
3905
+ const float d1 = db * (0.5f + ls1);
3906
+ const float d2 = db * (0.5f + ls2);
3907
+
3908
+ float sum1 = 0, sum2 = 0;
3909
+ for (int l = 0; l < 2; ++l) { // first two sub-groups of 8 weights -> scaled by d1
3910
+ const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511)); // low 9 bits: grid index; the entry's 8 bytes are the magnitudes
3911
+ const uint8_t signs = shared_signs[(q2[l] >> 9)]; // high 7 bits: sign index
3912
+ for (int j = 0; j < 8; ++j) {
3913
+ sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
3914
+ }
3915
+ }
3916
+ for (int l = 2; l < 4; ++l) { // last two sub-groups of 8 weights -> scaled by d2
3917
+ const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
3918
+ const uint8_t signs = shared_signs[(q2[l] >> 9)];
3919
+ for (int j = 0; j < 8; ++j) {
3920
+ sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
3921
+ }
3922
+ }
3923
+ sumf[row] += d1 * sum1 + d2 * sum2;
3924
+
3925
+ dh += nb*sizeof(block_iq2_xs)/2; // advance one full row (nb blocks); /2 because dh points at 2-byte halves
3926
+ q2 += nb*sizeof(block_iq2_xs)/2; // same row stride expressed in uint16_t units
3927
+ sc += nb*sizeof(block_iq2_xs); // same row stride expressed in bytes
3928
+ }
3929
+
3930
+ y4 += 32 * 32; // skip the 32 groups handled by the other lanes
3931
+ }
3932
+ #else
3933
+ // TODO: QK_K != 256 is not implemented; sumf stays zero on this path, so zeros are written to dst below
3934
+ #endif
3935
+
3936
+ for (int row = 0; row < N_DST; ++row) {
3937
+ all_sum = simd_sum(sumf[row]); // reduce the per-lane partial sums across the simdgroup
3938
+ if (tiisg == 0) {
3939
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; // NOTE(review): no visible check that first_row + row is a valid row - presumably guaranteed by the dispatch
3940
+ }
3941
+ }
3942
+ }
3943
+
3944
+ [[host_name("kernel_mul_mv_iq2_xs_f32")]] // name under which the host looks this kernel up
3945
+ kernel void kernel_mul_mv_iq2_xs_f32( // entry point: forwards to kernel_mul_mv_iq2_xs_f32_impl
3946
+ device const void * src0,
3947
+ device const float * src1,
3948
+ device float * dst,
3949
+ constant int64_t & ne00,
3950
+ constant int64_t & ne01,
3951
+ constant int64_t & ne02,
3952
+ constant uint64_t & nb00, // not forwarded to the impl
3953
+ constant uint64_t & nb01, // not forwarded to the impl
3954
+ constant uint64_t & nb02, // not forwarded to the impl
3955
+ constant int64_t & ne10,
3956
+ constant int64_t & ne11, // not forwarded to the impl
3957
+ constant int64_t & ne12,
3958
+ constant uint64_t & nb10, // not forwarded to the impl
3959
+ constant uint64_t & nb11, // not forwarded to the impl
3960
+ constant uint64_t & nb12, // not forwarded to the impl
3961
+ constant int64_t & ne0,
3962
+ constant int64_t & ne1,
3963
+ constant uint & r2,
3964
+ constant uint & r3,
3965
+ threadgroup int8_t * shared_values [[threadgroup(0)]], // scratch the impl fills with the grid and sign tables
3966
+ uint3 tgpig[[threadgroup_position_in_grid]],
3967
+ uint tiisg[[thread_index_in_simdgroup]],
3968
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3969
+
3970
+ kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
3971
+ }
3972
+
3471
3973
  //============================= templates and their specializations =============================
3472
3974
 
3473
3975
  // NOTE: this is not dequantizing - we are simply fitting the template
@@ -3620,8 +4122,8 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
3620
4122
  uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
3621
4123
  int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
3622
4124
  : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
3623
- half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
3624
- const half ml = 4.h * dl;
4125
+ float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
4126
+ const float ml = 4.f * dl;
3625
4127
 
3626
4128
  il = (il/2) & 3;
3627
4129
  const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
@@ -3688,7 +4190,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
3688
4190
  uint8_t ul = 1 << (il/2);
3689
4191
  il = il & 3;
3690
4192
  const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
3691
- const float d = il < 2 ? xb->d : xb->d / 16.h;
4193
+ const float d = il < 2 ? xb->d : xb->d / 16.f;
3692
4194
  const float min = xb->dmin;
3693
4195
  const float dl = d * sc[0];
3694
4196
  const float ml = min * sc[1];
@@ -3721,17 +4223,17 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
3721
4223
  #if QK_K == 256
3722
4224
  ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
3723
4225
  qh = qh + 32*(il/8) + 16*(il&1);
3724
- half sc = scales[(il%2) + 2 * ((il/2))];
4226
+ float sc = scales[(il%2) + 2 * ((il/2))];
3725
4227
  il = (il/2) & 3;
3726
4228
  #else
3727
4229
  ql = ql + 16 * (il&1);
3728
- half sc = scales[il];
4230
+ float sc = scales[il];
3729
4231
  #endif
3730
4232
  const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
3731
4233
  const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
3732
- const half coef = il>1 ? 1.f/16.h : 1.h;
3733
- const half ml = d_all * sc * 32.h;
3734
- const half dl = d_all * sc * coef;
4234
+ const float coef = il>1 ? 1.f/16.f : 1.f;
4235
+ const float ml = d_all * sc * 32.f;
4236
+ const float dl = d_all * sc * coef;
3735
4237
  for (int i = 0; i < 16; ++i) {
3736
4238
  const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
3737
4239
  : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
@@ -3739,6 +4241,52 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
3739
4241
  }
3740
4242
  }
3741
4243
 
4244
// Dequantizes 16 values from one block_iq2_xxs into `reg` (written as reg[0..3][0..3]).
//   xb  : source block; xb->d is the block scale, xb->qs holds the packed 16-bit words.
//   il  : sub-block selector; il/2 picks the group of 32 quants, il%2 the half of that group.
//   reg : output 4x4; rows 0-1 are filled from the first grid entry, rows 2-3 from the second.
// NOTE(review): ksigns_iq2xs and kmask_iq2xs are defined outside this view; from their use
// here they look like a sign-pattern lookup table and per-element bit masks -- confirm.
+ template <typename type4x4>
4245
+ void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
4246
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
4247
+ const float d = xb->d;
4248
+ const int ib32 = il/2;
4249
+ il = il%2;
4250
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
4251
+ // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
4252
+ device const uint16_t * q2 = xb->qs + 4*ib32;
4253
// Reassemble two 32-bit words from the four 16-bit words of this group:
// aux32_g is read byte-wise (via aux8) as grid indices, aux32_s supplies the
// sign indices and, in its top 4 bits, the group scale.
+ const uint32_t aux32_g = q2[0] | (q2[1] << 16);
4254
+ const uint32_t aux32_s = q2[2] | (q2[3] << 16);
4255
+ thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
4256
// effective scale: d * (0.5 + top 4 bits of aux32_s) * 0.25
+ const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
4257
// each iq2xxs_grid entry is a uint64_t; it is read here as 8 separate bytes
+ constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
4258
// 7-bit sign index for the first 8 values, taken at bit offset 14*il of aux32_s
+ uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
4259
+ for (int i = 0; i < 8; ++i) {
4260
+ reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
4261
+ }
4262
// second grid entry and the next 7-bit sign index (bit offset 14*il+7) for values 8..15
+ grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
4263
+ signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
4264
+ for (int i = 0; i < 8; ++i) {
4265
+ reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
4266
+ }
4267
+ }
4268
+
4269
// Dequantizes 16 values from one block_iq2_xs into `reg` (written as reg[0..3][0..3]).
//   xb  : source block; xb->d is the block scale, xb->qs the packed 16-bit words,
//         xb->scales one byte per group of 32 quants (indexed by il/2).
//   il  : sub-block selector; il/2 picks the group of 32 quants, il%2 the half of that group.
//   reg : output 4x4; rows 0-1 are filled from the first grid entry, rows 2-3 from the second.
// NOTE(review): iq2xs_grid, ksigns_iq2xs and kmask_iq2xs are defined outside this view;
// the grid is read as 8 bytes per entry, which assumes 64-bit entries -- confirm.
+ template <typename type4x4>
4270
+ void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
4271
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
4272
+ const float d = xb->d;
4273
+ const int ib32 = il/2;
4274
+ il = il%2;
4275
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
4276
+ device const uint16_t * q2 = xb->qs + 4*ib32;
4277
// effective scale: low nibble of scales[ib32] when il == 0, high nibble when il == 1
+ const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
4278
// each 16-bit word packs a grid index in its low 9 bits and a sign index in its upper 7 bits
+ constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
4279
+ uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
4280
+ for (int i = 0; i < 8; ++i) {
4281
+ reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
4282
+ }
4283
// same decoding for the second 16-bit word, producing values 8..15
+ grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
4284
+ signs = ksigns_iq2xs[q2[2*il+1] >> 9];
4285
+ for (int i = 0; i < 8; ++i) {
4286
+ reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
4287
+ }
4288
+ }
4289
+
3742
4290
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
3743
4291
  kernel void kernel_get_rows(
3744
4292
  device const void * src0,
@@ -4278,6 +4826,8 @@ template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows
4278
4826
  // Explicit instantiations of the kernel_get_rows template: each line binds a block
  // type, QK_NL and the matching dequantize function, and names the resulting kernel
  // through its host_name string. The iq2_xxs / iq2_xs entries use the same QK_NL
  // argument as the K-quant entries listed before them.
  template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
4279
4827
  template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
4280
4828
  template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
4829
+ template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
4830
+ template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
4281
4831
 
4282
4832
  //
4283
4833
  // matrix-matrix multiplication
@@ -4314,6 +4864,8 @@ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
4314
4864
  // Explicit instantiations of the kernel_mul_mm template: each line binds a block
  // type, QK_NL and the matching dequantize function, and names the resulting kernel
  // through its host_name string. The iq2_xxs / iq2_xs entries use the same QK_NL
  // argument as the K-quant entries listed before them.
  template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
4315
4865
  template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
4316
4866
  template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
4867
+ template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
4868
+ template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
4317
4869
 
4318
4870
  //
4319
4871
  // indirect matrix-matrix multiplication
@@ -4362,6 +4914,8 @@ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mu
4362
4914
  // Explicit instantiations of the kernel_mul_mm_id template: each line binds a block
  // type, QK_NL and the matching dequantize function, and names the resulting kernel
  // through its host_name string. The iq2_xxs / iq2_xs entries use the same QK_NL
  // argument as the K-quant entries listed before them.
  template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
4363
4915
  template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
4364
4916
  template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
4917
+ template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
4918
+ template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
4365
4919
 
4366
4920
  //
4367
4921
  // matrix-vector multiplication
@@ -5134,3 +5688,133 @@ kernel void kernel_mul_mv_id_q6_K_f32(
5134
5688
  tiisg,
5135
5689
  sgitg);
5136
5690
  }
5691
+
5692
// Kernel entry point that picks one of eight src0 buffers using an id read from `ids`
// and then delegates the actual work to kernel_mul_mv_iq2_xxs_f32_impl.
//   ids            : byte buffer holding int32 ids; row `bid` starts at byte offset bid*nbi1.
//   idx            : element index within that row of ids.
//   src00..src07   : the eight candidate src0 buffers.
//   src1, dst      : forwarded with per-`bid` offsets (bid*nb11 bytes, bid*ne0 floats).
//   shared_values  : threadgroup buffer, passed through unchanged.
// nb00, nb01, nb02, ne11, nb10, nb12, nb1 and tiitg are not referenced in this body.
// NOTE(review): `id` is used to index the 8-entry src0 array without a range check here.
+ [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
5693
+ kernel void kernel_mul_mv_id_iq2_xxs_f32(
5694
+ device const char * ids,
5695
+ device const char * src1,
5696
+ device float * dst,
5697
+ constant uint64_t & nbi1,
5698
+ constant int64_t & ne00,
5699
+ constant int64_t & ne01,
5700
+ constant int64_t & ne02,
5701
+ constant uint64_t & nb00,
5702
+ constant uint64_t & nb01,
5703
+ constant uint64_t & nb02,
5704
+ constant int64_t & ne10,
5705
+ constant int64_t & ne11,
5706
+ constant int64_t & ne12,
5707
+ constant int64_t & ne13,
5708
+ constant uint64_t & nb10,
5709
+ constant uint64_t & nb11,
5710
+ constant uint64_t & nb12,
5711
+ constant int64_t & ne0,
5712
+ constant int64_t & ne1,
5713
+ constant uint64_t & nb1,
5714
+ constant uint & r2,
5715
+ constant uint & r3,
5716
+ constant int & idx,
5717
+ device const char * src00,
5718
+ device const char * src01,
5719
+ device const char * src02,
5720
+ device const char * src03,
5721
+ device const char * src04,
5722
+ device const char * src05,
5723
+ device const char * src06,
5724
+ device const char * src07,
5725
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
5726
+ uint3 tgpig[[threadgroup_position_in_grid]],
5727
+ uint tiitg[[thread_index_in_threadgroup]],
5728
+ uint tiisg[[thread_index_in_simdgroup]],
5729
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
5730
// gather the eight buffer arguments so one can be selected by index
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5731
+
5732
// split tgpig.z: the quotient becomes bid, the remainder is what the impl receives
+ const int64_t bid = tgpig.z/(ne12*ne13);
5733
+
5734
+ tgpig.z = tgpig.z%(ne12*ne13);
5735
+
5736
// read the buffer selector for this bid
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
5737
+
5738
+ kernel_mul_mv_iq2_xxs_f32_impl(
5739
+ src0[id],
5740
+ (device const float *) (src1 + bid*nb11),
5741
+ dst + bid*ne0,
5742
+ ne00,
5743
+ ne01,
5744
+ ne02,
5745
+ ne10,
5746
+ ne12,
5747
+ ne0,
5748
+ ne1,
5749
+ r2,
5750
+ r3,
5751
+ shared_values,
5752
+ tgpig,
5753
+ tiisg,
5754
+ sgitg);
5755
+ }
5756
+
5757
// Kernel entry point that picks one of eight src0 buffers using an id read from `ids`
// and then delegates the actual work to kernel_mul_mv_iq2_xs_f32_impl.
//   ids            : byte buffer holding int32 ids; row `bid` starts at byte offset bid*nbi1.
//   idx            : element index within that row of ids.
//   src00..src07   : the eight candidate src0 buffers.
//   src1, dst      : forwarded with per-`bid` offsets (bid*nb11 bytes, bid*ne0 floats).
//   shared_values  : threadgroup buffer, passed through unchanged.
// nb00, nb01, nb02, ne11, nb10, nb12, nb1 and tiitg are not referenced in this body.
// NOTE(review): `id` is used to index the 8-entry src0 array without a range check here.
+ [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
5758
+ kernel void kernel_mul_mv_id_iq2_xs_f32(
5759
+ device const char * ids,
5760
+ device const char * src1,
5761
+ device float * dst,
5762
+ constant uint64_t & nbi1,
5763
+ constant int64_t & ne00,
5764
+ constant int64_t & ne01,
5765
+ constant int64_t & ne02,
5766
+ constant uint64_t & nb00,
5767
+ constant uint64_t & nb01,
5768
+ constant uint64_t & nb02,
5769
+ constant int64_t & ne10,
5770
+ constant int64_t & ne11,
5771
+ constant int64_t & ne12,
5772
+ constant int64_t & ne13,
5773
+ constant uint64_t & nb10,
5774
+ constant uint64_t & nb11,
5775
+ constant uint64_t & nb12,
5776
+ constant int64_t & ne0,
5777
+ constant int64_t & ne1,
5778
+ constant uint64_t & nb1,
5779
+ constant uint & r2,
5780
+ constant uint & r3,
5781
+ constant int & idx,
5782
+ device const char * src00,
5783
+ device const char * src01,
5784
+ device const char * src02,
5785
+ device const char * src03,
5786
+ device const char * src04,
5787
+ device const char * src05,
5788
+ device const char * src06,
5789
+ device const char * src07,
5790
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
5791
+ uint3 tgpig[[threadgroup_position_in_grid]],
5792
+ uint tiitg[[thread_index_in_threadgroup]],
5793
+ uint tiisg[[thread_index_in_simdgroup]],
5794
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
5795
// gather the eight buffer arguments so one can be selected by index
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5796
+
5797
// split tgpig.z: the quotient becomes bid, the remainder is what the impl receives
+ const int64_t bid = tgpig.z/(ne12*ne13);
5798
+
5799
+ tgpig.z = tgpig.z%(ne12*ne13);
5800
+
5801
// read the buffer selector for this bid
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
5802
+
5803
+ kernel_mul_mv_iq2_xs_f32_impl(
5804
+ src0[id],
5805
+ (device const float *) (src1 + bid*nb11),
5806
+ dst + bid*ne0,
5807
+ ne00,
5808
+ ne01,
5809
+ ne02,
5810
+ ne10,
5811
+ ne12,
5812
+ ne0,
5813
+ ne1,
5814
+ r2,
5815
+ r3,
5816
+ shared_values,
5817
+ tgpig,
5818
+ tiisg,
5819
+ sgitg);
5820
+ }