llama_cpp 0.12.7 → 0.13.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -2519,12 +2519,34 @@ typedef struct {
2519
2519
  } block_iq2_xs;
2520
2520
  // 74 bytes / block for QK_K = 256, so 2.3125 bpw
2521
2521
 
2522
+ // 2.5625 bpw quants: 82 bytes / block for QK_K = 256
2523
+ typedef struct {
2524
+ half d; // 2 bytes
2525
+ uint8_t qs[QK_K/4]; // 64 bytes for QK_K = 256
2526
+ uint8_t qh[QK_K/32]; // 8 bytes for QK_K = 256
2527
+ uint8_t scales[QK_K/32]; // 8 bytes for QK_K = 256
2528
+ } block_iq2_s;
2529
+
2522
2530
  typedef struct {
2523
2531
  half d; // 2 bytes
2524
2532
  uint8_t qs[3*QK_K/8]; // 96 bytes for QK_K = 256
2525
2533
  } block_iq3_xxs;
2526
2534
  // 98 bytes / block for QK_K = 256, so 3.0625 bpw
2527
2535
 
2536
+ // 3.4375 bpw: 110 bytes / block for QK_K = 256
2537
+ #if QK_K == 64
2538
+ #define IQ3S_N_SCALE 2
2539
+ #else
2540
+ #define IQ3S_N_SCALE (QK_K/64) // parenthesized so the macro stays a single operand inside larger expressions
2541
+ #endif
2542
+ typedef struct {
2543
+ half d; // 2 bytes
2544
+ uint8_t qs[QK_K/4]; // 64 bytes for QK_K = 256
2545
+ uint8_t qh[QK_K/32]; // 8 bytes for QK_K = 256
2546
+ uint8_t signs[QK_K/8]; // 32 bytes for QK_K = 256
2547
+ uint8_t scales[IQ3S_N_SCALE]; // 4 bytes for QK_K = 256, 2 bytes for QK_K = 64
2548
+ } block_iq3_s;
2549
+
2528
2550
  typedef struct {
2529
2551
  half d;
2530
2552
  uint8_t qs[QK_K/8];
@@ -2538,6 +2560,17 @@ typedef struct {
2538
2560
  uint8_t qs[QK4_NL/2];
2539
2561
  } block_iq4_nl;
2540
2562
 
2563
+ #if QK_K == 64
2564
+ #define block_iq4_xs block_iq4_nl // for QK_K == 64, block_iq4_xs is only an alias of block_iq4_nl
2565
+ #else
2566
+ typedef struct {
2567
+ half d; // 2 bytes
2568
+ uint16_t scales_h; // 2 bytes
2569
+ uint8_t scales_l[QK_K/64]; // 4 bytes for QK_K = 256
2570
+ uint8_t qs[QK_K/2]; // 128 bytes for QK_K = 256
2571
+ } block_iq4_xs; // 136 bytes / block for QK_K = 256, so 4.25 bpw
2572
+ #endif
2573
+
2541
2574
  //====================================== dot products =========================
2542
2575
 
2543
2576
  void kernel_mul_mv_q2_K_f32_impl(
@@ -3760,6 +3793,265 @@ constexpr constant static uint64_t iq2xs_grid[512] = {
3760
3793
  0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
3761
3794
  };
3762
3795
 
3796
+ constexpr constant static uint64_t iq2s_grid[1024] = {
3797
+ 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
3798
+ 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
3799
+ 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
3800
+ 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
3801
+ 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
3802
+ 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x08080808192b192b,
3803
+ 0x08080808192b2b19, 0x080808082b080808, 0x080808082b08082b, 0x080808082b081919,
3804
+ 0x080808082b082b08, 0x080808082b190819, 0x080808082b191908, 0x080808082b2b0808,
3805
+ 0x080808082b2b1919, 0x080808082b2b2b2b, 0x0808081908080819, 0x0808081908081908,
3806
+ 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, 0x080808190819082b,
3807
+ 0x0808081908191919, 0x0808081908192b08, 0x08080819082b0819, 0x08080819082b1908,
3808
+ 0x0808081919080808, 0x080808191908082b, 0x0808081919081919, 0x0808081919082b08,
3809
+ 0x0808081919190819, 0x0808081919191908, 0x080808191919192b, 0x0808081919192b19,
3810
+ 0x08080819192b0808, 0x08080819192b1919, 0x08080819192b2b08, 0x080808192b080819,
3811
+ 0x080808192b081908, 0x080808192b190808, 0x080808192b19082b, 0x080808192b191919,
3812
+ 0x080808192b2b0819, 0x080808192b2b1908, 0x0808082b08080808, 0x0808082b0808082b,
3813
+ 0x0808082b08081919, 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908,
3814
+ 0x0808082b082b0808, 0x0808082b082b2b2b, 0x0808082b19080819, 0x0808082b19081908,
3815
+ 0x0808082b1908192b, 0x0808082b19082b19, 0x0808082b19190808, 0x0808082b19191919,
3816
+ 0x0808082b2b080808, 0x0808082b2b081919, 0x0808082b2b082b2b, 0x0808082b2b191908,
3817
+ 0x0808082b2b2b082b, 0x0808190808080819, 0x0808190808081908, 0x080819080808192b,
3818
+ 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, 0x0808190808191919,
3819
+ 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, 0x08081908082b192b,
3820
+ 0x08081908082b2b19, 0x0808190819080808, 0x080819081908082b, 0x0808190819081919,
3821
+ 0x0808190819082b08, 0x0808190819082b2b, 0x0808190819190819, 0x0808190819191908,
3822
+ 0x080819081919192b, 0x0808190819192b19, 0x08081908192b0808, 0x08081908192b082b,
3823
+ 0x08081908192b1919, 0x080819082b080819, 0x080819082b081908, 0x080819082b08192b,
3824
+ 0x080819082b082b19, 0x080819082b190808, 0x080819082b191919, 0x080819082b192b08,
3825
+ 0x080819082b2b0819, 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b,
3826
+ 0x0808191908081919, 0x0808191908082b08, 0x0808191908082b2b, 0x0808191908190819,
3827
+ 0x0808191908191908, 0x080819190819192b, 0x0808191908192b19, 0x08081919082b0808,
3828
+ 0x08081919082b1919, 0x08081919082b2b08, 0x0808191919080819, 0x0808191919081908,
3829
+ 0x080819191908192b, 0x0808191919082b19, 0x0808191919190808, 0x080819191919082b,
3830
+ 0x0808191919191919, 0x0808191919192b08, 0x08081919192b0819, 0x08081919192b1908,
3831
+ 0x080819192b080808, 0x080819192b08082b, 0x080819192b081919, 0x080819192b082b08,
3832
+ 0x080819192b190819, 0x080819192b191908, 0x080819192b2b0808, 0x0808192b08080819,
3833
+ 0x0808192b08081908, 0x0808192b0808192b, 0x0808192b08082b19, 0x0808192b08190808,
3834
+ 0x0808192b08191919, 0x0808192b19080808, 0x0808192b19081919, 0x0808192b19082b08,
3835
+ 0x0808192b19190819, 0x0808192b19191908, 0x0808192b192b0808, 0x0808192b2b080819,
3836
+ 0x0808192b2b081908, 0x0808192b2b190808, 0x08082b0808080808, 0x08082b080808082b,
3837
+ 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808190819, 0x08082b0808191908,
3838
+ 0x08082b080819192b, 0x08082b0808192b19, 0x08082b08082b0808, 0x08082b08082b1919,
3839
+ 0x08082b08082b2b2b, 0x08082b0819080819, 0x08082b0819081908, 0x08082b081908192b,
3840
+ 0x08082b0819082b19, 0x08082b0819190808, 0x08082b081919082b, 0x08082b0819191919,
3841
+ 0x08082b0819192b08, 0x08082b08192b0819, 0x08082b08192b1908, 0x08082b082b080808,
3842
+ 0x08082b082b081919, 0x08082b082b191908, 0x08082b082b2b2b2b, 0x08082b1908080819,
3843
+ 0x08082b1908081908, 0x08082b1908190808, 0x08082b190819082b, 0x08082b1908191919,
3844
+ 0x08082b1908192b08, 0x08082b19082b0819, 0x08082b1919080808, 0x08082b1919081919,
3845
+ 0x08082b1919082b08, 0x08082b1919190819, 0x08082b1919191908, 0x08082b19192b0808,
3846
+ 0x08082b192b080819, 0x08082b192b190808, 0x08082b2b08080808, 0x08082b2b08190819,
3847
+ 0x08082b2b08191908, 0x08082b2b082b082b, 0x08082b2b082b2b08, 0x08082b2b082b2b2b,
3848
+ 0x08082b2b19190808, 0x08082b2b2b192b19, 0x0819080808080819, 0x0819080808081908,
3849
+ 0x081908080808192b, 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b,
3850
+ 0x0819080808191919, 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908,
3851
+ 0x08190808082b192b, 0x0819080819080808, 0x081908081908082b, 0x0819080819081919,
3852
+ 0x0819080819082b08, 0x0819080819190819, 0x0819080819191908, 0x081908081919192b,
3853
+ 0x0819080819192b19, 0x08190808192b0808, 0x08190808192b082b, 0x08190808192b1919,
3854
+ 0x08190808192b2b08, 0x081908082b080819, 0x081908082b081908, 0x081908082b08192b,
3855
+ 0x081908082b190808, 0x081908082b191919, 0x081908082b192b08, 0x081908082b2b0819,
3856
+ 0x081908082b2b1908, 0x0819081908080808, 0x081908190808082b, 0x0819081908081919,
3857
+ 0x0819081908082b08, 0x0819081908082b2b, 0x0819081908190819, 0x0819081908191908,
3858
+ 0x081908190819192b, 0x0819081908192b19, 0x08190819082b0808, 0x08190819082b082b,
3859
+ 0x08190819082b1919, 0x08190819082b2b08, 0x0819081919080819, 0x0819081919081908,
3860
+ 0x081908191908192b, 0x0819081919082b19, 0x0819081919190808, 0x081908191919082b,
3861
+ 0x0819081919191919, 0x0819081919192b08, 0x08190819192b0819, 0x08190819192b1908,
3862
+ 0x081908192b080808, 0x081908192b08082b, 0x081908192b081919, 0x081908192b082b08,
3863
+ 0x081908192b190819, 0x081908192b191908, 0x0819082b08080819, 0x0819082b08081908,
3864
+ 0x0819082b08082b19, 0x0819082b08190808, 0x0819082b08191919, 0x0819082b082b0819,
3865
+ 0x0819082b082b1908, 0x0819082b19080808, 0x0819082b19081919, 0x0819082b19190819,
3866
+ 0x0819082b19191908, 0x0819082b2b080819, 0x0819082b2b081908, 0x0819082b2b190808,
3867
+ 0x0819190808080808, 0x081919080808082b, 0x0819190808081919, 0x0819190808082b08,
3868
+ 0x0819190808190819, 0x0819190808191908, 0x081919080819192b, 0x0819190808192b19,
3869
+ 0x08191908082b0808, 0x08191908082b1919, 0x08191908082b2b08, 0x0819190819080819,
3870
+ 0x0819190819081908, 0x081919081908192b, 0x0819190819082b19, 0x0819190819190808,
3871
+ 0x081919081919082b, 0x0819190819191919, 0x0819190819192b08, 0x08191908192b0819,
3872
+ 0x08191908192b1908, 0x081919082b080808, 0x081919082b08082b, 0x081919082b081919,
3873
+ 0x081919082b082b08, 0x081919082b190819, 0x081919082b191908, 0x081919082b2b0808,
3874
+ 0x0819191908080819, 0x0819191908081908, 0x081919190808192b, 0x0819191908082b19,
3875
+ 0x0819191908190808, 0x081919190819082b, 0x0819191908191919, 0x0819191908192b08,
3876
+ 0x08191919082b0819, 0x08191919082b1908, 0x0819191919080808, 0x081919191908082b,
3877
+ 0x0819191919081919, 0x0819191919082b08, 0x0819191919190819, 0x0819191919191908,
3878
+ 0x08191919192b0808, 0x081919192b080819, 0x081919192b081908, 0x081919192b190808,
3879
+ 0x0819192b08080808, 0x0819192b08081919, 0x0819192b08082b08, 0x0819192b08190819,
3880
+ 0x0819192b08191908, 0x0819192b082b0808, 0x0819192b19080819, 0x0819192b19081908,
3881
+ 0x0819192b19190808, 0x0819192b2b080808, 0x0819192b2b2b2b2b, 0x08192b0808080819,
3882
+ 0x08192b0808081908, 0x08192b080808192b, 0x08192b0808082b19, 0x08192b0808190808,
3883
+ 0x08192b0808191919, 0x08192b0808192b08, 0x08192b08082b0819, 0x08192b0819080808,
3884
+ 0x08192b081908082b, 0x08192b0819081919, 0x08192b0819082b08, 0x08192b0819190819,
3885
+ 0x08192b0819191908, 0x08192b08192b0808, 0x08192b082b080819, 0x08192b082b081908,
3886
+ 0x08192b1908080808, 0x08192b190808082b, 0x08192b1908081919, 0x08192b1908082b08,
3887
+ 0x08192b1908190819, 0x08192b1908191908, 0x08192b19082b0808, 0x08192b1919080819,
3888
+ 0x08192b1919081908, 0x08192b1919190808, 0x08192b19192b2b19, 0x08192b192b2b082b,
3889
+ 0x08192b2b08081908, 0x08192b2b08190808, 0x08192b2b19080808, 0x08192b2b1919192b,
3890
+ 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, 0x082b080808082b08,
3891
+ 0x082b080808190819, 0x082b080808191908, 0x082b08080819192b, 0x082b080808192b19,
3892
+ 0x082b0808082b0808, 0x082b0808082b1919, 0x082b0808082b2b2b, 0x082b080819080819,
3893
+ 0x082b080819081908, 0x082b080819190808, 0x082b08081919082b, 0x082b080819191919,
3894
+ 0x082b0808192b1908, 0x082b08082b080808, 0x082b08082b082b2b, 0x082b08082b191908,
3895
+ 0x082b08082b2b2b2b, 0x082b081908080819, 0x082b081908081908, 0x082b081908190808,
3896
+ 0x082b08190819082b, 0x082b081908191919, 0x082b0819082b0819, 0x082b081919080808,
3897
+ 0x082b08191908082b, 0x082b081919081919, 0x082b081919190819, 0x082b081919191908,
3898
+ 0x082b0819192b0808, 0x082b08192b080819, 0x082b08192b081908, 0x082b08192b190808,
3899
+ 0x082b082b08080808, 0x082b082b08082b2b, 0x082b082b082b082b, 0x082b082b082b2b08,
3900
+ 0x082b082b082b2b2b, 0x082b082b19081908, 0x082b082b19190808, 0x082b082b2b082b08,
3901
+ 0x082b082b2b082b2b, 0x082b082b2b2b2b08, 0x082b190808080819, 0x082b190808081908,
3902
+ 0x082b19080808192b, 0x082b190808082b19, 0x082b190808190808, 0x082b190808191919,
3903
+ 0x082b190808192b08, 0x082b1908082b0819, 0x082b1908082b1908, 0x082b190819080808,
3904
+ 0x082b19081908082b, 0x082b190819081919, 0x082b190819082b08, 0x082b190819190819,
3905
+ 0x082b190819191908, 0x082b1908192b0808, 0x082b19082b080819, 0x082b19082b081908,
3906
+ 0x082b19082b190808, 0x082b191908080808, 0x082b191908081919, 0x082b191908082b08,
3907
+ 0x082b191908190819, 0x082b191908191908, 0x082b1919082b0808, 0x082b191919080819,
3908
+ 0x082b191919081908, 0x082b191919190808, 0x082b1919192b192b, 0x082b19192b080808,
3909
+ 0x082b192b08080819, 0x082b192b08081908, 0x082b192b08190808, 0x082b192b19080808,
3910
+ 0x082b192b19192b19, 0x082b2b0808080808, 0x082b2b0808081919, 0x082b2b0808190819,
3911
+ 0x082b2b0808191908, 0x082b2b0819080819, 0x082b2b0819081908, 0x082b2b0819190808,
3912
+ 0x082b2b082b082b2b, 0x082b2b082b2b2b2b, 0x082b2b1908080819, 0x082b2b1908081908,
3913
+ 0x082b2b1908190808, 0x082b2b192b191919, 0x082b2b2b08082b2b, 0x082b2b2b082b082b,
3914
+ 0x082b2b2b192b1908, 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819,
3915
+ 0x1908080808081908, 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808,
3916
+ 0x190808080819082b, 0x1908080808191919, 0x1908080808192b08, 0x1908080808192b2b,
3917
+ 0x19080808082b0819, 0x19080808082b1908, 0x19080808082b192b, 0x1908080819080808,
3918
+ 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, 0x1908080819082b2b,
3919
+ 0x1908080819190819, 0x1908080819191908, 0x190808081919192b, 0x1908080819192b19,
3920
+ 0x19080808192b0808, 0x19080808192b082b, 0x19080808192b1919, 0x190808082b080819,
3921
+ 0x190808082b081908, 0x190808082b190808, 0x190808082b191919, 0x190808082b192b08,
3922
+ 0x190808082b2b0819, 0x190808082b2b1908, 0x1908081908080808, 0x190808190808082b,
3923
+ 0x1908081908081919, 0x1908081908082b08, 0x1908081908190819, 0x1908081908191908,
3924
+ 0x190808190819192b, 0x1908081908192b19, 0x19080819082b0808, 0x19080819082b082b,
3925
+ 0x19080819082b1919, 0x1908081919080819, 0x1908081919081908, 0x190808191908192b,
3926
+ 0x1908081919082b19, 0x1908081919190808, 0x190808191919082b, 0x1908081919191919,
3927
+ 0x1908081919192b08, 0x19080819192b0819, 0x19080819192b1908, 0x190808192b080808,
3928
+ 0x190808192b08082b, 0x190808192b081919, 0x190808192b082b08, 0x190808192b190819,
3929
+ 0x190808192b191908, 0x190808192b2b0808, 0x1908082b08080819, 0x1908082b08081908,
3930
+ 0x1908082b08190808, 0x1908082b0819082b, 0x1908082b08191919, 0x1908082b08192b08,
3931
+ 0x1908082b082b1908, 0x1908082b19080808, 0x1908082b19081919, 0x1908082b19082b08,
3932
+ 0x1908082b19190819, 0x1908082b19191908, 0x1908082b192b0808, 0x1908082b2b080819,
3933
+ 0x1908082b2b081908, 0x1908190808080808, 0x190819080808082b, 0x1908190808081919,
3934
+ 0x1908190808082b08, 0x1908190808082b2b, 0x1908190808190819, 0x1908190808191908,
3935
+ 0x190819080819192b, 0x1908190808192b19, 0x19081908082b0808, 0x19081908082b082b,
3936
+ 0x19081908082b1919, 0x19081908082b2b08, 0x1908190819080819, 0x1908190819081908,
3937
+ 0x190819081908192b, 0x1908190819082b19, 0x1908190819190808, 0x190819081919082b,
3938
+ 0x1908190819191919, 0x1908190819192b08, 0x19081908192b0819, 0x19081908192b1908,
3939
+ 0x190819082b080808, 0x190819082b08082b, 0x190819082b081919, 0x190819082b082b08,
3940
+ 0x190819082b190819, 0x190819082b191908, 0x190819082b2b0808, 0x1908191908080819,
3941
+ 0x1908191908081908, 0x190819190808192b, 0x1908191908082b19, 0x1908191908190808,
3942
+ 0x190819190819082b, 0x1908191908191919, 0x1908191908192b08, 0x19081919082b0819,
3943
+ 0x19081919082b1908, 0x1908191919080808, 0x190819191908082b, 0x1908191919081919,
3944
+ 0x1908191919082b08, 0x1908191919190819, 0x1908191919191908, 0x19081919192b0808,
3945
+ 0x19081919192b2b2b, 0x190819192b080819, 0x190819192b081908, 0x190819192b190808,
3946
+ 0x1908192b08080808, 0x1908192b0808082b, 0x1908192b08081919, 0x1908192b08082b08,
3947
+ 0x1908192b08190819, 0x1908192b08191908, 0x1908192b082b0808, 0x1908192b19080819,
3948
+ 0x1908192b19081908, 0x1908192b19190808, 0x1908192b2b080808, 0x1908192b2b2b1919,
3949
+ 0x19082b0808080819, 0x19082b0808081908, 0x19082b0808082b19, 0x19082b0808190808,
3950
+ 0x19082b080819082b, 0x19082b0808191919, 0x19082b0808192b08, 0x19082b08082b0819,
3951
+ 0x19082b08082b1908, 0x19082b0819080808, 0x19082b081908082b, 0x19082b0819081919,
3952
+ 0x19082b0819082b08, 0x19082b0819190819, 0x19082b0819191908, 0x19082b08192b0808,
3953
+ 0x19082b082b081908, 0x19082b082b190808, 0x19082b1908080808, 0x19082b190808082b,
3954
+ 0x19082b1908081919, 0x19082b1908082b08, 0x19082b1908190819, 0x19082b1908191908,
3955
+ 0x19082b19082b0808, 0x19082b1919080819, 0x19082b1919081908, 0x19082b1919190808,
3956
+ 0x19082b192b080808, 0x19082b192b19192b, 0x19082b2b08080819, 0x19082b2b08081908,
3957
+ 0x19082b2b08190808, 0x19082b2b19080808, 0x1919080808080808, 0x191908080808082b,
3958
+ 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, 0x1919080808191908,
3959
+ 0x191908080819192b, 0x1919080808192b19, 0x19190808082b0808, 0x19190808082b082b,
3960
+ 0x19190808082b1919, 0x19190808082b2b08, 0x1919080819080819, 0x1919080819081908,
3961
+ 0x191908081908192b, 0x1919080819082b19, 0x1919080819190808, 0x191908081919082b,
3962
+ 0x1919080819191919, 0x1919080819192b08, 0x19190808192b0819, 0x19190808192b1908,
3963
+ 0x191908082b080808, 0x191908082b08082b, 0x191908082b081919, 0x191908082b082b08,
3964
+ 0x191908082b190819, 0x191908082b191908, 0x1919081908080819, 0x1919081908081908,
3965
+ 0x191908190808192b, 0x1919081908082b19, 0x1919081908190808, 0x191908190819082b,
3966
+ 0x1919081908191919, 0x1919081908192b08, 0x19190819082b0819, 0x19190819082b1908,
3967
+ 0x1919081919080808, 0x191908191908082b, 0x1919081919081919, 0x1919081919082b08,
3968
+ 0x1919081919190819, 0x1919081919191908, 0x19190819192b0808, 0x191908192b080819,
3969
+ 0x191908192b081908, 0x191908192b190808, 0x1919082b08080808, 0x1919082b08081919,
3970
+ 0x1919082b08082b08, 0x1919082b08190819, 0x1919082b08191908, 0x1919082b082b0808,
3971
+ 0x1919082b19080819, 0x1919082b19081908, 0x1919082b19190808, 0x1919082b192b2b19,
3972
+ 0x1919082b2b080808, 0x1919190808080819, 0x1919190808081908, 0x191919080808192b,
3973
+ 0x1919190808082b19, 0x1919190808190808, 0x191919080819082b, 0x1919190808191919,
3974
+ 0x1919190808192b08, 0x19191908082b0819, 0x19191908082b1908, 0x1919190819080808,
3975
+ 0x191919081908082b, 0x1919190819081919, 0x1919190819082b08, 0x1919190819190819,
3976
+ 0x1919190819191908, 0x19191908192b0808, 0x191919082b080819, 0x191919082b081908,
3977
+ 0x191919082b190808, 0x1919191908080808, 0x191919190808082b, 0x1919191908081919,
3978
+ 0x1919191908082b08, 0x1919191908190819, 0x1919191908191908, 0x19191919082b0808,
3979
+ 0x1919191919080819, 0x1919191919081908, 0x1919191919190808, 0x191919192b080808,
3980
+ 0x1919192b08080819, 0x1919192b08081908, 0x1919192b08190808, 0x1919192b082b192b,
3981
+ 0x1919192b19080808, 0x19192b0808080808, 0x19192b080808082b, 0x19192b0808081919,
3982
+ 0x19192b0808082b08, 0x19192b0808190819, 0x19192b0808191908, 0x19192b08082b0808,
3983
+ 0x19192b0819080819, 0x19192b0819081908, 0x19192b0819190808, 0x19192b0819192b2b,
3984
+ 0x19192b082b080808, 0x19192b1908080819, 0x19192b1908081908, 0x19192b1908190808,
3985
+ 0x19192b1919080808, 0x19192b2b08080808, 0x19192b2b08192b19, 0x19192b2b2b081919,
3986
+ 0x19192b2b2b2b2b08, 0x192b080808080819, 0x192b080808081908, 0x192b08080808192b,
3987
+ 0x192b080808190808, 0x192b08080819082b, 0x192b080808191919, 0x192b080808192b08,
3988
+ 0x192b0808082b0819, 0x192b0808082b1908, 0x192b080819080808, 0x192b080819081919,
3989
+ 0x192b080819082b08, 0x192b080819190819, 0x192b080819191908, 0x192b0808192b0808,
3990
+ 0x192b08082b081908, 0x192b08082b190808, 0x192b081908080808, 0x192b08190808082b,
3991
+ 0x192b081908081919, 0x192b081908082b08, 0x192b081908190819, 0x192b081908191908,
3992
+ 0x192b0819082b0808, 0x192b081919080819, 0x192b081919081908, 0x192b081919190808,
3993
+ 0x192b08192b080808, 0x192b08192b192b19, 0x192b082b08081908, 0x192b082b08190808,
3994
+ 0x192b082b19080808, 0x192b082b1919192b, 0x192b082b2b2b0819, 0x192b190808080808,
3995
+ 0x192b190808081919, 0x192b190808082b08, 0x192b190808190819, 0x192b190808191908,
3996
+ 0x192b1908082b0808, 0x192b190819080819, 0x192b190819081908, 0x192b190819190808,
3997
+ 0x192b19082b080808, 0x192b191908080819, 0x192b191908081908, 0x192b191908190808,
3998
+ 0x192b191919080808, 0x192b191919082b2b, 0x192b1919192b2b08, 0x192b19192b19082b,
3999
+ 0x192b192b08080808, 0x192b192b2b191908, 0x192b2b0808080819, 0x192b2b0808081908,
4000
+ 0x192b2b0808190808, 0x192b2b08192b1919, 0x192b2b082b192b08, 0x192b2b1908080808,
4001
+ 0x192b2b19082b2b2b, 0x192b2b2b1908082b, 0x192b2b2b2b2b0819, 0x2b08080808080808,
4002
+ 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, 0x2b08080808190819,
4003
+ 0x2b08080808191908, 0x2b08080808192b19, 0x2b080808082b0808, 0x2b080808082b1919,
4004
+ 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808081919082b,
4005
+ 0x2b08080819191919, 0x2b08080819192b08, 0x2b080808192b0819, 0x2b0808082b080808,
4006
+ 0x2b0808082b081919, 0x2b0808082b190819, 0x2b0808082b191908, 0x2b08081908080819,
4007
+ 0x2b08081908081908, 0x2b08081908082b19, 0x2b08081908190808, 0x2b0808190819082b,
4008
+ 0x2b08081908191919, 0x2b08081908192b08, 0x2b080819082b0819, 0x2b080819082b1908,
4009
+ 0x2b08081919080808, 0x2b0808191908082b, 0x2b08081919081919, 0x2b08081919082b08,
4010
+ 0x2b08081919190819, 0x2b08081919191908, 0x2b0808192b080819, 0x2b0808192b081908,
4011
+ 0x2b0808192b190808, 0x2b0808192b2b2b19, 0x2b08082b08080808, 0x2b08082b08081919,
4012
+ 0x2b08082b08082b2b, 0x2b08082b08190819, 0x2b08082b08191908, 0x2b08082b19080819,
4013
+ 0x2b08082b19081908, 0x2b08082b19190808, 0x2b08190808080819, 0x2b08190808081908,
4014
+ 0x2b0819080808192b, 0x2b08190808082b19, 0x2b08190808190808, 0x2b0819080819082b,
4015
+ 0x2b08190808191919, 0x2b08190808192b08, 0x2b081908082b0819, 0x2b08190819080808,
4016
+ 0x2b0819081908082b, 0x2b08190819081919, 0x2b08190819082b08, 0x2b08190819190819,
4017
+ 0x2b08190819191908, 0x2b081908192b0808, 0x2b0819082b080819, 0x2b0819082b081908,
4018
+ 0x2b0819082b190808, 0x2b08191908080808, 0x2b0819190808082b, 0x2b08191908081919,
4019
+ 0x2b08191908082b08, 0x2b08191908190819, 0x2b08191908191908, 0x2b081919082b0808,
4020
+ 0x2b08191919080819, 0x2b08191919081908, 0x2b08191919190808, 0x2b0819192b080808,
4021
+ 0x2b0819192b082b2b, 0x2b08192b08080819, 0x2b08192b08081908, 0x2b08192b08190808,
4022
+ 0x2b08192b082b2b19, 0x2b08192b19080808, 0x2b082b0808080808, 0x2b082b0808081919,
4023
+ 0x2b082b0808190819, 0x2b082b0808191908, 0x2b082b0819080819, 0x2b082b0819081908,
4024
+ 0x2b082b0819190808, 0x2b082b082b2b082b, 0x2b082b1908080819, 0x2b082b1908081908,
4025
+ 0x2b082b1919080808, 0x2b082b19192b1919, 0x2b082b2b082b082b, 0x2b082b2b19192b08,
4026
+ 0x2b082b2b19192b2b, 0x2b082b2b2b08082b, 0x2b082b2b2b2b082b, 0x2b19080808080819,
4027
+ 0x2b19080808081908, 0x2b19080808082b19, 0x2b19080808190808, 0x2b1908080819082b,
4028
+ 0x2b19080808191919, 0x2b19080808192b08, 0x2b190808082b1908, 0x2b19080819080808,
4029
+ 0x2b1908081908082b, 0x2b19080819081919, 0x2b19080819082b08, 0x2b19080819190819,
4030
+ 0x2b19080819191908, 0x2b190808192b0808, 0x2b1908082b080819, 0x2b1908082b081908,
4031
+ 0x2b1908082b190808, 0x2b19081908080808, 0x2b19081908081919, 0x2b19081908190819,
4032
+ 0x2b19081908191908, 0x2b19081919080819, 0x2b19081919081908, 0x2b19081919190808,
4033
+ 0x2b19081919192b2b, 0x2b19082b08080819, 0x2b19082b08081908, 0x2b19082b08190808,
4034
+ 0x2b19082b19080808, 0x2b19082b2b2b192b, 0x2b19190808080808, 0x2b1919080808082b,
4035
+ 0x2b19190808081919, 0x2b19190808082b08, 0x2b19190808190819, 0x2b19190808191908,
4036
+ 0x2b191908082b0808, 0x2b19190819080819, 0x2b19190819081908, 0x2b19190819190808,
4037
+ 0x2b1919082b080808, 0x2b1919082b19192b, 0x2b19191908080819, 0x2b19191908081908,
4038
+ 0x2b19191908190808, 0x2b19191919080808, 0x2b1919192b192b08, 0x2b1919192b2b0819,
4039
+ 0x2b19192b08080808, 0x2b19192b1908192b, 0x2b19192b192b1908, 0x2b192b0808080819,
4040
+ 0x2b192b0808081908, 0x2b192b0808190808, 0x2b192b08082b192b, 0x2b192b0819080808,
4041
+ 0x2b192b082b2b2b19, 0x2b192b1908080808, 0x2b192b1919082b19, 0x2b192b191919082b,
4042
+ 0x2b192b2b2b190808, 0x2b2b080808080808, 0x2b2b080808081919, 0x2b2b080808082b2b,
4043
+ 0x2b2b080808191908, 0x2b2b0808082b082b, 0x2b2b0808082b2b2b, 0x2b2b080819080819,
4044
+ 0x2b2b080819081908, 0x2b2b080819190808, 0x2b2b08082b2b082b, 0x2b2b08082b2b2b2b,
4045
+ 0x2b2b081919080808, 0x2b2b0819192b1919, 0x2b2b082b0808082b, 0x2b2b082b08082b2b,
4046
+ 0x2b2b082b082b082b, 0x2b2b082b082b2b08, 0x2b2b082b082b2b2b, 0x2b2b082b2b08082b,
4047
+ 0x2b2b082b2b082b08, 0x2b2b082b2b082b2b, 0x2b2b082b2b2b2b08, 0x2b2b190808080819,
4048
+ 0x2b2b190808081908, 0x2b2b190808190808, 0x2b2b190819080808, 0x2b2b19082b082b19,
4049
+ 0x2b2b19082b2b1908, 0x2b2b191908080808, 0x2b2b191908192b19, 0x2b2b192b19190819,
4050
+ 0x2b2b2b0808082b2b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b082b, 0x2b2b2b1919191908,
4051
+ 0x2b2b2b192b08192b, 0x2b2b2b2b08082b08, 0x2b2b2b2b08082b2b, 0x2b2b2b2b082b0808,
4052
+ 0x2b2b2b2b082b082b, 0x2b2b2b2b082b2b08, 0x2b2b2b2b2b082b08, 0x2b2b2b2b2b2b2b2b,
4053
+ };
4054
+
3763
4055
  constexpr constant static uint32_t iq3xxs_grid[256] = {
3764
4056
  0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
3765
4057
  0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
@@ -3795,6 +4087,73 @@ constexpr constant static uint32_t iq3xxs_grid[256] = {
3795
4087
  0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
3796
4088
  };
3797
4089
 
4090
+ constexpr constant static uint32_t iq3xs_grid[512] = {
4091
+ 0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14,
4092
+ 0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414,
4093
+ 0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24,
4094
+ 0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c,
4095
+ 0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c,
4096
+ 0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34,
4097
+ 0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c,
4098
+ 0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414,
4099
+ 0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c,
4100
+ 0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404,
4101
+ 0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434,
4102
+ 0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c,
4103
+ 0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404,
4104
+ 0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414,
4105
+ 0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414,
4106
+ 0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404,
4107
+ 0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c,
4108
+ 0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c,
4109
+ 0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404,
4110
+ 0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e,
4111
+ 0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14,
4112
+ 0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c,
4113
+ 0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424,
4114
+ 0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c,
4115
+ 0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c,
4116
+ 0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e,
4117
+ 0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e,
4118
+ 0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e,
4119
+ 0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424,
4120
+ 0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e,
4121
+ 0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424,
4122
+ 0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404,
4123
+ 0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c,
4124
+ 0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e,
4125
+ 0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c,
4126
+ 0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c,
4127
+ 0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c,
4128
+ 0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404,
4129
+ 0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04,
4130
+ 0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c,
4131
+ 0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414,
4132
+ 0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c,
4133
+ 0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c,
4134
+ 0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424,
4135
+ 0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c,
4136
+ 0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c,
4137
+ 0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414,
4138
+ 0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c,
4139
+ 0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e,
4140
+ 0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04,
4141
+ 0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424,
4142
+ 0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14,
4143
+ 0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34,
4144
+ 0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c,
4145
+ 0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434,
4146
+ 0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c,
4147
+ 0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424,
4148
+ 0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24,
4149
+ 0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24,
4150
+ 0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e,
4151
+ 0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c,
4152
+ 0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c,
4153
+ 0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c,
4154
+ 0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404,
4155
+ };
4156
+
3798
4157
  #define NGRID_IQ1S 512
3799
4158
  constexpr constant static uint64_t iq1s_grid[NGRID_IQ1S] = {
3800
4159
  0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
@@ -3991,7 +4350,6 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
3991
4350
  threadgroup_barrier(mem_flags::mem_threadgroup);
3992
4351
  }
3993
4352
 
3994
- #if QK_K == 256
3995
4353
  const int ix = tiisg;
3996
4354
 
3997
4355
  device const float * y4 = y + 32 * ix;
@@ -4032,12 +4390,6 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
4032
4390
 
4033
4391
  y4 += 32 * 32;
4034
4392
  }
4035
- #else
4036
- (void) x;
4037
- (void) y;
4038
- (void) yl;
4039
- (void) nb32;
4040
- #endif
4041
4393
 
4042
4394
  for (int row = 0; row < N_DST; ++row) {
4043
4395
  all_sum = simd_sum(sumf[row]);
@@ -4127,7 +4479,6 @@ void kernel_mul_mv_iq2_xs_f32_impl(
4127
4479
  threadgroup_barrier(mem_flags::mem_threadgroup);
4128
4480
  }
4129
4481
 
4130
- #if QK_K == 256
4131
4482
  const int ix = tiisg;
4132
4483
 
4133
4484
  device const float * y4 = y + 32 * ix;
@@ -4178,12 +4529,6 @@ void kernel_mul_mv_iq2_xs_f32_impl(
4178
4529
 
4179
4530
  y4 += 32 * 32;
4180
4531
  }
4181
- #else
4182
- (void) x;
4183
- (void) y;
4184
- (void) yl;
4185
- (void) nb32;
4186
- #endif
4187
4532
 
4188
4533
  for (int row = 0; row < N_DST; ++row) {
4189
4534
  all_sum = simd_sum(sumf[row]);
@@ -4273,7 +4618,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
4273
4618
  threadgroup_barrier(mem_flags::mem_threadgroup);
4274
4619
  }
4275
4620
 
4276
- #if QK_K == 256
4277
4621
  const int ix = tiisg;
4278
4622
 
4279
4623
  device const float * y4 = y + 32 * ix;
@@ -4317,12 +4661,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
4317
4661
 
4318
4662
  y4 += 32 * 32;
4319
4663
  }
4320
- #else
4321
- (void) x;
4322
- (void) y;
4323
- (void) yl;
4324
- (void) nb32;
4325
- #endif
4326
4664
 
4327
4665
  for (int row = 0; row < N_DST; ++row) {
4328
4666
  all_sum = simd_sum(sumf[row]);
@@ -4361,6 +4699,269 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
4361
4699
  kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
4362
4700
  }
4363
4701
 
4702
+ void kernel_mul_mv_iq3_s_f32_impl( // matrix-vector product: src0 is read as block_iq3_s data, src1 and dst are float; each simdgroup produces N_DST consecutive output rows
4703
+ device const void * src0,
4704
+ device const float * src1,
4705
+ device float * dst,
4706
+ constant int64_t & ne00,
4707
+ constant int64_t & ne01,
4708
+ constant int64_t & ne02,
4709
+ constant int64_t & ne10,
4710
+ constant int64_t & ne12,
4711
+ constant int64_t & ne0,
4712
+ constant int64_t & ne1,
4713
+ constant uint & r2,
4714
+ constant uint & r3,
4715
+ threadgroup int8_t * shared_values [[threadgroup(0)]], // scratch buffer; used below to hold a copy of iq3xs_grid
4716
+ uint3 tgpig[[threadgroup_position_in_grid]],
4717
+ uint tiisg[[thread_index_in_simdgroup]],
4718
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4719
+
4720
+ const int nb = ne00/QK_K; // number of QK_K-sized blocks per src0 row
4721
+ const int r0 = tgpig.x;
4722
+ const int r1 = tgpig.y;
4723
+ const int im = tgpig.z;
4724
+
4725
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; // first of the N_DST rows handled by this simdgroup
4726
+ const int ib_row = first_row * nb; // block index at which first_row begins
4727
+
4728
+ const uint i12 = im%ne12; // split im into two indices using ne12
4729
+ const uint i13 = im/ne12;
4730
+
4731
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); // block offset of the src0 matrix selected via r2/r3
4732
+
4733
+ device const block_iq3_s * x = (device const block_iq3_s *) src0 + ib_row + offset0;
4734
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
4735
+
4736
+ float yl[32]; // src1 values for one sub-block of 32 quants
4737
+ float sumf[N_DST]={0.f}, all_sum; // per-thread partial sum for each of the N_DST rows
4738
+
4739
+ const int nb32 = nb * (QK_K / 32); // number of 32-quant sub-blocks per row
4740
+
4741
+ threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values; // view the scratch buffer as uint32 grid entries
4742
+ {
4743
+ int nval = 8; // grid entries copied by each thread
4744
+ int pos = (32*sgitg + tiisg)*nval; // start of this thread's slice of the grid
4745
+ for (int i = 0; i < nval; ++i) values[pos + i] = iq3xs_grid[pos + i];
4746
+ threadgroup_barrier(mem_flags::mem_threadgroup); // all copies must complete before values[] is read below
4747
+ }
4748
+
4749
+ const int ix = tiisg;
4750
+
4751
+ device const float * y4 = y + 32 * ix; // each thread starts at its own 32-float slice of src1
4752
+
4753
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) { // threads step through sub-blocks with a stride of 32
4754
+
4755
+ for (int i = 0; i < 32; ++i) {
4756
+ yl[i] = y4[i];
4757
+ }
4758
+
4759
+ const int ibl = ib32 / (QK_K / 32); // index of the QK_K block containing sub-block ib32
4760
+ const int ib = ib32 % (QK_K / 32); // sub-block index inside that block
4761
+
4762
+ device const block_iq3_s * xr = x + ibl;
4763
+ device const uint8_t * qs = xr->qs + 8 * ib; // 8 grid-index bytes per sub-block
4764
+ device const uint8_t * qh = xr->qh + ib; // one byte of extra index bits per sub-block
4765
+ device const uint8_t * sc = xr->scales + (ib/2); // one scale byte is shared by two sub-blocks
4766
+ device const uint8_t * signs = xr->signs + 4 * ib; // 4 sign bytes per sub-block
4767
+ device const half * dh = &xr->d;
4768
+
4769
+ for (int row = 0; row < N_DST; row++) {
4770
+
4771
+ const float db = dh[0];
4772
+ const float d = db * (0.5f + ((sc[0] >> 4*(ib%2)) & 0xf)); // low or high nibble of the scale byte, chosen by ib%2
4773
+
4774
+ float2 sum = {0};
4775
+ for (int l = 0; l < 4; ++l) {
4776
+ const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))); // 9-bit index: qs byte plus bit 2*l of qh
4777
+ const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))); // 9-bit index: qs byte plus bit 2*l+1 of qh
4778
+ for (int j = 0; j < 4; ++j) {
4779
+ sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]); // negate when the matching sign bit is set
4780
+ sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
4781
+ }
4782
+ }
4783
+ sumf[row] += d * (sum[0] + sum[1]);
4784
+ // advance every pointer by one row, i.e. nb blocks
4785
+ dh += nb*sizeof(block_iq3_s)/2; // dh points to half, so the byte stride is halved
4786
+ qs += nb*sizeof(block_iq3_s);
4787
+ qh += nb*sizeof(block_iq3_s);
4788
+ sc += nb*sizeof(block_iq3_s);
4789
+ signs += nb*sizeof(block_iq3_s);
4790
+ }
4791
+
4792
+ y4 += 32 * 32; // matches the ib32 += 32 stride of the outer loop
4793
+ }
4794
+
4795
+ for (int row = 0; row < N_DST; ++row) {
4796
+ all_sum = simd_sum(sumf[row]); // combine the partial sums of all threads in the simdgroup
4797
+ if (tiisg == 0) {
4798
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f; // thread 0 stores the result with a final 0.5 factor
4799
+ }
4800
+ }
4801
+ }
4802
+
4803
+ [[host_name("kernel_mul_mv_iq3_s_f32")]] // kernel entry point; the body only forwards to kernel_mul_mv_iq3_s_f32_impl
4804
+ kernel void kernel_mul_mv_iq3_s_f32(
4805
+ device const void * src0,
4806
+ device const float * src1,
4807
+ device float * dst,
4808
+ constant int64_t & ne00,
4809
+ constant int64_t & ne01,
4810
+ constant int64_t & ne02,
4811
+ constant uint64_t & nb00, // nb00, nb01, nb02 are not passed on to the impl
4812
+ constant uint64_t & nb01,
4813
+ constant uint64_t & nb02,
4814
+ constant int64_t & ne10,
4815
+ constant int64_t & ne11, // not passed on to the impl
4816
+ constant int64_t & ne12,
4817
+ constant uint64_t & nb10, // nb10, nb11, nb12 are not passed on to the impl
4818
+ constant uint64_t & nb11,
4819
+ constant uint64_t & nb12,
4820
+ constant int64_t & ne0,
4821
+ constant int64_t & ne1,
4822
+ constant uint & r2,
4823
+ constant uint & r3,
4824
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
4825
+ uint3 tgpig[[threadgroup_position_in_grid]],
4826
+ uint tiisg[[thread_index_in_simdgroup]],
4827
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4828
+
4829
+ kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
4830
+ }
4831
+
4832
+ void kernel_mul_mv_iq2_s_f32_impl( // matrix-vector product: src0 is read as block_iq2_s data, src1 and dst are float; each simdgroup produces N_DST consecutive output rows
4833
+ device const void * src0,
4834
+ device const float * src1,
4835
+ device float * dst,
4836
+ constant int64_t & ne00,
4837
+ constant int64_t & ne01,
4838
+ constant int64_t & ne02,
4839
+ constant int64_t & ne10,
4840
+ constant int64_t & ne12,
4841
+ constant int64_t & ne0,
4842
+ constant int64_t & ne1,
4843
+ constant uint & r2,
4844
+ constant uint & r3,
4845
+ threadgroup int8_t * shared_values [[threadgroup(0)]], // only referenced by the commented-out staging code below
4846
+ uint3 tgpig[[threadgroup_position_in_grid]],
4847
+ uint tiisg[[thread_index_in_simdgroup]],
4848
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4849
+
4850
+ const int nb = ne00/QK_K; // number of QK_K-sized blocks per src0 row
4851
+ const int r0 = tgpig.x;
4852
+ const int r1 = tgpig.y;
4853
+ const int im = tgpig.z;
4854
+
4855
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; // first of the N_DST rows handled by this simdgroup
4856
+ const int ib_row = first_row * nb; // block index at which first_row begins
4857
+
4858
+ const uint i12 = im%ne12; // split im into two indices using ne12
4859
+ const uint i13 = im/ne12;
4860
+
4861
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); // block offset of the src0 matrix selected via r2/r3
4862
+
4863
+ device const block_iq2_s * x = (device const block_iq2_s *) src0 + ib_row + offset0;
4864
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
4865
+
4866
+ float yl[32]; // src1 values for one sub-block of 32 quants
4867
+ float sumf[N_DST]={0.f}, all_sum; // per-thread partial sum for each of the N_DST rows
4868
+
4869
+ const int nb32 = nb * (QK_K / 32); // number of 32-quant sub-blocks per row
4870
+ // the disabled code below would stage iq2s_grid in threadgroup memory; the active code reads the constant table directly
4871
+ //threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
4872
+ //{
4873
+ // int nval = 32;
4874
+ // int pos = (32*sgitg + tiisg)*nval;
4875
+ // for (int i = 0; i < nval; ++i) values[pos + i] = iq2s_grid[pos + i];
4876
+ // threadgroup_barrier(mem_flags::mem_threadgroup);
4877
+ //}
4878
+
4879
+ const int ix = tiisg;
4880
+
4881
+ device const float * y4 = y + 32 * ix; // each thread starts at its own 32-float slice of src1
4882
+
4883
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) { // threads step through sub-blocks with a stride of 32
4884
+
4885
+ for (int i = 0; i < 32; ++i) {
4886
+ yl[i] = y4[i];
4887
+ }
4888
+
4889
+ const int ibl = ib32 / (QK_K / 32); // index of the QK_K block containing sub-block ib32
4890
+ const int ib = ib32 % (QK_K / 32); // sub-block index inside that block
4891
+
4892
+ device const block_iq2_s * xr = x + ibl;
4893
+ device const uint8_t * qs = xr->qs + 4 * ib; // 4 grid-index bytes per sub-block
4894
+ device const uint8_t * qh = xr->qh + ib; // one byte of extra index bits per sub-block
4895
+ device const uint8_t * sc = xr->scales + ib; // one scale byte (two nibbles) per sub-block
4896
+ device const uint8_t * signs = qs + QK_K/8; // sign bytes are read from the qs array, QK_K/8 bytes past the matching index bytes
4897
+ device const half * dh = &xr->d;
4898
+
4899
+ for (int row = 0; row < N_DST; row++) {
4900
+
4901
+ const float db = dh[0];
4902
+ const float d1 = db * (0.5f + (sc[0] & 0xf)); // low nibble: scale for the first 16 values
4903
+ const float d2 = db * (0.5f + (sc[0] >> 4)); // high nibble: scale for the second 16 values
4904
+
4905
+ float2 sum = {0};
4906
+ for (int l = 0; l < 2; ++l) {
4907
+ //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
4908
+ //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
4909
+ constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); // 10-bit index: qs byte plus two bits from the low nibble of qh
4910
+ constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); // 10-bit index: qs byte plus two bits from the high nibble of qh
4911
+ for (int j = 0; j < 8; ++j) { // each uint64 grid entry is read as 8 bytes
4912
+ sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]); // negate when the matching sign bit is set
4913
+ sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]);
4914
+ }
4915
+ }
4916
+ sumf[row] += d1 * sum[0] + d2 * sum[1];
4917
+ // advance every pointer by one row, i.e. nb blocks
4918
+ dh += nb*sizeof(block_iq2_s)/2; // dh points to half, so the byte stride is halved
4919
+ qs += nb*sizeof(block_iq2_s);
4920
+ qh += nb*sizeof(block_iq2_s);
4921
+ sc += nb*sizeof(block_iq2_s);
4922
+ signs += nb*sizeof(block_iq2_s);
4923
+ }
4924
+
4925
+ y4 += 32 * 32; // matches the ib32 += 32 stride of the outer loop
4926
+ }
4927
+
4928
+ for (int row = 0; row < N_DST; ++row) {
4929
+ all_sum = simd_sum(sumf[row]); // combine the partial sums of all threads in the simdgroup
4930
+ if (tiisg == 0) {
4931
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; // thread 0 stores the result with a final 0.25 factor
4932
+ }
4933
+ }
4934
+ }
4935
+
4936
+ [[host_name("kernel_mul_mv_iq2_s_f32")]] // kernel entry point; the body only forwards to kernel_mul_mv_iq2_s_f32_impl
4937
+ kernel void kernel_mul_mv_iq2_s_f32(
4938
+ device const void * src0,
4939
+ device const float * src1,
4940
+ device float * dst,
4941
+ constant int64_t & ne00,
4942
+ constant int64_t & ne01,
4943
+ constant int64_t & ne02,
4944
+ constant uint64_t & nb00, // nb00, nb01, nb02 are not passed on to the impl
4945
+ constant uint64_t & nb01,
4946
+ constant uint64_t & nb02,
4947
+ constant int64_t & ne10,
4948
+ constant int64_t & ne11, // not passed on to the impl
4949
+ constant int64_t & ne12,
4950
+ constant uint64_t & nb10, // nb10, nb11, nb12 are not passed on to the impl
4951
+ constant uint64_t & nb11,
4952
+ constant uint64_t & nb12,
4953
+ constant int64_t & ne0,
4954
+ constant int64_t & ne1,
4955
+ constant uint & r2,
4956
+ constant uint & r3,
4957
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
4958
+ uint3 tgpig[[threadgroup_position_in_grid]],
4959
+ uint tiisg[[thread_index_in_simdgroup]],
4960
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4961
+
4962
+ kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
4963
+ }
4964
+
4364
4965
  void kernel_mul_mv_iq1_s_f32_impl(
4365
4966
  device const void * src0,
4366
4967
  device const float * src1,
@@ -4398,7 +4999,6 @@ void kernel_mul_mv_iq1_s_f32_impl(
4398
4999
 
4399
5000
  const int nb32 = nb * (QK_K / 32);
4400
5001
 
4401
- #if QK_K == 256
4402
5002
  const int ix = tiisg/2;
4403
5003
  const int il = tiisg%2;
4404
5004
 
@@ -4437,12 +5037,6 @@ void kernel_mul_mv_iq1_s_f32_impl(
4437
5037
 
4438
5038
  y4 += 16 * 32;
4439
5039
  }
4440
- #else
4441
- (void) x;
4442
- (void) y;
4443
- (void) yl;
4444
- (void) nb32;
4445
- #endif
4446
5040
 
4447
5041
  for (int row = 0; row < N_DST; ++row) {
4448
5042
  all_sum = simd_sum(sumf[row]);
@@ -4549,8 +5143,132 @@ void kernel_mul_mv_iq4_nl_f32_impl(
4549
5143
  }
4550
5144
  }
4551
5145
 
4552
- [[host_name("kernel_mul_mv_iq1_s_f32")]]
4553
- kernel void kernel_mul_mv_iq1_s_f32(
5146
+ #if QK_K != 64 // this impl is only compiled when QK_K is not 64
5147
+ void kernel_mul_mv_iq4_xs_f32_impl( // matrix-vector product: src0 is read as block_iq4_xs data, src1 and dst are float; each simdgroup produces 2 consecutive output rows
5148
+ device const void * src0,
5149
+ device const float * src1,
5150
+ device float * dst,
5151
+ constant int64_t & ne00,
5152
+ constant int64_t & ne01,
5153
+ constant int64_t & ne02,
5154
+ constant int64_t & ne10,
5155
+ constant int64_t & ne12,
5156
+ constant int64_t & ne0,
5157
+ constant int64_t & ne1,
5158
+ constant uint & r2,
5159
+ constant uint & r3,
5160
+ threadgroup float * shared_values [[threadgroup(0)]], // receives a copy of the kvalues_iq4nl_f lookup table
5161
+ uint3 tgpig[[threadgroup_position_in_grid]],
5162
+ uint tiisg[[thread_index_in_simdgroup]],
5163
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
5164
+
5165
+ const int nb = ne00/QK_K; // number of QK_K-sized blocks per src0 row
5166
+ const int r0 = tgpig.x;
5167
+ const int r1 = tgpig.y;
5168
+ const int im = tgpig.z;
5169
+ const int first_row = (r0 * 2 + sgitg) * 2; // first of the 2 rows handled by this simdgroup
5170
+ const int ib_row = first_row * nb; // block index at which first_row begins
5171
+
5172
+ const uint i12 = im%ne12; // split im into two indices using ne12
5173
+ const uint i13 = im/ne12;
5174
+
5175
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); // block offset of the src0 matrix selected via r2/r3
5176
+ device const block_iq4_xs * x = (device const block_iq4_xs *) src0 + ib_row + offset0;
5177
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
5178
+
5179
+ const int ix = tiisg/16; // 0 or 1
5180
+ const int it = tiisg%16; // 0...15
5181
+ const int ib = it/2; // sub-block of 32 quants inside a block
5182
+ const int il = it%2; // which 8-byte half of that sub-block's packed quants
5183
+
5184
+ shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16]; // copy the lookup table; later reads use 4-bit indices only
5185
+ threadgroup_barrier(mem_flags::mem_threadgroup); // table must be fully written before it is read
5186
+
5187
+ float4 yl[4];
5188
+ float sumf[2]={0.f}, all_sum; // per-thread partial sum for each of the 2 rows
5189
+
5190
+ device const float * yb = y + ix * QK_K + ib * 32 + il * 8; // src1 position matching (block ix, sub-block ib, half il)
5191
+
5192
+ uint32_t aux32[2];
5193
+ thread const uint8_t * q8 = (thread const uint8_t *)aux32; // byte-wise view of aux32
5194
+
5195
+ float4 qf1, qf2;
5196
+
5197
+ for (int ibl = ix; ibl < nb; ibl += 2) { // the ix == 0 and ix == 1 thread groups take alternating blocks
5198
+
5199
+ device const float4 * y4 = (device const float4 *)yb;
5200
+ yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; // floats 0..3, 16..19, 4..7, 20..23 relative to yb
5201
+
5202
+ for (int row = 0; row < 2; ++row) {
5203
+
5204
+ device const block_iq4_xs & xb = x[row*nb + ibl];
5205
+ device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); // 16 bytes per sub-block, 8 bytes per thread
5206
+
5207
+ float4 acc1 = {0.f}, acc2 = {0.f};
5208
+
5209
+ aux32[0] = q4[0] & 0x0f0f0f0f; // low nibbles of 4 bytes
5210
+ aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f; // high nibbles of the same 4 bytes
5211
+ qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; // table lookup for the low nibbles
5212
+ qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; // table lookup for the high nibbles
5213
+ acc1 += yl[0] * qf1;
5214
+ acc2 += yl[1] * qf2;
5215
+
5216
+ aux32[0] = q4[1] & 0x0f0f0f0f;
5217
+ aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f;
5218
+ qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
5219
+ qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
5220
+ acc1 += yl[2] * qf1;
5221
+ acc2 += yl[3] * qf2;
5222
+
5223
+ acc1 += acc2;
5224
+
5225
+ const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32; // 6-bit scale (4 bits from scales_l, 2 bits from scales_h) minus 32
5226
+ sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
5227
+
5228
+ }
5229
+
5230
+ yb += 2 * QK_K; // matches the ibl += 2 stride
5231
+ }
5232
+
5233
+ for (int row = 0; row < 2; ++row) {
5234
+ all_sum = simd_sum(sumf[row]); // combine the partial sums of all threads in the simdgroup
5235
+ if (tiisg == 0) {
5236
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; // thread 0 stores the result
5237
+ }
5238
+ }
5239
+ }
5240
+ #endif
5241
+
5242
+ [[host_name("kernel_mul_mv_iq1_s_f32")]] // kernel entry point; the body only forwards to kernel_mul_mv_iq1_s_f32_impl
5243
+ kernel void kernel_mul_mv_iq1_s_f32(
5244
+ device const void * src0,
5245
+ device const float * src1,
5246
+ device float * dst,
5247
+ constant int64_t & ne00,
5248
+ constant int64_t & ne01,
5249
+ constant int64_t & ne02,
5250
+ constant uint64_t & nb00, // nb00, nb01, nb02 are not passed on to the impl
5251
+ constant uint64_t & nb01,
5252
+ constant uint64_t & nb02,
5253
+ constant int64_t & ne10,
5254
+ constant int64_t & ne11, // not passed on to the impl
5255
+ constant int64_t & ne12,
5256
+ constant uint64_t & nb10, // nb10, nb11, nb12 are not passed on to the impl
5257
+ constant uint64_t & nb11,
5258
+ constant uint64_t & nb12,
5259
+ constant int64_t & ne0,
5260
+ constant int64_t & ne1,
5261
+ constant uint & r2,
5262
+ constant uint & r3,
5263
+ uint3 tgpig[[threadgroup_position_in_grid]],
5264
+ uint tiisg[[thread_index_in_simdgroup]],
5265
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
5266
+ // unlike the iq2_s / iq3_s entry points, no threadgroup buffer is declared or forwarded here
5267
+ kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
5268
+ }
5269
+
5270
+ [[host_name("kernel_mul_mv_iq4_nl_f32")]]
5271
+ kernel void kernel_mul_mv_iq4_nl_f32(
4554
5272
  device const void * src0,
4555
5273
  device const float * src1,
4556
5274
  device float * dst,
@@ -4570,15 +5288,16 @@ kernel void kernel_mul_mv_iq1_s_f32(
4570
5288
  constant int64_t & ne1,
4571
5289
  constant uint & r2,
4572
5290
  constant uint & r3,
5291
+ threadgroup float * shared_values [[threadgroup(0)]],
4573
5292
  uint3 tgpig[[threadgroup_position_in_grid]],
4574
- uint tiisg[[thread_index_in_simdgroup]],
4575
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
5293
+ uint tiisg[[thread_index_in_simdgroup]],
5294
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4576
5295
 
4577
- kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
5296
+ kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
4578
5297
  }
4579
5298
 
4580
- [[host_name("kernel_mul_mv_iq4_nl_f32")]]
4581
- kernel void kernel_mul_mv_iq4_nl_f32(
5299
+ [[host_name("kernel_mul_mv_iq4_xs_f32")]]
5300
+ kernel void kernel_mul_mv_iq4_xs_f32(
4582
5301
  device const void * src0,
4583
5302
  device const float * src1,
4584
5303
  device float * dst,
@@ -4603,7 +5322,11 @@ kernel void kernel_mul_mv_iq4_nl_f32(
4603
5322
  uint tiisg[[thread_index_in_simdgroup]],
4604
5323
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
4605
5324
 
5325
+ #if QK_K == 64
4606
5326
  kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
5327
+ #else
5328
+ kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
5329
+ #endif
4607
5330
  }
4608
5331
 
4609
5332
  //============================= templates and their specializations =============================
@@ -4952,6 +5675,50 @@ void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x
4952
5675
  }
4953
5676
  }
4954
5677
 
5678
+ template <typename type4x4> // reg is indexed as reg[0..3][0..3] below
5679
+ void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) { // writes 16 dequantized values of block xb, selected by il, into reg
5680
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
5681
+ const float d = xb->d;
5682
+ const int ib32 = il/2;
5683
+ il = il%2;
5684
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
5685
+ device const uint8_t * qs = xb->qs + 8*ib32; // 8 grid-index bytes per block of 32
5686
+ device const uint8_t * signs = xb->signs + 4*ib32 + 2*il; // the 2 sign bytes belonging to this half
5687
+ const uint8_t qh = xb->qh[ib32] >> 4*il; // extra index bits for this half moved into the low nibble
5688
+ const float dl = d * (0.5f + ((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * 0.5f; // block scale times (4-bit sub-scale + 0.5), with a 0.5 factor folded in
5689
+ constant uint8_t * grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+0] | ((qh << 8) & 256))); // 9-bit index: qs byte plus bit 0 of qh
5690
+ constant uint8_t * grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+1] | ((qh << 7) & 256))); // 9-bit index: qs byte plus bit 1 of qh
5691
+ for (int i = 0; i < 4; ++i) {
5692
+ reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]); // negate when the matching sign bit is set
5693
+ reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
5694
+ }
5695
+ grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+2] | ((qh << 6) & 256))); // bit 2 of qh
5696
+ grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+3] | ((qh << 5) & 256))); // bit 3 of qh
5697
+ for (int i = 0; i < 4; ++i) {
5698
+ reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
5699
+ reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
5700
+ }
5701
+ }
5702
+
5703
+ template <typename type4x4> // reg is indexed as reg[0..3][0..3] below
5704
+ void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) { // writes 16 dequantized values of block xb, selected by il, into reg
5705
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
5706
+ const float d = xb->d;
5707
+ const int ib32 = il/2;
5708
+ il = il%2;
5709
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
5710
+ device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; // the 2 grid-index bytes belonging to this half
5711
+ device const uint8_t * signs = qs + QK_K/8; // sign bytes are read from the qs array, QK_K/8 bytes past the index bytes
5712
+ const uint8_t qh = xb->qh[ib32] >> 4*il; // extra index bits for this half moved into the low nibble
5713
+ const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f; // block scale times (4-bit sub-scale + 0.5), with a 0.25 factor folded in
5714
+ constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300))); // 10-bit index: qs byte plus bits 0-1 of qh
5715
+ constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300))); // 10-bit index: qs byte plus bits 2-3 of qh
5716
+ for (int i = 0; i < 8; ++i) { // each uint64 grid entry is read as 8 bytes
5717
+ reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]); // grid1 fills reg[0] and reg[1]
5718
+ reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]); // grid2 fills reg[2] and reg[3]
5719
+ }
5720
+ }
5721
+
4955
5722
  template <typename type4x4>
4956
5723
  void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
4957
5724
  // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
@@ -4983,6 +5750,30 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4
4983
5750
  }
4984
5751
  }
4985
5752
 
5753
+ template <typename type4x4> // reg is indexed as reg[0..3][0..3] below
5754
+ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) { // writes 16 dequantized values of block xb, selected by il, into reg
5755
+ #if QK_K == 64
5756
+ dequantize_iq4_nl(xb, il, reg); // for QK_K == 64 the work is delegated to the iq4_nl routine
5757
+ #else
5758
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
5759
+ const int ib32 = il/2;
5760
+ il = il%2;
5761
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
5762
+ device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32; // 4 words (16 bytes) of packed nibbles per block of 32
5763
+ const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4); // 6-bit scale: 4 bits from scales_l, 2 bits from scales_h
5764
+ const float d = (float)xb->d * (ls - 32); // the 6-bit scale is offset by 32 before use
5765
+ uint32_t aux32;
5766
+ thread const uint8_t * q8 = (thread const uint8_t *)&aux32; // byte-wise view of aux32
5767
+ for (int i = 0; i < 4; ++i) {
5768
+ aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f; // il picks the low (0) or high (1) nibble of each byte
5769
+ reg[i][0] = d * kvalues_iq4nl_f[q8[0]]; // each nibble indexes the kvalues_iq4nl_f table
5770
+ reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
5771
+ reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
5772
+ reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
5773
+ }
5774
+ #endif
5775
+ }
5776
+
4986
5777
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
4987
5778
  kernel void kernel_get_rows(
4988
5779
  device const void * src0,
@@ -5525,8 +6316,15 @@ template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows
5525
6316
  template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
5526
6317
  template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5527
6318
  template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
6319
+ template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows<block_iq3_s, QK_NL, dequantize_iq3_s>;
6320
+ template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
5528
6321
  template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
5529
- template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
6322
+ template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
6323
+ #if QK_K == 64
6324
+ template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, 2, dequantize_iq4_xs>;
6325
+ #else
6326
+ template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6327
+ #endif
5530
6328
 
5531
6329
  //
5532
6330
  // matrix-matrix multiplication
@@ -5566,8 +6364,15 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
5566
6364
  template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
5567
6365
  template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5568
6366
  template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
6367
+ template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_s, QK_NL, dequantize_iq3_s>;
6368
+ template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>;
5569
6369
  template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
5570
- template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
6370
+ template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
6371
+ #if QK_K == 64
6372
+ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_xs>;
6373
+ #else
6374
+ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6375
+ #endif
5571
6376
 
5572
6377
  //
5573
6378
  // indirect matrix-matrix multiplication
@@ -5619,8 +6424,15 @@ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mu
5619
6424
  template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
5620
6425
  template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5621
6426
  template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
6427
+ template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s, QK_NL, dequantize_iq3_s>;
6428
+ template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
5622
6429
  template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
5623
- template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
6430
+ template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
6431
+ #if QK_K == 64
6432
+ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, 2, dequantize_iq4_xs>;
6433
+ #else
6434
+ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6435
+ #endif
5624
6436
 
5625
6437
  //
5626
6438
  // matrix-vector multiplication
@@ -6589,6 +7401,136 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
6589
7401
  sgitg);
6590
7402
  }
6591
7403
 
7404
+ [[host_name("kernel_mul_mv_id_iq3_s_f32")]] // kernel entry point: picks one of eight src0 pointers via ids, then calls kernel_mul_mv_iq3_s_f32_impl
7405
+ kernel void kernel_mul_mv_id_iq3_s_f32(
7406
+ device const char * ids, // read below as int32 values, at byte offset bid*nbi1
7407
+ device const char * src1,
7408
+ device float * dst,
7409
+ constant uint64_t & nbi1, // byte stride applied to ids per bid
7410
+ constant int64_t & ne00,
7411
+ constant int64_t & ne01,
7412
+ constant int64_t & ne02,
7413
+ constant uint64_t & nb00,
7414
+ constant uint64_t & nb01,
7415
+ constant uint64_t & nb02,
7416
+ constant int64_t & ne10,
7417
+ constant int64_t & ne11,
7418
+ constant int64_t & ne12,
7419
+ constant int64_t & ne13,
7420
+ constant uint64_t & nb10,
7421
+ constant uint64_t & nb11, // byte stride applied to src1 per bid
7422
+ constant uint64_t & nb12,
7423
+ constant int64_t & ne0,
7424
+ constant int64_t & ne1,
7425
+ constant uint64_t & nb1,
7426
+ constant uint & r2,
7427
+ constant uint & r3,
7428
+ constant int & idx, // element of the ids row that holds the src0 selector
7429
+ device const char * src00,
7430
+ device const char * src01,
7431
+ device const char * src02,
7432
+ device const char * src03,
7433
+ device const char * src04,
7434
+ device const char * src05,
7435
+ device const char * src06,
7436
+ device const char * src07,
7437
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
7438
+ uint3 tgpig[[threadgroup_position_in_grid]],
7439
+ uint tiitg[[thread_index_in_threadgroup]], // not used in this body
7440
+ uint tiisg[[thread_index_in_simdgroup]],
7441
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
7442
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; // candidate src0 pointers, indexed by id
7443
+
7444
+ const int64_t bid = tgpig.z/(ne12*ne13); // tgpig.z carries bid in its quotient by ne12*ne13
7445
+
7446
+ tgpig.z = tgpig.z%(ne12*ne13); // leave only the remainder for the impl
7447
+
7448
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; // selector of the src0 pointer for this bid
7449
+
7450
+ kernel_mul_mv_iq3_s_f32_impl(
7451
+ src0[id],
7452
+ (device const float *) (src1 + bid*nb11), // src1 advanced by bid*nb11 bytes
7453
+ dst + bid*ne0, // dst advanced by bid*ne0 floats
7454
+ ne00,
7455
+ ne01,
7456
+ ne02,
7457
+ ne10,
7458
+ ne12,
7459
+ ne0,
7460
+ ne1,
7461
+ r2,
7462
+ r3,
7463
+ shared_values,
7464
+ tgpig,
7465
+ tiisg,
7466
+ sgitg);
7467
+ }
7468
+
7469
+ [[host_name("kernel_mul_mv_id_iq2_s_f32")]] // kernel entry point: picks one of eight src0 pointers via ids, then calls kernel_mul_mv_iq2_s_f32_impl
7470
+ kernel void kernel_mul_mv_id_iq2_s_f32(
7471
+ device const char * ids, // read below as int32 values, at byte offset bid*nbi1
7472
+ device const char * src1,
7473
+ device float * dst,
7474
+ constant uint64_t & nbi1, // byte stride applied to ids per bid
7475
+ constant int64_t & ne00,
7476
+ constant int64_t & ne01,
7477
+ constant int64_t & ne02,
7478
+ constant uint64_t & nb00,
7479
+ constant uint64_t & nb01,
7480
+ constant uint64_t & nb02,
7481
+ constant int64_t & ne10,
7482
+ constant int64_t & ne11,
7483
+ constant int64_t & ne12,
7484
+ constant int64_t & ne13,
7485
+ constant uint64_t & nb10,
7486
+ constant uint64_t & nb11, // byte stride applied to src1 per bid
7487
+ constant uint64_t & nb12,
7488
+ constant int64_t & ne0,
7489
+ constant int64_t & ne1,
7490
+ constant uint64_t & nb1,
7491
+ constant uint & r2,
7492
+ constant uint & r3,
7493
+ constant int & idx, // element of the ids row that holds the src0 selector
7494
+ device const char * src00,
7495
+ device const char * src01,
7496
+ device const char * src02,
7497
+ device const char * src03,
7498
+ device const char * src04,
7499
+ device const char * src05,
7500
+ device const char * src06,
7501
+ device const char * src07,
7502
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
7503
+ uint3 tgpig[[threadgroup_position_in_grid]],
7504
+ uint tiitg[[thread_index_in_threadgroup]], // not used in this body
7505
+ uint tiisg[[thread_index_in_simdgroup]],
7506
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
7507
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; // candidate src0 pointers, indexed by id
7508
+
7509
+ const int64_t bid = tgpig.z/(ne12*ne13); // tgpig.z carries bid in its quotient by ne12*ne13
7510
+
7511
+ tgpig.z = tgpig.z%(ne12*ne13); // leave only the remainder for the impl
7512
+
7513
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; // selector of the src0 pointer for this bid
7514
+
7515
+ kernel_mul_mv_iq2_s_f32_impl(
7516
+ src0[id],
7517
+ (device const float *) (src1 + bid*nb11), // src1 advanced by bid*nb11 bytes
7518
+ dst + bid*ne0, // dst advanced by bid*ne0 floats
7519
+ ne00,
7520
+ ne01,
7521
+ ne02,
7522
+ ne10,
7523
+ ne12,
7524
+ ne0,
7525
+ ne1,
7526
+ r2,
7527
+ r3,
7528
+ shared_values,
7529
+ tgpig,
7530
+ tiisg,
7531
+ sgitg);
7532
+ }
7533
+
6592
7534
  [[host_name("kernel_mul_mv_id_iq1_s_f32")]]
6593
7535
  kernel void kernel_mul_mv_id_iq1_s_f32(
6594
7536
  device const char * ids,
@@ -6716,3 +7658,72 @@ kernel void kernel_mul_mv_id_iq4_nl_f32(
6716
7658
  tiisg,
6717
7659
  sgitg);
6718
7660
  }
7661
+
7662
// Indirect ("_id") matrix-vector kernel wrapper for iq4_xs data, exported under
// the host name "kernel_mul_mv_id_iq4_xs_f32".
// It picks one of the eight source buffers src00..src07 using an int32 index read
// from `ids`, then forwards to the matching *_impl function.
// Parameters not referenced in this body: nb00, nb01, nb02, ne11, nb10, nb12, nb1, tiitg.
// NOTE(review): `ids` presumably carries per-row selection indices - confirm against the host code.
+ [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
7663
+ kernel void kernel_mul_mv_id_iq4_xs_f32(
7664
+ device const char * ids,
7665
+ device const char * src1,
7666
+ device float * dst,
7667
+ constant uint64_t & nbi1,
7668
+ constant int64_t & ne00,
7669
+ constant int64_t & ne01,
7670
+ constant int64_t & ne02,
7671
+ constant uint64_t & nb00,
7672
+ constant uint64_t & nb01,
7673
+ constant uint64_t & nb02,
7674
+ constant int64_t & ne10,
7675
+ constant int64_t & ne11,
7676
+ constant int64_t & ne12,
7677
+ constant int64_t & ne13,
7678
+ constant uint64_t & nb10,
7679
+ constant uint64_t & nb11,
7680
+ constant uint64_t & nb12,
7681
+ constant int64_t & ne0,
7682
+ constant int64_t & ne1,
7683
+ constant uint64_t & nb1,
7684
+ constant uint & r2,
7685
+ constant uint & r3,
7686
+ constant int & idx,
7687
+ device const char * src00,
7688
+ device const char * src01,
7689
+ device const char * src02,
7690
+ device const char * src03,
7691
+ device const char * src04,
7692
+ device const char * src05,
7693
+ device const char * src06,
7694
+ device const char * src07,
7695
+ threadgroup float * shared_values [[threadgroup(0)]],
7696
+ uint3 tgpig[[threadgroup_position_in_grid]],
7697
+ uint tiitg[[thread_index_in_threadgroup]],
7698
+ uint tiisg[[thread_index_in_simdgroup]],
7699
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
7700
// Gather the eight separately-bound source pointers into one indexable array.
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
7701
+
7702
// Split the grid z coordinate: the quotient by ne12*ne13 becomes `bid`.
+ const int64_t bid = tgpig.z/(ne12*ne13);
7703
+
7704
// The remainder replaces tgpig.z, which is what the impl function receives.
+ tgpig.z = tgpig.z%(ne12*ne13);
7705
+
7706
// Read entry `idx` of the int32 row located `bid*nbi1` bytes into `ids`.
// NOTE(review): `id` indexes src0[8] below without a range check in this body.
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
7707
+
7708
// With QK_K == 64, block_iq4_xs is #defined to block_iq4_nl, so the iq4_nl
// implementation is called; otherwise the dedicated iq4_xs implementation is used.
// Both branches share the argument list that follows the #endif.
+ #if QK_K == 64
7709
+ kernel_mul_mv_iq4_nl_f32_impl(
7710
+ #else
7711
+ kernel_mul_mv_iq4_xs_f32_impl(
7712
+ #endif
7713
+ src0[id],
7714
// src1 is advanced by bid*nb11 bytes before being reinterpreted as float.
+ (device const float *) (src1 + bid*nb11),
7715
// dst is advanced by bid*ne0 float elements.
+ dst + bid*ne0,
7716
+ ne00,
7717
+ ne01,
7718
+ ne02,
7719
+ ne10,
7720
+ ne12,
7721
+ ne0,
7722
+ ne1,
7723
+ r2,
7724
+ r3,
7725
+ shared_values,
7726
+ tgpig,
7727
+ tiisg,
7728
+ sgitg);
7729
+ }