llama_cpp 0.12.7 → 0.14.0

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
@@ -172,6 +172,7 @@
  #endif
 
  typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
+ typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
  static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
  const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
  const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
@@ -196,6 +197,18 @@ static __device__ __forceinline__ int __vsub4(const int a, const int b) {
  return __vsubss4(a, b);
  }
 
+ static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) {
+ const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
+ const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
+ unsigned int c;
+ uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
+ #pragma unroll
+ for (int i = 0; i < 4; ++i) {
+ vc[i] = va[i] == vb[i] ? 0xff : 0x00;
+ }
+ return c;
+ }
+
  static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
  #if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
  c = __builtin_amdgcn_sdot4(a, b, c, false);
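Note on the hunk above: CUDA exposes __vcmpeq4 as a hardware SIMD-in-a-word intrinsic, so this definition only exists for the HIP/ROCm build, where it is emulated with the ext_vector_type helpers. The sketch below is my own host-side illustration (not package code) of the byte-wise semantics and of the "multiply by 0x01010101, mask with 0x08040201" trick that the new vec_dot_iq2_s_q8_1 / vec_dot_iq3_s_q8_1 kernels later in this diff use to expand four packed sign bits into a 0x00/0xff byte mask.

    // Host-side sketch only: reference semantics of __vcmpeq4 and the sign-bit
    // expansion trick used by the new iq2_s / iq3_s dot-product kernels.
    #include <cstdint>
    #include <cstdio>

    static uint32_t vcmpeq4_ref(uint32_t a, uint32_t b) {
        uint32_t c = 0;
        for (int i = 0; i < 4; ++i) {
            const uint32_t ba = (a >> (8*i)) & 0xff;
            const uint32_t bb = (b >> (8*i)) & 0xff;
            if (ba == bb) c |= 0xffu << (8*i); // 0xff where the bytes match, 0x00 elsewhere
        }
        return c;
    }

    int main() {
        // Four packed sign bits -> one 0x00/0xff mask byte each: replicate the nibble
        // into every byte, keep one distinct bit per byte, then byte-compare.
        const uint8_t sign_nibble = 0xA; // bits 3..0 = 1,0,1,0
        const uint32_t mask = vcmpeq4_ref((sign_nibble * 0x01010101u) & 0x08040201u, 0x08040201u);
        printf("mask = 0x%08x\n", mask); // prints 0xff00ff00
        return 0;
    }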
@@ -510,6 +523,17 @@ typedef struct {
  } block_iq2_xs;
  static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding");
 
+ // 2.5625 bpw quants
+ #define QR2_S 8
+ #define QI2_S (QK_K / (4*QR2_S))
+ typedef struct {
+ half d;
+ uint8_t qs[QK_K/4];
+ uint8_t qh[QK_K/32];
+ uint8_t scales[QK_K/32];
+ } block_iq2_s;
+ static_assert(sizeof(block_iq2_s) == sizeof(ggml_fp16_t) + QK_K/4 + QK_K/16, "wrong iq2_s block size/padding");
+
  #define QR3_XXS 8
  #define QI3_XXS (QK_K / (4*QR3_XXS))
  typedef struct {
@@ -518,6 +542,22 @@ typedef struct {
  } block_iq3_xxs;
  static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
 
+ #define QR3_XS 8
+ #define QI3_XS (QK_K / (4*QR3_XS))
+ #if QK_K == 64
+ #define IQ3S_N_SCALE 2
+ #else
+ #define IQ3S_N_SCALE QK_K/64
+ #endif
+ typedef struct {
+ half d;
+ uint8_t qs[QK_K/4];
+ uint8_t qh[QK_K/32];
+ uint8_t signs[QK_K/8];
+ uint8_t scales[IQ3S_N_SCALE];
+ } block_iq3_s;
+ static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 13*(QK_K/32) + IQ3S_N_SCALE, "wrong iq3_s block size/padding");
+
  #define QR1_S 8
  #define QI1_S (QK_K / (4*QR1_S))
  typedef struct {
@@ -536,6 +576,23 @@ typedef struct {
  } block_iq4_nl;
  static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4_nl block size/padding");
 
+ #if QK_K == 64
+ #define block_iq4_xs block_iq4_nl
+ #define QR4_XS QR4_NL
+ #define QI4_XS QI4_NL
+ #else
+ // QR4_XS = 8 is very slightly faster than QR4_XS = 4
+ #define QR4_XS 8
+ #define QI4_XS (QK_K / (4*QR4_XS))
+ typedef struct {
+ half d;
+ uint16_t scales_h;
+ uint8_t scales_l[QK_K/64];
+ uint8_t qs[QK_K/2];
+ } block_iq4_xs;
+ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_fp16_t) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
+ #endif
+
  #define WARP_SIZE 32
  #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
@@ -559,6 +616,8 @@ static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4
  #define CUDA_UPSCALE_BLOCK_SIZE 256
  #define CUDA_CONCAT_BLOCK_SIZE 256
  #define CUDA_PAD_BLOCK_SIZE 256
+ #define CUDA_ARANGE_BLOCK_SIZE 256
+ #define CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE 256
  #define CUDA_ACC_BLOCK_SIZE 256
  #define CUDA_IM2COL_BLOCK_SIZE 256
  #define CUDA_POOL2D_BLOCK_SIZE 256
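The three structs added above (block_iq2_s, block_iq3_s, block_iq4_xs) are what the "2.5625 bpw quants" comment refers to. A quick check of their storage density, assuming the default super-block size QK_K = 256 and a 2-byte half for the fp16 scale d (the sketch below is mine, not package code):

    // Sketch only: bytes per QK_K super-block and resulting bits per weight.
    #include <cstdio>

    constexpr int QK_K = 256;
    constexpr int iq2_s_bytes  = 2 + QK_K/4 + QK_K/32 + QK_K/32;          // d + qs + qh + scales          =  82
    constexpr int iq3_s_bytes  = 2 + QK_K/4 + QK_K/32 + QK_K/8 + QK_K/64; // d + qs + qh + signs + scales  = 110
    constexpr int iq4_xs_bytes = 2 + 2 + QK_K/64 + QK_K/2;                // d + scales_h + scales_l + qs  = 136

    int main() {
        printf("iq2_s : %3d bytes -> %.4f bpw\n", iq2_s_bytes,  8.0*iq2_s_bytes /QK_K); // 2.5625
        printf("iq3_s : %3d bytes -> %.4f bpw\n", iq3_s_bytes,  8.0*iq3_s_bytes /QK_K); // 3.4375
        printf("iq4_xs: %3d bytes -> %.4f bpw\n", iq4_xs_bytes, 8.0*iq4_xs_bytes/QK_K); // 4.2500
        return 0;
    }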
@@ -661,18 +720,20 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
  return a;
  }
 
- //static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
- //#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
- //#pragma unroll
- // for (int mask = 16; mask > 0; mask >>= 1) {
- // a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
- // }
- // return a;
- //#else
- // (void) a;
- // NO_DEVICE_CODE;
- //#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
- //}
+ #ifdef GGML_CUDA_F16
+ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
+ #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+ #pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
+ }
+ return a;
+ #else
+ (void) a;
+ NO_DEVICE_CODE;
+ #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+ }
+ #endif // GGML_CUDA_F16
 
  static __device__ __forceinline__ float warp_reduce_max(float x) {
  #pragma unroll
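The hunk above turns the previously commented-out half2 overload of warp_reduce_sum into real code, gated on GGML_CUDA_F16; the same shuffle-based butterfly is what many later hunks in this diff substitute for hand-rolled `for (mask = 16; ...)` loops. A minimal sketch of the pattern for the float case (my own example, assuming a 32-thread warp, not package code):

    // Sketch only: XOR-shuffle butterfly reduction across one warp. After the
    // loop every lane holds the sum of all 32 lanes' inputs.
    __device__ __forceinline__ float warp_reduce_sum_sketch(float x) {
    #pragma unroll
        for (int mask = 16; mask > 0; mask >>= 1) { // strides 16, 8, 4, 2, 1
            x += __shfl_xor_sync(0xffffffff, x, mask, 32);
        }
        return x;
    }

    // Example use: one 32-thread block reduces one row of a row-major matrix.
    __global__ void row_sum_sketch(const float * x, float * out, int ncols) {
        const float * row = x + (size_t) blockIdx.x * ncols;
        float acc = 0.0f;
        for (int i = threadIdx.x; i < ncols; i += 32) {
            acc += row[i];
        }
        acc = warp_reduce_sum_sketch(acc);
        if (threadIdx.x == 0) {
            out[blockIdx.x] = acc; // lane 0 writes the reduced value
        }
    }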
@@ -931,17 +992,21 @@ static __global__ void concat_f32(const float * x,const float * y, float * dst,
  nidx +
  blockIdx.y * ne0 +
  blockIdx.z * ne0 * gridDim.y;
- dst[offset_dst] = x[offset_src];
+ dst[offset_dst] = x[offset_src];
  } else {
  int offset_src =
  nidx +
  blockIdx.y * ne0 +
  (blockIdx.z - ne02) * ne0 * gridDim.y;
- dst[offset_dst] = y[offset_src];
+ dst[offset_dst] = y[offset_src];
  }
  }
 
- static __global__ void upscale_f32(const float * x, float * dst, const int ne00, const int nb02, const int scale_factor) {
+ static __global__ void upscale_f32(const float * x, float * dst, const int ne00, const int ne00xne01, const int scale_factor) {
+ // blockIdx.z: idx of ne02*ne03
+ // blockIdx.y: idx of ne01*scale_factor, aka ne1
+ // blockIDx.x: idx of ne00*scale_factor / BLOCK_SIZE
+ // ne00xne01: ne00 * ne01
  int ne0 = ne00 * scale_factor;
  int nidx = threadIdx.x + blockIdx.x * blockDim.x;
  if (nidx >= ne0) {
@@ -953,7 +1018,7 @@ static __global__ void upscale_f32(const float * x, float * dst, const int ne00,
  int offset_src =
  i00 +
  i01 * ne00 +
- blockIdx.z * nb02;
+ blockIdx.z * ne00xne01;
  int offset_dst =
  nidx +
  blockIdx.y * ne0 +
@@ -961,7 +1026,10 @@ static __global__ void upscale_f32(const float * x, float * dst, const int ne00,
  dst[offset_dst] = x[offset_src];
  }
 
- static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02) {
+ static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) {
+ // blockIdx.z: idx of ne2*ne3, aka ne02*ne03
+ // blockIdx.y: idx of ne1
+ // blockIDx.x: idx of ne0 / BLOCK_SIZE
  int nidx = threadIdx.x + blockIdx.x * blockDim.x;
  if (nidx >= ne0) {
  return;
@@ -972,19 +1040,53 @@ static __global__ void pad_f32(const float * x, float * dst, const int ne0, cons
  nidx +
  blockIdx.y * ne0 +
  blockIdx.z * ne0 * gridDim.y;
- if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02) {
+ if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) {
  int offset_src =
  nidx +
  blockIdx.y * ne00 +
  blockIdx.z * ne00 * ne01;
- dst[offset_dst] = x[offset_src];
+ dst[offset_dst] = x[offset_src];
  } else {
  dst[offset_dst] = 0.0f;
  }
  }
 
+ static __global__ void arange_f32(float * dst, const int ne0, const float start, const float step) {
+ // blockIDx.x: idx of ne0 / BLOCK_SIZE
+ int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+ if (nidx >= ne0) {
+ return;
+ }
+ dst[nidx] = start + step * nidx;
+ }
+
+ static __global__ void timestep_embedding_f32(const float * timesteps, float * dst, const int nb1, const int dim, const int max_period) {
+ // blockIDx.y: idx of timesteps->ne[0]
+ // blockIDx.x: idx of ((dim + 1) / 2) / BLOCK_SIZE
+ int i = blockIdx.y;
+ int j = threadIdx.x + blockIdx.x * blockDim.x;
+ float * embed_data = (float *)((char *)dst + i*nb1);
+
+ if (dim % 2 != 0 && j == ((dim + 1) / 2)) {
+ embed_data[dim] = 0.f;
+ }
+
+ int half = dim / 2;
+ if (j >= half) {
+ return;
+ }
+
+ float timestep = timesteps[i];
+ float freq = (float)expf(-logf(max_period) * j / half);
+ float arg = timestep * freq;
+ embed_data[j] = cosf(arg);
+ embed_data[j + half] = sinf(arg);
+ }
+
  template <int block_size>
  static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
+ // blockIdx.x: num_groups idx
+ // threadIdx.x: block_size idx
  int start = blockIdx.x * group_size;
  int end = start + group_size;
 
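The new timestep_embedding_f32 kernel above implements the standard sinusoidal timestep embedding used by diffusion models (cosine half first, then sine half), writing each timestep's row at a byte stride of nb1 and zero-padding the trailing element when dim is odd. A host-side reference of the per-timestep math for even dim (my sketch, not package code):

    // Sketch only: what one output row contains for timestep t and even dim.
    #include <cmath>
    #include <vector>

    std::vector<float> timestep_embedding_ref(float t, int dim, int max_period = 10000) {
        std::vector<float> e(dim, 0.0f);
        const int half = dim / 2;
        for (int j = 0; j < half; ++j) {
            const float freq = std::exp(-std::log((float) max_period) * j / half);
            e[j]        = std::cos(t * freq); // matches embed_data[j]
            e[j + half] = std::sin(t * freq); // matches embed_data[j + half]
        }
        return e;
    }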
@@ -1665,6 +1767,265 @@ static const __device__ uint64_t iq2xs_grid[512] = {
1665
1767
  0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
1666
1768
  };
1667
1769
 
1770
+ static const __device__ uint64_t iq2s_grid[1024] = {
1771
+ 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
1772
+ 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
1773
+ 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
1774
+ 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
1775
+ 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
1776
+ 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x08080808192b192b,
1777
+ 0x08080808192b2b19, 0x080808082b080808, 0x080808082b08082b, 0x080808082b081919,
1778
+ 0x080808082b082b08, 0x080808082b190819, 0x080808082b191908, 0x080808082b2b0808,
1779
+ 0x080808082b2b1919, 0x080808082b2b2b2b, 0x0808081908080819, 0x0808081908081908,
1780
+ 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, 0x080808190819082b,
1781
+ 0x0808081908191919, 0x0808081908192b08, 0x08080819082b0819, 0x08080819082b1908,
1782
+ 0x0808081919080808, 0x080808191908082b, 0x0808081919081919, 0x0808081919082b08,
1783
+ 0x0808081919190819, 0x0808081919191908, 0x080808191919192b, 0x0808081919192b19,
1784
+ 0x08080819192b0808, 0x08080819192b1919, 0x08080819192b2b08, 0x080808192b080819,
1785
+ 0x080808192b081908, 0x080808192b190808, 0x080808192b19082b, 0x080808192b191919,
1786
+ 0x080808192b2b0819, 0x080808192b2b1908, 0x0808082b08080808, 0x0808082b0808082b,
1787
+ 0x0808082b08081919, 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908,
1788
+ 0x0808082b082b0808, 0x0808082b082b2b2b, 0x0808082b19080819, 0x0808082b19081908,
1789
+ 0x0808082b1908192b, 0x0808082b19082b19, 0x0808082b19190808, 0x0808082b19191919,
1790
+ 0x0808082b2b080808, 0x0808082b2b081919, 0x0808082b2b082b2b, 0x0808082b2b191908,
1791
+ 0x0808082b2b2b082b, 0x0808190808080819, 0x0808190808081908, 0x080819080808192b,
1792
+ 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, 0x0808190808191919,
1793
+ 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, 0x08081908082b192b,
1794
+ 0x08081908082b2b19, 0x0808190819080808, 0x080819081908082b, 0x0808190819081919,
1795
+ 0x0808190819082b08, 0x0808190819082b2b, 0x0808190819190819, 0x0808190819191908,
1796
+ 0x080819081919192b, 0x0808190819192b19, 0x08081908192b0808, 0x08081908192b082b,
1797
+ 0x08081908192b1919, 0x080819082b080819, 0x080819082b081908, 0x080819082b08192b,
1798
+ 0x080819082b082b19, 0x080819082b190808, 0x080819082b191919, 0x080819082b192b08,
1799
+ 0x080819082b2b0819, 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b,
1800
+ 0x0808191908081919, 0x0808191908082b08, 0x0808191908082b2b, 0x0808191908190819,
1801
+ 0x0808191908191908, 0x080819190819192b, 0x0808191908192b19, 0x08081919082b0808,
1802
+ 0x08081919082b1919, 0x08081919082b2b08, 0x0808191919080819, 0x0808191919081908,
1803
+ 0x080819191908192b, 0x0808191919082b19, 0x0808191919190808, 0x080819191919082b,
1804
+ 0x0808191919191919, 0x0808191919192b08, 0x08081919192b0819, 0x08081919192b1908,
1805
+ 0x080819192b080808, 0x080819192b08082b, 0x080819192b081919, 0x080819192b082b08,
1806
+ 0x080819192b190819, 0x080819192b191908, 0x080819192b2b0808, 0x0808192b08080819,
1807
+ 0x0808192b08081908, 0x0808192b0808192b, 0x0808192b08082b19, 0x0808192b08190808,
1808
+ 0x0808192b08191919, 0x0808192b19080808, 0x0808192b19081919, 0x0808192b19082b08,
1809
+ 0x0808192b19190819, 0x0808192b19191908, 0x0808192b192b0808, 0x0808192b2b080819,
1810
+ 0x0808192b2b081908, 0x0808192b2b190808, 0x08082b0808080808, 0x08082b080808082b,
1811
+ 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808190819, 0x08082b0808191908,
1812
+ 0x08082b080819192b, 0x08082b0808192b19, 0x08082b08082b0808, 0x08082b08082b1919,
1813
+ 0x08082b08082b2b2b, 0x08082b0819080819, 0x08082b0819081908, 0x08082b081908192b,
1814
+ 0x08082b0819082b19, 0x08082b0819190808, 0x08082b081919082b, 0x08082b0819191919,
1815
+ 0x08082b0819192b08, 0x08082b08192b0819, 0x08082b08192b1908, 0x08082b082b080808,
1816
+ 0x08082b082b081919, 0x08082b082b191908, 0x08082b082b2b2b2b, 0x08082b1908080819,
1817
+ 0x08082b1908081908, 0x08082b1908190808, 0x08082b190819082b, 0x08082b1908191919,
1818
+ 0x08082b1908192b08, 0x08082b19082b0819, 0x08082b1919080808, 0x08082b1919081919,
1819
+ 0x08082b1919082b08, 0x08082b1919190819, 0x08082b1919191908, 0x08082b19192b0808,
1820
+ 0x08082b192b080819, 0x08082b192b190808, 0x08082b2b08080808, 0x08082b2b08190819,
1821
+ 0x08082b2b08191908, 0x08082b2b082b082b, 0x08082b2b082b2b08, 0x08082b2b082b2b2b,
1822
+ 0x08082b2b19190808, 0x08082b2b2b192b19, 0x0819080808080819, 0x0819080808081908,
1823
+ 0x081908080808192b, 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b,
1824
+ 0x0819080808191919, 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908,
1825
+ 0x08190808082b192b, 0x0819080819080808, 0x081908081908082b, 0x0819080819081919,
1826
+ 0x0819080819082b08, 0x0819080819190819, 0x0819080819191908, 0x081908081919192b,
1827
+ 0x0819080819192b19, 0x08190808192b0808, 0x08190808192b082b, 0x08190808192b1919,
1828
+ 0x08190808192b2b08, 0x081908082b080819, 0x081908082b081908, 0x081908082b08192b,
1829
+ 0x081908082b190808, 0x081908082b191919, 0x081908082b192b08, 0x081908082b2b0819,
1830
+ 0x081908082b2b1908, 0x0819081908080808, 0x081908190808082b, 0x0819081908081919,
1831
+ 0x0819081908082b08, 0x0819081908082b2b, 0x0819081908190819, 0x0819081908191908,
1832
+ 0x081908190819192b, 0x0819081908192b19, 0x08190819082b0808, 0x08190819082b082b,
1833
+ 0x08190819082b1919, 0x08190819082b2b08, 0x0819081919080819, 0x0819081919081908,
1834
+ 0x081908191908192b, 0x0819081919082b19, 0x0819081919190808, 0x081908191919082b,
1835
+ 0x0819081919191919, 0x0819081919192b08, 0x08190819192b0819, 0x08190819192b1908,
1836
+ 0x081908192b080808, 0x081908192b08082b, 0x081908192b081919, 0x081908192b082b08,
1837
+ 0x081908192b190819, 0x081908192b191908, 0x0819082b08080819, 0x0819082b08081908,
1838
+ 0x0819082b08082b19, 0x0819082b08190808, 0x0819082b08191919, 0x0819082b082b0819,
1839
+ 0x0819082b082b1908, 0x0819082b19080808, 0x0819082b19081919, 0x0819082b19190819,
1840
+ 0x0819082b19191908, 0x0819082b2b080819, 0x0819082b2b081908, 0x0819082b2b190808,
1841
+ 0x0819190808080808, 0x081919080808082b, 0x0819190808081919, 0x0819190808082b08,
1842
+ 0x0819190808190819, 0x0819190808191908, 0x081919080819192b, 0x0819190808192b19,
1843
+ 0x08191908082b0808, 0x08191908082b1919, 0x08191908082b2b08, 0x0819190819080819,
1844
+ 0x0819190819081908, 0x081919081908192b, 0x0819190819082b19, 0x0819190819190808,
1845
+ 0x081919081919082b, 0x0819190819191919, 0x0819190819192b08, 0x08191908192b0819,
1846
+ 0x08191908192b1908, 0x081919082b080808, 0x081919082b08082b, 0x081919082b081919,
1847
+ 0x081919082b082b08, 0x081919082b190819, 0x081919082b191908, 0x081919082b2b0808,
1848
+ 0x0819191908080819, 0x0819191908081908, 0x081919190808192b, 0x0819191908082b19,
1849
+ 0x0819191908190808, 0x081919190819082b, 0x0819191908191919, 0x0819191908192b08,
1850
+ 0x08191919082b0819, 0x08191919082b1908, 0x0819191919080808, 0x081919191908082b,
1851
+ 0x0819191919081919, 0x0819191919082b08, 0x0819191919190819, 0x0819191919191908,
1852
+ 0x08191919192b0808, 0x081919192b080819, 0x081919192b081908, 0x081919192b190808,
1853
+ 0x0819192b08080808, 0x0819192b08081919, 0x0819192b08082b08, 0x0819192b08190819,
1854
+ 0x0819192b08191908, 0x0819192b082b0808, 0x0819192b19080819, 0x0819192b19081908,
1855
+ 0x0819192b19190808, 0x0819192b2b080808, 0x0819192b2b2b2b2b, 0x08192b0808080819,
1856
+ 0x08192b0808081908, 0x08192b080808192b, 0x08192b0808082b19, 0x08192b0808190808,
1857
+ 0x08192b0808191919, 0x08192b0808192b08, 0x08192b08082b0819, 0x08192b0819080808,
1858
+ 0x08192b081908082b, 0x08192b0819081919, 0x08192b0819082b08, 0x08192b0819190819,
1859
+ 0x08192b0819191908, 0x08192b08192b0808, 0x08192b082b080819, 0x08192b082b081908,
1860
+ 0x08192b1908080808, 0x08192b190808082b, 0x08192b1908081919, 0x08192b1908082b08,
1861
+ 0x08192b1908190819, 0x08192b1908191908, 0x08192b19082b0808, 0x08192b1919080819,
1862
+ 0x08192b1919081908, 0x08192b1919190808, 0x08192b19192b2b19, 0x08192b192b2b082b,
1863
+ 0x08192b2b08081908, 0x08192b2b08190808, 0x08192b2b19080808, 0x08192b2b1919192b,
1864
+ 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, 0x082b080808082b08,
1865
+ 0x082b080808190819, 0x082b080808191908, 0x082b08080819192b, 0x082b080808192b19,
1866
+ 0x082b0808082b0808, 0x082b0808082b1919, 0x082b0808082b2b2b, 0x082b080819080819,
1867
+ 0x082b080819081908, 0x082b080819190808, 0x082b08081919082b, 0x082b080819191919,
1868
+ 0x082b0808192b1908, 0x082b08082b080808, 0x082b08082b082b2b, 0x082b08082b191908,
1869
+ 0x082b08082b2b2b2b, 0x082b081908080819, 0x082b081908081908, 0x082b081908190808,
1870
+ 0x082b08190819082b, 0x082b081908191919, 0x082b0819082b0819, 0x082b081919080808,
1871
+ 0x082b08191908082b, 0x082b081919081919, 0x082b081919190819, 0x082b081919191908,
1872
+ 0x082b0819192b0808, 0x082b08192b080819, 0x082b08192b081908, 0x082b08192b190808,
1873
+ 0x082b082b08080808, 0x082b082b08082b2b, 0x082b082b082b082b, 0x082b082b082b2b08,
1874
+ 0x082b082b082b2b2b, 0x082b082b19081908, 0x082b082b19190808, 0x082b082b2b082b08,
1875
+ 0x082b082b2b082b2b, 0x082b082b2b2b2b08, 0x082b190808080819, 0x082b190808081908,
1876
+ 0x082b19080808192b, 0x082b190808082b19, 0x082b190808190808, 0x082b190808191919,
1877
+ 0x082b190808192b08, 0x082b1908082b0819, 0x082b1908082b1908, 0x082b190819080808,
1878
+ 0x082b19081908082b, 0x082b190819081919, 0x082b190819082b08, 0x082b190819190819,
1879
+ 0x082b190819191908, 0x082b1908192b0808, 0x082b19082b080819, 0x082b19082b081908,
1880
+ 0x082b19082b190808, 0x082b191908080808, 0x082b191908081919, 0x082b191908082b08,
1881
+ 0x082b191908190819, 0x082b191908191908, 0x082b1919082b0808, 0x082b191919080819,
1882
+ 0x082b191919081908, 0x082b191919190808, 0x082b1919192b192b, 0x082b19192b080808,
1883
+ 0x082b192b08080819, 0x082b192b08081908, 0x082b192b08190808, 0x082b192b19080808,
1884
+ 0x082b192b19192b19, 0x082b2b0808080808, 0x082b2b0808081919, 0x082b2b0808190819,
1885
+ 0x082b2b0808191908, 0x082b2b0819080819, 0x082b2b0819081908, 0x082b2b0819190808,
1886
+ 0x082b2b082b082b2b, 0x082b2b082b2b2b2b, 0x082b2b1908080819, 0x082b2b1908081908,
1887
+ 0x082b2b1908190808, 0x082b2b192b191919, 0x082b2b2b08082b2b, 0x082b2b2b082b082b,
1888
+ 0x082b2b2b192b1908, 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819,
1889
+ 0x1908080808081908, 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808,
1890
+ 0x190808080819082b, 0x1908080808191919, 0x1908080808192b08, 0x1908080808192b2b,
1891
+ 0x19080808082b0819, 0x19080808082b1908, 0x19080808082b192b, 0x1908080819080808,
1892
+ 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, 0x1908080819082b2b,
1893
+ 0x1908080819190819, 0x1908080819191908, 0x190808081919192b, 0x1908080819192b19,
1894
+ 0x19080808192b0808, 0x19080808192b082b, 0x19080808192b1919, 0x190808082b080819,
1895
+ 0x190808082b081908, 0x190808082b190808, 0x190808082b191919, 0x190808082b192b08,
1896
+ 0x190808082b2b0819, 0x190808082b2b1908, 0x1908081908080808, 0x190808190808082b,
1897
+ 0x1908081908081919, 0x1908081908082b08, 0x1908081908190819, 0x1908081908191908,
1898
+ 0x190808190819192b, 0x1908081908192b19, 0x19080819082b0808, 0x19080819082b082b,
1899
+ 0x19080819082b1919, 0x1908081919080819, 0x1908081919081908, 0x190808191908192b,
1900
+ 0x1908081919082b19, 0x1908081919190808, 0x190808191919082b, 0x1908081919191919,
1901
+ 0x1908081919192b08, 0x19080819192b0819, 0x19080819192b1908, 0x190808192b080808,
1902
+ 0x190808192b08082b, 0x190808192b081919, 0x190808192b082b08, 0x190808192b190819,
1903
+ 0x190808192b191908, 0x190808192b2b0808, 0x1908082b08080819, 0x1908082b08081908,
1904
+ 0x1908082b08190808, 0x1908082b0819082b, 0x1908082b08191919, 0x1908082b08192b08,
1905
+ 0x1908082b082b1908, 0x1908082b19080808, 0x1908082b19081919, 0x1908082b19082b08,
1906
+ 0x1908082b19190819, 0x1908082b19191908, 0x1908082b192b0808, 0x1908082b2b080819,
1907
+ 0x1908082b2b081908, 0x1908190808080808, 0x190819080808082b, 0x1908190808081919,
1908
+ 0x1908190808082b08, 0x1908190808082b2b, 0x1908190808190819, 0x1908190808191908,
1909
+ 0x190819080819192b, 0x1908190808192b19, 0x19081908082b0808, 0x19081908082b082b,
1910
+ 0x19081908082b1919, 0x19081908082b2b08, 0x1908190819080819, 0x1908190819081908,
1911
+ 0x190819081908192b, 0x1908190819082b19, 0x1908190819190808, 0x190819081919082b,
1912
+ 0x1908190819191919, 0x1908190819192b08, 0x19081908192b0819, 0x19081908192b1908,
1913
+ 0x190819082b080808, 0x190819082b08082b, 0x190819082b081919, 0x190819082b082b08,
1914
+ 0x190819082b190819, 0x190819082b191908, 0x190819082b2b0808, 0x1908191908080819,
1915
+ 0x1908191908081908, 0x190819190808192b, 0x1908191908082b19, 0x1908191908190808,
1916
+ 0x190819190819082b, 0x1908191908191919, 0x1908191908192b08, 0x19081919082b0819,
1917
+ 0x19081919082b1908, 0x1908191919080808, 0x190819191908082b, 0x1908191919081919,
1918
+ 0x1908191919082b08, 0x1908191919190819, 0x1908191919191908, 0x19081919192b0808,
1919
+ 0x19081919192b2b2b, 0x190819192b080819, 0x190819192b081908, 0x190819192b190808,
1920
+ 0x1908192b08080808, 0x1908192b0808082b, 0x1908192b08081919, 0x1908192b08082b08,
1921
+ 0x1908192b08190819, 0x1908192b08191908, 0x1908192b082b0808, 0x1908192b19080819,
1922
+ 0x1908192b19081908, 0x1908192b19190808, 0x1908192b2b080808, 0x1908192b2b2b1919,
1923
+ 0x19082b0808080819, 0x19082b0808081908, 0x19082b0808082b19, 0x19082b0808190808,
1924
+ 0x19082b080819082b, 0x19082b0808191919, 0x19082b0808192b08, 0x19082b08082b0819,
1925
+ 0x19082b08082b1908, 0x19082b0819080808, 0x19082b081908082b, 0x19082b0819081919,
1926
+ 0x19082b0819082b08, 0x19082b0819190819, 0x19082b0819191908, 0x19082b08192b0808,
1927
+ 0x19082b082b081908, 0x19082b082b190808, 0x19082b1908080808, 0x19082b190808082b,
1928
+ 0x19082b1908081919, 0x19082b1908082b08, 0x19082b1908190819, 0x19082b1908191908,
1929
+ 0x19082b19082b0808, 0x19082b1919080819, 0x19082b1919081908, 0x19082b1919190808,
1930
+ 0x19082b192b080808, 0x19082b192b19192b, 0x19082b2b08080819, 0x19082b2b08081908,
1931
+ 0x19082b2b08190808, 0x19082b2b19080808, 0x1919080808080808, 0x191908080808082b,
1932
+ 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, 0x1919080808191908,
1933
+ 0x191908080819192b, 0x1919080808192b19, 0x19190808082b0808, 0x19190808082b082b,
1934
+ 0x19190808082b1919, 0x19190808082b2b08, 0x1919080819080819, 0x1919080819081908,
1935
+ 0x191908081908192b, 0x1919080819082b19, 0x1919080819190808, 0x191908081919082b,
1936
+ 0x1919080819191919, 0x1919080819192b08, 0x19190808192b0819, 0x19190808192b1908,
1937
+ 0x191908082b080808, 0x191908082b08082b, 0x191908082b081919, 0x191908082b082b08,
1938
+ 0x191908082b190819, 0x191908082b191908, 0x1919081908080819, 0x1919081908081908,
1939
+ 0x191908190808192b, 0x1919081908082b19, 0x1919081908190808, 0x191908190819082b,
1940
+ 0x1919081908191919, 0x1919081908192b08, 0x19190819082b0819, 0x19190819082b1908,
1941
+ 0x1919081919080808, 0x191908191908082b, 0x1919081919081919, 0x1919081919082b08,
1942
+ 0x1919081919190819, 0x1919081919191908, 0x19190819192b0808, 0x191908192b080819,
1943
+ 0x191908192b081908, 0x191908192b190808, 0x1919082b08080808, 0x1919082b08081919,
1944
+ 0x1919082b08082b08, 0x1919082b08190819, 0x1919082b08191908, 0x1919082b082b0808,
1945
+ 0x1919082b19080819, 0x1919082b19081908, 0x1919082b19190808, 0x1919082b192b2b19,
1946
+ 0x1919082b2b080808, 0x1919190808080819, 0x1919190808081908, 0x191919080808192b,
1947
+ 0x1919190808082b19, 0x1919190808190808, 0x191919080819082b, 0x1919190808191919,
1948
+ 0x1919190808192b08, 0x19191908082b0819, 0x19191908082b1908, 0x1919190819080808,
1949
+ 0x191919081908082b, 0x1919190819081919, 0x1919190819082b08, 0x1919190819190819,
1950
+ 0x1919190819191908, 0x19191908192b0808, 0x191919082b080819, 0x191919082b081908,
1951
+ 0x191919082b190808, 0x1919191908080808, 0x191919190808082b, 0x1919191908081919,
1952
+ 0x1919191908082b08, 0x1919191908190819, 0x1919191908191908, 0x19191919082b0808,
1953
+ 0x1919191919080819, 0x1919191919081908, 0x1919191919190808, 0x191919192b080808,
1954
+ 0x1919192b08080819, 0x1919192b08081908, 0x1919192b08190808, 0x1919192b082b192b,
1955
+ 0x1919192b19080808, 0x19192b0808080808, 0x19192b080808082b, 0x19192b0808081919,
1956
+ 0x19192b0808082b08, 0x19192b0808190819, 0x19192b0808191908, 0x19192b08082b0808,
1957
+ 0x19192b0819080819, 0x19192b0819081908, 0x19192b0819190808, 0x19192b0819192b2b,
1958
+ 0x19192b082b080808, 0x19192b1908080819, 0x19192b1908081908, 0x19192b1908190808,
1959
+ 0x19192b1919080808, 0x19192b2b08080808, 0x19192b2b08192b19, 0x19192b2b2b081919,
1960
+ 0x19192b2b2b2b2b08, 0x192b080808080819, 0x192b080808081908, 0x192b08080808192b,
1961
+ 0x192b080808190808, 0x192b08080819082b, 0x192b080808191919, 0x192b080808192b08,
1962
+ 0x192b0808082b0819, 0x192b0808082b1908, 0x192b080819080808, 0x192b080819081919,
1963
+ 0x192b080819082b08, 0x192b080819190819, 0x192b080819191908, 0x192b0808192b0808,
1964
+ 0x192b08082b081908, 0x192b08082b190808, 0x192b081908080808, 0x192b08190808082b,
1965
+ 0x192b081908081919, 0x192b081908082b08, 0x192b081908190819, 0x192b081908191908,
1966
+ 0x192b0819082b0808, 0x192b081919080819, 0x192b081919081908, 0x192b081919190808,
1967
+ 0x192b08192b080808, 0x192b08192b192b19, 0x192b082b08081908, 0x192b082b08190808,
1968
+ 0x192b082b19080808, 0x192b082b1919192b, 0x192b082b2b2b0819, 0x192b190808080808,
1969
+ 0x192b190808081919, 0x192b190808082b08, 0x192b190808190819, 0x192b190808191908,
1970
+ 0x192b1908082b0808, 0x192b190819080819, 0x192b190819081908, 0x192b190819190808,
1971
+ 0x192b19082b080808, 0x192b191908080819, 0x192b191908081908, 0x192b191908190808,
1972
+ 0x192b191919080808, 0x192b191919082b2b, 0x192b1919192b2b08, 0x192b19192b19082b,
1973
+ 0x192b192b08080808, 0x192b192b2b191908, 0x192b2b0808080819, 0x192b2b0808081908,
1974
+ 0x192b2b0808190808, 0x192b2b08192b1919, 0x192b2b082b192b08, 0x192b2b1908080808,
1975
+ 0x192b2b19082b2b2b, 0x192b2b2b1908082b, 0x192b2b2b2b2b0819, 0x2b08080808080808,
1976
+ 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, 0x2b08080808190819,
1977
+ 0x2b08080808191908, 0x2b08080808192b19, 0x2b080808082b0808, 0x2b080808082b1919,
1978
+ 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808081919082b,
1979
+ 0x2b08080819191919, 0x2b08080819192b08, 0x2b080808192b0819, 0x2b0808082b080808,
1980
+ 0x2b0808082b081919, 0x2b0808082b190819, 0x2b0808082b191908, 0x2b08081908080819,
1981
+ 0x2b08081908081908, 0x2b08081908082b19, 0x2b08081908190808, 0x2b0808190819082b,
1982
+ 0x2b08081908191919, 0x2b08081908192b08, 0x2b080819082b0819, 0x2b080819082b1908,
1983
+ 0x2b08081919080808, 0x2b0808191908082b, 0x2b08081919081919, 0x2b08081919082b08,
1984
+ 0x2b08081919190819, 0x2b08081919191908, 0x2b0808192b080819, 0x2b0808192b081908,
1985
+ 0x2b0808192b190808, 0x2b0808192b2b2b19, 0x2b08082b08080808, 0x2b08082b08081919,
1986
+ 0x2b08082b08082b2b, 0x2b08082b08190819, 0x2b08082b08191908, 0x2b08082b19080819,
1987
+ 0x2b08082b19081908, 0x2b08082b19190808, 0x2b08190808080819, 0x2b08190808081908,
1988
+ 0x2b0819080808192b, 0x2b08190808082b19, 0x2b08190808190808, 0x2b0819080819082b,
1989
+ 0x2b08190808191919, 0x2b08190808192b08, 0x2b081908082b0819, 0x2b08190819080808,
1990
+ 0x2b0819081908082b, 0x2b08190819081919, 0x2b08190819082b08, 0x2b08190819190819,
1991
+ 0x2b08190819191908, 0x2b081908192b0808, 0x2b0819082b080819, 0x2b0819082b081908,
1992
+ 0x2b0819082b190808, 0x2b08191908080808, 0x2b0819190808082b, 0x2b08191908081919,
1993
+ 0x2b08191908082b08, 0x2b08191908190819, 0x2b08191908191908, 0x2b081919082b0808,
1994
+ 0x2b08191919080819, 0x2b08191919081908, 0x2b08191919190808, 0x2b0819192b080808,
1995
+ 0x2b0819192b082b2b, 0x2b08192b08080819, 0x2b08192b08081908, 0x2b08192b08190808,
1996
+ 0x2b08192b082b2b19, 0x2b08192b19080808, 0x2b082b0808080808, 0x2b082b0808081919,
1997
+ 0x2b082b0808190819, 0x2b082b0808191908, 0x2b082b0819080819, 0x2b082b0819081908,
1998
+ 0x2b082b0819190808, 0x2b082b082b2b082b, 0x2b082b1908080819, 0x2b082b1908081908,
1999
+ 0x2b082b1919080808, 0x2b082b19192b1919, 0x2b082b2b082b082b, 0x2b082b2b19192b08,
2000
+ 0x2b082b2b19192b2b, 0x2b082b2b2b08082b, 0x2b082b2b2b2b082b, 0x2b19080808080819,
2001
+ 0x2b19080808081908, 0x2b19080808082b19, 0x2b19080808190808, 0x2b1908080819082b,
2002
+ 0x2b19080808191919, 0x2b19080808192b08, 0x2b190808082b1908, 0x2b19080819080808,
2003
+ 0x2b1908081908082b, 0x2b19080819081919, 0x2b19080819082b08, 0x2b19080819190819,
2004
+ 0x2b19080819191908, 0x2b190808192b0808, 0x2b1908082b080819, 0x2b1908082b081908,
2005
+ 0x2b1908082b190808, 0x2b19081908080808, 0x2b19081908081919, 0x2b19081908190819,
2006
+ 0x2b19081908191908, 0x2b19081919080819, 0x2b19081919081908, 0x2b19081919190808,
2007
+ 0x2b19081919192b2b, 0x2b19082b08080819, 0x2b19082b08081908, 0x2b19082b08190808,
2008
+ 0x2b19082b19080808, 0x2b19082b2b2b192b, 0x2b19190808080808, 0x2b1919080808082b,
2009
+ 0x2b19190808081919, 0x2b19190808082b08, 0x2b19190808190819, 0x2b19190808191908,
2010
+ 0x2b191908082b0808, 0x2b19190819080819, 0x2b19190819081908, 0x2b19190819190808,
2011
+ 0x2b1919082b080808, 0x2b1919082b19192b, 0x2b19191908080819, 0x2b19191908081908,
2012
+ 0x2b19191908190808, 0x2b19191919080808, 0x2b1919192b192b08, 0x2b1919192b2b0819,
2013
+ 0x2b19192b08080808, 0x2b19192b1908192b, 0x2b19192b192b1908, 0x2b192b0808080819,
2014
+ 0x2b192b0808081908, 0x2b192b0808190808, 0x2b192b08082b192b, 0x2b192b0819080808,
2015
+ 0x2b192b082b2b2b19, 0x2b192b1908080808, 0x2b192b1919082b19, 0x2b192b191919082b,
2016
+ 0x2b192b2b2b190808, 0x2b2b080808080808, 0x2b2b080808081919, 0x2b2b080808082b2b,
2017
+ 0x2b2b080808191908, 0x2b2b0808082b082b, 0x2b2b0808082b2b2b, 0x2b2b080819080819,
2018
+ 0x2b2b080819081908, 0x2b2b080819190808, 0x2b2b08082b2b082b, 0x2b2b08082b2b2b2b,
2019
+ 0x2b2b081919080808, 0x2b2b0819192b1919, 0x2b2b082b0808082b, 0x2b2b082b08082b2b,
2020
+ 0x2b2b082b082b082b, 0x2b2b082b082b2b08, 0x2b2b082b082b2b2b, 0x2b2b082b2b08082b,
2021
+ 0x2b2b082b2b082b08, 0x2b2b082b2b082b2b, 0x2b2b082b2b2b2b08, 0x2b2b190808080819,
2022
+ 0x2b2b190808081908, 0x2b2b190808190808, 0x2b2b190819080808, 0x2b2b19082b082b19,
2023
+ 0x2b2b19082b2b1908, 0x2b2b191908080808, 0x2b2b191908192b19, 0x2b2b192b19190819,
2024
+ 0x2b2b2b0808082b2b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b082b, 0x2b2b2b1919191908,
2025
+ 0x2b2b2b192b08192b, 0x2b2b2b2b08082b08, 0x2b2b2b2b08082b2b, 0x2b2b2b2b082b0808,
2026
+ 0x2b2b2b2b082b082b, 0x2b2b2b2b082b2b08, 0x2b2b2b2b2b082b08, 0x2b2b2b2b2b2b2b2b,
2027
+ };
2028
+
1668
2029
  static const __device__ uint32_t iq3xxs_grid[256] = {
1669
2030
  0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
1670
2031
  0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
@@ -1700,6 +2061,73 @@ static const __device__ uint32_t iq3xxs_grid[256] = {
1700
2061
  0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
1701
2062
  };
1702
2063
 
2064
+ static const __device__ uint32_t iq3s_grid[512] = {
2065
+ 0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305,
2066
+ 0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,
2067
+ 0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09,
2068
+ 0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b,
2069
+ 0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b,
2070
+ 0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d,
2071
+ 0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03,
2072
+ 0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505,
2073
+ 0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03,
2074
+ 0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901,
2075
+ 0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d,
2076
+ 0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303,
2077
+ 0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501,
2078
+ 0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105,
2079
+ 0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505,
2080
+ 0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101,
2081
+ 0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707,
2082
+ 0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b,
2083
+ 0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01,
2084
+ 0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f,
2085
+ 0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305,
2086
+ 0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103,
2087
+ 0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509,
2088
+ 0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503,
2089
+ 0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b,
2090
+ 0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f,
2091
+ 0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f,
2092
+ 0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f,
2093
+ 0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109,
2094
+ 0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f,
2095
+ 0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509,
2096
+ 0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501,
2097
+ 0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303,
2098
+ 0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f,
2099
+ 0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907,
2100
+ 0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703,
2101
+ 0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03,
2102
+ 0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01,
2103
+ 0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01,
2104
+ 0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903,
2105
+ 0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505,
2106
+ 0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b,
2107
+ 0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107,
2108
+ 0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509,
2109
+ 0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303,
2110
+ 0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103,
2111
+ 0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05,
2112
+ 0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b,
2113
+ 0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f,
2114
+ 0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701,
2115
+ 0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909,
2116
+ 0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305,
2117
+ 0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d,
2118
+ 0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b,
2119
+ 0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d,
2120
+ 0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307,
2121
+ 0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09,
2122
+ 0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309,
2123
+ 0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709,
2124
+ 0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f,
2125
+ 0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303,
2126
+ 0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503,
2127
+ 0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,
2128
+ 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
2129
+ };
2130
+
1703
2131
  static const __device__ uint64_t iq1s_grid[512] = {
1704
2132
  0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
1705
2133
  0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01,
@@ -1945,6 +2373,27 @@ static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst
 
  }
 
+ template<typename dst_t>
+ static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+ const int i = blockIdx.x;
+ const block_iq2_s * x = (const block_iq2_s *) vx;
+
+ const int tid = threadIdx.x;
+ #if QK_K == 256
+ const int il = tid/8; // 0...3
+ const int ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+ const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
+ const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
+ const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];
+ for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+ #else
+ assert(false);
+ #endif
+
+ }
+
  template<typename dst_t>
  static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
 
@@ -1973,6 +2422,32 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds
 
  }
 
+ template<typename dst_t>
+ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+ const int i = blockIdx.x;
+ const block_iq3_s * x = (const block_iq3_s *) vx;
+
+ const int tid = threadIdx.x;
+ #if QK_K == 256
+ const int il = tid/8; // 0...3
+ const int ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+ const uint8_t * qs = x[i].qs + 8*ib;
+ const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
+ const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)));
+ const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf));
+ const uint8_t signs = x[i].signs[4*ib + il];
+ for (int j = 0; j < 4; ++j) {
+ y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+ y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
+ }
+ #else
+ assert(false);
+ #endif
+
+ }
+
  template<typename dst_t>
  static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
 
@@ -2016,6 +2491,25 @@ static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst
 
  }
 
+ #if QK_K != 64
+ template<typename dst_t>
+ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+ const int i = blockIdx.x;
+ const block_iq4_xs * x = (const block_iq4_xs *)vx;
+
+ const int tid = threadIdx.x;
+ const int il = tid/8; // 0...3
+ const int ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+ const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
+ const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
+ for (int j = 0; j < 4; ++j) {
+ y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
+ y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
+ }
+ }
+ #endif
+
  static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
 
  static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
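The three dequantize kernels added above each decode one QK_K super-block per thread block, with 32 threads splitting the block into (ib, il) sub-tiles. They are launched by per-row wrapper functions elsewhere in this file; a hedged sketch of what such a wrapper could look like for iq2_s (the wrapper name and signature here are illustrative, not copied from the package):

    // Sketch only: one thread block per super-block, 32 threads per block,
    // matching the tid -> (il, ib) split inside dequantize_block_iq2_s.
    template <typename dst_t>
    static void dequantize_row_iq2_s_cuda_sketch(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
        const int nb = k / QK_K; // number of QK_K super-blocks in the row
        dequantize_block_iq2_s<<<nb, 32, 0, stream>>>(vx, y);
    }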
@@ -2112,10 +2606,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
2112
2606
  #endif
2113
2607
 
2114
2608
  // sum up partial sums and write back result
2115
- #pragma unroll
2116
- for (int mask = 16; mask > 0; mask >>= 1) {
2117
- tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
2118
- }
2609
+ tmp = warp_reduce_sum(tmp);
2119
2610
 
2120
2611
  if (threadIdx.x == 0) {
2121
2612
  dst[row] = tmp;
@@ -2216,10 +2707,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx,
2216
2707
  #endif
2217
2708
 
2218
2709
  // sum up partial sums and write back result
2219
- #pragma unroll
2220
- for (int mask = 16; mask > 0; mask >>= 1) {
2221
- tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
2222
- }
2710
+ tmp = warp_reduce_sum(tmp);
2223
2711
 
2224
2712
  if (threadIdx.x == 0) {
2225
2713
  dst[row] = tmp;
@@ -2352,10 +2840,7 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
2352
2840
  #endif
2353
2841
 
2354
2842
  // sum up partial sums and write back result
2355
- #pragma unroll
2356
- for (int mask = 16; mask > 0; mask >>= 1) {
2357
- tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
2358
- }
2843
+ tmp = warp_reduce_sum(tmp);
2359
2844
 
2360
2845
  if (tid == 0) {
2361
2846
  dst[row] = tmp;
@@ -2468,10 +2953,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx,
2468
2953
  #endif
2469
2954
 
2470
2955
  // sum up partial sums and write back result
2471
- #pragma unroll
2472
- for (int mask = 16; mask > 0; mask >>= 1) {
2473
- tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
2474
- }
2956
+ tmp = warp_reduce_sum(tmp);
2475
2957
 
2476
2958
  if (threadIdx.x == 0) {
2477
2959
  dst[row] = tmp;
@@ -2578,10 +3060,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
2578
3060
  #endif
2579
3061
 
2580
3062
  // sum up partial sums and write back result
2581
- #pragma unroll
2582
- for (int mask = 16; mask > 0; mask >>= 1) {
2583
- tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
2584
- }
3063
+ tmp = warp_reduce_sum(tmp);
2585
3064
 
2586
3065
  if (tid == 0) {
2587
3066
  dst[row] = tmp;
@@ -2616,11 +3095,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
2616
3095
  float amax = fabsf(xi);
2617
3096
  float sum = xi;
2618
3097
 
2619
- #pragma unroll
2620
- for (int mask = 16; mask > 0; mask >>= 1) {
2621
- amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32));
2622
- sum += __shfl_xor_sync(0xffffffff, sum, mask, 32);
2623
- }
3098
+ amax = warp_reduce_max(amax);
3099
+ sum = warp_reduce_sum(sum);
2624
3100
 
2625
3101
  const float d = amax / 127;
2626
3102
  const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
@@ -4682,6 +5158,54 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
4682
5158
  #endif
4683
5159
  }
4684
5160
 
5161
+ // TODO
5162
+ static __device__ __forceinline__ float vec_dot_iq2_s_q8_1(
5163
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
5164
+ #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
5165
+ #if QK_K == 256
5166
+ const block_iq2_s * bq2 = (const block_iq2_s *) vbq;
5167
+
5168
+ const int ib32 = iqs;
5169
+ const int8_t * q8 = bq8_1[ib32].qs;
5170
+ const uint8_t * signs = bq2->qs + QK_K/8 + 4*ib32;
5171
+ const uint8_t ls1 = bq2->scales[ib32] & 0xf;
5172
+ const uint8_t ls2 = bq2->scales[ib32] >> 4;
5173
+ int sumi1 = 0;
5174
+ for (int l = 0; l < 2; ++l) {
5175
+ const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300)));
5176
+ const uint32_t signs0 = __vcmpeq4(((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
5177
+ const uint32_t signs1 = __vcmpeq4(((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201);
5178
+ const int grid_l = __vsub4(grid[0] ^ signs0, signs0);
5179
+ const int grid_h = __vsub4(grid[1] ^ signs1, signs1);
5180
+ sumi1 = __dp4a(grid_l, *((const int *)q8 + 0), sumi1);
5181
+ sumi1 = __dp4a(grid_h, *((const int *)q8 + 1), sumi1);
5182
+ q8 += 8;
5183
+ }
5184
+ int sumi2 = 0;
5185
+ for (int l = 2; l < 4; ++l) {
5186
+ const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300)));
5187
+ const uint32_t signs0 = __vcmpeq4(((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
5188
+ const uint32_t signs1 = __vcmpeq4(((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201);
5189
+ const int grid_l = __vsub4(grid[0] ^ signs0, signs0);
5190
+ const int grid_h = __vsub4(grid[1] ^ signs1, signs1);
5191
+ sumi2 = __dp4a(grid_l, *((const int *)q8 + 0), sumi2);
5192
+ sumi2 = __dp4a(grid_h, *((const int *)q8 + 1), sumi2);
5193
+ q8 += 8;
5194
+ }
5195
+ const float d = (float)bq2->d * __low2float(bq8_1[ib32].ds) * 0.25f;
5196
+ return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
5197
+ #else
5198
+ (void) ksigns64;
5199
+ assert(false);
5200
+ return 0.f;
5201
+ #endif
5202
+ #else
5203
+ (void) ksigns64;
5204
+ assert(false);
5205
+ return 0.f;
5206
+ #endif
5207
+ }
5208
+
4685
5209
  static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
4686
5210
  const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
4687
5211
  #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
@@ -4717,6 +5241,41 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
4717
5241
  #endif
4718
5242
  }
4719
5243
 
5244
+ // TODO: don't use lookup table for signs
5245
+ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
5246
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
5247
+ #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
5248
+ #if QK_K == 256
5249
+ const block_iq3_s * bq2 = (const block_iq3_s *) vbq;
5250
+
5251
+ const int ib32 = iqs;
5252
+ const uint8_t * qs = bq2->qs + 8*ib32;
5253
+ const int8_t * q8 = bq8_1[ib32].qs;
5254
+ int sumi = 0;
5255
+ for (int l = 0; l < 4; ++l) {
5256
+ const uint32_t * grid1 = iq3s_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256));
5257
+ const uint32_t * grid2 = iq3s_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256));
5258
+ uint32_t signs0 = __vcmpeq4(((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
5259
+ uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201);
5260
+ const int grid_l = __vsub4(grid1[0] ^ signs0, signs0);
5261
+ const int grid_h = __vsub4(grid2[0] ^ signs1, signs1);
5262
+ sumi = __dp4a(grid_l, *((int *)q8+0), sumi);
5263
+ sumi = __dp4a(grid_h, *((int *)q8+1), sumi);
5264
+ q8 += 8;
5265
+ }
5266
+ const float d = (float)bq2->d * (1 + 2*((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * __low2float(bq8_1[ib32].ds);
5267
+ return d * sumi;
5268
+ #else
5269
+ assert(false);
5270
+ return 0.f;
5271
+ #endif
5272
+ #else
5273
+ assert(false);
5274
+ return 0.f;
5275
+ #endif
5276
+ }
5277
+
5278
+
4720
5279
  static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
4721
5280
  const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
4722
5281
  #if QK_K == 256
@@ -4810,6 +5369,75 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
4810
5369
  return d * (sumi1 + sumi2);
4811
5370
  }
4812
5371
 
5372
+ static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
5373
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
5374
+
5375
+ #if QK_K == 256
5376
+ #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
5377
+
5378
+ const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq;
5379
+ const uint8_t * values = (const uint8_t *)kvalues_iq4nl;
5380
+
5381
+ //// iqs is 0...7
5382
+ //const int ib64 = iqs/2;
5383
+ //const int il = iqs%2;
5384
+ //const int32_t * q8_1 = (const int *)bq8_1[2*ib64+0].qs + 2*il;
5385
+ //const int32_t * q8_2 = (const int *)bq8_1[2*ib64+1].qs + 2*il;
5386
+ //const uint32_t * q4_1 = (const uint32_t *)bq4->qs + 8*ib64 + 2*il;
5387
+ //const uint32_t * q4_2 = q4_1 + 4;
5388
+ //const int8_t ls1 = (bq4->scales_l[ib64] & 0xf) | (((bq4->scales_h >> (4*ib64+0)) & 3) << 4);
5389
+ //const int8_t ls2 = (bq4->scales_l[ib64] >> 4) | (((bq4->scales_h >> (4*ib64+2)) & 3) << 4);
5390
+ //const float d1 = (float)bq4->d * (ls1 - 32) * __low2float(bq8_1[2*ib64+0].ds);
5391
+ //const float d2 = (float)bq4->d * (ls2 - 32) * __low2float(bq8_1[2*ib64+1].ds);
5392
+ //int v1, v2;
5393
+ //int sumi1 = 0, sumi2 = 0;
5394
+ //for (int j = 0; j < 2; ++j) {
5395
+ // get_int_from_table_16(q4_1[j], values, v1, v2);
5396
+ // sumi1 = __dp4a(v2, q8_1[j+4], __dp4a(v1, q8_1[j+0], sumi1));
5397
+ // get_int_from_table_16(q4_2[j], values, v1, v2);
5398
+ // sumi2 = __dp4a(v2, q8_2[j+4], __dp4a(v1, q8_2[j+0], sumi2));
5399
+ //}
5400
+ //return d1 * sumi1 + d2 * sumi2;
5401
+
5402
+ // iqs is 0...7
5403
+ const int ib32 = iqs;
5404
+ const int32_t * q8 = (const int *)bq8_1[ib32].qs;
5405
+ const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32;
5406
+ const int8_t ls = ((bq4->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((bq4->scales_h >> 2*ib32) & 3) << 4);
5407
+ const float d = (float)bq4->d * (ls - 32) * __low2float(bq8_1[ib32].ds);
5408
+ int v1, v2;
5409
+ int sumi1 = 0, sumi2 = 0;
5410
+ for (int j = 0; j < 4; ++j) {
5411
+ get_int_from_table_16(q4[j], values, v1, v2);
5412
+ sumi1 = __dp4a(v1, q8[j+0], sumi1);
5413
+ sumi2 = __dp4a(v2, q8[j+4], sumi2);
5414
+ }
5415
+ return d * (sumi1 + sumi2);
5416
+
5417
+ //// iqs is 0...15
5418
+ //const int ib32 = iqs/2;
5419
+ //const int il = iqs%2;
5420
+ //const int32_t * q8 = (const int *)bq8_1[ib32].qs + 2*il;
5421
+ //const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32 + 2*il;
5422
+ //const int8_t ls = ((bq4->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((bq4->scales_h >> 2*ib32) & 3) << 4);
5423
+ //const float d = (float)bq4->d * (ls - 32) * __low2float(bq8_1[ib32].ds);
5424
+ //int v1, v2;
5425
+ //int sumi1 = 0, sumi2 = 0;
5426
+ //for (int j = 0; j < 2; ++j) {
5427
+ // get_int_from_table_16(q4[j], values, v1, v2);
5428
+ // sumi1 = __dp4a(v1, q8[j+0], sumi1);
5429
+ // sumi2 = __dp4a(v2, q8[j+4], sumi2);
5430
+ //}
5431
+ //return d * (sumi1 + sumi2);
5432
+ #else
5433
+ assert(false);
5434
+ return 0.f;
5435
+ #endif
5436
+ #else
5437
+ return vec_dot_iq4_xs_q8_1(vbq, bq8_1, iqs);
5438
+ #endif
5439
+ }
5440
+
4813
5441
  template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
4814
5442
  allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
4815
5443
  static __device__ __forceinline__ void mul_mat_q(
@@ -5730,10 +6358,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
5730
6358
  }
5731
6359
 
5732
6360
  // sum up partial sums and write back result
5733
- #pragma unroll
5734
- for (int mask = 16; mask > 0; mask >>= 1) {
5735
- tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
5736
- }
6361
+ tmp = warp_reduce_sum(tmp);
5737
6362
 
5738
6363
  if (tid == 0) {
5739
6364
  #ifdef GGML_CUDA_F16
@@ -5783,10 +6408,7 @@ static __global__ void mul_mat_p021_f16_f32(
5783
6408
  const int idst = channel*nrows_dst + row_dst;
5784
6409
 
5785
6410
  // sum up partial sums and write back result
5786
- #pragma unroll
5787
- for (int mask = 16; mask > 0; mask >>= 1) {
5788
- tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
5789
- }
6411
+ tmp = warp_reduce_sum(tmp);
5790
6412
 
5791
6413
  if (threadIdx.x == 0) {
5792
6414
  dst[idst] = tmp;
@@ -5829,10 +6451,7 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
5829
6451
  }
5830
6452
 
5831
6453
  // sum up partial sums and write back result
5832
- #pragma unroll
5833
- for (int mask = 16; mask > 0; mask >>= 1) {
5834
- tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
5835
- }
6454
+ tmp = warp_reduce_sum(tmp);
5836
6455
 
5837
6456
  if (threadIdx.x == 0) {
5838
6457
  dst[idst] = tmp;
@@ -5872,7 +6491,7 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
5872
6491
  const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
5873
6492
  const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
5874
6493
  const int nb12, const int nb13) {
5875
- const int i = blockDim.x*blockIdx.x + threadIdx.x;
6494
+ const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
5876
6495
 
5877
6496
  if (i >= ne) {
5878
6497
  return;
@@ -5880,17 +6499,17 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
5880
6499
 
5881
6500
  // determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
5882
6501
  // then combine those indices with the corresponding byte offsets to get the total offsets
5883
- const int i03 = i/(ne00 * ne01 * ne02);
5884
- const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
5885
- const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
5886
- const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
5887
- const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
5888
-
5889
- const int i13 = i/(ne10 * ne11 * ne12);
5890
- const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
5891
- const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
5892
- const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
5893
- const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
6502
+ const int64_t i03 = i/(ne00 * ne01 * ne02);
6503
+ const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
6504
+ const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
6505
+ const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
6506
+ const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
6507
+
6508
+ const int64_t i13 = i/(ne10 * ne11 * ne12);
6509
+ const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
6510
+ const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
6511
+ const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
6512
+ const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
5894
6513
 
5895
6514
  cpy_1(cx + x_offset, cdst + dst_offset);
5896
6515
  }
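Note: cpy_f32_f16 now derives the flattened index and every byte offset in 64-bit arithmetic. With plain int, the offset products wrap as soon as a tensor crosses 2 GiB (or 2^31 elements); a tiny stand-alone check with hypothetical sizes makes the failure mode concrete:

    #include <limits.h>
    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const int64_t nelements = 700000000;        // hypothetical f32 tensor size
        const int64_t nbytes    = nelements * 4;    // 2'800'000'000 bytes
        printf("offset = %lld, INT_MAX = %d, 32-bit overflow: %s\n",
               (long long) nbytes, INT_MAX, nbytes > INT_MAX ? "yes" : "no");
        return 0;
    }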
@@ -6216,11 +6835,11 @@ static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int n
6216
6835
  int ixj = col ^ j;
6217
6836
  if (ixj > col) {
6218
6837
  if ((col & k) == 0) {
6219
- if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
6838
+ if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
6220
6839
  swap(dst_row[col], dst_row[ixj]);
6221
6840
  }
6222
6841
  } else {
6223
- if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
6842
+ if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
6224
6843
  swap(dst_row[col], dst_row[ixj]);
6225
6844
  }
6226
6845
  }
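Note: this hunk is part of a release-wide rename of public ggml enumerators: GGML_SORT_ASC/DESC become GGML_SORT_ORDER_ASC/DESC here, and later hunks do the same for GGML_BACKEND_* to GGML_BACKEND_TYPE_* and GGML_TASK_* to GGML_TASK_TYPE_*; the kernel logic itself is untouched. Downstream code that still spells the old names can bridge the gap while migrating; a sketch of such a shim (assuming the old identifiers are no longer defined anywhere else):

    #define GGML_SORT_ASC          GGML_SORT_ORDER_ASC
    #define GGML_SORT_DESC         GGML_SORT_ORDER_DESC
    #define GGML_BACKEND_CPU       GGML_BACKEND_TYPE_CPU
    #define GGML_BACKEND_GPU       GGML_BACKEND_TYPE_GPU
    #define GGML_BACKEND_GPU_SPLIT GGML_BACKEND_TYPE_GPU_SPLIT
    #define GGML_TASK_INIT         GGML_TASK_TYPE_INIT
    #define GGML_TASK_COMPUTE      GGML_TASK_TYPE_COMPUTE
    #define GGML_TASK_FINALIZE     GGML_TASK_TYPE_FINALIZE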
@@ -6328,6 +6947,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
6328
6947
  // find the sum of exps in the block
6329
6948
  tmp = warp_reduce_sum(tmp);
6330
6949
  if (block_size > WARP_SIZE) {
6950
+ __syncthreads();
6331
6951
  if (warp_id == 0) {
6332
6952
  buf_iw[lane_id] = 0.0f;
6333
6953
  }
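Note: the only change in this hunk is an extra __syncthreads() in soft_max_f32, placed after the warp-level sum and before warp 0 re-initializes buf_iw. The usual reason for a barrier at this spot is that the same shared buffer was just used for the block-wide max reduction, and warp 0 must not start clearing it while other warps may still be reading it. A self-contained illustration of that reuse pattern (everything below is illustrative, not code from the package):

    #define WARP_SIZE 32

    static __global__ void reuse_shared_buffer(const float * x, float * y, int n) {
        __shared__ float buf[WARP_SIZE];
        const int tid     = threadIdx.x;
        const int lane_id = tid % WARP_SIZE;
        const int warp_id = tid / WARP_SIZE;
        const int nwarps  = blockDim.x / WARP_SIZE;

        // phase 1: each warp publishes one value, every thread then reads the buffer
        if (lane_id == 0) {
            buf[warp_id] = tid < n ? x[tid] : 0.0f;
        }
        __syncthreads();
        float phase1 = 0.0f;
        for (int w = 0; w < nwarps; ++w) {
            phase1 += buf[w];              // phase-1 reads of the shared buffer
        }

        __syncthreads();                   // the added barrier: all phase-1 reads
                                           // are finished before the buffer is reused
        // phase 2: warp 0 reuses the very same buffer
        if (warp_id == 0) {
            buf[lane_id] = 0.0f;
        }
        __syncthreads();

        if (tid < n) {
            y[tid] = phase1 + buf[lane_id];
        }
    }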
@@ -6379,23 +6999,23 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
6379
6999
 
6380
7000
  template <typename T>
6381
7001
  static __global__ void im2col_kernel(
6382
- const float * x, T * dst, int batch_offset,
6383
- int offset_delta, int IC, int IW, int IH, int OH, int OW, int KW, int KH, int pelements, int CHW,
7002
+ const float * x, T * dst, int64_t batch_offset,
7003
+ int64_t offset_delta, int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW,
6384
7004
  int s0, int s1, int p0, int p1, int d0, int d1) {
6385
- const int i = threadIdx.x + blockIdx.x * blockDim.x;
7005
+ const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
6386
7006
  if (i >= pelements) {
6387
7007
  return;
6388
7008
  }
6389
7009
 
6390
- const int ksize = OW * (KH > 1 ? KW : 1);
6391
- const int kx = i / ksize;
6392
- const int kd = kx * ksize;
6393
- const int ky = (i - kd) / OW;
6394
- const int ix = i % OW;
7010
+ const int64_t ksize = OW * (KH > 1 ? KW : 1);
7011
+ const int64_t kx = i / ksize;
7012
+ const int64_t kd = kx * ksize;
7013
+ const int64_t ky = (i - kd) / OW;
7014
+ const int64_t ix = i % OW;
6395
7015
 
6396
- const int oh = blockIdx.y;
6397
- const int batch = blockIdx.z / IC;
6398
- const int ic = blockIdx.z % IC;
7016
+ const int64_t oh = blockIdx.y;
7017
+ const int64_t batch = blockIdx.z / IC;
7018
+ const int64_t ic = blockIdx.z % IC;
6399
7019
 
6400
7020
  const int64_t iiw = ix * s0 + kx * d0 - p0;
6401
7021
  const int64_t iih = oh * s1 + ky * d1 - p1;
@@ -6721,19 +7341,33 @@ static void concat_f32_cuda(const float * x, const float * y, float * dst, const
6721
7341
  concat_f32<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
6722
7342
  }
6723
7343
 
6724
- static void upscale_f32_cuda(const float * x, float * dst, const int ne00, const int ne01, const int ne02, const int scale_factor, cudaStream_t stream) {
7344
+ static void upscale_f32_cuda(const float * x, float * dst, const int ne00, const int ne01, const int ne02, const int ne03,
7345
+ const int scale_factor, cudaStream_t stream) {
6725
7346
  int ne0 = (ne00 * scale_factor);
6726
7347
  int num_blocks = (ne0 + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
6727
- dim3 gridDim(num_blocks, (ne01 * scale_factor), ne02);
7348
+ dim3 gridDim(num_blocks, (ne01 * scale_factor), ne02*ne03);
6728
7349
  upscale_f32<<<gridDim, CUDA_UPSCALE_BLOCK_SIZE, 0, stream>>>(x, dst, ne00, ne00 * ne01, scale_factor);
6729
7350
  }
6730
7351
 
6731
7352
  static void pad_f32_cuda(const float * x, float * dst,
6732
- const int ne00, const int ne01, const int ne02,
6733
- const int ne0, const int ne1, const int ne2, cudaStream_t stream) {
7353
+ const int ne00, const int ne01, const int ne02, const int ne03,
7354
+ const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
6734
7355
  int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
6735
- dim3 gridDim(num_blocks, ne1, ne2);
6736
- pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02);
7356
+ dim3 gridDim(num_blocks, ne1, ne2*ne3);
7357
+ pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
7358
+ }
7359
+
7360
+ static void arange_f32_cuda(float * dst, const int ne0, const float start, const float step, cudaStream_t stream) {
7361
+ int num_blocks = (ne0 + CUDA_ARANGE_BLOCK_SIZE - 1) / CUDA_ARANGE_BLOCK_SIZE;
7362
+ arange_f32<<<num_blocks, CUDA_ARANGE_BLOCK_SIZE, 0, stream>>>(dst, ne0, start, step);
7363
+ }
7364
+
7365
+ static void timestep_embedding_f32_cuda(const float * x, float * dst, const int ne00, const int nb1,
7366
+ const int dim, const int max_period, cudaStream_t stream) {
7367
+ int half_ceil = (dim + 1) / 2;
7368
+ int num_blocks = (half_ceil + CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE;
7369
+ dim3 gridDim(num_blocks, ne00, 1);
7370
+ timestep_embedding_f32<<<gridDim, CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE, 0, stream>>>(x, dst, nb1, dim, max_period);
6737
7371
  }
6738
7372
 
6739
7373
  static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
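Note: the new timestep_embedding_f32_cuda launcher spreads half_ceil = (dim + 1) / 2 frequency slots per position across blocks in x and one row per timestep in y (the ceil handles odd dim). A hedged host-side reference of the sinusoidal embedding this presumably implements, using the standard cos/sin layout and assuming an even dim for brevity (timestep_embedding_ref is illustrative, not the kernel itself):

    #include <math.h>

    // Reference for an even dim: half = dim/2 frequencies, cos values in the
    // first half of the row, sin values in the second half.
    static void timestep_embedding_ref(float t, float * dst, int dim, int max_period) {
        const int half = dim / 2;
        for (int j = 0; j < half; ++j) {
            const float freq = expf(-logf((float) max_period) * (float) j / (float) half);
            dst[j]        = cosf(t * freq);
            dst[j + half] = sinf(t * freq);
        }
    }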
@@ -6843,12 +7477,24 @@ static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int k,
6843
7477
  dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
6844
7478
  }
6845
7479
 
7480
+ template<typename dst_t>
7481
+ static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
7482
+ const int nb = k / QK_K;
7483
+ dequantize_block_iq2_s<<<nb, 32, 0, stream>>>(vx, y);
7484
+ }
7485
+
6846
7486
  template<typename dst_t>
6847
7487
  static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
6848
7488
  const int nb = k / QK_K;
6849
7489
  dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
6850
7490
  }
6851
7491
 
7492
+ template<typename dst_t>
7493
+ static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
7494
+ const int nb = k / QK_K;
7495
+ dequantize_block_iq3_s<<<nb, 32, 0, stream>>>(vx, y);
7496
+ }
7497
+
6852
7498
  template<typename dst_t>
6853
7499
  static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
6854
7500
  const int nb = k / QK_K;
@@ -6861,6 +7507,16 @@ static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k,
6861
7507
  dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
6862
7508
  }
6863
7509
 
7510
+ template<typename dst_t>
7511
+ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
7512
+ const int nb = (k + QK_K - 1) / QK_K;
7513
+ #if QK_K == 64
7514
+ dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
7515
+ #else
7516
+ dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
7517
+ #endif
7518
+ }
7519
+
6864
7520
  template <typename src_t, typename dst_t>
6865
7521
  static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
6866
7522
  const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
@@ -6898,12 +7554,18 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
6898
7554
  return dequantize_row_iq2_xxs_cuda;
6899
7555
  case GGML_TYPE_IQ2_XS:
6900
7556
  return dequantize_row_iq2_xs_cuda;
7557
+ case GGML_TYPE_IQ2_S:
7558
+ return dequantize_row_iq2_s_cuda;
6901
7559
  case GGML_TYPE_IQ3_XXS:
6902
7560
  return dequantize_row_iq3_xxs_cuda;
6903
7561
  case GGML_TYPE_IQ1_S:
6904
7562
  return dequantize_row_iq1_s_cuda;
6905
7563
  case GGML_TYPE_IQ4_NL:
6906
7564
  return dequantize_row_iq4_nl_cuda;
7565
+ case GGML_TYPE_IQ4_XS:
7566
+ return dequantize_row_iq4_xs_cuda;
7567
+ case GGML_TYPE_IQ3_S:
7568
+ return dequantize_row_iq3_s_cuda;
6907
7569
  case GGML_TYPE_F32:
6908
7570
  return convert_unary_cuda<float>;
6909
7571
  default:
@@ -6937,12 +7599,18 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
6937
7599
  return dequantize_row_iq2_xxs_cuda;
6938
7600
  case GGML_TYPE_IQ2_XS:
6939
7601
  return dequantize_row_iq2_xs_cuda;
7602
+ case GGML_TYPE_IQ2_S:
7603
+ return dequantize_row_iq2_s_cuda;
6940
7604
  case GGML_TYPE_IQ3_XXS:
6941
7605
  return dequantize_row_iq3_xxs_cuda;
6942
7606
  case GGML_TYPE_IQ1_S:
6943
7607
  return dequantize_row_iq1_s_cuda;
6944
7608
  case GGML_TYPE_IQ4_NL:
6945
7609
  return dequantize_row_iq4_nl_cuda;
7610
+ case GGML_TYPE_IQ4_XS:
7611
+ return dequantize_row_iq4_xs_cuda;
7612
+ case GGML_TYPE_IQ3_S:
7613
+ return dequantize_row_iq3_s_cuda;
6946
7614
  case GGML_TYPE_F16:
6947
7615
  return convert_unary_cuda<half>;
6948
7616
  default:
@@ -7764,10 +8432,10 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
7764
8432
 
7765
8433
  const dim3 block_dims(ncols, 1, 1);
7766
8434
  const dim3 block_nums(1, nrows, 1);
7767
- if (order == GGML_SORT_ASC) {
7768
- k_argsort_f32_i32<GGML_SORT_ASC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
7769
- } else if (order == GGML_SORT_DESC) {
7770
- k_argsort_f32_i32<GGML_SORT_DESC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
8435
+ if (order == GGML_SORT_ORDER_ASC) {
8436
+ k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
8437
+ } else if (order == GGML_SORT_ORDER_DESC) {
8438
+ k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
7771
8439
  } else {
7772
8440
  GGML_ASSERT(false);
7773
8441
  }
@@ -7832,8 +8500,8 @@ static void soft_max_f32_cuda(const float * x, const float * mask, const float *
7832
8500
 
7833
8501
  template <typename T>
7834
8502
  static void im2col_cuda(const float* x, T* dst,
7835
- int IW, int IH, int OW, int OH, int KW, int KH, int IC,
7836
- int batch, int batch_offset, int offset_delta,
8503
+ int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
8504
+ int64_t batch, int64_t batch_offset, int64_t offset_delta,
7837
8505
  int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
7838
8506
  const int parallel_elements = OW * KW * KH;
7839
8507
  const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
@@ -7916,8 +8584,8 @@ static void * ggml_cuda_pool_malloc_leg(int device, size_t size, size_t * actual
7916
8584
  *actual_size = look_ahead_size;
7917
8585
  g_cuda_pool_size[device] += look_ahead_size;
7918
8586
  #ifdef DEBUG_CUDA_MALLOC
7919
- fprintf(stderr, "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, id, nnz,
7920
- (uint32_t)(max_size/1024/1024), (uint32_t)(g_cuda_pool_size[id]/1024/1024), (uint32_t)(size/1024/1024));
8587
+ fprintf(stderr, "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, device, nnz,
8588
+ (uint32_t)(max_size/1024/1024), (uint32_t)(g_cuda_pool_size[device]/1024/1024), (uint32_t)(size/1024/1024));
7921
8589
  #endif
7922
8590
  return ptr;
7923
8591
  }
@@ -8003,7 +8671,7 @@ static void * ggml_cuda_pool_malloc_vmm(int device, size_t size, size_t * actual
8003
8671
  g_cuda_pool_used[device] += size;
8004
8672
 
8005
8673
  #ifdef DEBUG_CUDA_MALLOC
8006
- printf("cuda pool[%d]: allocated %llu bytes at %llx [%s]\n", id, (unsigned long long) size, ptr);
8674
+ printf("cuda pool[%d]: allocated %llu bytes at %llx\n", device, (unsigned long long) size, ptr);
8007
8675
  #endif
8008
8676
 
8009
8677
  return ptr;
@@ -8013,7 +8681,7 @@ static void ggml_cuda_pool_free_vmm(int device, void * ptr, size_t size) {
8013
8681
  scoped_spin_lock lock(g_cuda_pool_lock);
8014
8682
 
8015
8683
  #ifdef DEBUG_CUDA_MALLOC
8016
- printf("cuda pool[%d]: freed %llu bytes at %llx\n", id, (unsigned long long) size, ptr);
8684
+ printf("cuda pool[%d]: freed %llu bytes at %llx\n", device, (unsigned long long) size, ptr);
8017
8685
  #endif
8018
8686
 
8019
8687
  g_cuda_pool_used[device] -= size;
@@ -8199,11 +8867,11 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
8199
8867
 
8200
8868
  cudaMemcpyKind kind;
8201
8869
  char * src_ptr;
8202
- if (src->backend == GGML_BACKEND_CPU) {
8870
+ if (src->backend == GGML_BACKEND_TYPE_CPU) {
8203
8871
  kind = cudaMemcpyHostToDevice;
8204
8872
  src_ptr = (char *) src->data;
8205
- } else if (src->backend == GGML_BACKEND_GPU || src->backend == GGML_BACKEND_GPU_SPLIT) {
8206
- GGML_ASSERT(src->backend != GGML_BACKEND_GPU_SPLIT || (i1_low == 0 && i1_high == src->ne[1]));
8873
+ } else if (src->backend == GGML_BACKEND_TYPE_GPU || src->backend == GGML_BACKEND_TYPE_GPU_SPLIT) {
8874
+ GGML_ASSERT(src->backend != GGML_BACKEND_TYPE_GPU_SPLIT || (i1_low == 0 && i1_high == src->ne[1]));
8207
8875
  kind = cudaMemcpyDeviceToDevice;
8208
8876
  ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra;
8209
8877
  int id;
@@ -8512,7 +9180,7 @@ static void ggml_cuda_op_group_norm(
8512
9180
 
8513
9181
  int num_groups = dst->op_params[0];
8514
9182
  int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
8515
- group_norm_f32_cuda(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream);
9183
+ group_norm_f32_cuda(src0_dd, dst_dd, num_groups * src0->ne[3], group_size, ggml_nelements(src0), main_stream);
8516
9184
 
8517
9185
  (void) src1;
8518
9186
  (void) dst;
@@ -8545,7 +9213,7 @@ static void ggml_cuda_op_upscale(
8545
9213
 
8546
9214
  const int scale_factor = dst->op_params[0];
8547
9215
 
8548
- upscale_f32_cuda(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], scale_factor, main_stream);
9216
+ upscale_f32_cuda(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], scale_factor, main_stream);
8549
9217
 
8550
9218
  (void) src1;
8551
9219
  (void) dst;
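Note: upscale_f32_cuda and pad_f32_cuda (further up) now fold the third and fourth tensor dimensions into a single grid z extent (ne02*ne03, ne2*ne3), and ggml_cuda_op_upscale passes src0->ne[3] through, so 4D tensors are covered. A kernel consuming such a grid has to unfold blockIdx.z again; a minimal illustration (names are illustrative, not the package's pad_f32/upscale_f32):

    // One thread per element of a ne0 x ne1 x ne2 x ne3 tensor, with the last
    // two dimensions folded into gridDim.z = ne2*ne3 at launch time.
    static __global__ void fold_z_example(float * dst, const int ne0, const int ne1,
                                          const int ne2, const int ne3) {
        const int i0 = blockIdx.x*blockDim.x + threadIdx.x;
        const int i1 = blockIdx.y;
        const int i2 = blockIdx.z % ne2;   // third dimension
        const int i3 = blockIdx.z / ne2;   // fourth dimension
        if (i0 >= ne0 || i3 >= ne3) {
            return;
        }
        const int64_t idx = (((int64_t) i3*ne2 + i2)*ne1 + i1)*ne0 + i0;
        dst[idx] = 0.0f;                   // stand-in for the real per-element work
    }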
@@ -8561,8 +9229,49 @@ static void ggml_cuda_op_pad(
8561
9229
  GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
8562
9230
 
8563
9231
  pad_f32_cuda(src0_dd, dst_dd,
8564
- src0->ne[0], src0->ne[1], src0->ne[2],
8565
- dst->ne[0], dst->ne[1], dst->ne[2], main_stream);
9232
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
9233
+ dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], main_stream);
9234
+
9235
+ (void) src1;
9236
+ (void) dst;
9237
+ (void) src1_dd;
9238
+ }
9239
+
9240
+ static void ggml_cuda_op_arange(
9241
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
9242
+ const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
9243
+
9244
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
9245
+
9246
+ float start;
9247
+ float stop;
9248
+ float step;
9249
+ memcpy(&start, (float *)dst->op_params + 0, sizeof(float));
9250
+ memcpy(&stop, (float *)dst->op_params + 1, sizeof(float));
9251
+ memcpy(&step, (float *)dst->op_params + 2, sizeof(float));
9252
+
9253
+ int64_t steps = (int64_t)ceil((stop - start) / step);
9254
+ GGML_ASSERT(ggml_nelements(dst) == steps);
9255
+
9256
+ arange_f32_cuda(dst_dd, dst->ne[0], start, step, main_stream);
9257
+
9258
+ (void) src0;
9259
+ (void) src1;
9260
+ (void) src0_dd;
9261
+ (void) src1_dd;
9262
+ }
9263
+
9264
+ static void ggml_cuda_op_timestep_embedding(
9265
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
9266
+ const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
9267
+
9268
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
9269
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
9270
+
9271
+ const int dim = dst->op_params[0];
9272
+ const int max_period = dst->op_params[1];
9273
+
9274
+ timestep_embedding_f32_cuda(src0_dd, dst_dd, src0->ne[0], dst->nb[1], dim, max_period, main_stream);
8566
9275
 
8567
9276
  (void) src1;
8568
9277
  (void) dst;
@@ -8608,7 +9317,7 @@ static void ggml_cuda_op_mul_mat_q(
8608
9317
 
8609
9318
  // the main device has a larger memory buffer to hold the results from all GPUs
8610
9319
  // nrows_dst == nrows of the matrix that the kernel writes into
8611
- const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;
9320
+ const int64_t nrows_dst = dst->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device ? ne0 : row_diff;
8612
9321
 
8613
9322
  switch (src0->type) {
8614
9323
  case GGML_TYPE_Q4_0:
@@ -8685,9 +9394,12 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
8685
9394
  case GGML_TYPE_Q6_K:
8686
9395
  case GGML_TYPE_IQ2_XXS:
8687
9396
  case GGML_TYPE_IQ2_XS:
9397
+ case GGML_TYPE_IQ2_S:
8688
9398
  case GGML_TYPE_IQ3_XXS:
8689
9399
  case GGML_TYPE_IQ1_S:
8690
9400
  case GGML_TYPE_IQ4_NL:
9401
+ case GGML_TYPE_IQ4_XS:
9402
+ case GGML_TYPE_IQ3_S:
8691
9403
  return max_compute_capability >= CC_RDNA2 ? 128 : 64;
8692
9404
  default:
8693
9405
  GGML_ASSERT(false);
@@ -8710,9 +9422,12 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
8710
9422
  case GGML_TYPE_Q5_K:
8711
9423
  case GGML_TYPE_IQ2_XXS:
8712
9424
  case GGML_TYPE_IQ2_XS:
9425
+ case GGML_TYPE_IQ2_S:
8713
9426
  case GGML_TYPE_IQ3_XXS:
8714
9427
  case GGML_TYPE_IQ1_S:
8715
9428
  case GGML_TYPE_IQ4_NL:
9429
+ case GGML_TYPE_IQ4_XS:
9430
+ case GGML_TYPE_IQ3_S:
8716
9431
  return max_compute_capability >= CC_VOLTA ? 128 : 64;
8717
9432
  case GGML_TYPE_Q6_K:
8718
9433
  return 64;
@@ -8755,7 +9470,7 @@ static void ggml_cuda_op_mul_mat_vec_q(
8755
9470
 
8756
9471
  // the main device has a larger memory buffer to hold the results from all GPUs
8757
9472
  // nrows_dst == nrows of the matrix that the kernel writes into
8758
- const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;
9473
+ const int64_t nrows_dst = dst->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device ? ne0 : row_diff;
8759
9474
 
8760
9475
  switch (src0->type) {
8761
9476
  case GGML_TYPE_Q4_0:
@@ -8806,6 +9521,10 @@ static void ggml_cuda_op_mul_mat_vec_q(
8806
9521
  mul_mat_vec_q_cuda<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
8807
9522
  (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
8808
9523
  break;
9524
+ case GGML_TYPE_IQ2_S:
9525
+ mul_mat_vec_q_cuda<QK_K, QI2_S, block_iq2_s, 1, vec_dot_iq2_s_q8_1>
9526
+ (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
9527
+ break;
8809
9528
  case GGML_TYPE_IQ3_XXS:
8810
9529
  mul_mat_vec_q_cuda<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
8811
9530
  (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
@@ -8818,6 +9537,14 @@ static void ggml_cuda_op_mul_mat_vec_q(
8818
9537
  mul_mat_vec_q_cuda<QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
8819
9538
  (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
8820
9539
  break;
9540
+ case GGML_TYPE_IQ4_XS:
9541
+ mul_mat_vec_q_cuda<QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1>
9542
+ (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
9543
+ break;
9544
+ case GGML_TYPE_IQ3_S:
9545
+ mul_mat_vec_q_cuda<QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
9546
+ (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
9547
+ break;
8821
9548
  default:
8822
9549
  GGML_ASSERT(false);
8823
9550
  break;
@@ -8927,7 +9654,7 @@ static void ggml_cuda_op_mul_mat_cublas(
8927
9654
 
8928
9655
  // the main device has a larger memory buffer to hold the results from all GPUs
8929
9656
  // ldc == nrows of the matrix that cuBLAS writes into
8930
- int ldc = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;
9657
+ int ldc = dst->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device ? ne0 : row_diff;
8931
9658
 
8932
9659
  const int compute_capability = g_device_caps[id].cc;
8933
9660
 
@@ -9275,7 +10002,7 @@ static void ggml_cuda_op_soft_max(
9275
10002
  const bool use_src2 = src2 != nullptr;
9276
10003
 
9277
10004
  if (use_src2) {
9278
- const bool src2_on_device = src2->backend == GGML_BACKEND_GPU;
10005
+ const bool src2_on_device = src2->backend == GGML_BACKEND_TYPE_GPU;
9279
10006
 
9280
10007
  if (src2_on_device) {
9281
10008
  ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) src2->extra;
@@ -9333,16 +10060,16 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
9333
10060
  const bool use_src1 = src1 != nullptr;
9334
10061
  const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1;
9335
10062
 
9336
- GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT);
9337
- GGML_ASSERT( dst->backend != GGML_BACKEND_GPU_SPLIT);
10063
+ GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
10064
+ GGML_ASSERT( dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
9338
10065
 
9339
10066
  ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
9340
10067
  ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
9341
10068
  ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
9342
10069
 
9343
- const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT;
9344
- const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_GPU;
9345
- const bool dst_on_device = dst->backend == GGML_BACKEND_GPU;
10070
+ const bool src0_on_device = src0->backend == GGML_BACKEND_TYPE_GPU || src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT;
10071
+ const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_TYPE_GPU;
10072
+ const bool dst_on_device = dst->backend == GGML_BACKEND_TYPE_GPU;
9346
10073
 
9347
10074
  // dd = data device
9348
10075
  float * src0_ddf = nullptr;
@@ -9386,7 +10113,7 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
9386
10113
  CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream));
9387
10114
  }
9388
10115
 
9389
- if (dst->backend == GGML_BACKEND_CPU) {
10116
+ if (dst->backend == GGML_BACKEND_TYPE_CPU) {
9390
10117
  CUDA_CHECK(cudaDeviceSynchronize());
9391
10118
  }
9392
10119
  }
@@ -9467,8 +10194,8 @@ static void ggml_cuda_op_mul_mat(
9467
10194
  const int nb2 = dst->nb[2];
9468
10195
  const int nb3 = dst->nb[3];
9469
10196
 
9470
- GGML_ASSERT(dst->backend != GGML_BACKEND_GPU_SPLIT);
9471
- GGML_ASSERT(src1->backend != GGML_BACKEND_GPU_SPLIT);
10197
+ GGML_ASSERT(dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
10198
+ GGML_ASSERT(src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
9472
10199
  GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
9473
10200
 
9474
10201
  GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
@@ -9484,20 +10211,20 @@ static void ggml_cuda_op_mul_mat(
9484
10211
  ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
9485
10212
  ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
9486
10213
 
9487
- const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT;
10214
+ const bool src0_on_device = src0->backend == GGML_BACKEND_TYPE_GPU || src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT;
9488
10215
  const bool src0_is_contiguous = ggml_is_contiguous(src0);
9489
10216
  const bool src1_is_contiguous = ggml_is_contiguous(src1);
9490
10217
 
9491
10218
  const int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);
9492
10219
 
9493
- const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
10220
+ const bool split = src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT;
9494
10221
  GGML_ASSERT(!(split && ne02 > 1));
9495
10222
  GGML_ASSERT(!(split && ne03 > 1));
9496
10223
  GGML_ASSERT(!(split && ne02 < ne12));
9497
10224
 
9498
10225
  std::array<float, GGML_CUDA_MAX_DEVICES> tensor_split;
9499
10226
  if (split) {
9500
- // TODO: check that src0->buffer->buft is a split buffer type, replace GGML_BACKEND_GPU_SPLIT check
10227
+ // TODO: check that src0->buffer->buft is a split buffer type, replace GGML_BACKEND_TYPE_GPU_SPLIT check
9501
10228
  // GGML_ASSERT(src0->buffer != nullptr && src0->buffer->buft == ...);
9502
10229
  ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
9503
10230
  tensor_split = buft_ctx->tensor_split;
@@ -9555,8 +10282,8 @@ static void ggml_cuda_op_mul_mat(
9555
10282
 
9556
10283
  used_devices++;
9557
10284
 
9558
- const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device;
9559
- const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device;
10285
+ const bool src1_on_device = src1->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device;
10286
+ const bool dst_on_device = dst->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device;
9560
10287
 
9561
10288
  ggml_cuda_set_device(id);
9562
10289
  cudaStream_t stream = g_cudaStreams[id][0];
@@ -9607,8 +10334,8 @@ static void ggml_cuda_op_mul_mat(
9607
10334
  continue;
9608
10335
  }
9609
10336
 
9610
- const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device;
9611
- const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device;
10337
+ const bool src1_on_device = src1->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device;
10338
+ const bool dst_on_device = dst->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device;
9612
10339
  const int64_t row_diff = dev[id].row_high - dev[id].row_low;
9613
10340
 
9614
10341
  ggml_cuda_set_device(id);
@@ -9633,12 +10360,12 @@ static void ggml_cuda_op_mul_mat(
9633
10360
 
9634
10361
  // the main device memory buffer can be on VRAM scratch, with space for all partial results
9635
10362
  // in that case an offset on dst_ddf_i is needed
9636
- if (dst->backend == GGML_BACKEND_GPU && id == g_main_device) {
10363
+ if (dst->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device) {
9637
10364
  dst_dd_i += dev[id].row_low; // offset is 0 if no tensor split
9638
10365
  }
9639
10366
 
9640
10367
  // copy src0, src1 to device if necessary
9641
- if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) {
10368
+ if (src1->backend == GGML_BACKEND_TYPE_GPU && src1_is_contiguous) {
9642
10369
  if (id != g_main_device) {
9643
10370
  if (convert_src1_to_q8_1) {
9644
10371
  char * src1_ddq_i_source = dev[g_main_device].src1_ddq + src1_ddq_i_offset;
@@ -9651,14 +10378,14 @@ static void ggml_cuda_op_mul_mat(
9651
10378
  src1_ncols*ne10*sizeof(float), stream));
9652
10379
  }
9653
10380
  }
9654
- } else if (src1->backend == GGML_BACKEND_CPU || (src1_on_device && !src1_is_contiguous)) {
10381
+ } else if (src1->backend == GGML_BACKEND_TYPE_CPU || (src1_on_device && !src1_is_contiguous)) {
9655
10382
  CUDA_CHECK(ggml_cuda_cpy_tensor_2d(
9656
10383
  src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));
9657
10384
  } else {
9658
10385
  GGML_ASSERT(false);
9659
10386
  }
9660
10387
 
9661
- if (convert_src1_to_q8_1 && (src1->backend == GGML_BACKEND_CPU || !src1_is_contiguous)) {
10388
+ if (convert_src1_to_q8_1 && (src1->backend == GGML_BACKEND_TYPE_CPU || !src1_is_contiguous)) {
9662
10389
  quantize_row_q8_1_cuda(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream);
9663
10390
  CUDA_CHECK(cudaGetLastError());
9664
10391
  }
@@ -9676,10 +10403,10 @@ static void ggml_cuda_op_mul_mat(
9676
10403
  if (!dst_on_device) {
9677
10404
  void * dst_off_device;
9678
10405
  cudaMemcpyKind kind;
9679
- if (dst->backend == GGML_BACKEND_CPU) {
10406
+ if (dst->backend == GGML_BACKEND_TYPE_CPU) {
9680
10407
  dst_off_device = dst->data;
9681
10408
  kind = cudaMemcpyDeviceToHost;
9682
- } else if (dst->backend == GGML_BACKEND_GPU) {
10409
+ } else if (dst->backend == GGML_BACKEND_TYPE_GPU) {
9683
10410
  dst_off_device = dst_extra->data_device[g_main_device];
9684
10411
  kind = cudaMemcpyDeviceToDevice;
9685
10412
  } else {
@@ -9744,7 +10471,7 @@ static void ggml_cuda_op_mul_mat(
9744
10471
  }
9745
10472
  }
9746
10473
 
9747
- if (dst->backend == GGML_BACKEND_CPU) {
10474
+ if (dst->backend == GGML_BACKEND_TYPE_CPU) {
9748
10475
  ggml_cuda_set_device(g_main_device);
9749
10476
  CUDA_CHECK(cudaDeviceSynchronize());
9750
10477
  }
@@ -9829,6 +10556,45 @@ static void ggml_cuda_pad(const ggml_tensor * src0, const ggml_tensor * src1, gg
9829
10556
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_pad);
9830
10557
  }
9831
10558
 
10559
+ static void ggml_cuda_arange(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
10560
+ ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
10561
+
10562
+ const bool dst_on_device = dst->backend == GGML_BACKEND_TYPE_GPU;
10563
+
10564
+ // dd = data device
10565
+ float * src0_ddf = nullptr;
10566
+ float * src1_ddf = nullptr;
10567
+ float * dst_ddf = nullptr;
10568
+
10569
+ cuda_pool_alloc<float> dst_f;
10570
+
10571
+ ggml_cuda_set_device(g_main_device);
10572
+ cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
10573
+
10574
+ if (dst_on_device) {
10575
+ dst_ddf = (float *) dst_extra->data_device[g_main_device];
10576
+ } else {
10577
+ dst_ddf = dst_f.alloc(ggml_nelements(dst));
10578
+ }
10579
+
10580
+ // do the computation
10581
+ ggml_cuda_op_arange(src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream);
10582
+ CUDA_CHECK(cudaGetLastError());
10583
+
10584
+ // copy dst to host if necessary
10585
+ if (!dst_on_device) {
10586
+ CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream));
10587
+ }
10588
+
10589
+ if (dst->backend == GGML_BACKEND_TYPE_CPU) {
10590
+ CUDA_CHECK(cudaDeviceSynchronize());
10591
+ }
10592
+ }
10593
+
10594
+ static void ggml_cuda_timestep_embedding(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
10595
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_timestep_embedding);
10596
+ }
10597
+
9832
10598
  static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
9833
10599
  ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm);
9834
10600
  }
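Note: ARANGE has no input tensors, so it gets the hand-written ggml_cuda_arange wrapper above instead of going through ggml_cuda_op_flatten, which dereferences src0. Semantically the op is the usual arange; a host-side reference that matches the ceil-based length check in ggml_cuda_op_arange (arange_ref is an illustrative helper):

    #include <math.h>
    #include <stdint.h>

    // dst[i] = start + i*step for i in [0, ceil((stop - start)/step));
    // the returned count must equal ggml_nelements(dst), as the op asserts.
    static int64_t arange_ref(float start, float stop, float step, float * dst) {
        const int64_t n = (int64_t) ceilf((stop - start) / step);
        for (int64_t i = 0; i < n; ++i) {
            dst[i] = start + (float) i*step;
        }
        return n;
    }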
@@ -9850,7 +10616,7 @@ GGML_CALL bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const stru
9850
10616
 
9851
10617
  static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
9852
10618
  GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
9853
- GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
10619
+ GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
9854
10620
  GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
9855
10621
  GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation
9856
10622
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
@@ -9881,7 +10647,7 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
9881
10647
  GGML_ASSERT(!ggml_is_transposed(src0));
9882
10648
  GGML_ASSERT(!ggml_is_transposed(src1));
9883
10649
  GGML_ASSERT(!ggml_is_permuted(src0));
9884
- GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
10650
+ GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
9885
10651
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
9886
10652
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
9887
10653
 
@@ -9940,7 +10706,7 @@ static void ggml_cuda_mul_mat_batched_cublas(const ggml_tensor * src0, const ggm
9940
10706
  GGML_ASSERT(!ggml_is_transposed(src0));
9941
10707
  GGML_ASSERT(!ggml_is_transposed(src1));
9942
10708
 
9943
- GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
10709
+ GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
9944
10710
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
9945
10711
 
9946
10712
  GGML_TENSOR_BINARY_OP_LOCALS
@@ -10086,11 +10852,11 @@ static void ggml_cuda_mul_mat_batched_cublas(const ggml_tensor * src0, const ggm
10086
10852
 
10087
10853
  static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
10088
10854
  const bool all_on_device =
10089
- (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
10090
- (src1->backend == GGML_BACKEND_GPU) &&
10091
- ( dst->backend == GGML_BACKEND_GPU);
10855
+ (src0->backend == GGML_BACKEND_TYPE_GPU || src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT) &&
10856
+ (src1->backend == GGML_BACKEND_TYPE_GPU) &&
10857
+ ( dst->backend == GGML_BACKEND_TYPE_GPU);
10092
10858
 
10093
- const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
10859
+ const bool split = src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT;
10094
10860
 
10095
10861
  int64_t min_compute_capability = INT_MAX;
10096
10862
 
@@ -10240,7 +11006,7 @@ static void ggml_cuda_mul_mat_id_cublas(ggml_tensor * dst) {
10240
11006
  GGML_ASSERT(!ggml_is_transposed(src00));
10241
11007
  GGML_ASSERT(!ggml_is_transposed(src1));
10242
11008
 
10243
- GGML_ASSERT(src00->backend != GGML_BACKEND_GPU_SPLIT);
11009
+ GGML_ASSERT(src00->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
10244
11010
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
10245
11011
 
10246
11012
  const int64_t ne00 = src00->ne[0]; GGML_UNUSED(ne00);
@@ -10384,7 +11150,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
10384
11150
 
10385
11151
  cudaStream_t stream = g_cudaStreams[g_main_device][0];
10386
11152
 
10387
- if (ids->backend == GGML_BACKEND_GPU) {
11153
+ if (ids->backend == GGML_BACKEND_TYPE_GPU) {
10388
11154
  const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
10389
11155
  CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
10390
11156
  CUDA_CHECK(cudaStreamSynchronize(stream));
@@ -10401,20 +11167,20 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
10401
11167
  ggml_tensor src1_row = *src1;
10402
11168
  ggml_tensor dst_row = *dst;
10403
11169
 
10404
- src1_row.backend = GGML_BACKEND_GPU;
10405
- dst_row.backend = GGML_BACKEND_GPU;
11170
+ src1_row.backend = GGML_BACKEND_TYPE_GPU;
11171
+ dst_row.backend = GGML_BACKEND_TYPE_GPU;
10406
11172
 
10407
11173
  src1_row.extra = &src1_row_extra;
10408
11174
  dst_row.extra = &dst_row_extra;
10409
11175
 
10410
- char * src1_original = src1->backend == GGML_BACKEND_CPU ?
11176
+ char * src1_original = src1->backend == GGML_BACKEND_TYPE_CPU ?
10411
11177
  (char *) src1->data : (char *) src1_extra->data_device[g_main_device];
10412
- char * dst_original = dst->backend == GGML_BACKEND_CPU ?
11178
+ char * dst_original = dst->backend == GGML_BACKEND_TYPE_CPU ?
10413
11179
  (char *) dst->data : (char *) dst_extra->data_device[g_main_device];
10414
11180
 
10415
11181
  if (src1->ne[1] == 1) {
10416
- GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
10417
- GGML_ASSERT(dst->backend == GGML_BACKEND_GPU);
11182
+ GGML_ASSERT(src1->backend == GGML_BACKEND_TYPE_GPU);
11183
+ GGML_ASSERT(dst->backend == GGML_BACKEND_TYPE_GPU);
10418
11184
 
10419
11185
  for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
10420
11186
  //int32_t row_id;
@@ -10442,9 +11208,9 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
10442
11208
  src1_row_extra.data_device[g_main_device] = src1_contiguous.get();
10443
11209
  dst_row_extra.data_device[g_main_device] = dst_contiguous.get();
10444
11210
 
10445
- const cudaMemcpyKind src1_kind = src1->backend == GGML_BACKEND_CPU ?
11211
+ const cudaMemcpyKind src1_kind = src1->backend == GGML_BACKEND_TYPE_CPU ?
10446
11212
  cudaMemcpyHostToDevice : cudaMemcpyDeviceToDevice;
10447
- const cudaMemcpyKind dst_kind = dst->backend == GGML_BACKEND_CPU ?
11213
+ const cudaMemcpyKind dst_kind = dst->backend == GGML_BACKEND_TYPE_CPU ?
10448
11214
  cudaMemcpyDeviceToHost : cudaMemcpyDeviceToDevice;
10449
11215
 
10450
11216
  for (int32_t row_id = 0; row_id < n_as; ++row_id) {
@@ -10499,7 +11265,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
10499
11265
  }
10500
11266
  }
10501
11267
 
10502
- if (dst->backend == GGML_BACKEND_CPU) {
11268
+ if (dst->backend == GGML_BACKEND_TYPE_CPU) {
10503
11269
  CUDA_CHECK(cudaStreamSynchronize(stream));
10504
11270
  }
10505
11271
  }
@@ -10516,8 +11282,8 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
10516
11282
  const int64_t ne = ggml_nelements(src0);
10517
11283
  GGML_ASSERT(ne == ggml_nelements(src1));
10518
11284
 
10519
- GGML_ASSERT(src0->backend == GGML_BACKEND_GPU);
10520
- GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
11285
+ GGML_ASSERT(src0->backend == GGML_BACKEND_TYPE_GPU);
11286
+ GGML_ASSERT(src1->backend == GGML_BACKEND_TYPE_GPU);
10521
11287
 
10522
11288
  GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
10523
11289
  GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
@@ -10648,9 +11414,9 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
10648
11414
  if (!g_cublas_loaded) return false;
10649
11415
 
10650
11416
  ggml_cuda_func_t func;
10651
- const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
10652
- || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
10653
- || (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_GPU);
11417
+ const bool any_on_device = tensor->backend == GGML_BACKEND_TYPE_GPU
11418
+ || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU || tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU_SPLIT))
11419
+ || (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_TYPE_GPU);
10654
11420
 
10655
11421
  if (!any_on_device && tensor->op != GGML_OP_MUL_MAT && tensor->op != GGML_OP_MUL_MAT_ID) {
10656
11422
  return false;
@@ -10729,6 +11495,12 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
10729
11495
  case GGML_OP_PAD:
10730
11496
  func = ggml_cuda_pad;
10731
11497
  break;
11498
+ case GGML_OP_ARANGE:
11499
+ func = ggml_cuda_arange;
11500
+ break;
11501
+ case GGML_OP_TIMESTEP_EMBEDDING:
11502
+ func = ggml_cuda_timestep_embedding;
11503
+ break;
10732
11504
  case GGML_OP_LEAKY_RELU:
10733
11505
  func = ggml_cuda_leaky_relu;
10734
11506
  break;
@@ -10797,14 +11569,14 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
10797
11569
  return false;
10798
11570
  }
10799
11571
 
10800
- if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT) {
11572
+ if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU_SPLIT) {
10801
11573
  ggml_cuda_set_peer_access(tensor->src[1]->ne[1]);
10802
11574
  }
10803
11575
 
10804
11576
  if (params->ith != 0) {
10805
11577
  return true;
10806
11578
  }
10807
- if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
11579
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
10808
11580
  return true;
10809
11581
  }
10810
11582
  func(tensor->src[0], tensor->src[1], tensor);
@@ -10903,7 +11675,7 @@ GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t
10903
11675
 
10904
11676
  extra->data_device[ctx->device] = tensor->data;
10905
11677
 
10906
- tensor->backend = GGML_BACKEND_GPU;
11678
+ tensor->backend = GGML_BACKEND_TYPE_GPU;
10907
11679
  tensor->extra = extra;
10908
11680
 
10909
11681
  if (ggml_is_quantized(tensor->type)) {
@@ -10918,7 +11690,7 @@ GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t
10918
11690
  }
10919
11691
 
10920
11692
  GGML_CALL static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
10921
- GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
11693
+ GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
10922
11694
 
10923
11695
  ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
10924
11696
 
@@ -10929,7 +11701,7 @@ GGML_CALL static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t
10929
11701
  }
10930
11702
 
10931
11703
  GGML_CALL static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
10932
- GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
11704
+ GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
10933
11705
 
10934
11706
  ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
10935
11707
 
@@ -11164,7 +11936,7 @@ GGML_CALL static void ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_bu
11164
11936
  CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id][is], cudaEventDisableTiming));
11165
11937
  }
11166
11938
  }
11167
- tensor->backend = GGML_BACKEND_GPU_SPLIT;
11939
+ tensor->backend = GGML_BACKEND_TYPE_GPU_SPLIT;
11168
11940
  tensor->extra = extra;
11169
11941
  }
11170
11942
 
@@ -11436,7 +12208,7 @@ GGML_CALL static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend,
11436
12208
  ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
11437
12209
 
11438
12210
  GGML_ASSERT(tensor->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
11439
- GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
12211
+ GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
11440
12212
 
11441
12213
  CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, g_cudaStreams[cuda_ctx->device][0]));
11442
12214
  }
@@ -11445,7 +12217,7 @@ GGML_CALL static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend,
11445
12217
  ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
11446
12218
 
11447
12219
  GGML_ASSERT(tensor->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
11448
- GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
12220
+ GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
11449
12221
 
11450
12222
  CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, g_cudaStreams[cuda_ctx->device][0]));
11451
12223
  }
@@ -11469,13 +12241,13 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
11469
12241
  UNUSED(backend);
11470
12242
  }
11471
12243
 
11472
- GGML_CALL static bool ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
12244
+ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
11473
12245
  ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
11474
12246
 
11475
12247
  ggml_cuda_set_main_device(cuda_ctx->device);
11476
12248
 
11477
12249
  ggml_compute_params params = {};
11478
- params.type = GGML_TASK_COMPUTE;
12250
+ params.type = GGML_TASK_TYPE_COMPUTE;
11479
12251
  params.ith = 0;
11480
12252
  for (int i = 0; i < cgraph->n_nodes; i++) {
11481
12253
  ggml_tensor * node = cgraph->nodes[i];
@@ -11485,13 +12257,13 @@ GGML_CALL static bool ggml_backend_cuda_graph_compute(ggml_backend_t backend, gg
11485
12257
  }
11486
12258
 
11487
12259
  #ifndef NDEBUG
11488
- assert(node->backend == GGML_BACKEND_GPU || node->backend == GGML_BACKEND_GPU_SPLIT);
12260
+ assert(node->backend == GGML_BACKEND_TYPE_GPU || node->backend == GGML_BACKEND_TYPE_GPU_SPLIT);
11489
12261
  assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
11490
12262
  assert(node->extra != nullptr);
11491
12263
 
11492
12264
  for (int j = 0; j < GGML_MAX_SRC; j++) {
11493
12265
  if (node->src[j] != nullptr) {
11494
- assert(node->src[j]->backend == GGML_BACKEND_GPU || node->src[j]->backend == GGML_BACKEND_GPU_SPLIT);
12266
+ assert(node->src[j]->backend == GGML_BACKEND_TYPE_GPU || node->src[j]->backend == GGML_BACKEND_TYPE_GPU_SPLIT);
11495
12267
  assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer));
11496
12268
  assert(node->src[j]->extra != nullptr);
11497
12269
  }
@@ -11505,7 +12277,7 @@ GGML_CALL static bool ggml_backend_cuda_graph_compute(ggml_backend_t backend, gg
11505
12277
  GGML_ASSERT(ok);
11506
12278
  }
11507
12279
 
11508
- return true;
12280
+ return GGML_STATUS_SUCCESS;
11509
12281
  }
11510
12282
 
11511
12283
  GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
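Note: graph_compute now reports an enum ggml_status instead of a bool, returning GGML_STATUS_SUCCESS on the happy path, so callers can distinguish failure modes. A sketch of a caller adapting to the change (assuming the public ggml_backend_graph_compute wrapper in ggml-backend.h forwards the enum unchanged; compute_graph_checked is illustrative):

    #include <stdio.h>
    #include "ggml-backend.h"

    static bool compute_graph_checked(ggml_backend_t backend, struct ggml_cgraph * graph) {
        const enum ggml_status status = ggml_backend_graph_compute(backend, graph);
        if (status != GGML_STATUS_SUCCESS) {
            fprintf(stderr, "graph compute failed (status %d)\n", (int) status);
            return false;
        }
        return true;
    }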
@@ -11541,7 +12313,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
11541
12313
  }
11542
12314
  ggml_type a_type = a->type;
11543
12315
  if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS ||
11544
- a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ4_NL) {
12316
+ a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ3_S ||
12317
+ a_type == GGML_TYPE_IQ2_S || a_type == GGML_TYPE_IQ4_XS) {
11545
12318
  if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
11546
12319
  return false;
11547
12320
  }
@@ -11623,6 +12396,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
11623
12396
  case GGML_OP_GROUP_NORM:
11624
12397
  case GGML_OP_UPSCALE:
11625
12398
  case GGML_OP_PAD:
12399
+ case GGML_OP_ARANGE:
12400
+ case GGML_OP_TIMESTEP_EMBEDDING:
11626
12401
  case GGML_OP_LEAKY_RELU:
11627
12402
  return true;
11628
12403
  default:
@@ -11647,6 +12422,11 @@ static ggml_backend_i ggml_backend_cuda_interface = {
11647
12422
  /* .supports_op = */ ggml_backend_cuda_supports_op,
11648
12423
  };
11649
12424
 
12425
+ static ggml_guid_t ggml_backend_cuda_guid() {
12426
+ static ggml_guid guid = { 0x2c, 0xdd, 0xe8, 0x1c, 0x65, 0xb3, 0x65, 0x73, 0x6a, 0x12, 0x88, 0x61, 0x1c, 0xc9, 0xdc, 0x25 };
12427
+ return &guid;
12428
+ }
12429
+
11650
12430
  GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device) {
11651
12431
  ggml_init_cublas(); // TODO: remove from ggml.c
11652
12432
 
@@ -11664,6 +12444,7 @@ GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device) {
11664
12444
  };
11665
12445
 
11666
12446
  ggml_backend_t cuda_backend = new ggml_backend {
12447
+ /* .guid = */ ggml_backend_cuda_guid(),
11667
12448
  /* .interface = */ ggml_backend_cuda_interface,
11668
12449
  /* .context = */ ctx
11669
12450
  };
@@ -11672,7 +12453,7 @@ GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device) {
11672
12453
  }
11673
12454
 
11674
12455
  GGML_CALL bool ggml_backend_is_cuda(ggml_backend_t backend) {
11675
- return backend && backend->iface.get_name == ggml_backend_cuda_name;
12456
+ return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_cuda_guid());
11676
12457
  }
11677
12458
 
11678
12459
  GGML_CALL int ggml_backend_cuda_get_device_count() {
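Note: the backend identity check no longer compares the get_name function pointer; each backend now carries a 16-byte GUID, and ggml_backend_is_cuda asks ggml_guid_matches whether the backend's GUID matches the static one returned by ggml_backend_cuda_guid(), which presumably stays valid even if the interface table is copied or wrapped. What such a match boils down to, as a self-contained sketch (assuming ggml_guid is a plain 16-byte array and the real ggml_guid_matches is a byte comparison; the names below are illustrative):

    #include <stdbool.h>
    #include <stdint.h>
    #include <string.h>

    typedef uint8_t example_guid[16];   // illustrative stand-in for ggml_guid

    static bool example_guid_matches(const example_guid a, const example_guid b) {
        return memcmp(a, b, sizeof(example_guid)) == 0;
    }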