llama_cpp 0.2.2 → 0.3.0
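
The changes below are from the bundled llama.cpp CUDA source (ggml-cuda.cu). They add a GGML_QKK_64 compile-time option that shrinks the k-quant super-block size QK_K from 256 to 64 weights, add a matching 64-weight code path to every k-quant dequantization and mat-vec kernel, and fix two GPU-offload bugs (an uninitialized ggml_tensor_extra_gpu and a missing null check on tensor->src0).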

@@ -117,7 +117,13 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo
 
 //================================= k-quants
 
+#ifdef GGML_QKK_64
+#define QK_K 64
+#define K_SCALE_SIZE 4
+#else
 #define QK_K 256
+#define K_SCALE_SIZE 12
+#endif
 
 typedef struct {
     uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
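
With GGML_QKK_64 defined, k-quant super-blocks hold 64 weights instead of 256, which matters for models whose row sizes are divisible by 64 but not by 256. A quick way to see what the definitions above imply for storage cost (an illustrative check, not part of the diff; the arithmetic mirrors the static_asserts):

    #include <stdio.h>

    /* Mirror the two configurations declared above. */
    #define QK_K_256 256
    #define QK_K_64   64

    int main(void) {
        /* block_q2_K: QK_K/16 scale bytes + QK_K/4 quant bytes + 2 fp16 values */
        const int b256 = QK_K_256/16 + QK_K_256/4 + 2*2;  /* 84 bytes */
        const int b64  = QK_K_64/16  + QK_K_64/4  + 2*2;  /* 24 bytes */
        printf("q2_K @ QK_K=256: %d bytes -> %.3f bits/weight\n", b256, 8.0*b256/QK_K_256);
        printf("q2_K @ QK_K=64 : %d bytes -> %.3f bits/weight\n", b64,  8.0*b64 /QK_K_64);
        return 0;
    }
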
@@ -128,13 +134,25 @@ typedef struct {
 static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
 
 typedef struct {
-    uint8_t hmask[QK_K/8];
-    uint8_t qs[QK_K/4]; // nibbles / quants
-    uint8_t scales[3*QK_K/64];
-    half d;
+    uint8_t hmask[QK_K/8];        // quants - high bit
+    uint8_t qs[QK_K/4];           // quants - low 2 bits
+#ifdef GGML_QKK_64
+    uint8_t scales[2];            // scales, quantized with 8 bits
+#else
+    uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
+#endif
+    half d;                       // super-block scale
 } block_q3_K;
-static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding");
+//static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding");
 
+#ifdef GGML_QKK_64
+typedef struct {
+    half    d[2];       // super-block scales/mins
+    uint8_t scales[2];  // 4-bit block scales/mins
+    uint8_t qs[QK_K/2]; // 4-bit quants
+} block_q4_K;
+static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
+#else
 typedef struct {
     half d;    // super-block scale for quantized scales
     half dmin; // super-block scale for quantized mins
@@ -142,15 +160,26 @@ typedef struct {
     uint8_t qs[QK_K/2]; // 4-bit quants
 } block_q4_K;
 static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
+#endif
 
+#ifdef GGML_QKK_64
 typedef struct {
-    half d;                    // super-block scale for quantized scales
-    half dmin;                 // super-block scale for quantized mins
-    uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
+    half    d;               // super-block scale
+    int8_t  scales[QK_K/16]; // block scales
+    uint8_t qh[QK_K/8];      // quants, high bit
+    uint8_t qs[QK_K/2];      // quants, low 4 bits
+} block_q5_K;
+static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
+#else
+typedef struct {
+    half d;                       // super-block scale for quantized scales
+    half dmin;                    // super-block scale for quantized mins
+    uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
     uint8_t qh[QK_K/8];           // quants, high bit
     uint8_t qs[QK_K/2];           // quants, low 4 bits
 } block_q5_K;
-static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+#endif
 
 typedef struct {
     uint8_t ql[QK_K/2]; // quants, lower 4 bits
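
Under GGML_QKK_64 the q4_K and q5_K blocks change shape rather than just size: q4_K keeps a separate fp16 scale and min in d[2] with two bytes of packed 4-bit sub-scales, while q5_K drops the min entirely in favor of signed 8-bit per-16-weight scales. The static_asserts above pin the layouts down; the implied storage cost (illustrative arithmetic only):

    #include <stdio.h>

    int main(void) {
        const int qk = 64;                       /* QK_K under GGML_QKK_64 */
        const int q4 = 2*2 + 2 + qk/2;           /* d[2] + scales[2] + qs  */
        const int q5 = 2 + qk/16 + qk/8 + qk/2;  /* d + scales + qh + qs   */
        printf("q4_K: %d bytes -> %.2f bits/weight\n", q4, 8.0*q4/qk);  /* 4.75 */
        printf("q5_K: %d bytes -> %.2f bits/weight\n", q5, 8.0*q5/qk);  /* 5.75 */
        return 0;
    }
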
@@ -349,13 +378,14 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
 static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
 
     const int i   = blockIdx.x;
+    const block_q2_K * x = (const block_q2_K *) vx;
+
     const int tid = threadIdx.x;
+#if QK_K == 256
     const int n   = tid/32;
     const int l   = tid - 32*n;
     const int is  = 8*n + l/16;
 
-    const block_q2_K * x = (const block_q2_K *) vx;
-
     const uint8_t q = x[i].qs[32*n + l];
     float * y = yy + i*QK_K + 128*n;
 
@@ -365,21 +395,32 @@ static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
     y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
     y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
     y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
+#else
+    const int is = tid/16; // 0 or 1
+    const int il = tid%16; // 0...15
+    const uint8_t q = x[i].qs[il] >> (2*is);
+    float * y = yy + i*QK_K + 16*is + il;
+    float dall = x[i].d;
+    float dmin = x[i].dmin;
+    y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
+    y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
+#endif
 
 }
 
 static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
 
-    int r = threadIdx.x/4;
-    int i = blockIdx.x;
-    int tid = r/2;
-    int is0 = r%2;
-    int l0 = 16*is0 + 4*(threadIdx.x%4);
-    int n = tid / 4;
-    int j = tid - 4*n;
-
+    const int i = blockIdx.x;
     const block_q3_K * x = (const block_q3_K *) vx;
 
+#if QK_K == 256
+    const int r = threadIdx.x/4;
+    const int tid = r/2;
+    const int is0 = r%2;
+    const int l0 = 16*is0 + 4*(threadIdx.x%4);
+    const int n = tid / 4;
+    const int j = tid - 4*n;
+
     uint8_t m = 1 << (4*n + j);
     int is = 8*n + 2*j + is0;
     int shift = 2*j;
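
In the new QK_K == 64 branch of dequantize_block_q2_K, each of 32 threads expands one quant byte position into two outputs: qs[il] packs a 2-bit quant for each of the four 16-weight groups, and scales[g] carries the group's 4-bit scale (low nibble) and min (high nibble). A host-side sketch of the same mapping, assuming the block_q2_K layout above and ggml's ggml_fp16_to_fp32 helper:

    // Illustrative CPU reference for the QK_K == 64 q2_K path (not in the diff).
    static void dequantize_q2_K_64_ref(const block_q2_K * x, float * y, int nb) {
        for (int i = 0; i < nb; ++i) {
            const float dall = ggml_fp16_to_fp32(x[i].d);
            const float dmin = ggml_fp16_to_fp32(x[i].dmin);
            for (int il = 0; il < 16; ++il) {          // byte position in qs
                for (int g = 0; g < 4; ++g) {          // 16-weight group
                    const int q = (x[i].qs[il] >> (2*g)) & 3;
                    y[64*i + 16*g + il] = dall * (x[i].scales[g] & 0xF) * q
                                        - dmin * (x[i].scales[g] >>  4);
                }
            }
        }
    }
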
@@ -396,9 +437,31 @@ static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
     const uint8_t * hm = x[i].hmask;
 
     for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
+#else
+    const int tid = threadIdx.x;
+    const int is  = tid/16; // 0 or 1
+    const int il  = tid%16; // 0...15
+    const int im  = il/8;   // 0...1
+    const int in  = il%8;   // 0...7
+
+    float * y = yy + i*QK_K + 16*is + il;
+
+    const uint8_t q = x[i].qs[il] >> (2*is);
+    const uint8_t h = x[i].hmask[in] >> (2*is + im);
+    const float   d = (float)x[i].d;
+
+    if (is == 0) {
+        y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
+        y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
+    } else {
+        y[ 0] = d * ((x[i].scales[0] >> 4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
+        y[32] = d * ((x[i].scales[1] >> 4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
+    }
+#endif
 
 }
 
+#if QK_K == 256
 static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
     if (j < 4) {
         d = q[j] & 63; m = q[j + 4] & 63;
@@ -407,19 +470,14 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t
         m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
     }
 }
+#endif
 
 static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
     const block_q4_K * x = (const block_q4_K *) vx;
 
     const int i = blockIdx.x;
 
-    //// assume 64 threads - this is very slightly better than the one below
-    //const int tid = threadIdx.x;
-    //const int il  = tid/16;
-    //const int ir  = tid%16;
-    //const int is  = 2*il;
-    //const int n   = 2;
-
+#if QK_K == 256
     // assume 32 threads
     const int tid = threadIdx.x;
     const int il  = tid/8;
@@ -443,6 +501,15 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
         y[l + 0] = d1 * (q[l] & 0xF) - m1;
         y[l +32] = d2 * (q[l] >>  4) - m2;
     }
+#else
+    const int tid = threadIdx.x;
+    const uint8_t * q = x[i].qs;
+    float * y = yy + i*QK_K;
+    const float d = (float)x[i].d[0];
+    const float m = (float)x[i].d[1];
+    y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
+    y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >>  4) - m * (x[i].scales[1] >> 4);
+#endif
 }
 
 static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
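
The QK_K == 64 branch of dequantize_block_q4_K reinterprets the block: d[0] is the super-block scale, d[1] the super-block min, and the low nibbles of scales[0] and scales[1] scale the two 32-weight halves while their high nibbles scale the mins. The same expansion on the host (an illustrative sketch; names follow the QKK_64 struct above):

    // Illustrative expansion of one 64-weight q4_K block (not in the diff).
    static void dequantize_q4_K_64_ref(const block_q4_K * x, float * y) {
        const float d = ggml_fp16_to_fp32(x->d[0]);   // super-block scale
        const float m = ggml_fp16_to_fp32(x->d[1]);   // super-block min
        for (int l = 0; l < 32; ++l) {
            y[l +  0] = d * (x->scales[0] & 0xF) * (x->qs[l] & 0xF) - m * (x->scales[0] >> 4);
            y[l + 32] = d * (x->scales[1] & 0xF) * (x->qs[l] >>  4) - m * (x->scales[1] >> 4);
        }
    }
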
@@ -450,6 +517,7 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
 
     const int i = blockIdx.x;
 
+#if QK_K == 256
     // assume 64 threads - this is very slightly better than the one below
     const int tid = threadIdx.x;
     const int il  = tid/16; // il is in 0...3
@@ -476,12 +544,25 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
     hm <<= 1;
     y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
     y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
+#else
+    const int tid = threadIdx.x;
+    const uint8_t q = x[i].qs[tid];
+    const int im = tid/8;  // 0...3
+    const int in = tid%8;  // 0...7
+    const int is = tid/16; // 0 or 1
+    const uint8_t h = x[i].qh[in] >> im;
+    const float   d = x[i].d;
+    float * y = yy + i*QK_K + tid;
+    y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
+    y[32] = d * x[i].scales[is+2] * ((q >>  4) - ((h >> 4) & 1 ? 0 : 16));
+#endif
 }
 
 static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
     const block_q6_K * x = (const block_q6_K *) vx;
 
     const int i = blockIdx.x;
+#if QK_K == 256
 
     // assume 64 threads - this is very slightly better than the one below
     const int tid = threadIdx.x;
@@ -501,6 +582,24 @@ static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
     y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
     y[64] = d * sc[4] * ((int8_t)((ql[ 0] >>  4) | (((qh >> 4) & 3) << 4)) - 32);
     y[96] = d * sc[6] * ((int8_t)((ql[32] >>  4) | (((qh >> 6) & 3) << 4)) - 32);
+#else
+
+    // assume 32 threads
+    const int tid = threadIdx.x;
+    const int ip  = tid/16;      // 0 or 1
+    const int il  = tid - 16*ip; // 0...15
+
+    float * y = yy + i*QK_K + 16*ip + il;
+
+    const float d = x[i].d;
+
+    const uint8_t  ql = x[i].ql[16*ip + il];
+    const uint8_t  qh = x[i].qh[il] >> (2*ip);
+    const int8_t * sc = x[i].scales;
+
+    y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
+    y[32] = d * sc[ip+2] * ((int8_t)((ql >>  4) | (((qh >> 4) & 3) << 4)) - 32);
+#endif
 }
 
 static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
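
For q6_K at QK_K == 64, each weight is reassembled from a nibble in ql plus two bits in qh, then recentered by 32; 32 threads each produce two outputs. The same mapping on the host (illustrative sketch only):

    // Illustrative CPU reference for the QK_K == 64 q6_K path (not in the diff).
    static void dequantize_q6_K_64_ref(const block_q6_K * x, float * y) {
        const float d = ggml_fp16_to_fp32(x->d);
        for (int ip = 0; ip < 2; ++ip) {               // which half of ql
            for (int il = 0; il < 16; ++il) {
                const uint8_t ql = x->ql[16*ip + il];
                const uint8_t qh = x->qh[il] >> (2*ip);
                y[16*ip + il +  0] = d * x->scales[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
                y[16*ip + il + 32] = d * x->scales[ip+2] * ((int8_t)((ql >>  4) | (((qh >> 4) & 3) << 4)) - 32);
            }
        }
    }
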
@@ -515,6 +614,9 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
 
     const block_q2_K * x = (const block_q2_K *)vx + ib0;
 
+    float tmp = 0; // partial sum for thread in warp
+
+#if QK_K == 256
     const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15
     const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
 
@@ -528,8 +630,6 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
     const int s_offset = 8*im;
     const int y_offset = 128*im + l0;
 
-    float tmp = 0; // partial sum for thread in warp
-
     uint32_t aux[4];
     const uint8_t * d = (const uint8_t *)aux;
     const uint8_t * m = (const uint8_t *)(aux + 2);
@@ -565,6 +665,39 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
         tmp += dall * sum1 - dmin * sum2;
 
     }
+#else
+    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7
+    const int ix  = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3
+    const int offset = tid * K_QUANTS_PER_ITERATION;
+
+    uint32_t uaux[2];
+    const uint8_t * d = (const uint8_t *)uaux;
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+
+        const float   * y = yy + i * QK_K + offset;
+        const uint8_t * q = x[i].qs + offset;
+        const uint32_t * s = (const uint32_t *)x[i].scales;
+
+        uaux[0] = s[0] & 0x0f0f0f0f;
+        uaux[1] = (s[0] >> 4) & 0x0f0f0f0f;
+
+        const half2 * dh = (const half2 *)&x[i].d;
+
+        const float2 dall = __half22float2(dh[0]);
+
+        float sum1 = 0, sum2 = 0;
+        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+            const uint8_t ql = q[l];
+            sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3)
+                  + y[l+16] * d[1] * ((ql >> 2) & 3)
+                  + y[l+32] * d[2] * ((ql >> 4) & 3)
+                  + y[l+48] * d[3] * ((ql >> 6) & 3);
+            sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7];
+        }
+        tmp += dall.x * sum1 - dall.y * sum2;
+    }
+#endif
 
     // sum up partial sums and write back result
     __syncthreads();
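
The QK_K == 64 branch above avoids per-element nibble extraction by loading all four packed scale/min bytes as one 32-bit word and splitting it with two masks, and by reading d and dmin together as one half2. The mask trick in isolation (a standalone illustration; the byte values are made up and little-endian layout is assumed):

    #include <stdint.h>
    #include <stdio.h>
    #include <string.h>

    int main(void) {
        const uint8_t scales[4] = {0x21, 0x43, 0x65, 0x87}; // lo nibble = scale, hi = min
        uint32_t s, uaux[2];
        memcpy(&s, scales, 4);
        uaux[0] = s & 0x0f0f0f0f;          // low nibbles: the four scales
        uaux[1] = (s >> 4) & 0x0f0f0f0f;   // high nibbles: the four mins
        const uint8_t * d = (const uint8_t *)uaux;
        for (int k = 0; k < 8; ++k) printf("%d ", d[k]);  // prints: 1 3 5 7 2 4 6 8
        printf("\n");
        return 0;
    }
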
@@ -573,16 +706,13 @@
         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
     }
 
-    if (tid == 0) {
+    if (threadIdx.x == 0) {
         dst[row] = tmp;
     }
 }
 
 static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
 
-    const uint16_t kmask1 = 0x0303;
-    const uint16_t kmask2 = 0x0f0f;
-
     const int row = blockIdx.y*blockDim.y + threadIdx.y;
     if (row > nrows) return;
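
The guard change from tid == 0 to threadIdx.x == 0 matters now that tid is defined differently inside each #if branch (and, for K_QUANTS_PER_ITERATION == 2, tid == 0 held for more than one thread). After the butterfly reduction every lane holds the full row sum, so there should be exactly one well-defined writer. The epilogue these kernels share:

    // sum up partial sums and write back result
    __syncthreads();
    #pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
    }
    if (threadIdx.x == 0) {   // one thread per row writes the reduced sum
        dst[row] = tmp;
    }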
 
@@ -591,6 +721,13 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
 
     const block_q3_K * x = (const block_q3_K *)vx + ib0;
 
+    float tmp = 0; // partial sum for thread in warp
+
+#if QK_K == 256
+
+    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask2 = 0x0f0f;
+
     const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
     const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
 
@@ -610,8 +747,6 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
 
     const uint16_t s_shift = 4*im;
 
-    float tmp = 0; // partial sum for thread in warp
-
     for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
 
         const float * y = yy + i * QK_K + y_offset;
@@ -640,6 +775,34 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
         tmp += d * sum;
 
     }
+#else
+
+    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7
+    const int ix  = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3
+    const int offset = tid * K_QUANTS_PER_ITERATION;        // 0...15 or 0...14
+    const int in = offset/8;                                // 0 or 1
+    const int im = offset%8;                                // 0...7
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+
+        const float   * y = yy + i * QK_K + offset;
+        const uint8_t * q = x[i].qs + offset;
+        const uint8_t * s = x[i].scales;
+
+        const float dall = (float)x[i].d;
+
+        float sum = 0;
+        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+            const uint8_t hl = x[i].hmask[im+l] >> in;
+            const uint8_t ql = q[l];
+            sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4))
+                 + y[l+16] * dall * ((s[0] >>  4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4))
+                 + y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4))
+                 + y[l+48] * dall * ((s[1] >>  4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4));
+        }
+        tmp += sum;
+    }
+#endif
 
     // sum up partial sums and write back result
     __syncthreads();
@@ -648,22 +811,25 @@
         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
     }
 
-    if (tid == 0) {
+    if (threadIdx.x == 0) {
         dst[row] = tmp;
     }
 }
 
 static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
 
-    const uint16_t kmask1 = 0x3f3f;
-    const uint16_t kmask2 = 0x0f0f;
-    const uint16_t kmask3 = 0xc0c0;
-
     const int row = blockIdx.y*blockDim.y + threadIdx.y;
     if (row > nrows) return;
     const int num_blocks_per_row = ncols / QK_K;
     const int ib0 = row*num_blocks_per_row;
 
+    const block_q4_K * x = (const block_q4_K *)vx + ib0;
+
+#if QK_K == 256
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
     const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
     const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
 
@@ -683,8 +849,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
     uint16_t aux[4];
     const uint8_t * sc = (const uint8_t *)aux;
 
-    const block_q4_K * x = (const block_q4_K *)vx + ib0;
-
     float tmp = 0; // partial sum for thread in warp
 
     for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
@@ -713,6 +877,36 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
         tmp += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin;
 
     }
+#else
+    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15
+    const int ix  = threadIdx.x%(2*K_QUANTS_PER_ITERATION);
+
+    const int step = tid * K_QUANTS_PER_ITERATION;
+
+    uint16_t aux16[2];
+    const uint8_t * s = (const uint8_t *)aux16;
+
+    float tmp = 0;
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+        const uint8_t * q = x[i].qs + step;
+        const float   * y = yy + i*QK_K + step;
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+        const float d = (float)x[i].d[0];
+        const float m = (float)x[i].d[1];
+        float sum = 0.f;
+        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+            sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2])
+                 + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2])
+                 + y[j+32] * (d * s[1] * (q[j+ 0] >>  4) - m * s[3])
+                 + y[j+48] * (d * s[1] * (q[j+16] >>  4) - m * s[3]);
+        }
+        tmp += sum;
+    }
+
+#endif
 
     // sum up partial sums and write back result
     __syncthreads();
@@ -728,15 +922,19 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
 
 static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float * yy, float * dst, const int ncols) {
 
-    const uint16_t kmask1 = 0x3f3f;
-    const uint16_t kmask2 = 0x0f0f;
-    const uint16_t kmask3 = 0xc0c0;
-
-    //const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int row = blockIdx.x;
     const int num_blocks_per_row = ncols / QK_K;
     const int ib0 = row*num_blocks_per_row;
 
+    const block_q5_K * x = (const block_q5_K *)vx + ib0;
+
+    float tmp = 0; // partial sum for thread in warp
+
+#if QK_K == 256
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
     const int tid = threadIdx.x/2; // 0...15
     const int ix  = threadIdx.x%2;
 
@@ -757,10 +955,6 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
     uint16_t aux[4];
     const uint8_t * sc = (const uint8_t *)aux;
 
-    const block_q5_K * x = (const block_q5_K *)vx + ib0;
-
-    float tmp = 0; // partial sum for thread in warp
-
     for (int i = ix; i < num_blocks_per_row; i += 2) {
 
         const uint8_t * ql1 = x[i].qs + q_offset;
@@ -793,8 +987,31 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
                     + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
         }
         tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
+    }
 
+#else
+    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15
+    const int ix  = threadIdx.x%(2*K_QUANTS_PER_ITERATION);
+    const int step = tid * K_QUANTS_PER_ITERATION;
+    const int im = step/8;
+    const int in = step%8;
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+        const uint8_t * q = x[i].qs + step;
+        const int8_t  * s = x[i].scales;
+        const float   * y = yy + i*QK_K + step;
+        const float     d = x[i].d;
+        float sum = 0.f;
+        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+            const uint8_t h = x[i].qh[in+j] >> im;
+            sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16))
+                 + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16))
+                 + y[j+32] * d * s[2] * ((q[j+ 0] >>  4) - ((h >> 4) & 1 ? 0 : 16))
+                 + y[j+48] * d * s[3] * ((q[j+16] >>  4) - ((h >> 6) & 1 ? 0 : 16));
+        }
+        tmp += sum;
     }
+#endif
 
     // sum up partial sums and write back result
     __syncthreads();
@@ -803,7 +1020,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
     }
 
-    if (tid == 0) {
+    if (threadIdx.x == 0) {
         dst[row] = tmp;
     }
 }
@@ -820,6 +1037,8 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float
 
     const block_q6_K * x = (const block_q6_K *)vx + ib0;
 
+#if QK_K == 256
+
     const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
     const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
 
@@ -874,6 +1093,37 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float
 
     }
 
+#else
+
+    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...7
+    const int ix  = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0...3
+
+    const int step = tid * K_QUANTS_PER_ITERATION;
+
+    float tmp = 0; // partial sum for thread in warp
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+
+        const float   * y  = yy + i * QK_K + step;
+        const uint8_t * ql = x[i].ql + step;
+        const uint8_t * qh = x[i].qh + step;
+        const int8_t  * s  = x[i].scales;
+
+        const float d = x[i+0].d;
+
+        float sum = 0;
+        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+            sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32)
+                 + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32)
+                 + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >>  4) | ((qh[j] & 0x30) >> 0)) - 32)
+                 + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >>  4) | ((qh[j] & 0xc0) >> 2)) - 32);
+        }
+        tmp += sum;
+
+    }
+
+#endif
+
     // sum up partial sums and write back result
     __syncthreads();
 #pragma unroll
@@ -1252,12 +1502,20 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
 
 static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
+#if QK_K == 256
     dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
+#else
+    dequantize_block_q2_K<<<nb, 32, 0, stream>>>(vx, y);
+#endif
 }
 
 static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
+#if QK_K == 256
     dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
+#else
+    dequantize_block_q3_K<<<nb, 32, 0, stream>>>(vx, y);
+#endif
 }
 
 static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
@@ -1267,12 +1525,20 @@ static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cu
 
 static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
+#if QK_K == 256
     dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
+#else
+    dequantize_block_q5_K<<<nb, 32, 0, stream>>>(vx, y);
+#endif
 }
 
 static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
+#if QK_K == 256
     dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
+#else
+    dequantize_block_q6_K<<<nb, 32, 0, stream>>>(vx, y);
+#endif
 }
 
 static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
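
The launcher changes keep the grid shape in sync with the kernels' thread mappings: one CUDA block per super-block, with 64 threads when QK_K == 256 and 32 when QK_K == 64. The common shape, as a sketch (dequantize_block_qX_K is a hypothetical stand-in for any of the kernels above):

    static void dequantize_row_qX_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
        const int nb = k / QK_K;   // one CUDA block per super-block
    #if QK_K == 256
        dequantize_block_qX_K<<<nb, 64, 0, stream>>>(vx, y);
    #else
        dequantize_block_qX_K<<<nb, 32, 0, stream>>>(vx, y);
    #endif
    }
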
@@ -2553,6 +2819,7 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
 
     tensor->backend = GGML_BACKEND_GPU;
     struct ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu;
+    memset(extra, 0, sizeof(*extra));
 
     const bool inplace = (tensor->src0 != nullptr && tensor->src0->data == tensor->data) ||
         tensor->op == GGML_OP_VIEW;
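
The new memset zero-initializes the freshly allocated ggml_tensor_extra_gpu, so its device-pointer members start out null instead of holding garbage; later code can then safely test and free them. In C++ the same effect could be had with value-initialization:

    // Equivalent alternative (illustrative): value-initialization zeroes the members.
    struct ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu();
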
@@ -2635,7 +2902,7 @@ void ggml_cuda_free_scratch() {
 bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor){
     ggml_cuda_func_t func;
     const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
-        || tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT
+        || (tensor->src0 != nullptr && (tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT))
        || (tensor->src1 != nullptr && tensor->src1->backend == GGML_BACKEND_GPU);
 
     switch (tensor->op) {
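
The any_on_device fix brings src0 in line with the existing src1 handling: src0 may be null (not every tensor has a first source), so its backend field must only be read behind a null check, exactly as was already done for src1.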