llama_cpp 0.12.7 → 0.14.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -1959,6 +1959,49 @@ kernel void kernel_pad_f32(
1959
1959
  }
1960
1960
  }
1961
1961
 
1962
// Fills dst (viewed as float) with the arithmetic sequence
//     dst[k] = start + step * k    for k in [0, ne0).
//
// Work split: a thread starts at its x position inside the threadgroup
// (tpitg.x) and advances by the threadgroup width (ntg.x).
// tgpig is accepted but not used, so the values written do not depend on
// which threadgroup a thread belongs to.
kernel void kernel_arange_f32(
        device      char * dst,
        constant int64_t & ne0,
        constant   float & start,
        constant   float & step,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    // dst arrives as a raw byte pointer; the kernel writes 32-bit floats.
    device float * out = (device float *) dst;

    int idx = tpitg.x;
    while (idx < ne0) {
        out[idx] = start + step * idx;
        idx += ntg.x;
    }
}
1977
+
1978
// Sinusoidal timestep embedding.
//
// Threadgroup tgpig.x selects one timestep (src0 read as float[tgpig.x]) and
// one output row located at dst + tgpig.x*nb1 bytes. With n_half = dim/2 the
// row is filled as
//     row[j]          = cos(t * f_j)
//     row[j + n_half] = sin(t * f_j)      for j in [0, n_half)
// where f_j = exp(-log(max_period) * j / n_half).
// Threads of the group split j with a stride of ntg.x starting at tpitg.x.
kernel void kernel_timestep_embedding_f32(
        device  const char * src0,
        device        char * dst,
        constant  uint64_t & nb1,
        constant       int & dim,
        constant       int & max_period,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int row = tgpig.x;

    // nb1 is a byte stride, hence the pointer arithmetic on the char pointer.
    device float * embed_data = (device float *)(dst + row*nb1);

    const int   n_half         = dim / 2;
    const float neg_log_period = -log((float)max_period);

    int j = tpitg.x;
    while (j < n_half) {
        const float timestep = ((device float *)src0)[row];
        const float freq     = (float)exp(neg_log_period * j / n_half);
        const float arg      = timestep * freq;

        embed_data[j]          = cos(arg);
        embed_data[j + n_half] = sin(arg);

        j += ntg.x;
    }

    // Odd dim: a single thread zeroes element [dim].
    // NOTE(review): for odd dim the loop above covers indices [0, dim-1), so
    // index dim-1 is not written by this kernel and index dim lies past the
    // first dim floats of the row. This assumes the host allocates at least
    // dim+1 floats per row - confirm against the caller.
    if (dim % 2 != 0 && tpitg.x == 0) {
        embed_data[dim] = 0.f;
    }
}
2004
+
1962
2005
  // bitonic sort implementation following the CUDA kernels as reference
1963
2006
  typedef void (argsort_t)(
1964
2007
  device const float * x,
@@ -2519,12 +2562,34 @@ typedef struct {
2519
2562
  } block_iq2_xs;
2520
2563
  // 74 bytes / block for QK_K = 256, so 2.3125 bpw
2521
2564
 
2565
+ // 2.5625 bpw quants
2566
// One quantization block of type iq2_s.
// Field sizes for QK_K = 256: 2 + 64 + 8 + 8 = 82 bytes / block, so 2.5625 bpw.
// The field order and sizes define the block layout; keep them in sync with
// whatever produces this data (presumably the host-side quantizer - confirm).
typedef struct {
    half    d;                // one half-precision value per block (presumably the block scale)
    uint8_t qs[QK_K/4];       // QK_K/4  bytes
    uint8_t qh[QK_K/32];      // QK_K/32 bytes
    uint8_t scales[QK_K/32];  // QK_K/32 bytes
} block_iq2_s;
2572
+
2522
2573
  typedef struct {
2523
2574
  half d;
2524
2575
  uint8_t qs[3*QK_K/8];
2525
2576
  } block_iq3_xxs;
2526
2577
  // 98 bytes / block for QK_K = 256, so 3.0625 bpw
2527
2578
 
2579
+ // 3.4375 bpw
2580
// IQ3S_N_SCALE: number of bytes in block_iq3_s::scales
// (2 when QK_K == 64, QK_K/64 otherwise).
// The replacement list is parenthesized so the macro always expands as a
// single operand, e.g. `n / IQ3S_N_SCALE` divides by the scale count instead
// of expanding to `n / QK_K / 64`.
#if QK_K == 64
#define IQ3S_N_SCALE 2
#else
#define IQ3S_N_SCALE (QK_K/64)
#endif

// One quantization block of type iq3_s.
// Field sizes for QK_K = 256: 2 + 64 + 8 + 32 + 4 = 110 bytes / block, so 3.4375 bpw.
typedef struct {
    half    d;                     // one half-precision value per block (presumably the block scale)
    uint8_t qs[QK_K/4];            // QK_K/4  bytes
    uint8_t qh[QK_K/32];           // QK_K/32 bytes
    uint8_t signs[QK_K/8];         // QK_K/8  bytes (one bit per weight, going by the size)
    uint8_t scales[IQ3S_N_SCALE];  // IQ3S_N_SCALE bytes
} block_iq3_s;
2592
+
2528
2593
  typedef struct {
2529
2594
  half d;
2530
2595
  uint8_t qs[QK_K/8];
@@ -2538,6 +2603,17 @@ typedef struct {
2538
2603
  uint8_t qs[QK4_NL/2];
2539
2604
  } block_iq4_nl;
2540
2605
 
2606
// block_iq4_xs: a dedicated struct for every QK_K other than 64; when
// QK_K == 64 the name is simply an alias of block_iq4_nl.
// Field sizes for QK_K = 256: 2 + 2 + 4 + 128 = 136 bytes / block.
#if QK_K != 64
typedef struct {
    half     d;                  // one half-precision value per block (presumably the block scale)
    uint16_t scales_h;           // 16 bits (presumably the high bits of the sub-block scales)
    uint8_t  scales_l[QK_K/64];  // QK_K/64 bytes
    uint8_t  qs[QK_K/2];         // QK_K/2 bytes, i.e. 4 bits per weight
} block_iq4_xs;
#else
#define block_iq4_xs block_iq4_nl
#endif
2616
+
2541
2617
  //====================================== dot products =========================
2542
2618
 
2543
2619
  void kernel_mul_mv_q2_K_f32_impl(
@@ -3760,6 +3836,265 @@ constexpr constant static uint64_t iq2xs_grid[512] = {
3760
3836
  0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
3761
3837
  };
3762
3838
 
3839
+ constexpr constant static uint64_t iq2s_grid[1024] = {
3840
+ 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
3841
+ 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
3842
+ 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
3843
+ 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
3844
+ 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
3845
+ 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x08080808192b192b,
3846
+ 0x08080808192b2b19, 0x080808082b080808, 0x080808082b08082b, 0x080808082b081919,
3847
+ 0x080808082b082b08, 0x080808082b190819, 0x080808082b191908, 0x080808082b2b0808,
3848
+ 0x080808082b2b1919, 0x080808082b2b2b2b, 0x0808081908080819, 0x0808081908081908,
3849
+ 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, 0x080808190819082b,
3850
+ 0x0808081908191919, 0x0808081908192b08, 0x08080819082b0819, 0x08080819082b1908,
3851
+ 0x0808081919080808, 0x080808191908082b, 0x0808081919081919, 0x0808081919082b08,
3852
+ 0x0808081919190819, 0x0808081919191908, 0x080808191919192b, 0x0808081919192b19,
3853
+ 0x08080819192b0808, 0x08080819192b1919, 0x08080819192b2b08, 0x080808192b080819,
3854
+ 0x080808192b081908, 0x080808192b190808, 0x080808192b19082b, 0x080808192b191919,
3855
+ 0x080808192b2b0819, 0x080808192b2b1908, 0x0808082b08080808, 0x0808082b0808082b,
3856
+ 0x0808082b08081919, 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908,
3857
+ 0x0808082b082b0808, 0x0808082b082b2b2b, 0x0808082b19080819, 0x0808082b19081908,
3858
+ 0x0808082b1908192b, 0x0808082b19082b19, 0x0808082b19190808, 0x0808082b19191919,
3859
+ 0x0808082b2b080808, 0x0808082b2b081919, 0x0808082b2b082b2b, 0x0808082b2b191908,
3860
+ 0x0808082b2b2b082b, 0x0808190808080819, 0x0808190808081908, 0x080819080808192b,
3861
+ 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, 0x0808190808191919,
3862
+ 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, 0x08081908082b192b,
3863
+ 0x08081908082b2b19, 0x0808190819080808, 0x080819081908082b, 0x0808190819081919,
3864
+ 0x0808190819082b08, 0x0808190819082b2b, 0x0808190819190819, 0x0808190819191908,
3865
+ 0x080819081919192b, 0x0808190819192b19, 0x08081908192b0808, 0x08081908192b082b,
3866
+ 0x08081908192b1919, 0x080819082b080819, 0x080819082b081908, 0x080819082b08192b,
3867
+ 0x080819082b082b19, 0x080819082b190808, 0x080819082b191919, 0x080819082b192b08,
3868
+ 0x080819082b2b0819, 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b,
3869
+ 0x0808191908081919, 0x0808191908082b08, 0x0808191908082b2b, 0x0808191908190819,
3870
+ 0x0808191908191908, 0x080819190819192b, 0x0808191908192b19, 0x08081919082b0808,
3871
+ 0x08081919082b1919, 0x08081919082b2b08, 0x0808191919080819, 0x0808191919081908,
3872
+ 0x080819191908192b, 0x0808191919082b19, 0x0808191919190808, 0x080819191919082b,
3873
+ 0x0808191919191919, 0x0808191919192b08, 0x08081919192b0819, 0x08081919192b1908,
3874
+ 0x080819192b080808, 0x080819192b08082b, 0x080819192b081919, 0x080819192b082b08,
3875
+ 0x080819192b190819, 0x080819192b191908, 0x080819192b2b0808, 0x0808192b08080819,
3876
+ 0x0808192b08081908, 0x0808192b0808192b, 0x0808192b08082b19, 0x0808192b08190808,
3877
+ 0x0808192b08191919, 0x0808192b19080808, 0x0808192b19081919, 0x0808192b19082b08,
3878
+ 0x0808192b19190819, 0x0808192b19191908, 0x0808192b192b0808, 0x0808192b2b080819,
3879
+ 0x0808192b2b081908, 0x0808192b2b190808, 0x08082b0808080808, 0x08082b080808082b,
3880
+ 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808190819, 0x08082b0808191908,
3881
+ 0x08082b080819192b, 0x08082b0808192b19, 0x08082b08082b0808, 0x08082b08082b1919,
3882
+ 0x08082b08082b2b2b, 0x08082b0819080819, 0x08082b0819081908, 0x08082b081908192b,
3883
+ 0x08082b0819082b19, 0x08082b0819190808, 0x08082b081919082b, 0x08082b0819191919,
3884
+ 0x08082b0819192b08, 0x08082b08192b0819, 0x08082b08192b1908, 0x08082b082b080808,
3885
+ 0x08082b082b081919, 0x08082b082b191908, 0x08082b082b2b2b2b, 0x08082b1908080819,
3886
+ 0x08082b1908081908, 0x08082b1908190808, 0x08082b190819082b, 0x08082b1908191919,
3887
+ 0x08082b1908192b08, 0x08082b19082b0819, 0x08082b1919080808, 0x08082b1919081919,
3888
+ 0x08082b1919082b08, 0x08082b1919190819, 0x08082b1919191908, 0x08082b19192b0808,
3889
+ 0x08082b192b080819, 0x08082b192b190808, 0x08082b2b08080808, 0x08082b2b08190819,
3890
+ 0x08082b2b08191908, 0x08082b2b082b082b, 0x08082b2b082b2b08, 0x08082b2b082b2b2b,
3891
+ 0x08082b2b19190808, 0x08082b2b2b192b19, 0x0819080808080819, 0x0819080808081908,
3892
+ 0x081908080808192b, 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b,
3893
+ 0x0819080808191919, 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908,
3894
+ 0x08190808082b192b, 0x0819080819080808, 0x081908081908082b, 0x0819080819081919,
3895
+ 0x0819080819082b08, 0x0819080819190819, 0x0819080819191908, 0x081908081919192b,
3896
+ 0x0819080819192b19, 0x08190808192b0808, 0x08190808192b082b, 0x08190808192b1919,
3897
+ 0x08190808192b2b08, 0x081908082b080819, 0x081908082b081908, 0x081908082b08192b,
3898
+ 0x081908082b190808, 0x081908082b191919, 0x081908082b192b08, 0x081908082b2b0819,
3899
+ 0x081908082b2b1908, 0x0819081908080808, 0x081908190808082b, 0x0819081908081919,
3900
+ 0x0819081908082b08, 0x0819081908082b2b, 0x0819081908190819, 0x0819081908191908,
3901
+ 0x081908190819192b, 0x0819081908192b19, 0x08190819082b0808, 0x08190819082b082b,
3902
+ 0x08190819082b1919, 0x08190819082b2b08, 0x0819081919080819, 0x0819081919081908,
3903
+ 0x081908191908192b, 0x0819081919082b19, 0x0819081919190808, 0x081908191919082b,
3904
+ 0x0819081919191919, 0x0819081919192b08, 0x08190819192b0819, 0x08190819192b1908,
3905
+ 0x081908192b080808, 0x081908192b08082b, 0x081908192b081919, 0x081908192b082b08,
3906
+ 0x081908192b190819, 0x081908192b191908, 0x0819082b08080819, 0x0819082b08081908,
3907
+ 0x0819082b08082b19, 0x0819082b08190808, 0x0819082b08191919, 0x0819082b082b0819,
3908
+ 0x0819082b082b1908, 0x0819082b19080808, 0x0819082b19081919, 0x0819082b19190819,
3909
+ 0x0819082b19191908, 0x0819082b2b080819, 0x0819082b2b081908, 0x0819082b2b190808,
3910
+ 0x0819190808080808, 0x081919080808082b, 0x0819190808081919, 0x0819190808082b08,
3911
+ 0x0819190808190819, 0x0819190808191908, 0x081919080819192b, 0x0819190808192b19,
3912
+ 0x08191908082b0808, 0x08191908082b1919, 0x08191908082b2b08, 0x0819190819080819,
3913
+ 0x0819190819081908, 0x081919081908192b, 0x0819190819082b19, 0x0819190819190808,
3914
+ 0x081919081919082b, 0x0819190819191919, 0x0819190819192b08, 0x08191908192b0819,
3915
+ 0x08191908192b1908, 0x081919082b080808, 0x081919082b08082b, 0x081919082b081919,
3916
+ 0x081919082b082b08, 0x081919082b190819, 0x081919082b191908, 0x081919082b2b0808,
3917
+ 0x0819191908080819, 0x0819191908081908, 0x081919190808192b, 0x0819191908082b19,
3918
+ 0x0819191908190808, 0x081919190819082b, 0x0819191908191919, 0x0819191908192b08,
3919
+ 0x08191919082b0819, 0x08191919082b1908, 0x0819191919080808, 0x081919191908082b,
3920
+ 0x0819191919081919, 0x0819191919082b08, 0x0819191919190819, 0x0819191919191908,
3921
+ 0x08191919192b0808, 0x081919192b080819, 0x081919192b081908, 0x081919192b190808,
3922
+ 0x0819192b08080808, 0x0819192b08081919, 0x0819192b08082b08, 0x0819192b08190819,
3923
+ 0x0819192b08191908, 0x0819192b082b0808, 0x0819192b19080819, 0x0819192b19081908,
3924
+ 0x0819192b19190808, 0x0819192b2b080808, 0x0819192b2b2b2b2b, 0x08192b0808080819,
3925
+ 0x08192b0808081908, 0x08192b080808192b, 0x08192b0808082b19, 0x08192b0808190808,
3926
+ 0x08192b0808191919, 0x08192b0808192b08, 0x08192b08082b0819, 0x08192b0819080808,
3927
+ 0x08192b081908082b, 0x08192b0819081919, 0x08192b0819082b08, 0x08192b0819190819,
3928
+ 0x08192b0819191908, 0x08192b08192b0808, 0x08192b082b080819, 0x08192b082b081908,
3929
+ 0x08192b1908080808, 0x08192b190808082b, 0x08192b1908081919, 0x08192b1908082b08,
3930
+ 0x08192b1908190819, 0x08192b1908191908, 0x08192b19082b0808, 0x08192b1919080819,
3931
+ 0x08192b1919081908, 0x08192b1919190808, 0x08192b19192b2b19, 0x08192b192b2b082b,
3932
+ 0x08192b2b08081908, 0x08192b2b08190808, 0x08192b2b19080808, 0x08192b2b1919192b,
3933
+ 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, 0x082b080808082b08,
3934
+ 0x082b080808190819, 0x082b080808191908, 0x082b08080819192b, 0x082b080808192b19,
3935
+ 0x082b0808082b0808, 0x082b0808082b1919, 0x082b0808082b2b2b, 0x082b080819080819,
3936
+ 0x082b080819081908, 0x082b080819190808, 0x082b08081919082b, 0x082b080819191919,
3937
+ 0x082b0808192b1908, 0x082b08082b080808, 0x082b08082b082b2b, 0x082b08082b191908,
3938
+ 0x082b08082b2b2b2b, 0x082b081908080819, 0x082b081908081908, 0x082b081908190808,
3939
+ 0x082b08190819082b, 0x082b081908191919, 0x082b0819082b0819, 0x082b081919080808,
3940
+ 0x082b08191908082b, 0x082b081919081919, 0x082b081919190819, 0x082b081919191908,
3941
+ 0x082b0819192b0808, 0x082b08192b080819, 0x082b08192b081908, 0x082b08192b190808,
3942
+ 0x082b082b08080808, 0x082b082b08082b2b, 0x082b082b082b082b, 0x082b082b082b2b08,
3943
+ 0x082b082b082b2b2b, 0x082b082b19081908, 0x082b082b19190808, 0x082b082b2b082b08,
3944
+ 0x082b082b2b082b2b, 0x082b082b2b2b2b08, 0x082b190808080819, 0x082b190808081908,
3945
+ 0x082b19080808192b, 0x082b190808082b19, 0x082b190808190808, 0x082b190808191919,
3946
+ 0x082b190808192b08, 0x082b1908082b0819, 0x082b1908082b1908, 0x082b190819080808,
3947
+ 0x082b19081908082b, 0x082b190819081919, 0x082b190819082b08, 0x082b190819190819,
3948
+ 0x082b190819191908, 0x082b1908192b0808, 0x082b19082b080819, 0x082b19082b081908,
3949
+ 0x082b19082b190808, 0x082b191908080808, 0x082b191908081919, 0x082b191908082b08,
3950
+ 0x082b191908190819, 0x082b191908191908, 0x082b1919082b0808, 0x082b191919080819,
3951
+ 0x082b191919081908, 0x082b191919190808, 0x082b1919192b192b, 0x082b19192b080808,
3952
+ 0x082b192b08080819, 0x082b192b08081908, 0x082b192b08190808, 0x082b192b19080808,
3953
+ 0x082b192b19192b19, 0x082b2b0808080808, 0x082b2b0808081919, 0x082b2b0808190819,
3954
+ 0x082b2b0808191908, 0x082b2b0819080819, 0x082b2b0819081908, 0x082b2b0819190808,
3955
+ 0x082b2b082b082b2b, 0x082b2b082b2b2b2b, 0x082b2b1908080819, 0x082b2b1908081908,
3956
+ 0x082b2b1908190808, 0x082b2b192b191919, 0x082b2b2b08082b2b, 0x082b2b2b082b082b,
3957
+ 0x082b2b2b192b1908, 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819,
3958
+ 0x1908080808081908, 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808,
3959
+ 0x190808080819082b, 0x1908080808191919, 0x1908080808192b08, 0x1908080808192b2b,
3960
+ 0x19080808082b0819, 0x19080808082b1908, 0x19080808082b192b, 0x1908080819080808,
3961
+ 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, 0x1908080819082b2b,
3962
+ 0x1908080819190819, 0x1908080819191908, 0x190808081919192b, 0x1908080819192b19,
3963
+ 0x19080808192b0808, 0x19080808192b082b, 0x19080808192b1919, 0x190808082b080819,
3964
+ 0x190808082b081908, 0x190808082b190808, 0x190808082b191919, 0x190808082b192b08,
3965
+ 0x190808082b2b0819, 0x190808082b2b1908, 0x1908081908080808, 0x190808190808082b,
3966
+ 0x1908081908081919, 0x1908081908082b08, 0x1908081908190819, 0x1908081908191908,
3967
+ 0x190808190819192b, 0x1908081908192b19, 0x19080819082b0808, 0x19080819082b082b,
3968
+ 0x19080819082b1919, 0x1908081919080819, 0x1908081919081908, 0x190808191908192b,
3969
+ 0x1908081919082b19, 0x1908081919190808, 0x190808191919082b, 0x1908081919191919,
3970
+ 0x1908081919192b08, 0x19080819192b0819, 0x19080819192b1908, 0x190808192b080808,
3971
+ 0x190808192b08082b, 0x190808192b081919, 0x190808192b082b08, 0x190808192b190819,
3972
+ 0x190808192b191908, 0x190808192b2b0808, 0x1908082b08080819, 0x1908082b08081908,
3973
+ 0x1908082b08190808, 0x1908082b0819082b, 0x1908082b08191919, 0x1908082b08192b08,
3974
+ 0x1908082b082b1908, 0x1908082b19080808, 0x1908082b19081919, 0x1908082b19082b08,
3975
+ 0x1908082b19190819, 0x1908082b19191908, 0x1908082b192b0808, 0x1908082b2b080819,
3976
+ 0x1908082b2b081908, 0x1908190808080808, 0x190819080808082b, 0x1908190808081919,
3977
+ 0x1908190808082b08, 0x1908190808082b2b, 0x1908190808190819, 0x1908190808191908,
3978
+ 0x190819080819192b, 0x1908190808192b19, 0x19081908082b0808, 0x19081908082b082b,
3979
+ 0x19081908082b1919, 0x19081908082b2b08, 0x1908190819080819, 0x1908190819081908,
3980
+ 0x190819081908192b, 0x1908190819082b19, 0x1908190819190808, 0x190819081919082b,
3981
+ 0x1908190819191919, 0x1908190819192b08, 0x19081908192b0819, 0x19081908192b1908,
3982
+ 0x190819082b080808, 0x190819082b08082b, 0x190819082b081919, 0x190819082b082b08,
3983
+ 0x190819082b190819, 0x190819082b191908, 0x190819082b2b0808, 0x1908191908080819,
3984
+ 0x1908191908081908, 0x190819190808192b, 0x1908191908082b19, 0x1908191908190808,
3985
+ 0x190819190819082b, 0x1908191908191919, 0x1908191908192b08, 0x19081919082b0819,
3986
+ 0x19081919082b1908, 0x1908191919080808, 0x190819191908082b, 0x1908191919081919,
3987
+ 0x1908191919082b08, 0x1908191919190819, 0x1908191919191908, 0x19081919192b0808,
3988
+ 0x19081919192b2b2b, 0x190819192b080819, 0x190819192b081908, 0x190819192b190808,
3989
+ 0x1908192b08080808, 0x1908192b0808082b, 0x1908192b08081919, 0x1908192b08082b08,
3990
+ 0x1908192b08190819, 0x1908192b08191908, 0x1908192b082b0808, 0x1908192b19080819,
3991
+ 0x1908192b19081908, 0x1908192b19190808, 0x1908192b2b080808, 0x1908192b2b2b1919,
3992
+ 0x19082b0808080819, 0x19082b0808081908, 0x19082b0808082b19, 0x19082b0808190808,
3993
+ 0x19082b080819082b, 0x19082b0808191919, 0x19082b0808192b08, 0x19082b08082b0819,
3994
+ 0x19082b08082b1908, 0x19082b0819080808, 0x19082b081908082b, 0x19082b0819081919,
3995
+ 0x19082b0819082b08, 0x19082b0819190819, 0x19082b0819191908, 0x19082b08192b0808,
3996
+ 0x19082b082b081908, 0x19082b082b190808, 0x19082b1908080808, 0x19082b190808082b,
3997
+ 0x19082b1908081919, 0x19082b1908082b08, 0x19082b1908190819, 0x19082b1908191908,
3998
+ 0x19082b19082b0808, 0x19082b1919080819, 0x19082b1919081908, 0x19082b1919190808,
3999
+ 0x19082b192b080808, 0x19082b192b19192b, 0x19082b2b08080819, 0x19082b2b08081908,
4000
+ 0x19082b2b08190808, 0x19082b2b19080808, 0x1919080808080808, 0x191908080808082b,
4001
+ 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, 0x1919080808191908,
4002
+ 0x191908080819192b, 0x1919080808192b19, 0x19190808082b0808, 0x19190808082b082b,
4003
+ 0x19190808082b1919, 0x19190808082b2b08, 0x1919080819080819, 0x1919080819081908,
4004
+ 0x191908081908192b, 0x1919080819082b19, 0x1919080819190808, 0x191908081919082b,
4005
+ 0x1919080819191919, 0x1919080819192b08, 0x19190808192b0819, 0x19190808192b1908,
4006
+ 0x191908082b080808, 0x191908082b08082b, 0x191908082b081919, 0x191908082b082b08,
4007
+ 0x191908082b190819, 0x191908082b191908, 0x1919081908080819, 0x1919081908081908,
4008
+ 0x191908190808192b, 0x1919081908082b19, 0x1919081908190808, 0x191908190819082b,
4009
+ 0x1919081908191919, 0x1919081908192b08, 0x19190819082b0819, 0x19190819082b1908,
4010
+ 0x1919081919080808, 0x191908191908082b, 0x1919081919081919, 0x1919081919082b08,
4011
+ 0x1919081919190819, 0x1919081919191908, 0x19190819192b0808, 0x191908192b080819,
4012
+ 0x191908192b081908, 0x191908192b190808, 0x1919082b08080808, 0x1919082b08081919,
4013
+ 0x1919082b08082b08, 0x1919082b08190819, 0x1919082b08191908, 0x1919082b082b0808,
4014
+ 0x1919082b19080819, 0x1919082b19081908, 0x1919082b19190808, 0x1919082b192b2b19,
4015
+ 0x1919082b2b080808, 0x1919190808080819, 0x1919190808081908, 0x191919080808192b,
4016
+ 0x1919190808082b19, 0x1919190808190808, 0x191919080819082b, 0x1919190808191919,
4017
+ 0x1919190808192b08, 0x19191908082b0819, 0x19191908082b1908, 0x1919190819080808,
4018
+ 0x191919081908082b, 0x1919190819081919, 0x1919190819082b08, 0x1919190819190819,
4019
+ 0x1919190819191908, 0x19191908192b0808, 0x191919082b080819, 0x191919082b081908,
4020
+ 0x191919082b190808, 0x1919191908080808, 0x191919190808082b, 0x1919191908081919,
4021
+ 0x1919191908082b08, 0x1919191908190819, 0x1919191908191908, 0x19191919082b0808,
4022
+ 0x1919191919080819, 0x1919191919081908, 0x1919191919190808, 0x191919192b080808,
4023
+ 0x1919192b08080819, 0x1919192b08081908, 0x1919192b08190808, 0x1919192b082b192b,
4024
+ 0x1919192b19080808, 0x19192b0808080808, 0x19192b080808082b, 0x19192b0808081919,
4025
+ 0x19192b0808082b08, 0x19192b0808190819, 0x19192b0808191908, 0x19192b08082b0808,
4026
+ 0x19192b0819080819, 0x19192b0819081908, 0x19192b0819190808, 0x19192b0819192b2b,
4027
+ 0x19192b082b080808, 0x19192b1908080819, 0x19192b1908081908, 0x19192b1908190808,
4028
+ 0x19192b1919080808, 0x19192b2b08080808, 0x19192b2b08192b19, 0x19192b2b2b081919,
4029
+ 0x19192b2b2b2b2b08, 0x192b080808080819, 0x192b080808081908, 0x192b08080808192b,
4030
+ 0x192b080808190808, 0x192b08080819082b, 0x192b080808191919, 0x192b080808192b08,
4031
+ 0x192b0808082b0819, 0x192b0808082b1908, 0x192b080819080808, 0x192b080819081919,
4032
+ 0x192b080819082b08, 0x192b080819190819, 0x192b080819191908, 0x192b0808192b0808,
4033
+ 0x192b08082b081908, 0x192b08082b190808, 0x192b081908080808, 0x192b08190808082b,
4034
+ 0x192b081908081919, 0x192b081908082b08, 0x192b081908190819, 0x192b081908191908,
4035
+ 0x192b0819082b0808, 0x192b081919080819, 0x192b081919081908, 0x192b081919190808,
4036
+ 0x192b08192b080808, 0x192b08192b192b19, 0x192b082b08081908, 0x192b082b08190808,
4037
+ 0x192b082b19080808, 0x192b082b1919192b, 0x192b082b2b2b0819, 0x192b190808080808,
4038
+ 0x192b190808081919, 0x192b190808082b08, 0x192b190808190819, 0x192b190808191908,
4039
+ 0x192b1908082b0808, 0x192b190819080819, 0x192b190819081908, 0x192b190819190808,
4040
+ 0x192b19082b080808, 0x192b191908080819, 0x192b191908081908, 0x192b191908190808,
4041
+ 0x192b191919080808, 0x192b191919082b2b, 0x192b1919192b2b08, 0x192b19192b19082b,
4042
+ 0x192b192b08080808, 0x192b192b2b191908, 0x192b2b0808080819, 0x192b2b0808081908,
4043
+ 0x192b2b0808190808, 0x192b2b08192b1919, 0x192b2b082b192b08, 0x192b2b1908080808,
4044
+ 0x192b2b19082b2b2b, 0x192b2b2b1908082b, 0x192b2b2b2b2b0819, 0x2b08080808080808,
4045
+ 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, 0x2b08080808190819,
4046
+ 0x2b08080808191908, 0x2b08080808192b19, 0x2b080808082b0808, 0x2b080808082b1919,
4047
+ 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808081919082b,
4048
+ 0x2b08080819191919, 0x2b08080819192b08, 0x2b080808192b0819, 0x2b0808082b080808,
4049
+ 0x2b0808082b081919, 0x2b0808082b190819, 0x2b0808082b191908, 0x2b08081908080819,
4050
+ 0x2b08081908081908, 0x2b08081908082b19, 0x2b08081908190808, 0x2b0808190819082b,
4051
+ 0x2b08081908191919, 0x2b08081908192b08, 0x2b080819082b0819, 0x2b080819082b1908,
4052
+ 0x2b08081919080808, 0x2b0808191908082b, 0x2b08081919081919, 0x2b08081919082b08,
4053
+ 0x2b08081919190819, 0x2b08081919191908, 0x2b0808192b080819, 0x2b0808192b081908,
4054
+ 0x2b0808192b190808, 0x2b0808192b2b2b19, 0x2b08082b08080808, 0x2b08082b08081919,
4055
+ 0x2b08082b08082b2b, 0x2b08082b08190819, 0x2b08082b08191908, 0x2b08082b19080819,
4056
+ 0x2b08082b19081908, 0x2b08082b19190808, 0x2b08190808080819, 0x2b08190808081908,
4057
+ 0x2b0819080808192b, 0x2b08190808082b19, 0x2b08190808190808, 0x2b0819080819082b,
4058
+ 0x2b08190808191919, 0x2b08190808192b08, 0x2b081908082b0819, 0x2b08190819080808,
4059
+ 0x2b0819081908082b, 0x2b08190819081919, 0x2b08190819082b08, 0x2b08190819190819,
4060
+ 0x2b08190819191908, 0x2b081908192b0808, 0x2b0819082b080819, 0x2b0819082b081908,
4061
+ 0x2b0819082b190808, 0x2b08191908080808, 0x2b0819190808082b, 0x2b08191908081919,
4062
+ 0x2b08191908082b08, 0x2b08191908190819, 0x2b08191908191908, 0x2b081919082b0808,
4063
+ 0x2b08191919080819, 0x2b08191919081908, 0x2b08191919190808, 0x2b0819192b080808,
4064
+ 0x2b0819192b082b2b, 0x2b08192b08080819, 0x2b08192b08081908, 0x2b08192b08190808,
4065
+ 0x2b08192b082b2b19, 0x2b08192b19080808, 0x2b082b0808080808, 0x2b082b0808081919,
4066
+ 0x2b082b0808190819, 0x2b082b0808191908, 0x2b082b0819080819, 0x2b082b0819081908,
4067
+ 0x2b082b0819190808, 0x2b082b082b2b082b, 0x2b082b1908080819, 0x2b082b1908081908,
4068
+ 0x2b082b1919080808, 0x2b082b19192b1919, 0x2b082b2b082b082b, 0x2b082b2b19192b08,
4069
+ 0x2b082b2b19192b2b, 0x2b082b2b2b08082b, 0x2b082b2b2b2b082b, 0x2b19080808080819,
4070
+ 0x2b19080808081908, 0x2b19080808082b19, 0x2b19080808190808, 0x2b1908080819082b,
4071
+ 0x2b19080808191919, 0x2b19080808192b08, 0x2b190808082b1908, 0x2b19080819080808,
4072
+ 0x2b1908081908082b, 0x2b19080819081919, 0x2b19080819082b08, 0x2b19080819190819,
4073
+ 0x2b19080819191908, 0x2b190808192b0808, 0x2b1908082b080819, 0x2b1908082b081908,
4074
+ 0x2b1908082b190808, 0x2b19081908080808, 0x2b19081908081919, 0x2b19081908190819,
4075
+ 0x2b19081908191908, 0x2b19081919080819, 0x2b19081919081908, 0x2b19081919190808,
4076
+ 0x2b19081919192b2b, 0x2b19082b08080819, 0x2b19082b08081908, 0x2b19082b08190808,
4077
+ 0x2b19082b19080808, 0x2b19082b2b2b192b, 0x2b19190808080808, 0x2b1919080808082b,
4078
+ 0x2b19190808081919, 0x2b19190808082b08, 0x2b19190808190819, 0x2b19190808191908,
4079
+ 0x2b191908082b0808, 0x2b19190819080819, 0x2b19190819081908, 0x2b19190819190808,
4080
+ 0x2b1919082b080808, 0x2b1919082b19192b, 0x2b19191908080819, 0x2b19191908081908,
4081
+ 0x2b19191908190808, 0x2b19191919080808, 0x2b1919192b192b08, 0x2b1919192b2b0819,
4082
+ 0x2b19192b08080808, 0x2b19192b1908192b, 0x2b19192b192b1908, 0x2b192b0808080819,
4083
+ 0x2b192b0808081908, 0x2b192b0808190808, 0x2b192b08082b192b, 0x2b192b0819080808,
4084
+ 0x2b192b082b2b2b19, 0x2b192b1908080808, 0x2b192b1919082b19, 0x2b192b191919082b,
4085
+ 0x2b192b2b2b190808, 0x2b2b080808080808, 0x2b2b080808081919, 0x2b2b080808082b2b,
4086
+ 0x2b2b080808191908, 0x2b2b0808082b082b, 0x2b2b0808082b2b2b, 0x2b2b080819080819,
4087
+ 0x2b2b080819081908, 0x2b2b080819190808, 0x2b2b08082b2b082b, 0x2b2b08082b2b2b2b,
4088
+ 0x2b2b081919080808, 0x2b2b0819192b1919, 0x2b2b082b0808082b, 0x2b2b082b08082b2b,
4089
+ 0x2b2b082b082b082b, 0x2b2b082b082b2b08, 0x2b2b082b082b2b2b, 0x2b2b082b2b08082b,
4090
+ 0x2b2b082b2b082b08, 0x2b2b082b2b082b2b, 0x2b2b082b2b2b2b08, 0x2b2b190808080819,
4091
+ 0x2b2b190808081908, 0x2b2b190808190808, 0x2b2b190819080808, 0x2b2b19082b082b19,
4092
+ 0x2b2b19082b2b1908, 0x2b2b191908080808, 0x2b2b191908192b19, 0x2b2b192b19190819,
4093
+ 0x2b2b2b0808082b2b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b082b, 0x2b2b2b1919191908,
4094
+ 0x2b2b2b192b08192b, 0x2b2b2b2b08082b08, 0x2b2b2b2b08082b2b, 0x2b2b2b2b082b0808,
4095
+ 0x2b2b2b2b082b082b, 0x2b2b2b2b082b2b08, 0x2b2b2b2b2b082b08, 0x2b2b2b2b2b2b2b2b,
4096
+ };
4097
+
3763
4098
  constexpr constant static uint32_t iq3xxs_grid[256] = {
3764
4099
  0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
3765
4100
  0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
@@ -3795,6 +4130,73 @@ constexpr constant static uint32_t iq3xxs_grid[256] = {
3795
4130
  0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
3796
4131
  };
3797
4132
 
4133
+ constexpr constant static uint32_t iq3s_grid[512] = {
4134
+ 0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305,
4135
+ 0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,
4136
+ 0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09,
4137
+ 0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b,
4138
+ 0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b,
4139
+ 0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d,
4140
+ 0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03,
4141
+ 0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505,
4142
+ 0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03,
4143
+ 0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901,
4144
+ 0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d,
4145
+ 0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303,
4146
+ 0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501,
4147
+ 0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105,
4148
+ 0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505,
4149
+ 0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101,
4150
+ 0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707,
4151
+ 0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b,
4152
+ 0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01,
4153
+ 0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f,
4154
+ 0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305,
4155
+ 0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103,
4156
+ 0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509,
4157
+ 0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503,
4158
+ 0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b,
4159
+ 0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f,
4160
+ 0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f,
4161
+ 0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f,
4162
+ 0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109,
4163
+ 0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f,
4164
+ 0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509,
4165
+ 0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501,
4166
+ 0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303,
4167
+ 0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f,
4168
+ 0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907,
4169
+ 0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703,
4170
+ 0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03,
4171
+ 0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01,
4172
+ 0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01,
4173
+ 0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903,
4174
+ 0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505,
4175
+ 0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b,
4176
+ 0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107,
4177
+ 0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509,
4178
+ 0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303,
4179
+ 0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103,
4180
+ 0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05,
4181
+ 0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b,
4182
+ 0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f,
4183
+ 0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701,
4184
+ 0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909,
4185
+ 0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305,
4186
+ 0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d,
4187
+ 0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b,
4188
+ 0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d,
4189
+ 0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307,
4190
+ 0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09,
4191
+ 0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309,
4192
+ 0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709,
4193
+ 0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f,
4194
+ 0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303,
4195
+ 0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503,
4196
+ 0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,
4197
+ 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
4198
+ };
4199
+
3798
4200
  #define NGRID_IQ1S 512
3799
4201
  constexpr constant static uint64_t iq1s_grid[NGRID_IQ1S] = {
3800
4202
  0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
@@ -3991,7 +4393,6 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
3991
4393
  threadgroup_barrier(mem_flags::mem_threadgroup);
3992
4394
  }
3993
4395
 
3994
- #if QK_K == 256
3995
4396
  const int ix = tiisg;
3996
4397
 
3997
4398
  device const float * y4 = y + 32 * ix;
@@ -4032,12 +4433,6 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
4032
4433
 
4033
4434
  y4 += 32 * 32;
4034
4435
  }
4035
- #else
4036
- (void) x;
4037
- (void) y;
4038
- (void) yl;
4039
- (void) nb32;
4040
- #endif
4041
4436
 
4042
4437
  for (int row = 0; row < N_DST; ++row) {
4043
4438
  all_sum = simd_sum(sumf[row]);
@@ -4127,7 +4522,6 @@ void kernel_mul_mv_iq2_xs_f32_impl(
4127
4522
  threadgroup_barrier(mem_flags::mem_threadgroup);
4128
4523
  }
4129
4524
 
4130
- #if QK_K == 256
4131
4525
  const int ix = tiisg;
4132
4526
 
4133
4527
  device const float * y4 = y + 32 * ix;
@@ -4178,12 +4572,6 @@ void kernel_mul_mv_iq2_xs_f32_impl(
4178
4572
 
4179
4573
  y4 += 32 * 32;
4180
4574
  }
4181
- #else
4182
- (void) x;
4183
- (void) y;
4184
- (void) yl;
4185
- (void) nb32;
4186
- #endif
4187
4575
 
4188
4576
  for (int row = 0; row < N_DST; ++row) {
4189
4577
  all_sum = simd_sum(sumf[row]);
@@ -4273,7 +4661,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
4273
4661
  threadgroup_barrier(mem_flags::mem_threadgroup);
4274
4662
  }
4275
4663
 
4276
- #if QK_K == 256
4277
4664
  const int ix = tiisg;
4278
4665
 
4279
4666
  device const float * y4 = y + 32 * ix;
@@ -4317,12 +4704,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
4317
4704
 
4318
4705
  y4 += 32 * 32;
4319
4706
  }
4320
- #else
4321
- (void) x;
4322
- (void) y;
4323
- (void) yl;
4324
- (void) nb32;
4325
- #endif
4326
4707
 
4327
4708
  for (int row = 0; row < N_DST; ++row) {
4328
4709
  all_sum = simd_sum(sumf[row]);
@@ -4361,7 +4742,7 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
4361
4742
  kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
4362
4743
  }
4363
4744
 
4364
- void kernel_mul_mv_iq1_s_f32_impl(
4745
+ void kernel_mul_mv_iq3_s_f32_impl(
4365
4746
  device const void * src0,
4366
4747
  device const float * src1,
4367
4748
  device float * dst,
@@ -4374,6 +4755,7 @@ void kernel_mul_mv_iq1_s_f32_impl(
4374
4755
  constant int64_t & ne1,
4375
4756
  constant uint & r2,
4376
4757
  constant uint & r3,
4758
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
4377
4759
  uint3 tgpig[[threadgroup_position_in_grid]],
4378
4760
  uint tiisg[[thread_index_in_simdgroup]],
4379
4761
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -4390,59 +4772,70 @@ void kernel_mul_mv_iq1_s_f32_impl(
4390
4772
  const uint i13 = im/ne12;
4391
4773
 
4392
4774
  const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
4393
- device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0;
4775
+
4776
+ device const block_iq3_s * x = (device const block_iq3_s *) src0 + ib_row + offset0;
4394
4777
  device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
4395
4778
 
4396
- float yl[16];
4779
+ float yl[32];
4397
4780
  float sumf[N_DST]={0.f}, all_sum;
4398
4781
 
4399
4782
  const int nb32 = nb * (QK_K / 32);
4400
4783
 
4401
- #if QK_K == 256
4402
- const int ix = tiisg/2;
4403
- const int il = tiisg%2;
4784
+ threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
4785
+ {
4786
+ int nval = 8;
4787
+ int pos = (32*sgitg + tiisg)*nval;
4788
+ for (int i = 0; i < nval; ++i) values[pos + i] = iq3s_grid[pos + i];
4789
+ threadgroup_barrier(mem_flags::mem_threadgroup);
4790
+ }
4404
4791
 
4405
- device const float * y4 = y + 32 * ix + 16 * il;
4792
+ const int ix = tiisg;
4406
4793
 
4407
- for (int ib32 = ix; ib32 < nb32; ib32 += 16) {
4794
+ device const float * y4 = y + 32 * ix;
4408
4795
 
4409
- for (int i = 0; i < 16; ++i) {
4796
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
4797
+
4798
+ for (int i = 0; i < 32; ++i) {
4410
4799
  yl[i] = y4[i];
4411
4800
  }
4412
4801
 
4413
4802
  const int ibl = ib32 / (QK_K / 32);
4414
4803
  const int ib = ib32 % (QK_K / 32);
4415
4804
 
4416
- device const block_iq1_s * xr = x + ibl;
4417
- device const uint8_t * qs = xr->qs + 4 * ib + 2 * il;
4418
- device const uint8_t * sc = xr->scales + 2 * ib + il;
4419
- device const half * dh = &xr->d;
4805
+ device const block_iq3_s * xr = x + ibl;
4806
+ device const uint8_t * qs = xr->qs + 8 * ib;
4807
+ device const uint8_t * qh = xr->qh + ib;
4808
+ device const uint8_t * sc = xr->scales + (ib/2);
4809
+ device const uint8_t * signs = xr->signs + 4 * ib;
4810
+ device const half * dh = &xr->d;
4420
4811
 
4421
4812
  for (int row = 0; row < N_DST; row++) {
4422
4813
 
4423
- constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
4424
- constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
4814
+ const float db = dh[0];
4815
+ const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf));
4425
4816
 
4426
4817
  float2 sum = {0};
4427
- for (int j = 0; j < 8; ++j) {
4428
- sum[0] += yl[j+ 0] * grid1[j];
4429
- sum[1] += yl[j+ 8] * grid2[j];
4818
+ for (int l = 0; l < 4; ++l) {
4819
+ const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? values + 256 : values;
4820
+ const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? values + 256 : values;
4821
+ const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
4822
+ const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
4823
+ for (int j = 0; j < 4; ++j) {
4824
+ sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
4825
+ sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
4826
+ }
4430
4827
  }
4431
- sumf[row] += (float)dh[0] * (sum[0] * (2*(sc[0] & 7) + 1) + sum[1] * (2*((sc[0] >> 4) & 7) + 1));
4828
+ sumf[row] += d * (sum[0] + sum[1]);
4432
4829
 
4433
- dh += nb*sizeof(block_iq1_s)/2;
4434
- qs += nb*sizeof(block_iq1_s);
4435
- sc += nb*sizeof(block_iq1_s);
4830
+ dh += nb*sizeof(block_iq3_s)/2;
4831
+ qs += nb*sizeof(block_iq3_s);
4832
+ qh += nb*sizeof(block_iq3_s);
4833
+ sc += nb*sizeof(block_iq3_s);
4834
+ signs += nb*sizeof(block_iq3_s);
4436
4835
  }
4437
4836
 
4438
- y4 += 16 * 32;
4837
+ y4 += 32 * 32;
4439
4838
  }
4440
- #else
4441
- (void) x;
4442
- (void) y;
4443
- (void) yl;
4444
- (void) nb32;
4445
- #endif
4446
4839
 
4447
4840
  for (int row = 0; row < N_DST; ++row) {
4448
4841
  all_sum = simd_sum(sumf[row]);
@@ -4452,11 +4845,36 @@ void kernel_mul_mv_iq1_s_f32_impl(
4452
4845
  }
4453
4846
  }
4454
4847
 
4455
- constexpr constant static float kvalues_iq4nl_f[16] = {
4456
- -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
4457
- };
4848
+ [[host_name("kernel_mul_mv_iq3_s_f32")]]
4849
+ kernel void kernel_mul_mv_iq3_s_f32(
4850
+ device const void * src0,
4851
+ device const float * src1,
4852
+ device float * dst,
4853
+ constant int64_t & ne00,
4854
+ constant int64_t & ne01,
4855
+ constant int64_t & ne02,
4856
+ constant uint64_t & nb00,
4857
+ constant uint64_t & nb01,
4858
+ constant uint64_t & nb02,
4859
+ constant int64_t & ne10,
4860
+ constant int64_t & ne11,
4861
+ constant int64_t & ne12,
4862
+ constant uint64_t & nb10,
4863
+ constant uint64_t & nb11,
4864
+ constant uint64_t & nb12,
4865
+ constant int64_t & ne0,
4866
+ constant int64_t & ne1,
4867
+ constant uint & r2,
4868
+ constant uint & r3,
4869
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
4870
+ uint3 tgpig[[threadgroup_position_in_grid]],
4871
+ uint tiisg[[thread_index_in_simdgroup]],
4872
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4458
4873
 
4459
- void kernel_mul_mv_iq4_nl_f32_impl(
4874
+ kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
4875
+ }
4876
+
4877
+ void kernel_mul_mv_iq2_s_f32_impl(
4460
4878
  device const void * src0,
4461
4879
  device const float * src1,
4462
4880
  device float * dst,
@@ -4469,88 +4887,99 @@ void kernel_mul_mv_iq4_nl_f32_impl(
4469
4887
  constant int64_t & ne1,
4470
4888
  constant uint & r2,
4471
4889
  constant uint & r3,
4472
- threadgroup float * shared_values [[threadgroup(0)]],
4890
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
4473
4891
  uint3 tgpig[[threadgroup_position_in_grid]],
4474
4892
  uint tiisg[[thread_index_in_simdgroup]],
4475
4893
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
4476
4894
 
4477
- const int nb = ne00/QK4_NL;
4895
+ const int nb = ne00/QK_K;
4478
4896
  const int r0 = tgpig.x;
4479
4897
  const int r1 = tgpig.y;
4480
4898
  const int im = tgpig.z;
4481
- const int first_row = (r0 * 2 + sgitg) * 2;
4899
+
4900
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
4482
4901
  const int ib_row = first_row * nb;
4483
4902
 
4484
4903
  const uint i12 = im%ne12;
4485
4904
  const uint i13 = im/ne12;
4486
4905
 
4487
4906
  const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
4488
- device const block_iq4_nl * x = (device const block_iq4_nl *) src0 + ib_row + offset0;
4489
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
4490
-
4491
- const int ix = tiisg/2; // 0...15
4492
- const int it = tiisg%2; // 0 or 1
4493
4907
 
4494
- shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
4495
- threadgroup_barrier(mem_flags::mem_threadgroup);
4496
-
4497
- float4 yl[4];
4498
- float sumf[2]={0.f}, all_sum;
4908
+ device const block_iq2_s * x = (device const block_iq2_s *) src0 + ib_row + offset0;
4909
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
4499
4910
 
4500
- device const float * yb = y + ix * QK4_NL + it * 8;
4911
+ float yl[32];
4912
+ float sumf[N_DST]={0.f}, all_sum;
4501
4913
 
4502
- uint32_t aux32[2];
4503
- thread const uint8_t * q8 = (thread const uint8_t *)aux32;
4914
+ const int nb32 = nb * (QK_K / 32);
4504
4915
 
4505
- float4 qf1, qf2;
4916
+ //threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
4917
+ //{
4918
+ // int nval = 32;
4919
+ // int pos = (32*sgitg + tiisg)*nval;
4920
+ // for (int i = 0; i < nval; ++i) values[pos + i] = iq2s_grid[pos + i];
4921
+ // threadgroup_barrier(mem_flags::mem_threadgroup);
4922
+ //}
4506
4923
 
4507
- for (int ib = ix; ib < nb; ib += 16) {
4924
+ const int ix = tiisg;
4508
4925
 
4509
- device const float4 * y4 = (device const float4 *)yb;
4510
- yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
4926
+ device const float * y4 = y + 32 * ix;
4511
4927
 
4512
- for (int row = 0; row < 2; ++row) {
4928
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
4513
4929
 
4514
- device const block_iq4_nl & xb = x[row*nb + ib];
4515
- device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
4930
+ for (int i = 0; i < 32; ++i) {
4931
+ yl[i] = y4[i];
4932
+ }
4516
4933
 
4517
- float4 acc1 = {0.f}, acc2 = {0.f};
4934
+ const int ibl = ib32 / (QK_K / 32);
4935
+ const int ib = ib32 % (QK_K / 32);
4518
4936
 
4519
- aux32[0] = q4[0] | (q4[1] << 16);
4520
- aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
4521
- aux32[0] &= 0x0f0f0f0f;
4522
- qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
4523
- qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
4524
- acc1 += yl[0] * qf1;
4525
- acc2 += yl[1] * qf2;
4937
+ device const block_iq2_s * xr = x + ibl;
4938
+ device const uint8_t * qs = xr->qs + 4 * ib;
4939
+ device const uint8_t * qh = xr->qh + ib;
4940
+ device const uint8_t * sc = xr->scales + ib;
4941
+ device const uint8_t * signs = qs + QK_K/8;
4942
+ device const half * dh = &xr->d;
4526
4943
 
4527
- aux32[0] = q4[2] | (q4[3] << 16);
4528
- aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
4529
- aux32[0] &= 0x0f0f0f0f;
4530
- qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
4531
- qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
4532
- acc1 += yl[2] * qf1;
4533
- acc2 += yl[3] * qf2;
4944
+ for (int row = 0; row < N_DST; row++) {
4534
4945
 
4535
- acc1 += acc2;
4946
+ const float db = dh[0];
4947
+ const float d1 = db * (0.5f + (sc[0] & 0xf));
4948
+ const float d2 = db * (0.5f + (sc[0] >> 4));
4536
4949
 
4537
- sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
4950
+ float2 sum = {0};
4951
+ for (int l = 0; l < 2; ++l) {
4952
+ //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
4953
+ //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
4954
+ constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
4955
+ constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
4956
+ for (int j = 0; j < 8; ++j) {
4957
+ sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]);
4958
+ sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]);
4959
+ }
4960
+ }
4961
+ sumf[row] += d1 * sum[0] + d2 * sum[1];
4538
4962
 
4963
+ dh += nb*sizeof(block_iq2_s)/2;
4964
+ qs += nb*sizeof(block_iq2_s);
4965
+ qh += nb*sizeof(block_iq2_s);
4966
+ sc += nb*sizeof(block_iq2_s);
4967
+ signs += nb*sizeof(block_iq2_s);
4539
4968
  }
4540
4969
 
4541
- yb += 16 * QK4_NL;
4970
+ y4 += 32 * 32;
4542
4971
  }
4543
4972
 
4544
- for (int row = 0; row < 2; ++row) {
4973
+ for (int row = 0; row < N_DST; ++row) {
4545
4974
  all_sum = simd_sum(sumf[row]);
4546
4975
  if (tiisg == 0) {
4547
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
4976
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
4548
4977
  }
4549
4978
  }
4550
4979
  }
4551
4980
 
4552
- [[host_name("kernel_mul_mv_iq1_s_f32")]]
4553
- kernel void kernel_mul_mv_iq1_s_f32(
4981
+ [[host_name("kernel_mul_mv_iq2_s_f32")]]
4982
+ kernel void kernel_mul_mv_iq2_s_f32(
4554
4983
  device const void * src0,
4555
4984
  device const float * src1,
4556
4985
  device float * dst,
@@ -4570,67 +4999,406 @@ kernel void kernel_mul_mv_iq1_s_f32(
4570
4999
  constant int64_t & ne1,
4571
5000
  constant uint & r2,
4572
5001
  constant uint & r3,
5002
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
4573
5003
  uint3 tgpig[[threadgroup_position_in_grid]],
4574
5004
  uint tiisg[[thread_index_in_simdgroup]],
4575
5005
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
4576
5006
 
4577
- kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
5007
+ kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
4578
5008
  }
4579
5009
 
4580
- [[host_name("kernel_mul_mv_iq4_nl_f32")]]
4581
- kernel void kernel_mul_mv_iq4_nl_f32(
5010
+ void kernel_mul_mv_iq1_s_f32_impl(
4582
5011
  device const void * src0,
4583
5012
  device const float * src1,
4584
5013
  device float * dst,
4585
5014
  constant int64_t & ne00,
4586
5015
  constant int64_t & ne01,
4587
5016
  constant int64_t & ne02,
4588
- constant uint64_t & nb00,
4589
- constant uint64_t & nb01,
4590
- constant uint64_t & nb02,
4591
5017
  constant int64_t & ne10,
4592
- constant int64_t & ne11,
4593
5018
  constant int64_t & ne12,
4594
- constant uint64_t & nb10,
4595
- constant uint64_t & nb11,
4596
- constant uint64_t & nb12,
4597
5019
  constant int64_t & ne0,
4598
5020
  constant int64_t & ne1,
4599
5021
  constant uint & r2,
4600
5022
  constant uint & r3,
4601
- threadgroup float * shared_values [[threadgroup(0)]],
4602
5023
  uint3 tgpig[[threadgroup_position_in_grid]],
4603
- uint tiisg[[thread_index_in_simdgroup]],
4604
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
5024
+ uint tiisg[[thread_index_in_simdgroup]],
5025
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4605
5026
 
4606
- kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
4607
- }
5027
+ const int nb = ne00/QK_K;
5028
+ const int r0 = tgpig.x;
5029
+ const int r1 = tgpig.y;
5030
+ const int im = tgpig.z;
4608
5031
 
4609
- //============================= templates and their specializations =============================
5032
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
5033
+ const int ib_row = first_row * nb;
4610
5034
 
4611
- // NOTE: this is not dequantizing - we are simply fitting the template
4612
- template <typename type4x4>
4613
- void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
4614
- float4x4 temp = *(((device float4x4 *)src));
4615
- for (int i = 0; i < 16; i++){
4616
- reg[i/4][i%4] = temp[i/4][i%4];
4617
- }
4618
- }
5035
+ const uint i12 = im%ne12;
5036
+ const uint i13 = im/ne12;
4619
5037
 
4620
- template <typename type4x4>
4621
- void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
4622
- half4x4 temp = *(((device half4x4 *)src));
4623
- for (int i = 0; i < 16; i++){
4624
- reg[i/4][i%4] = temp[i/4][i%4];
4625
- }
4626
- }
5038
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
5039
+ device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0;
5040
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
4627
5041
 
4628
- template <typename type4x4>
4629
- void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
4630
- device const uint16_t * qs = ((device const uint16_t *)xb + 1);
4631
- const float d1 = il ? (xb->d / 16.h) : xb->d;
4632
- const float d2 = d1 / 256.f;
4633
- const float md = -8.h * xb->d;
5042
+ float yl[16];
5043
+ float sumf[N_DST]={0.f}, all_sum;
5044
+
5045
+ const int nb32 = nb * (QK_K / 32);
5046
+
5047
+ const int ix = tiisg/2;
5048
+ const int il = tiisg%2;
5049
+
5050
+ device const float * y4 = y + 32 * ix + 16 * il;
5051
+
5052
+ for (int ib32 = ix; ib32 < nb32; ib32 += 16) {
5053
+
5054
+ for (int i = 0; i < 16; ++i) {
5055
+ yl[i] = y4[i];
5056
+ }
5057
+
5058
+ const int ibl = ib32 / (QK_K / 32);
5059
+ const int ib = ib32 % (QK_K / 32);
5060
+
5061
+ device const block_iq1_s * xr = x + ibl;
5062
+ device const uint8_t * qs = xr->qs + 4 * ib + 2 * il;
5063
+ device const uint8_t * sc = xr->scales + 2 * ib + il;
5064
+ device const half * dh = &xr->d;
5065
+
5066
+ for (int row = 0; row < N_DST; row++) {
5067
+
5068
+ constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
5069
+ constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
5070
+
5071
+ float2 sum = {0};
5072
+ for (int j = 0; j < 8; ++j) {
5073
+ sum[0] += yl[j+ 0] * grid1[j];
5074
+ sum[1] += yl[j+ 8] * grid2[j];
5075
+ }
5076
+ sumf[row] += (float)dh[0] * (sum[0] * (2*(sc[0] & 7) + 1) + sum[1] * (2*((sc[0] >> 4) & 7) + 1));
5077
+
5078
+ dh += nb*sizeof(block_iq1_s)/2;
5079
+ qs += nb*sizeof(block_iq1_s);
5080
+ sc += nb*sizeof(block_iq1_s);
5081
+ }
5082
+
5083
+ y4 += 16 * 32;
5084
+ }
5085
+
5086
+ for (int row = 0; row < N_DST; ++row) {
5087
+ all_sum = simd_sum(sumf[row]);
5088
+ if (tiisg == 0) {
5089
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
5090
+ }
5091
+ }
5092
+ }
5093
+
5094
+ constexpr constant static float kvalues_iq4nl_f[16] = {
5095
+ -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
5096
+ };
5097
+
5098
+ void kernel_mul_mv_iq4_nl_f32_impl(
5099
+ device const void * src0,
5100
+ device const float * src1,
5101
+ device float * dst,
5102
+ constant int64_t & ne00,
5103
+ constant int64_t & ne01,
5104
+ constant int64_t & ne02,
5105
+ constant int64_t & ne10,
5106
+ constant int64_t & ne12,
5107
+ constant int64_t & ne0,
5108
+ constant int64_t & ne1,
5109
+ constant uint & r2,
5110
+ constant uint & r3,
5111
+ threadgroup float * shared_values [[threadgroup(0)]],
5112
+ uint3 tgpig[[threadgroup_position_in_grid]],
5113
+ uint tiisg[[thread_index_in_simdgroup]],
5114
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
5115
+
5116
+ const int nb = ne00/QK4_NL;
5117
+ const int r0 = tgpig.x;
5118
+ const int r1 = tgpig.y;
5119
+ const int im = tgpig.z;
5120
+ const int first_row = (r0 * 2 + sgitg) * 2;
5121
+ const int ib_row = first_row * nb;
5122
+
5123
+ const uint i12 = im%ne12;
5124
+ const uint i13 = im/ne12;
5125
+
5126
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
5127
+ device const block_iq4_nl * x = (device const block_iq4_nl *) src0 + ib_row + offset0;
5128
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
5129
+
5130
+ const int ix = tiisg/2; // 0...15
5131
+ const int it = tiisg%2; // 0 or 1
5132
+
5133
+ shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
5134
+ threadgroup_barrier(mem_flags::mem_threadgroup);
5135
+
5136
+ float4 yl[4];
5137
+ float sumf[2]={0.f}, all_sum;
5138
+
5139
+ device const float * yb = y + ix * QK4_NL + it * 8;
5140
+
5141
+ uint32_t aux32[2];
5142
+ thread const uint8_t * q8 = (thread const uint8_t *)aux32;
5143
+
5144
+ float4 qf1, qf2;
5145
+
5146
+ for (int ib = ix; ib < nb; ib += 16) {
5147
+
5148
+ device const float4 * y4 = (device const float4 *)yb;
5149
+ yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
5150
+
5151
+ for (int row = 0; row < 2; ++row) {
5152
+
5153
+ device const block_iq4_nl & xb = x[row*nb + ib];
5154
+ device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
5155
+
5156
+ float4 acc1 = {0.f}, acc2 = {0.f};
5157
+
5158
+ aux32[0] = q4[0] | (q4[1] << 16);
5159
+ aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
5160
+ aux32[0] &= 0x0f0f0f0f;
5161
+ qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
5162
+ qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
5163
+ acc1 += yl[0] * qf1;
5164
+ acc2 += yl[1] * qf2;
5165
+
5166
+ aux32[0] = q4[2] | (q4[3] << 16);
5167
+ aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
5168
+ aux32[0] &= 0x0f0f0f0f;
5169
+ qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
5170
+ qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
5171
+ acc1 += yl[2] * qf1;
5172
+ acc2 += yl[3] * qf2;
5173
+
5174
+ acc1 += acc2;
5175
+
5176
+ sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
5177
+
5178
+ }
5179
+
5180
+ yb += 16 * QK4_NL;
5181
+ }
5182
+
5183
+ for (int row = 0; row < 2; ++row) {
5184
+ all_sum = simd_sum(sumf[row]);
5185
+ if (tiisg == 0) {
5186
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
5187
+ }
5188
+ }
5189
+ }
5190
+
5191
+ #if QK_K != 64
5192
+ void kernel_mul_mv_iq4_xs_f32_impl(
5193
+ device const void * src0,
5194
+ device const float * src1,
5195
+ device float * dst,
5196
+ constant int64_t & ne00,
5197
+ constant int64_t & ne01,
5198
+ constant int64_t & ne02,
5199
+ constant int64_t & ne10,
5200
+ constant int64_t & ne12,
5201
+ constant int64_t & ne0,
5202
+ constant int64_t & ne1,
5203
+ constant uint & r2,
5204
+ constant uint & r3,
5205
+ threadgroup float * shared_values [[threadgroup(0)]],
5206
+ uint3 tgpig[[threadgroup_position_in_grid]],
5207
+ uint tiisg[[thread_index_in_simdgroup]],
5208
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
5209
+
5210
+ const int nb = ne00/QK_K;
5211
+ const int r0 = tgpig.x;
5212
+ const int r1 = tgpig.y;
5213
+ const int im = tgpig.z;
5214
+ const int first_row = (r0 * 2 + sgitg) * 2;
5215
+ const int ib_row = first_row * nb;
5216
+
5217
+ const uint i12 = im%ne12;
5218
+ const uint i13 = im/ne12;
5219
+
5220
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
5221
+ device const block_iq4_xs * x = (device const block_iq4_xs *) src0 + ib_row + offset0;
5222
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
5223
+
5224
+ const int ix = tiisg/16; // 0 or 1
5225
+ const int it = tiisg%16; // 0...15
5226
+ const int ib = it/2;
5227
+ const int il = it%2;
5228
+
5229
+ shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
5230
+ threadgroup_barrier(mem_flags::mem_threadgroup);
5231
+
5232
+ float4 yl[4];
5233
+ float sumf[2]={0.f}, all_sum;
5234
+
5235
+ device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
5236
+
5237
+ uint32_t aux32[2];
5238
+ thread const uint8_t * q8 = (thread const uint8_t *)aux32;
5239
+
5240
+ float4 qf1, qf2;
5241
+
5242
+ for (int ibl = ix; ibl < nb; ibl += 2) {
5243
+
5244
+ device const float4 * y4 = (device const float4 *)yb;
5245
+ yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
5246
+
5247
+ for (int row = 0; row < 2; ++row) {
5248
+
5249
+ device const block_iq4_xs & xb = x[row*nb + ibl];
5250
+ device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
5251
+
5252
+ float4 acc1 = {0.f}, acc2 = {0.f};
5253
+
5254
+ aux32[0] = q4[0] & 0x0f0f0f0f;
5255
+ aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f;
5256
+ qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
5257
+ qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
5258
+ acc1 += yl[0] * qf1;
5259
+ acc2 += yl[1] * qf2;
5260
+
5261
+ aux32[0] = q4[1] & 0x0f0f0f0f;
5262
+ aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f;
5263
+ qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
5264
+ qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
5265
+ acc1 += yl[2] * qf1;
5266
+ acc2 += yl[3] * qf2;
5267
+
5268
+ acc1 += acc2;
5269
+
5270
+ const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32;
5271
+ sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
5272
+
5273
+ }
5274
+
5275
+ yb += 2 * QK_K;
5276
+ }
5277
+
5278
+ for (int row = 0; row < 2; ++row) {
5279
+ all_sum = simd_sum(sumf[row]);
5280
+ if (tiisg == 0) {
5281
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
5282
+ }
5283
+ }
5284
+ }
5285
+ #endif
5286
+
5287
+ [[host_name("kernel_mul_mv_iq1_s_f32")]]
5288
+ kernel void kernel_mul_mv_iq1_s_f32(
5289
+ device const void * src0,
5290
+ device const float * src1,
5291
+ device float * dst,
5292
+ constant int64_t & ne00,
5293
+ constant int64_t & ne01,
5294
+ constant int64_t & ne02,
5295
+ constant uint64_t & nb00,
5296
+ constant uint64_t & nb01,
5297
+ constant uint64_t & nb02,
5298
+ constant int64_t & ne10,
5299
+ constant int64_t & ne11,
5300
+ constant int64_t & ne12,
5301
+ constant uint64_t & nb10,
5302
+ constant uint64_t & nb11,
5303
+ constant uint64_t & nb12,
5304
+ constant int64_t & ne0,
5305
+ constant int64_t & ne1,
5306
+ constant uint & r2,
5307
+ constant uint & r3,
5308
+ uint3 tgpig[[threadgroup_position_in_grid]],
5309
+ uint tiisg[[thread_index_in_simdgroup]],
5310
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
5311
+
5312
+ kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
5313
+ }
5314
+
5315
+ [[host_name("kernel_mul_mv_iq4_nl_f32")]]
5316
+ kernel void kernel_mul_mv_iq4_nl_f32(
5317
+ device const void * src0,
5318
+ device const float * src1,
5319
+ device float * dst,
5320
+ constant int64_t & ne00,
5321
+ constant int64_t & ne01,
5322
+ constant int64_t & ne02,
5323
+ constant uint64_t & nb00,
5324
+ constant uint64_t & nb01,
5325
+ constant uint64_t & nb02,
5326
+ constant int64_t & ne10,
5327
+ constant int64_t & ne11,
5328
+ constant int64_t & ne12,
5329
+ constant uint64_t & nb10,
5330
+ constant uint64_t & nb11,
5331
+ constant uint64_t & nb12,
5332
+ constant int64_t & ne0,
5333
+ constant int64_t & ne1,
5334
+ constant uint & r2,
5335
+ constant uint & r3,
5336
+ threadgroup float * shared_values [[threadgroup(0)]],
5337
+ uint3 tgpig[[threadgroup_position_in_grid]],
5338
+ uint tiisg[[thread_index_in_simdgroup]],
5339
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
5340
+
5341
+ kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
5342
+ }
5343
+
5344
+ [[host_name("kernel_mul_mv_iq4_xs_f32")]]
5345
+ kernel void kernel_mul_mv_iq4_xs_f32(
5346
+ device const void * src0,
5347
+ device const float * src1,
5348
+ device float * dst,
5349
+ constant int64_t & ne00,
5350
+ constant int64_t & ne01,
5351
+ constant int64_t & ne02,
5352
+ constant uint64_t & nb00,
5353
+ constant uint64_t & nb01,
5354
+ constant uint64_t & nb02,
5355
+ constant int64_t & ne10,
5356
+ constant int64_t & ne11,
5357
+ constant int64_t & ne12,
5358
+ constant uint64_t & nb10,
5359
+ constant uint64_t & nb11,
5360
+ constant uint64_t & nb12,
5361
+ constant int64_t & ne0,
5362
+ constant int64_t & ne1,
5363
+ constant uint & r2,
5364
+ constant uint & r3,
5365
+ threadgroup float * shared_values [[threadgroup(0)]],
5366
+ uint3 tgpig[[threadgroup_position_in_grid]],
5367
+ uint tiisg[[thread_index_in_simdgroup]],
5368
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
5369
+
5370
+ #if QK_K == 64
5371
+ kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
5372
+ #else
5373
+ kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
5374
+ #endif
5375
+ }
5376
+
5377
+ //============================= templates and their specializations =============================
5378
+
5379
+ // NOTE: this is not dequantizing - we are simply fitting the template
5380
+ template <typename type4x4>
5381
+ void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
5382
+ float4x4 temp = *(((device float4x4 *)src));
5383
+ for (int i = 0; i < 16; i++){
5384
+ reg[i/4][i%4] = temp[i/4][i%4];
5385
+ }
5386
+ }
5387
+
5388
+ template <typename type4x4>
5389
+ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
5390
+ half4x4 temp = *(((device half4x4 *)src));
5391
+ for (int i = 0; i < 16; i++){
5392
+ reg[i/4][i%4] = temp[i/4][i%4];
5393
+ }
5394
+ }
5395
+
5396
+ template <typename type4x4>
5397
+ void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
5398
+ device const uint16_t * qs = ((device const uint16_t *)xb + 1);
5399
+ const float d1 = il ? (xb->d / 16.h) : xb->d;
5400
+ const float d2 = d1 / 256.f;
5401
+ const float md = -8.h * xb->d;
4634
5402
  const ushort mask0 = il ? 0x00F0 : 0x000F;
4635
5403
  const ushort mask1 = mask0 << 8;
4636
5404
 
@@ -4952,6 +5720,50 @@ void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x
4952
5720
  }
4953
5721
  }
4954
5722
 
5723
// Expands 16 values of an iq3_s block into `reg` (four vectors of four).
//   xb  - block to read from
//   il  - selects which group of 16 inside the block (see comments below)
//   reg - output 4x4 tile; every element is overwritten
template <typename type4x4>
void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) {
    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
    const int   ib32 = il/2;
    const float d    = xb->d;
    il = il%2;
    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
    device const uint8_t * qs    = xb->qs    + 8*ib32;
    device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
    const uint8_t qh = xb->qh[ib32] >> 4*il;
    // 4-bit scale for this block of 32: low nibble for even ib32, high nibble for odd
    const float   dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));

    for (int k = 0; k < 4; ++k) {
        // bit k of qh is moved to bit 8 and OR-ed onto the 8-bit index from qs
        constant uint8_t * grid = (constant uint8_t *)(iq3s_grid + (qs[4*il + k] | ((qh << (8 - k)) & 256)));
        // vectors 0,1 share signs[0]; vectors 2,3 share signs[1].
        // even k uses mask entries 0..3, odd k uses entries 4..7
        const uint8_t sign_byte = signs[k/2];
        const int     mask_off  = 4*(k%2);
        for (int i = 0; i < 4; ++i) {
            reg[k][i] = dl * grid[i] * select(1, -1, sign_byte & kmask_iq2xs[i + mask_off]);
        }
    }
}
5747
+
5748
// Expands 16 values of an iq2_s block into `reg` (four vectors of four).
//   xb  - block to read from
//   il  - selects which group of 16 inside the block (see comments below)
//   reg - output 4x4 tile; every element is overwritten
template <typename type4x4>
void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) {
    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
    const int   ib32 = il/2;
    const float d    = xb->d;
    il = il%2;
    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
    device const uint8_t * qs    = xb->qs + 4*ib32 + 2*il;
    // the sign bytes are read from the same array, QK_K/8 bytes past the matching qs byte
    device const uint8_t * signs = qs + QK_K/8;
    const uint8_t qh = xb->qh[ib32] >> 4*il;
    const float   dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;

    for (int h = 0; h < 2; ++h) {
        // bits 2h and 2h+1 of qh are moved to bits 8 and 9 of the grid index
        constant uint8_t * grid = (constant uint8_t *)(iq2s_grid + (qs[h] | ((qh << (8 - 2*h)) & 0x300)));
        const uint8_t sign_byte = signs[h];
        // each grid entry yields 8 values, which fill two consecutive vectors of reg
        for (int i = 0; i < 8; ++i) {
            reg[2*h + i/4][i%4] = dl * grid[i] * select(1, -1, sign_byte & kmask_iq2xs[i]);
        }
    }
}
5766
+
4955
5767
  template <typename type4x4>
4956
5768
  void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
4957
5769
  // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
@@ -4983,6 +5795,30 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4
4983
5795
  }
4984
5796
  }
4985
5797
 
5798
// Expands 16 values of an iq4_xs block into `reg` (four vectors of four).
// When QK_K == 64 the work is delegated unchanged to dequantize_iq4_nl.
//   xb  - block to read from
//   il  - selects which group of 16 inside the block (see comments below)
//   reg - output 4x4 tile; every element is overwritten
template <typename type4x4>
void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
#if QK_K == 64
    dequantize_iq4_nl(xb, il, reg);
#else
    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
    const int ib32 = il/2;
    il = il%2;
    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
    device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;

    // 6-bit scale: low 4 bits come from scales_l, high 2 bits from scales_h
    const int lo = (xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf;
    const int hi = (xb->scales_h >> 2*ib32) & 3;
    const int ls = lo | (hi << 4);
    const float d = (float)xb->d * (ls - 32);

    // packed holds four 4-bit indices, one per byte; idx views those bytes
    uint32_t packed;
    thread const uint8_t * idx = (thread const uint8_t *)&packed;
    for (int i = 0; i < 4; ++i) {
        // il picks the low (0) or high (1) nibble of every byte in the word
        packed = (q4[i] >> 4*il) & 0x0f0f0f0f;
        for (int j = 0; j < 4; ++j) {
            reg[i][j] = d * kvalues_iq4nl_f[idx[j]];
        }
    }
#endif
}
5821
+
4986
5822
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
4987
5823
  kernel void kernel_get_rows(
4988
5824
  device const void * src0,
@@ -5525,8 +6361,15 @@ template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows
5525
6361
  template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
5526
6362
  template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5527
6363
  template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
6364
+ template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows<block_iq3_s, QK_NL, dequantize_iq3_s>;
6365
+ template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
5528
6366
  template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
5529
- template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
6367
+ template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
6368
+ #if QK_K == 64
6369
+ template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, 2, dequantize_iq4_xs>;
6370
+ #else
6371
+ template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6372
+ #endif
5530
6373
 
5531
6374
  //
5532
6375
  // matrix-matrix multiplication
@@ -5566,8 +6409,15 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
5566
6409
  template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
5567
6410
  template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5568
6411
  template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
6412
+ template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_s, QK_NL, dequantize_iq3_s>;
6413
+ template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>;
5569
6414
  template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
5570
- template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
6415
+ template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
6416
// Explicit instantiation of kernel_mul_mm for iq4_xs weights.
// Both branches name block_iq4_xs, matching the kernel_get_rows_iq4_xs and
// kernel_mul_mm_id_iq4_xs_f32 instantiations and the parameter type of
// dequantize_iq4_xs; only the `nl` template argument differs with QK_K.
#if QK_K == 64
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, 2, dequantize_iq4_xs>;
#else
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
#endif
5571
6421
 
5572
6422
  //
5573
6423
  // indirect matrix-matrix multiplication
@@ -5619,8 +6469,15 @@ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mu
5619
6469
  template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
5620
6470
  template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5621
6471
  template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
6472
+ template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s, QK_NL, dequantize_iq3_s>;
6473
+ template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
5622
6474
  template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
5623
- template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
6475
+ template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
6476
+ #if QK_K == 64
6477
+ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, 2, dequantize_iq4_xs>;
6478
+ #else
6479
+ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6480
+ #endif
5624
6481
 
5625
6482
  //
5626
6483
  // matrix-vector multiplication
@@ -6589,6 +7446,136 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
6589
7446
  sgitg);
6590
7447
  }
6591
7448
 
7449
// Indirect matrix-vector product for iq3_s weights.
// Splits tgpig.z by ne12*ne13 into a row index `bid` and a remainder, reads the
// int32 at position `idx` of row `bid` in `ids`, uses it to choose one of the
// eight src0x buffers, and forwards to kernel_mul_mv_iq3_s_f32_impl with src1
// and dst advanced to row `bid`. Parameters not mentioned in the body are
// accepted but unused here.
[[host_name("kernel_mul_mv_id_iq3_s_f32")]]
kernel void kernel_mul_mv_id_iq3_s_f32(
        device const   char * ids,
        device const   char * src1,
        device        float * dst,
        constant   uint64_t & nbi1,
        constant    int64_t & ne00,
        constant    int64_t & ne01,
        constant    int64_t & ne02,
        constant   uint64_t & nb00,
        constant   uint64_t & nb01,
        constant   uint64_t & nb02,
        constant    int64_t & ne10,
        constant    int64_t & ne11,
        constant    int64_t & ne12,
        constant    int64_t & ne13,
        constant   uint64_t & nb10,
        constant   uint64_t & nb11,
        constant   uint64_t & nb12,
        constant    int64_t & ne0,
        constant    int64_t & ne1,
        constant   uint64_t & nb1,
        constant       uint & r2,
        constant       uint & r3,
        constant        int & idx,
        device const   char * src00,
        device const   char * src01,
        device const   char * src02,
        device const   char * src03,
        device const   char * src04,
        device const   char * src05,
        device const   char * src06,
        device const   char * src07,
        threadgroup  int8_t * shared_values [[threadgroup(0)]],
        uint3                 tgpig[[threadgroup_position_in_grid]],
        uint                  tiitg[[thread_index_in_threadgroup]],
        uint                  tiisg[[thread_index_in_simdgroup]],
        uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
    device const char * candidates[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // quotient selects the ids/src1/dst row, remainder is what the impl sees as tgpig.z
    const int64_t per_row = ne12*ne13;
    const int64_t bid     = tgpig.z/per_row;
    tgpig.z = tgpig.z%per_row;

    // nbi1 is the byte stride between rows of ids
    device const int32_t * id_row = (device const int32_t *) (ids + bid*nbi1);
    const int32_t id = id_row[idx];

    kernel_mul_mv_iq3_s_f32_impl(
        candidates[id],
        (device const float *) (src1 + bid*nb11),
        dst + bid*ne0,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0, ne1,
        r2, r3,
        shared_values,
        tgpig, tiisg, sgitg);
}
7513
+
7514
// Indirect matrix-vector product for iq2_s weights.
// Splits tgpig.z by ne12*ne13 into a row index `bid` and a remainder, reads the
// int32 at position `idx` of row `bid` in `ids`, uses it to choose one of the
// eight src0x buffers, and forwards to kernel_mul_mv_iq2_s_f32_impl with src1
// and dst advanced to row `bid`. Parameters not mentioned in the body are
// accepted but unused here.
[[host_name("kernel_mul_mv_id_iq2_s_f32")]]
kernel void kernel_mul_mv_id_iq2_s_f32(
        device const   char * ids,
        device const   char * src1,
        device        float * dst,
        constant   uint64_t & nbi1,
        constant    int64_t & ne00,
        constant    int64_t & ne01,
        constant    int64_t & ne02,
        constant   uint64_t & nb00,
        constant   uint64_t & nb01,
        constant   uint64_t & nb02,
        constant    int64_t & ne10,
        constant    int64_t & ne11,
        constant    int64_t & ne12,
        constant    int64_t & ne13,
        constant   uint64_t & nb10,
        constant   uint64_t & nb11,
        constant   uint64_t & nb12,
        constant    int64_t & ne0,
        constant    int64_t & ne1,
        constant   uint64_t & nb1,
        constant       uint & r2,
        constant       uint & r3,
        constant        int & idx,
        device const   char * src00,
        device const   char * src01,
        device const   char * src02,
        device const   char * src03,
        device const   char * src04,
        device const   char * src05,
        device const   char * src06,
        device const   char * src07,
        threadgroup  int8_t * shared_values [[threadgroup(0)]],
        uint3                 tgpig[[threadgroup_position_in_grid]],
        uint                  tiitg[[thread_index_in_threadgroup]],
        uint                  tiisg[[thread_index_in_simdgroup]],
        uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
    device const char * candidates[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // quotient selects the ids/src1/dst row, remainder is what the impl sees as tgpig.z
    const int64_t per_row = ne12*ne13;
    const int64_t bid     = tgpig.z/per_row;
    tgpig.z = tgpig.z%per_row;

    // nbi1 is the byte stride between rows of ids
    device const int32_t * id_row = (device const int32_t *) (ids + bid*nbi1);
    const int32_t id = id_row[idx];

    kernel_mul_mv_iq2_s_f32_impl(
        candidates[id],
        (device const float *) (src1 + bid*nb11),
        dst + bid*ne0,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0, ne1,
        r2, r3,
        shared_values,
        tgpig, tiisg, sgitg);
}
7578
+
6592
7579
  [[host_name("kernel_mul_mv_id_iq1_s_f32")]]
6593
7580
  kernel void kernel_mul_mv_id_iq1_s_f32(
6594
7581
  device const char * ids,
@@ -6716,3 +7703,72 @@ kernel void kernel_mul_mv_id_iq4_nl_f32(
6716
7703
  tiisg,
6717
7704
  sgitg);
6718
7705
  }
7706
+
7707
// Indirect matrix-vector product for iq4_xs weights.
// Splits tgpig.z by ne12*ne13 into a row index `bid` and a remainder, reads the
// int32 at position `idx` of row `bid` in `ids`, uses it to choose one of the
// eight src0x buffers, and forwards to the matching *_impl with src1 and dst
// advanced to row `bid`. With QK_K == 64 the iq4_nl implementation is called,
// otherwise the iq4_xs one. Parameters not mentioned in the body are accepted
// but unused here.
[[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
kernel void kernel_mul_mv_id_iq4_xs_f32(
        device const   char * ids,
        device const   char * src1,
        device        float * dst,
        constant   uint64_t & nbi1,
        constant    int64_t & ne00,
        constant    int64_t & ne01,
        constant    int64_t & ne02,
        constant   uint64_t & nb00,
        constant   uint64_t & nb01,
        constant   uint64_t & nb02,
        constant    int64_t & ne10,
        constant    int64_t & ne11,
        constant    int64_t & ne12,
        constant    int64_t & ne13,
        constant   uint64_t & nb10,
        constant   uint64_t & nb11,
        constant   uint64_t & nb12,
        constant    int64_t & ne0,
        constant    int64_t & ne1,
        constant   uint64_t & nb1,
        constant       uint & r2,
        constant       uint & r3,
        constant        int & idx,
        device const   char * src00,
        device const   char * src01,
        device const   char * src02,
        device const   char * src03,
        device const   char * src04,
        device const   char * src05,
        device const   char * src06,
        device const   char * src07,
        threadgroup   float * shared_values [[threadgroup(0)]],
        uint3                 tgpig[[threadgroup_position_in_grid]],
        uint                  tiitg[[thread_index_in_threadgroup]],
        uint                  tiisg[[thread_index_in_simdgroup]],
        uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
    device const char * candidates[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // quotient selects the ids/src1/dst row, remainder is what the impl sees as tgpig.z
    const int64_t per_row = ne12*ne13;
    const int64_t bid     = tgpig.z/per_row;
    tgpig.z = tgpig.z%per_row;

    // nbi1 is the byte stride between rows of ids
    device const int32_t * id_row = (device const int32_t *) (ids + bid*nbi1);
    const int32_t id = id_row[idx];

#if QK_K == 64
    kernel_mul_mv_iq4_nl_f32_impl(
        candidates[id],
        (device const float *) (src1 + bid*nb11),
        dst + bid*ne0,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0, ne1,
        r2, r3,
        shared_values,
        tgpig, tiisg, sgitg);
#else
    kernel_mul_mv_iq4_xs_f32_impl(
        candidates[id],
        (device const float *) (src1 + bid*nb11),
        dst + bid*ne0,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0, ne1,
        r2, r3,
        shared_values,
        tgpig, tiisg, sgitg);
#endif
}