llama_cpp 0.2.0 → 0.2.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -256,6 +256,72 @@ kernel void kernel_get_rows_q4_1(
             (device float *) ((device char *) dst + i*nb1), ne00);
 }
 
+kernel void kernel_norm(
+        device const void * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant     float & eps,
+        threadgroup float * sum [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint ntg[[threads_per_threadgroup]]) {
+    device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
+    // MEAN
+    // parallel sum
+    sum[tpitg] = 0.0f;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        sum[tpitg] += x[i00];
+    }
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg/2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            sum[tpitg] += sum[tpitg + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+    // broadcast
+    if (tpitg == 0) {
+        sum[0] /= ne00;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    const float mean = sum[0];
+
+    // recenter
+    device float * y = dst + tgpig*ne00;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        y[i00] = x[i00] - mean;
+    }
+
+    // VARIANCE
+    // parallel sum
+    sum[tpitg] = 0.0f;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        sum[tpitg] += y[i00] * y[i00];
+    }
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg/2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            sum[tpitg] += sum[tpitg + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+    // broadcast
+    if (tpitg == 0) {
+        sum[0] /= ne00;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    const float variance = sum[0];
+
+    const float scale = 1.0f/sqrt(variance + eps);
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        y[i00] = y[i00] * scale;
+    }
+}
+
+
 kernel void kernel_rms_norm(
         device const void * src0,
         device       float * dst,
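
Note: the new kernel_norm above implements standard (mean and variance) layer normalization alongside the existing kernel_rms_norm. One threadgroup handles one row: a parallel tree reduction over threadgroup memory computes the mean, the row is recentered into dst, a second reduction computes the variance, and the row is scaled by 1/sqrt(variance + eps). A scalar C++ reference of the arithmetic (a sketch with our own names; the Metal kernel distributes each loop across ntg threads):

    #include <cmath>
    #include <cstddef>

    // Per-row reference: y = (x - mean(x)) / sqrt(var(x) + eps),
    // with var the biased (divide-by-N) variance, as in the kernel.
    static void norm_row_ref(const float * x, float * y, std::size_t ne00, float eps) {
        float mean = 0.0f;
        for (std::size_t i = 0; i < ne00; ++i) mean += x[i];
        mean /= (float) ne00;

        float variance = 0.0f;
        for (std::size_t i = 0; i < ne00; ++i) {
            const float v = x[i] - mean;
            y[i] = v;              // recenter in the output, as the kernel does
            variance += v * v;
        }
        variance /= (float) ne00;

        const float scale = 1.0f / std::sqrt(variance + eps);
        for (std::size_t i = 0; i < ne00; ++i) y[i] *= scale;
    }
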
@@ -304,34 +370,22 @@ kernel void kernel_mul_mat_q4_0_f32(
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
         constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
         constant   int64_t & ne0,
-        constant   int64_t & ne1,
         threadgroup float * sum [[threadgroup(0)]],
         uint2 tgpig[[threadgroup_position_in_grid]],
-        uint2 tpig[[thread_position_in_grid]],
         uint2 tpitg[[thread_position_in_threadgroup]],
         uint2 tptg[[threads_per_threadgroup]]) {
     const int nb = ne00/QK4_0;
 
-    const int8_t m8 = 8;
-
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
 
     device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb;
     device const float * y = (device const float *) src1 + r1*ne10;
 
-    const uint nth = tptg.x*tptg.y;
-    const uint ith = tptg.y*tpitg.x + tpitg.y;
+    const int nth = tptg.x*tptg.y;
+    const int ith = tptg.y*tpitg.x + tpitg.y;
 
     const int ix = tpitg.y/4;      // 0 or 1
     const int iy = tpitg.y - 4*ix; // 0...3
@@ -351,47 +405,32 @@ kernel void kernel_mul_mat_q4_0_f32(
 
         for (int j = 0; j < 4; ++j) {
 
-            acc[0] += yl[j+ 0] * ((int8_t)(xl[j] & 0xF) - m8);
-            acc[1] += yl[j+16] * ((int8_t)(xl[j] >> 4) - m8);
+            acc[0] += yl[j] * (xl[j] & 0xF) + yl[j+16] * (xl[j] >> 4);
+            acc[1] += yl[j] + yl[j+16];
 
         }
 
-        sumf += d * (acc[0] + acc[1]);
+        sumf += d * (acc[0] - 8.f*acc[1]);
     }
 
     sum[ith] = sumf;
 
     //
     // Accumulate the sum from all threads in the threadgroup
-    // This version is slightly faster than the commented out one below,
-    // which I copy-pasted from ggerganov's q4_0 dot product for metal.
     //
     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (ith%4 == 0) {
-        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+        sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
     }
     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (ith%16 == 0) {
-        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+        sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
     }
     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (ith == 0) {
-        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+        for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
         dst[r1*ne0 + r0] = sum[0];
     }
-
-    //// accumulate the sum from all threads in the threadgroup
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
-    //for (uint i = nth/2; i > 0; i /= 2) {
-    //    if (ith < i) {
-    //        sum[ith] += sum[ith + i];
-    //    }
-    //    threadgroup_barrier(mem_flags::mem_threadgroup);
-    //}
-
-    //if (ith == 0) {
-    //    dst[r1*ne0 + r0] = sum[0];
-    //}
 }
 
 kernel void kernel_mul_mat_q4_1_f32(
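
Note: the reworked q4_0 inner loop relies on the identity sum_j y_j*(q_j - 8) = sum_j y_j*q_j - 8 * sum_j y_j. acc[0] now accumulates the raw nibble products and acc[1] the plain sum of the y values, so the constant -8 offset (the removed m8) is applied once per block as acc[0] - 8.f*acc[1] instead of once per element. A minimal C++ sketch of the equivalence (our names; the two forms agree up to float rounding):

    #include <cstdint>

    // Old form: subtract the offset from every quant before multiplying.
    static float dot_q4_0_old(float d, const uint8_t * q, const float * y, int n) {
        float acc = 0.f;
        for (int j = 0; j < n; ++j) acc += y[j] * (float)((int)q[j] - 8);
        return d * acc;
    }

    // New form: fold the offset out of the loop.
    static float dot_q4_0_new(float d, const uint8_t * q, const float * y, int n) {
        float acc0 = 0.f, acc1 = 0.f;
        for (int j = 0; j < n; ++j) {
            acc0 += y[j] * (float)q[j]; // raw products
            acc1 += y[j];               // plain sum of y
        }
        return d * (acc0 - 8.f * acc1); // offset applied once per block
    }
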
@@ -399,20 +438,10 @@ kernel void kernel_mul_mat_q4_1_f32(
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
         constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
         constant   int64_t & ne0,
-        constant   int64_t & ne1,
         threadgroup float * sum [[threadgroup(0)]],
         uint2 tgpig[[threadgroup_position_in_grid]],
-        uint2 tpig[[thread_position_in_grid]],
         uint2 tpitg[[thread_position_in_threadgroup]],
         uint2 tptg[[threads_per_threadgroup]]) {
     const int nb = ne00/QK4_1;
@@ -460,11 +489,11 @@ kernel void kernel_mul_mat_q4_1_f32(
     //
     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (ith%4 == 0) {
-        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+        sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
     }
     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (ith%16 == 0) {
-        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+        sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
     }
     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (ith == 0) {
@@ -522,6 +551,48 @@ kernel void kernel_mul_mat_f16_f32(
     }
 }
 
+kernel void kernel_alibi_f32(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        constant     float & m0,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+    float m_k = pow(m0, i2 + 1);
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
+    }
+}
+
 kernel void kernel_rope(
         device const void * src0,
         device       float * dst,
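
Note: kernel_alibi_f32 is new in this release and applies the ALiBi (Attention with Linear Biases) position bias. Row index i2 selects a head-specific slope m_k = m0^(i2+1), and the element at column i00 receives an additive bias of m_k * (i00 - ne00 + 1): zero at the last column, increasingly negative toward the start. A scalar host-side sketch (our names):

    #include <cmath>

    // Bias one attention row the way kernel_alibi_f32 does.
    static void alibi_row_ref(const float * src, float * dst,
                              long ne00, long head, float m0) {
        const float m_k = std::pow(m0, (float)(head + 1)); // per-head slope
        for (long i00 = 0; i00 < ne00; ++i00) {
            dst[i00] = src[i00] + m_k * (float)(i00 - ne00 + 1);
        }
    }
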
@@ -577,6 +648,47 @@ kernel void kernel_rope(
     }
 }
 
+kernel void kernel_cpy_f16_f16(
+        device const half * src0,
+        device       half * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0];
+    }
+}
+
 kernel void kernel_cpy_f32_f16(
         device const float * src0,
         device       half * dst,
@@ -671,6 +783,15 @@ typedef struct {
     half d;    // super-block scale for quantized scales
     half dmin; // super-block scale for quantized mins
 } block_q2_k;
+// 84 bytes / block
+
+typedef struct {
+    uint8_t hmask[QK_K/8];     // quants - high bit
+    uint8_t qs[QK_K/4];        // quants - low 2 bits
+    uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
+    half d;                    // super-block scale
+} block_q3_k;
+// 110 bytes / block
 
 typedef struct {
     half d; // super-block scale for quantized scales
@@ -678,6 +799,16 @@ typedef struct {
     uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
     uint8_t qs[QK_K/2];        // 4--bit quants
 } block_q4_k;
+// 144 bytes / block
+
+typedef struct {
+    half d;                    // super-block scale for quantized scales
+    half dmin;                 // super-block scale for quantized mins
+    uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
+    uint8_t qh[QK_K/8];        // quants, high bit
+    uint8_t qs[QK_K/2];        // quants, low 4 bits
+} block_q5_k;
+// 176 bytes / block
 
 typedef struct {
     uint8_t ql[QK_K/2]; // quants, lower 4 bits
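
Note: the new "// N bytes / block" comments are consistent with QK_K = 256 and 2-byte half values: block_q3_k is 32 + 64 + 12 + 2 = 110 bytes and block_q5_k is 2 + 2 + 12 + 32 + 128 = 176 bytes. A host-side compile-time check (a sketch, assuming QK_K = 256, with uint16_t standing in for Metal's half; these layouts happen to have no padding):

    #include <cstdint>

    constexpr int QK_K = 256; // super-block size assumed by these layouts

    struct block_q3_k_host {
        uint8_t  hmask[QK_K/8];      // 32 bytes
        uint8_t  qs[QK_K/4];         // 64 bytes
        uint8_t  scales[3*QK_K/64];  // 12 bytes
        uint16_t d;                  //  2 bytes (stand-in for half)
    };

    struct block_q5_k_host {
        uint16_t d;                  //   2 bytes
        uint16_t dmin;               //   2 bytes
        uint8_t  scales[3*QK_K/64];  //  12 bytes
        uint8_t  qh[QK_K/8];         //  32 bytes
        uint8_t  qs[QK_K/2];         // 128 bytes
    };

    static_assert(sizeof(block_q3_k_host) == 110, "matches the diff comment");
    static_assert(sizeof(block_q5_k_host) == 176, "matches the diff comment");
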
@@ -685,16 +816,19 @@ typedef struct {
     int8_t scales[QK_K/16]; // scales, quantized with 8 bits
     half d;                 // super-block scale
 } block_q6_k;
+// 210 bytes / block
 
 static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
     uchar4 r;
     if (j < 4) {
-        r[0] = q[j+0] & 63; r[1] = q[j+4] & 63;
-        r[2] = q[j+1] & 63; r[3] = q[j+5] & 63;
+        r[0] = q[j+0] & 63;
+        r[2] = q[j+1] & 63;
+        r[1] = q[j+4] & 63;
+        r[3] = q[j+5] & 63;
     } else {
         r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
-        r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
         r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
+        r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
         r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4);
     }
     return r;
@@ -735,10 +869,65 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
     }
 }
 
+static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask2 = 0x0f0f;
+
+    uint16_t aux[8];
+    thread const int8_t * scales = (thread const int8_t*)aux;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d_all = (float)(x[i].d);
+
+        device const uint8_t * q = x[i].qs;
+        device const uint8_t * h = x[i].hmask;
+        uint8_t m = 1;
+
+        device const uint16_t * a = (device const uint16_t *)x[i].scales;
+        aux[0] = (a[0] & kmask2) | (((a[4] >> 0) & kmask1) << 4);
+        aux[1] = (a[1] & kmask2) | (((a[5] >> 0) & kmask1) << 4);
+        aux[2] = (a[2] & kmask2) | (((a[4] >> 2) & kmask1) << 4);
+        aux[3] = (a[3] & kmask2) | (((a[5] >> 2) & kmask1) << 4);
+        aux[4] = ((a[0] >> 4) & kmask2) | (((a[4] >> 4) & kmask1) << 4);
+        aux[5] = ((a[1] >> 4) & kmask2) | (((a[5] >> 4) & kmask1) << 4);
+        aux[6] = ((a[2] >> 4) & kmask2) | (((a[4] >> 6) & kmask1) << 4);
+        aux[7] = ((a[3] >> 4) & kmask2) | (((a[5] >> 6) & kmask1) << 4);
+
+        int is = 0;
+        float dl;
+        for (int n = 0; n < QK_K; n += 128) {
+            int shift = 0;
+            for (int j = 0; j < 4; ++j) {
+
+                dl = d_all * (scales[is++] - 32);
+                for (int l = 0; l < 16; ++l) {
+                    *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((h[l+ 0] & m) ? 0 : 4));
+                }
+
+                dl = d_all * (scales[is++] - 32);
+                for (int l = 0; l < 16; ++l) {
+                    *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((h[l+16] & m) ? 0 : 4));
+                }
+
+                shift += 2;
+                m <<= 1;
+            }
+            q += 32;
+        }
+
+    }
+
+}
+
 static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
+
     for (int i = 0; i < nb; i++) {
 
         const float d = x[i].d;
@@ -760,6 +949,33 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i
     }
 }
 
+static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d = (float)(x[i].d);
+        const float min = (float)(x[i].dmin);
+
+        device const uint8_t * ql = x[i].qs;
+        device const uint8_t * qh = x[i].qh;
+
+        int is = 0;
+        uint8_t u1 = 1, u2 = 2;
+        for (int j = 0; j < QK_K; j += 64) {
+            const uchar4 sc = get_scale_min_k4(is, x[i].scales);
+            const float d1 = d * sc[0]; const float m1 = min * sc[1];
+            const float d2 = d * sc[2]; const float m2 = min * sc[3];
+            for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
+            for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
+            ql += 32; is += 2;
+            u1 <<= 2; u2 <<= 2;
+        }
+    }
+
+}
+
 static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, int k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
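
Note: dequantize_row_q5_k rebuilds each 5-bit weight from two places: the low 4 bits come from a nibble of qs, and the fifth bit from qh under the masks u1/u2, which shift left by two after every 64-weight group. One weight's reconstruction, as a scalar C++ sketch (our names):

    #include <cstdint>

    // The quantized value spans 0..31 before scaling: low nibble plus
    // 16 if the high bit selected by bit_mask is set in qh_byte.
    static float dequant_q5_one(float d, float min, uint8_t sc, uint8_t scm,
                                uint8_t ql_nibble, uint8_t qh_byte, uint8_t bit_mask) {
        const int q = (ql_nibble & 0xF) + ((qh_byte & bit_mask) ? 16 : 0);
        return d * (float)sc * (float)q - min * (float)scm;
    }
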
@@ -808,6 +1024,22 @@ kernel void kernel_get_rows_q2_k(
             (device float *) ((device char *) dst + i*nb1), ne00);
 }
 
+kernel void kernel_get_rows_q3_k(
+        device const void * src0,
+        device const int * src1,
+        device      float * dst,
+        constant  int64_t & ne00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb1,
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    dequantize_row_q3_k(
+            (device const block_q3_k *) ((device char *) src0 + r*nb01),
+            (device float *) ((device char *) dst + i*nb1), ne00);
+}
+
 kernel void kernel_get_rows_q4_k(
         device const void * src0,
         device const int * src1,
@@ -824,6 +1056,22 @@ kernel void kernel_get_rows_q4_k(
             (device float *) ((device char *) dst + i*nb1), ne00);
 }
 
+kernel void kernel_get_rows_q5_k(
+        device const void * src0,
+        device const int * src1,
+        device      float * dst,
+        constant  int64_t & ne00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb1,
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    dequantize_row_q5_k(
+            (device const block_q5_k *) ((device char *) src0 + r*nb01),
+            (device float *) ((device char *) dst + i*nb1), ne00);
+}
+
 kernel void kernel_get_rows_q6_k(
         device const void * src0,
         device const int * src1,
@@ -847,20 +1095,10 @@ kernel void kernel_mul_mat_q2_k_f32(
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
         constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
         constant   int64_t & ne0,
-        constant   int64_t & ne1,
         threadgroup float * sum [[threadgroup(0)]],
         uint2 tgpig[[threadgroup_position_in_grid]],
-        uint2 tpig[[thread_position_in_grid]], // we don't use this for now
         uint2 tpitg[[thread_position_in_threadgroup]],
         uint2 tptg[[threads_per_threadgroup]]) {
 
@@ -875,7 +1113,6 @@ kernel void kernel_mul_mat_q2_k_f32(
     const int nth = tptg.x*tptg.y;
     const int ith = tptg.y*tpitg.x + tpitg.y;
 
-
     const int tid = tpitg.y; // 0...16
     const int il = tid/4;    // 0...3
     const int ir = tid%4;    // 0...3
@@ -885,35 +1122,54 @@ kernel void kernel_mul_mat_q2_k_f32(
     const int n = 8;
     const int is = 4*il + (n*ir)/16;
 
+    const int y_offset = 64*il + n*ir;
+    const int q_offset = 32*ip + n*ir;
+
     sum[ith] = 0.0f;
 
     float sumf = 0;
     for (int i = tpitg.x; i < nb; i += tptg.x) {
 
-        device const uint8_t * q = x[i].qs + 32*ip + n*ir;
+        device const uint8_t * q = x[i].qs + q_offset;
         device const uint8_t * scales = x[i].scales + is;
 
         uint8_t d1 = scales[0] & 0xF;
-        uint8_t m1 = scales[0] >> 4;
         uint8_t d2 = scales[2] & 0xF;
+        uint8_t m1 = scales[0] >> 4;
         uint8_t m2 = scales[2] >> 4;
 
-        device const float * y = yy + i*QK_K + 64*il + n*ir;
+        device const float * y = yy + i*QK_K + y_offset;
 
-        const float dall = (float)x[i].d;
-        const float dmin = (float)x[i].dmin;
-
-        float4 s = {0.f, 0.f, 0.f, 0.f};
+        //float4 s = {0.f, 0.f, 0.f, 0.f};
+        float2 s = {0.f, 0.f};
+        float smin = 0;
         for (int l = 0; l < n; ++l) {
-            s[0] += y[l+ 0] * ((q[l] >> shift1) & 3); s[1] += y[l+ 0];
-            s[2] += y[l+32] * ((q[l] >> shift2) & 3); s[3] += y[l+32];
+            s[0] += y[l+ 0] * ((q[l] >> shift1) & 3);
+            s[1] += y[l+32] * ((q[l] >> shift2) & 3);
+            smin += y[l+ 0] * m1 + y[l+32] * m2;
         }
-        sumf += dall * (s[0] * d1 + s[2] * d2) - dmin * (s[1] * m1 + s[3] * m2);
 
+        const float dall = (float)x[i].d;
+        const float dmin = (float)x[i].dmin;
+
+        sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
 
     }
     sum[ith] = sumf;
 
+    //int mask1 = (ith%4 == 0);
+    //int mask2 = (ith%16 == 0);
+
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    //for (int i = 1; i < 4; ++i) sum[ith] += mask1 * sum[ith + i];
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    //for (int i = 4; i < 16; i += 4) sum[ith] += mask2 * sum[ith + i];
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    //if (ith == 0) {
+    //    for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+    //    dst[r1*ne0 + r0] = sum[0];
+    //}
+
     //
     // Accumulate the sum from all threads in the threadgroup
     // This version is slightly faster than the commented out one below,
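
Note: the q2_k loop above shrinks the per-thread state from a float4 to a float2 plus one smin accumulator. Because the per-group minima m1 and m2 are invariant across the loop, the two separate y-sums can be folded into a single running term smin. A C++ sketch of the two equivalent forms (our names; equal up to float rounding):

    // Old shape: keep the two y-sums and weight them by m1/m2 at the end.
    static float q2k_old(const float * y, const int * q0, const int * q1, int n,
                         float dall, float dmin, int d1, int d2, int m1, int m2) {
        float s0 = 0, s1 = 0, t0 = 0, t1 = 0;
        for (int l = 0; l < n; ++l) {
            s0 += y[l]    * q0[l]; t0 += y[l];
            s1 += y[l+32] * q1[l]; t1 += y[l+32];
        }
        return dall * (s0 * d1 + s1 * d2) - dmin * (t0 * m1 + t1 * m2);
    }

    // New shape: fold the loop-invariant minima in as you go.
    static float q2k_new(const float * y, const int * q0, const int * q1, int n,
                         float dall, float dmin, int d1, int d2, int m1, int m2) {
        float s0 = 0, s1 = 0, smin = 0;
        for (int l = 0; l < n; ++l) {
            s0 += y[l]    * q0[l];
            s1 += y[l+32] * q1[l];
            smin += y[l] * m1 + y[l+32] * m2;
        }
        return dall * (s0 * d1 + s1 * d2) - dmin * smin;
    }
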
@@ -932,19 +1188,109 @@ kernel void kernel_mul_mat_q2_k_f32(
         for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
         dst[r1*ne0 + r0] = sum[0];
     }
+}
 
-    //// accumulate the sum from all threads in the threadgroup
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
-    //for (uint i = nth/2; i > 0; i /= 2) {
-    //    if (ith < i) {
-    //        sum[ith] += sum[ith + i];
-    //    }
-    //    threadgroup_barrier(mem_flags::mem_threadgroup);
-    //}
+kernel void kernel_mul_mat_q3_k_f32(
+        device const void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne10,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        threadgroup float * sum [[threadgroup(0)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint2 tpitg[[thread_position_in_threadgroup]],
+        uint2 tptg[[threads_per_threadgroup]]) {
+
+    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask2 = 0x0f0f;
+
+    const uint8_t m3 = 3;
+    const int8_t  m4 = 4;
+
+    const int nb = ne00/QK_K;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+
+    device const block_q3_k * x = (device const block_q3_k *) src0 + r0*nb;
+    device const float * yy = (device const float *) src1 + r1*ne10;
+
+    const int nth = tptg.x*tptg.y;
+    const int ith = tptg.y*tpitg.x + tpitg.y;
+
+    const int tid = tpitg.y;      // expecting 16
+    const int ip  = tid/8;        // 0 or 1
+    const int il  = tid/2 - 4*ip; // 0...3
+    const int ir  = tid%2;
+    const int n   = 8;
+    const int l0  = n*ir;
+
+    const uint8_t m = 1 << (4*ip + il);
+
+    const int shift = 2*il;
+
+    const uint16_t s_shift1 = 4*ip;
+    const uint16_t s_shift2 = s_shift1 + 2*(il/2);
+    const int ik = 4 + (il%2);
+
+    const int q_offset = 32*ip + l0;
+    const int y_offset = 128*ip + 32*il + l0;
+
+    //float sumf = 0;
+    float sumf1 = 0, sumf2 = 0;
+    for (int i = tpitg.x; i < nb; i += tptg.x) {
+
+        const float d_all = (float)(x[i].d);
+
+        device const uint8_t * q = x[i].qs + q_offset;
+        device const uint8_t * h = x[i].hmask + l0;
+        device const float   * y = yy + i * QK_K + y_offset;
+
+        device const uint16_t * a = (device const uint16_t *)x[i].scales;
+        const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
+
+        float s = 0;
+        for (int l = 0; l < n; ++l) {
+            s += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & m3) - ((h[l+ 0] & m) ? 0 : m4));
+        }
+        float d = d_all * s;
+        sumf1 += d * scales[0];
+        sumf2 += d;
+        //sumf += d_all * s * (scales[0] - 32);
+
+        s = 0;
+        for (int l = 0; l < n; ++l) {
+            s += y[l+16] * ((int8_t)((q[l+16] >> shift) & m3) - ((h[l+16] & m) ? 0 : m4));
+        }
+        d = d_all * s;
+        sumf1 += d * scales[1];
+        sumf2 += d;
+        //sumf += d_all * s * (scales[1] - 32);
+
+    }
+
+    //sum[ith] = sumf;
+    sum[ith] = sumf1 - 32.f*sumf2;
+
+    //
+    // Accumulate the sum from all threads in the threadgroup
+    //
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%4 == 0) {
+        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%16 == 0) {
+        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith == 0) {
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+        dst[r1*ne0 + r0] = sum[0];
+    }
 
-    //if (ith == 0) {
-    //    dst[r1*ne0 + r0] = sum[0];
-    //}
 }
 
 kernel void kernel_mul_mat_q4_k_f32(
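
Note: the new kernel_mul_mat_q3_k_f32 stores its 6-bit scales with a +32 bias and defers that bias to the end of the dot product: instead of computing d_all * s * (scale - 32) per sub-block (the commented-out //sumf lines), it accumulates sumf1 = sum(d*scale) and sumf2 = sum(d) and writes sumf1 - 32.f*sumf2. The identity, as a host C++ sketch (our names):

    // Two equivalent ways to apply the +32 scale bias.
    static float q3k_bias_inline(const float * d, const int * scale, int n) {
        float sumf = 0.f;
        for (int i = 0; i < n; ++i) sumf += d[i] * (float)(scale[i] - 32);
        return sumf;
    }

    static float q3k_bias_deferred(const float * d, const int * scale, int n) {
        float sumf1 = 0.f, sumf2 = 0.f;
        for (int i = 0; i < n; ++i) {
            sumf1 += d[i] * (float)scale[i];
            sumf2 += d[i];
        }
        return sumf1 - 32.f * sumf2; // bias applied once, outside the loop
    }
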
  kernel void kernel_mul_mat_q4_k_f32(
@@ -952,23 +1298,17 @@ kernel void kernel_mul_mat_q4_k_f32(
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
         constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
         constant   int64_t & ne0,
-        constant   int64_t & ne1,
         threadgroup float * sum [[threadgroup(0)]],
         uint2 tgpig[[threadgroup_position_in_grid]],
-        uint2 tpig[[thread_position_in_grid]], // we don't use this for now
         uint2 tpitg[[thread_position_in_threadgroup]],
         uint2 tptg[[threads_per_threadgroup]]) {
 
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
     const int nb = ne00/QK_K;
 
@@ -977,37 +1317,55 @@ kernel void kernel_mul_mat_q4_k_f32(
     device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
     device const float * yy = (device const float *) src1 + r1*ne10;
 
-    const uint nth = tptg.x*tptg.y;
-    const uint ith = tptg.y*tpitg.x + tpitg.y;
+    const int nth = tptg.x*tptg.y;
+    const int ith = tptg.y*tpitg.x + tpitg.y;
 
     const int tid = tpitg.y; // 0...16
     const int il  = tid/4;   // 0...3
-    const int ir  = tid%4;   // 0...3
-    const int n = 8;
-    const int is = 2*il;
+    const int ir  = tid - 4*il;// 0...3
+    const int n = 4;
+
+    const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+    const int in = il%2;
+
+    const int l0 = n*(2*ir + in);
+    const int q_offset = 32*im + l0;
+    const int y_offset = 64*im + l0;
 
     sum[ith] = 0.0f;
 
+    uchar2 sc1, sc2, sc3, sc4;
+
     float sumf = 0;
     for (int i = tpitg.x; i < nb; i += tptg.x) {
 
-        device const uint8_t * q = (x + i)->qs + 32*il + n*ir;
-        device const float   * y = yy + i*QK_K + 64*il + n*ir;
-        device const uint8_t * scales = (x + i)->scales;
+        device const uint8_t * q1 = (x + i)->qs + q_offset;
+        device const uint8_t * q2 = q1 + 64;
+        device const float   * y1 = yy + i*QK_K + y_offset;
+        device const float   * y2 = y1 + 128;
 
         const float dall = (float)((x + i)->d);
         const float dmin = (float)((x + i)->dmin);
 
-        const uchar4 sc = get_scale_min_k4(is, scales);
+        device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
+        sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
+        sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
+        sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
+        sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
 
         float4 s = {0.f, 0.f, 0.f, 0.f};
+        float smin = 0;
         for (int l = 0; l < n; ++l) {
-            s[0] += y[l+ 0] * (q[l] & 0xF); s[1] += y[l+ 0];
-            s[2] += y[l+32] * (q[l] >> 4);  s[3] += y[l+32];
+
+            s[0] += y1[l] * (q1[l] & 0xF); s[1] += y1[l+32] * (q1[l] >> 4);
+            s[2] += y2[l] * (q2[l] & 0xF); s[3] += y2[l+32] * (q2[l] >> 4);
+            smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
+
         }
-        sumf += dall * (s[0] * sc[0] + s[2] * sc[2]) - dmin * (s[1] * sc[1] + s[3] * sc[3]);
+        sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
 
     }
+
     sum[ith] = sumf;
 
     //
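
Note: kernel_mul_mat_q4_k_f32 (and the q5_k kernel added below) now unpacks two (scale, min) pairs at a time: the packed scales are read as 16-bit words, the 6-bit fields are selected with kmask1/kmask2/kmask3, and the result is bit-cast to two bytes with as_type<uchar2>, replacing the per-pair get_scale_min_k4 call. On the host the same bit-cast can be sketched with memcpy (our names; a sketch, not the kernel's code):

    #include <cstdint>
    #include <cstring>

    struct uchar2_host { uint8_t x, y; };

    // Host-side stand-in for Metal's as_type<uchar2>: reinterpret the 16 bits
    // of v as two bytes without changing them.
    static uchar2_host as_uchar2(uint16_t v) {
        uchar2_host r;
        std::memcpy(&r, &v, sizeof r); // well-defined type pun
        return r;
    }

    // Example: extract two 6-bit scales at once, as sc1 does in the kernel.
    static uchar2_host two_scales(uint16_t packed) {
        return as_uchar2(packed & 0x3f3f); // kmask1 keeps the low 6 bits of each byte
    }
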
@@ -1043,25 +1401,114 @@ kernel void kernel_mul_mat_q4_k_f32(
     //}
 }
 
+kernel void kernel_mul_mat_q5_k_f32(
+        device const void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne10,
+        constant   int64_t & ne0,
+        threadgroup float * sum [[threadgroup(0)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint2 tpitg[[thread_position_in_threadgroup]],
+        uint2 tptg[[threads_per_threadgroup]]) {
+
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
+    const int nb = ne00/QK_K;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+
+    device const block_q5_k * x = (device const block_q5_k *) src0 + r0*nb;
+    device const float * yy = (device const float *) src1 + r1*ne10;
+
+    const int nth = tptg.x*tptg.y;
+    const int ith = tptg.y*tpitg.x + tpitg.y;
+
+    const int tid = tpitg.y; // 0...16
+    const int il  = tid/4;   // 0...3
+    const int ir  = tid - 4*il;// 0...3
+    const int n = 4;
+
+    const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+    const int in = il%2;
+
+    const int l0 = n*(2*ir + in);
+    const int q_offset = 32*im + l0;
+    const int y_offset = 64*im + l0;
+
+    const uint8_t hm1 = 1u << (2*im);
+    const uint8_t hm2 = hm1 << 1;
+    const uint8_t hm3 = hm1 << 4;
+    const uint8_t hm4 = hm2 << 4;
+
+    uchar2 sc1, sc2, sc3, sc4;
+
+    float sumf = 0;
+    for (int i = tpitg.x; i < nb; i += tptg.x) {
+
+        device const uint8_t * q1 = (x + i)->qs + q_offset;
+        device const uint8_t * q2 = q1 + 64;
+        device const uint8_t * qh = (x + i)->qh + l0;
+        device const float   * y1 = yy + i*QK_K + y_offset;
+        device const float   * y2 = y1 + 128;
+
+        const float dall = (float)((x + i)->d);
+        const float dmin = (float)((x + i)->dmin);
+
+        device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
+        sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
+        sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
+        sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
+        sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
+
+        float4 s = {0.f, 0.f, 0.f, 0.f};
+        float smin = 0;
+        for (int l = 0; l < n; ++l) {
+
+            s[0] += y1[l+ 0] * ((q1[l] & 0xF) + (qh[l] & hm1 ? 16 : 0));
+            s[1] += y1[l+32] * ((q1[l] >> 4)  + (qh[l] & hm2 ? 16 : 0));
+            s[2] += y2[l+ 0] * ((q2[l] & 0xF) + (qh[l] & hm3 ? 16 : 0));
+            s[3] += y2[l+32] * ((q2[l] >> 4)  + (qh[l] & hm4 ? 16 : 0));
+            smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
+
+        }
+        sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
+
+    }
+    sum[ith] = sumf;
+
+    //
+    // Accumulate the sum from all threads in the threadgroup
+    //
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%4 == 0) {
+        sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%16 == 0) {
+        sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith == 0) {
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+        dst[r1*ne0 + r0] = sum[0];
+    }
+
+}
+
 kernel void kernel_mul_mat_q6_k_f32(
         device const void * src0,
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
         constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
         constant   int64_t & ne0,
-        constant   int64_t & ne1,
         threadgroup float * sum [[threadgroup(0)]],
         uint2 tgpig[[threadgroup_position_in_grid]],
-        uint2 tpig[[thread_position_in_grid]], // we don't use this for now
         uint2 tpitg[[thread_position_in_threadgroup]],
         uint2 tptg[[threads_per_threadgroup]]) {
 
@@ -1078,24 +1525,29 @@ kernel void kernel_mul_mat_q6_k_f32(
     device const block_q6_k * x = (device const block_q6_k *) src0 + r0*nb;
     device const float * yy = (device const float *) src1 + r1*ne10;
 
-    const uint nth = tptg.x*tptg.y;
-    const uint ith = tptg.y*tpitg.x + tpitg.y;
+    const int nth = tptg.x*tptg.y;
+    const int ith = tptg.y*tpitg.x + tpitg.y;
 
-    const int step = QK_K / tptg.y;  // we expect this to be 16
-    const int iqs  = step * tpitg.y; // 0...240 in steps of 16
+    // Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
+    const int iqs = 16 * tpitg.y;
     const int ip  = iqs / 128;         // 0 or 1
     const int il  = (iqs - 128*ip)/16; // 0...7
     const int n   = 4;
-    const int is = 8*ip + (n*il)/16;
+    const int l0 = n*il;
+    const int is = 8*ip + l0/16;
+
+    const int y_offset = 128*ip + l0;
+    const int q_offset_l = 64*ip + l0;
+    const int q_offset_h = 32*ip + l0;
 
     float sumf = 0;
     for (int i = tpitg.x; i < nb; i += tptg.x) {
 
-        device const uint8_t * ql = x[i].ql + 64*ip + n*il;
-        device const uint8_t * qh = x[i].qh + 32*ip + n*il;
+        device const uint8_t * ql = x[i].ql + q_offset_l;
+        device const uint8_t * qh = x[i].qh + q_offset_h;
         device const int8_t  * sc = x[i].scales + is;
 
-        device const float * y = yy + i * QK_K + 128*ip + n*il;
+        device const float * y = yy + i * QK_K + y_offset;
 
         const float dall = x[i].d;