llama_cpp 0.2.0 → 0.2.2: changes to the bundled ggml Metal shader source (ggml-metal.metal)

@@ -256,6 +256,72 @@ kernel void kernel_get_rows_q4_1(
         (device float *) ((device char *) dst + i*nb1), ne00);
 }
 
+kernel void kernel_norm(
+        device const void * src0,
+        device float * dst,
+        constant int64_t & ne00,
+        constant uint64_t & nb01,
+        constant float & eps,
+        threadgroup float * sum [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint ntg[[threads_per_threadgroup]]) {
+    device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
+    // MEAN
+    // parallel sum
+    sum[tpitg] = 0.0f;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        sum[tpitg] += x[i00];
+    }
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg/2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            sum[tpitg] += sum[tpitg + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+    // broadcast
+    if (tpitg == 0) {
+        sum[0] /= ne00;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    const float mean = sum[0];
+
+    // recenter
+    device float * y = dst + tgpig*ne00;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        y[i00] = x[i00] - mean;
+    }
+
+    // VARIANCE
+    // parallel sum
+    sum[tpitg] = 0.0f;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        sum[tpitg] += y[i00] * y[i00];
+    }
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg/2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            sum[tpitg] += sum[tpitg + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+    // broadcast
+    if (tpitg == 0) {
+        sum[0] /= ne00;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    const float variance = sum[0];
+
+    const float scale = 1.0f/sqrt(variance + eps);
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        y[i00] = y[i00] * scale;
+    }
+}
+
+
 kernel void kernel_rms_norm(
         device const void * src0,
         device float * dst,
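
Note: the new kernel_norm above is a two-pass layer normalization. A tree reduction in threadgroup memory computes the row mean, each thread recenters its strided slice of the row, a second reduction computes the variance, and the row is rescaled by 1/sqrt(variance + eps). A single-threaded C++ sketch of the same math (a reference for checking the kernel's output, not code that ships in the gem):

    #include <cmath>

    // Reference layer norm over one row of length ne00 (CPU sketch).
    void norm_ref(const float * x, float * y, int ne00, float eps) {
        float mean = 0.0f;
        for (int i = 0; i < ne00; ++i) mean += x[i];
        mean /= ne00;

        float variance = 0.0f;
        for (int i = 0; i < ne00; ++i) {
            y[i] = x[i] - mean;        // recenter, as the kernel writes into dst
            variance += y[i] * y[i];
        }
        variance /= ne00;

        const float scale = 1.0f / std::sqrt(variance + eps);
        for (int i = 0; i < ne00; ++i) y[i] *= scale;
    }
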
@@ -304,34 +370,22 @@ kernel void kernel_mul_mat_q4_0_f32(
         device const float * src1,
         device float * dst,
         constant int64_t & ne00,
-        constant int64_t & ne01,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
         constant int64_t & ne10,
-        constant int64_t & ne11,
-        constant uint64_t & nb10,
-        constant uint64_t & nb11,
-        constant uint64_t & nb12,
         constant int64_t & ne0,
-        constant int64_t & ne1,
         threadgroup float * sum [[threadgroup(0)]],
         uint2 tgpig[[threadgroup_position_in_grid]],
-        uint2 tpig[[thread_position_in_grid]],
         uint2 tpitg[[thread_position_in_threadgroup]],
         uint2 tptg[[threads_per_threadgroup]]) {
    const int nb = ne00/QK4_0;
 
-    const int8_t m8 = 8;
-
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
 
     device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb;
     device const float * y = (device const float *) src1 + r1*ne10;
 
-    const uint nth = tptg.x*tptg.y;
-    const uint ith = tptg.y*tpitg.x + tpitg.y;
+    const int nth = tptg.x*tptg.y;
+    const int ith = tptg.y*tpitg.x + tpitg.y;
 
     const int ix = tpitg.y/4;      // 0 or 1
     const int iy = tpitg.y - 4*ix; // 0...3
@@ -351,47 +405,32 @@ kernel void kernel_mul_mat_q4_0_f32(
 
         for (int j = 0; j < 4; ++j) {
 
-            acc[0] += yl[j+ 0] * ((int8_t)(xl[j] & 0xF) - m8);
-            acc[1] += yl[j+16] * ((int8_t)(xl[j] >> 4) - m8);
+            acc[0] += yl[j] * (xl[j] & 0xF) + yl[j+16] * (xl[j] >> 4);
+            acc[1] += yl[j] + yl[j+16];
 
         }
 
-        sumf += d * (acc[0] + acc[1]);
+        sumf += d * (acc[0] - 8.f*acc[1]);
     }
 
     sum[ith] = sumf;
 
     //
     // Accumulate the sum from all threads in the threadgroup
-    // This version is slightly faster than the commented out one below,
-    // which I copy-pasted from ggerganov's q4_0 dot product for metal.
     //
     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (ith%4 == 0) {
-        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+        sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
     }
     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (ith%16 == 0) {
-        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+        sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
     }
     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (ith == 0) {
-        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+        for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
         dst[r1*ne0 + r0] = sum[0];
     }
-
-    //// accumulate the sum from all threads in the threadgroup
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
-    //for (uint i = nth/2; i > 0; i /= 2) {
-    //    if (ith < i) {
-    //        sum[ith] += sum[ith + i];
-    //    }
-    //    threadgroup_barrier(mem_flags::mem_threadgroup);
-    //}
-
-    //if (ith == 0) {
-    //    dst[r1*ne0 + r0] = sum[0];
-    //}
 }
 
 kernel void kernel_mul_mat_q4_1_f32(
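
Note: the q4_0 inner-loop change is pure algebra. Every q4_0 quant carries an implicit offset of 8, and sum_j y_j*(q_j - 8) = sum_j y_j*q_j - 8*sum_j y_j, so the subtraction (and the int8_t cast) can be hoisted out of the hot loop: acc[0] now collects the y*q products and acc[1] the plain sum of y. A small C++ check of the equivalence (hypothetical helper name; assumes nibbles packed two per byte as in q4_0):

    #include <cassert>
    #include <cmath>
    #include <cstdint>

    // Old vs. new q4_0 dot formulation over 4 packed bytes (8 quants).
    void check_q4_0_identity(const uint8_t xl[4], const float yl[20], float d) {
        float old_acc = 0, acc0 = 0, acc1 = 0;
        for (int j = 0; j < 4; ++j) {
            old_acc += yl[j+ 0] * ((int8_t)(xl[j] & 0xF) - 8)
                     + yl[j+16] * ((int8_t)(xl[j] >> 4) - 8);
            acc0 += yl[j] * (xl[j] & 0xF) + yl[j+16] * (xl[j] >> 4);
            acc1 += yl[j] + yl[j+16];
        }
        // identical up to float rounding
        assert(std::fabs(d*old_acc - d*(acc0 - 8.f*acc1)) < 1e-4f);
    }
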
@@ -399,20 +438,10 @@ kernel void kernel_mul_mat_q4_1_f32(
         device const float * src1,
         device float * dst,
         constant int64_t & ne00,
-        constant int64_t & ne01,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
         constant int64_t & ne10,
-        constant int64_t & ne11,
-        constant uint64_t & nb10,
-        constant uint64_t & nb11,
-        constant uint64_t & nb12,
         constant int64_t & ne0,
-        constant int64_t & ne1,
         threadgroup float * sum [[threadgroup(0)]],
         uint2 tgpig[[threadgroup_position_in_grid]],
-        uint2 tpig[[thread_position_in_grid]],
         uint2 tpitg[[thread_position_in_threadgroup]],
         uint2 tptg[[threads_per_threadgroup]]) {
     const int nb = ne00/QK4_1;
@@ -460,11 +489,11 @@ kernel void kernel_mul_mat_q4_1_f32(
     //
     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (ith%4 == 0) {
-        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+        sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
     }
     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (ith%16 == 0) {
-        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+        sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
     }
     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (ith == 0) {
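
Note: q4_0 and q4_1 now share the same fully unrolled three-stage reduction: threads whose index is a multiple of 4 fold in their three neighbours, multiples of 16 fold in the three 4-sums below them, and thread 0 strides across the 16-sums. A CPU model of the pattern (sketch; assumes the thread count is a multiple of 16, which the kernels appear to rely on):

    #include <cassert>
    #include <vector>

    // Mirrors the three barrier-separated stages of the Metal reduction.
    float reduce3(std::vector<float> sum) {
        const int nth = (int) sum.size();
        assert(nth % 16 == 0);
        for (int ith = 0; ith < nth; ith += 4)   // stage 1: groups of 4
            sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
        for (int ith = 0; ith < nth; ith += 16)  // stage 2: groups of 16
            sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
        for (int i = 16; i < nth; i += 16)       // stage 3: thread 0 finishes
            sum[0] += sum[i];
        return sum[0];
    }
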
@@ -522,6 +551,48 @@ kernel void kernel_mul_mat_f16_f32(
     }
 }
 
+kernel void kernel_alibi_f32(
+        device const float * src0,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & ne2,
+        constant int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        constant float & m0,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+    float m_k = pow(m0, i2 + 1);
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
+    }
+}
+
 kernel void kernel_rope(
         device const void * src0,
         device float * dst,
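
Note: kernel_alibi_f32 is new in this release and applies the ALiBi ("Attention with Linear Biases") position bias on the GPU. Each head (index i2) gets a slope m_k = m0^(i2+1); element i00 of a row of length ne00 then receives the bias m_k * (i00 - ne00 + 1), which is zero at the last position and grows in magnitude toward older positions. In scalar form (sketch, not gem code):

    #include <cmath>

    // Bias added by kernel_alibi_f32 for one element.
    float alibi_bias(float m0, int head, int pos, int ne00) {
        const float m_k = std::pow(m0, head + 1); // per-head slope
        return m_k * (pos - ne00 + 1);            // 0 at the newest position
    }
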
@@ -577,6 +648,47 @@ kernel void kernel_rope(
     }
 }
 
+kernel void kernel_cpy_f16_f16(
+        device const half * src0,
+        device half * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & ne2,
+        constant int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0];
+    }
+}
+
 kernel void kernel_cpy_f32_f16(
         device const float * src0,
         device half * dst,
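
Note: kernel_cpy_f16_f16, like the existing copy kernels it sits beside, linearizes the source row coordinate into a running element count n and then re-splits n against the destination shape, so source and destination may disagree on strides (and layout) as long as the element counts line up. The index arithmetic as a CPU sketch:

    #include <cstdint>

    // Split a flat element index n into destination coordinates (i0..i3).
    void split_index(int64_t n, int64_t ne0, int64_t ne1, int64_t ne2,
                     int64_t & i0, int64_t & i1, int64_t & i2, int64_t & i3) {
        i3 = n / (ne2*ne1*ne0);
        i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
        i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
        i0 =  n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0;
    }
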
@@ -671,6 +783,15 @@ typedef struct {
     half d;    // super-block scale for quantized scales
     half dmin; // super-block scale for quantized mins
 } block_q2_k;
+// 84 bytes / block
+
+typedef struct {
+    uint8_t hmask[QK_K/8];     // quants - high bit
+    uint8_t qs[QK_K/4];        // quants - low 2 bits
+    uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
+    half d;                    // super-block scale
+} block_q3_k;
+// 110 bytes / block
 
 typedef struct {
     half d; // super-block scale for quantized scales
@@ -678,6 +799,16 @@ typedef struct {
     uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
     uint8_t qs[QK_K/2];        // 4-bit quants
 } block_q4_k;
+// 144 bytes / block
+
+typedef struct {
+    half d;                    // super-block scale for quantized scales
+    half dmin;                 // super-block scale for quantized mins
+    uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
+    uint8_t qh[QK_K/8];        // quants, high bit
+    uint8_t qs[QK_K/2];        // quants, low 4 bits
+} block_q5_k;
+// 176 bytes / block
 
 typedef struct {
     uint8_t ql[QK_K/2]; // quants, lower 4 bits
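
Note: the "bytes / block" comments follow directly from QK_K = 256 and a 2-byte half. For block_q3_k: 32 (hmask) + 64 (qs) + 12 (scales) + 2 (d) = 110 bytes; for block_q5_k: 2 + 2 + 12 + 32 + 128 = 176 bytes. A compile-time check (sketch; uint16_t stands in for Metal's half on the host):

    #include <cstdint>

    constexpr int QK_K = 256;

    struct block_q3_k_mirror {
        uint8_t  hmask[QK_K/8];     // 32 bytes: high bit of each quant
        uint8_t  qs[QK_K/4];        // 64 bytes: low 2 bits of each quant
        uint8_t  scales[3*QK_K/64]; // 12 bytes: 6-bit scales
        uint16_t d;                 //  2 bytes (half stand-in)
    };
    struct block_q5_k_mirror {
        uint16_t d, dmin;           //  4 bytes (half stand-ins)
        uint8_t  scales[3*QK_K/64]; // 12 bytes
        uint8_t  qh[QK_K/8];        // 32 bytes: high bit of each quant
        uint8_t  qs[QK_K/2];        // 128 bytes: low 4 bits
    };
    static_assert(sizeof(block_q3_k_mirror) == 110, "110 bytes / block");
    static_assert(sizeof(block_q5_k_mirror) == 176, "176 bytes / block");
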
@@ -685,16 +816,19 @@ typedef struct {
     int8_t scales[QK_K/16]; // scales, quantized with 8 bits
     half d;                 // super-block scale
 } block_q6_k;
+// 210 bytes / block
 
 static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
     uchar4 r;
     if (j < 4) {
-        r[0] = q[j+0] & 63; r[1] = q[j+4] & 63;
-        r[2] = q[j+1] & 63; r[3] = q[j+5] & 63;
+        r[0] = q[j+0] & 63;
+        r[2] = q[j+1] & 63;
+        r[1] = q[j+4] & 63;
+        r[3] = q[j+5] & 63;
     } else {
         r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
-        r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
         r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
+        r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
         r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4);
     }
     return r;
@@ -735,10 +869,65 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
     }
 }
 
+static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask2 = 0x0f0f;
+
+    uint16_t aux[8];
+    thread const int8_t * scales = (thread const int8_t*)aux;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d_all = (float)(x[i].d);
+
+        device const uint8_t * q = x[i].qs;
+        device const uint8_t * h = x[i].hmask;
+        uint8_t m = 1;
+
+        device const uint16_t * a = (device const uint16_t *)x[i].scales;
+        aux[0] = (a[0] & kmask2) | (((a[4] >> 0) & kmask1) << 4);
+        aux[1] = (a[1] & kmask2) | (((a[5] >> 0) & kmask1) << 4);
+        aux[2] = (a[2] & kmask2) | (((a[4] >> 2) & kmask1) << 4);
+        aux[3] = (a[3] & kmask2) | (((a[5] >> 2) & kmask1) << 4);
+        aux[4] = ((a[0] >> 4) & kmask2) | (((a[4] >> 4) & kmask1) << 4);
+        aux[5] = ((a[1] >> 4) & kmask2) | (((a[5] >> 4) & kmask1) << 4);
+        aux[6] = ((a[2] >> 4) & kmask2) | (((a[4] >> 6) & kmask1) << 4);
+        aux[7] = ((a[3] >> 4) & kmask2) | (((a[5] >> 6) & kmask1) << 4);
+
+        int is = 0;
+        float dl;
+        for (int n = 0; n < QK_K; n += 128) {
+            int shift = 0;
+            for (int j = 0; j < 4; ++j) {
+
+                dl = d_all * (scales[is++] - 32);
+                for (int l = 0; l < 16; ++l) {
+                    *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((h[l+ 0] & m) ? 0 : 4));
+                }
+
+                dl = d_all * (scales[is++] - 32);
+                for (int l = 0; l < 16; ++l) {
+                    *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((h[l+16] & m) ? 0 : 4));
+                }
+
+                shift += 2;
+                m <<= 1;
+            }
+            q += 32;
+        }
+
+    }
+
+}
+
 static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
+
     for (int i = 0; i < nb; i++) {
 
         const float d = x[i].d;
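
Note: in dequantize_row_q3_k a 3-bit quant is reassembled from two bit planes: the low 2 bits come from qs and the sign-carrying high bit from hmask. (q >> shift) & 3 yields 0..3, and (h & m) ? 0 : 4 subtracts 4 exactly when the high bit is clear, so the result spans -4..3 with no explicit sign extension. Enumerating the mapping (sketch):

    #include <cstdio>

    // q3_k value reconstruction: hbit set -> lo (0..3); hbit clear -> lo - 4.
    int main() {
        for (int hbit = 0; hbit <= 1; ++hbit)
            for (int lo = 0; lo < 4; ++lo)
                std::printf("hbit=%d lo=%d -> %d\n", hbit, lo, lo - (hbit ? 0 : 4));
        return 0; // prints the full -4..3 range
    }
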
@@ -760,6 +949,33 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i
     }
 }
 
+static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d = (float)(x[i].d);
+        const float min = (float)(x[i].dmin);
+
+        device const uint8_t * ql = x[i].qs;
+        device const uint8_t * qh = x[i].qh;
+
+        int is = 0;
+        uint8_t u1 = 1, u2 = 2;
+        for (int j = 0; j < QK_K; j += 64) {
+            const uchar4 sc = get_scale_min_k4(is, x[i].scales);
+            const float d1 = d * sc[0]; const float m1 = min * sc[1];
+            const float d2 = d * sc[2]; const float m2 = min * sc[3];
+            for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
+            for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
+            ql += 32; is += 2;
+            u1 <<= 2; u2 <<= 2;
+        }
+    }
+
+}
+
 static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, int k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
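
Note: dequantize_row_q5_k extends the q4_k layout with one extra bit per quant. The low 4 bits come from qs and the fifth bit from qh (the masks u1/u2 walk up the byte as the block is traversed), so the integer quant spans 0..31 before the per-group scale and min apply. One element in isolation (sketch):

    #include <cstdint>

    // value = d*sc * ((ql & 0xF) + (qh bit ? 16 : 0)) - min*sc_min
    float dequant_q5(uint8_t ql_nibble, bool qh_bit, float d_scaled, float min_scaled) {
        const int q = (ql_nibble & 0xF) + (qh_bit ? 16 : 0); // 0..31
        return d_scaled * q - min_scaled;
    }
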
@@ -808,6 +1024,22 @@ kernel void kernel_get_rows_q2_k(
         (device float *) ((device char *) dst + i*nb1), ne00);
 }
 
+kernel void kernel_get_rows_q3_k(
+        device const void * src0,
+        device const int * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb1,
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    dequantize_row_q3_k(
+        (device const block_q3_k *) ((device char *) src0 + r*nb01),
+        (device float *) ((device char *) dst + i*nb1), ne00);
+}
+
 kernel void kernel_get_rows_q4_k(
         device const void * src0,
         device const int * src1,
@@ -824,6 +1056,22 @@ kernel void kernel_get_rows_q4_k(
         (device float *) ((device char *) dst + i*nb1), ne00);
 }
 
+kernel void kernel_get_rows_q5_k(
+        device const void * src0,
+        device const int * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb1,
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    dequantize_row_q5_k(
+        (device const block_q5_k *) ((device char *) src0 + r*nb01),
+        (device float *) ((device char *) dst + i*nb1), ne00);
+}
+
 kernel void kernel_get_rows_q6_k(
         device const void * src0,
         device const int * src1,
@@ -847,20 +1095,10 @@ kernel void kernel_mul_mat_q2_k_f32(
         device const float * src1,
         device float * dst,
         constant int64_t & ne00,
-        constant int64_t & ne01,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
         constant int64_t & ne10,
-        constant int64_t & ne11,
-        constant uint64_t & nb10,
-        constant uint64_t & nb11,
-        constant uint64_t & nb12,
         constant int64_t & ne0,
-        constant int64_t & ne1,
         threadgroup float * sum [[threadgroup(0)]],
         uint2 tgpig[[threadgroup_position_in_grid]],
-        uint2 tpig[[thread_position_in_grid]], // we don't use this for now
         uint2 tpitg[[thread_position_in_threadgroup]],
         uint2 tptg[[threads_per_threadgroup]]) {
 
@@ -875,7 +1113,6 @@ kernel void kernel_mul_mat_q2_k_f32(
     const int nth = tptg.x*tptg.y;
     const int ith = tptg.y*tpitg.x + tpitg.y;
 
-
     const int tid = tpitg.y; // 0...16
     const int il = tid/4;    // 0...3
     const int ir = tid%4;    // 0...3
@@ -885,35 +1122,54 @@ kernel void kernel_mul_mat_q2_k_f32(
     const int n = 8;
     const int is = 4*il + (n*ir)/16;
 
+    const int y_offset = 64*il + n*ir;
+    const int q_offset = 32*ip + n*ir;
+
     sum[ith] = 0.0f;
 
     float sumf = 0;
     for (int i = tpitg.x; i < nb; i += tptg.x) {
 
-        device const uint8_t * q = x[i].qs + 32*ip + n*ir;
+        device const uint8_t * q = x[i].qs + q_offset;
         device const uint8_t * scales = x[i].scales + is;
 
         uint8_t d1 = scales[0] & 0xF;
-        uint8_t m1 = scales[0] >> 4;
         uint8_t d2 = scales[2] & 0xF;
+        uint8_t m1 = scales[0] >> 4;
         uint8_t m2 = scales[2] >> 4;
 
-        device const float * y = yy + i*QK_K + 64*il + n*ir;
+        device const float * y = yy + i*QK_K + y_offset;
 
-        const float dall = (float)x[i].d;
-        const float dmin = (float)x[i].dmin;
-
-        float4 s = {0.f, 0.f, 0.f, 0.f};
+        //float4 s = {0.f, 0.f, 0.f, 0.f};
+        float2 s = {0.f, 0.f};
+        float smin = 0;
         for (int l = 0; l < n; ++l) {
-            s[0] += y[l+ 0] * ((q[l] >> shift1) & 3); s[1] += y[l+ 0];
-            s[2] += y[l+32] * ((q[l] >> shift2) & 3); s[3] += y[l+32];
+            s[0] += y[l+ 0] * ((q[l] >> shift1) & 3);
+            s[1] += y[l+32] * ((q[l] >> shift2) & 3);
+            smin += y[l+ 0] * m1 + y[l+32] * m2;
        }
-        sumf += dall * (s[0] * d1 + s[2] * d2) - dmin * (s[1] * m1 + s[3] * m2);
 
+        const float dall = (float)x[i].d;
+        const float dmin = (float)x[i].dmin;
+
+        sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
 
     }
     sum[ith] = sumf;
 
+    //int mask1 = (ith%4 == 0);
+    //int mask2 = (ith%16 == 0);
+
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    //for (int i = 1; i < 4; ++i) sum[ith] += mask1 * sum[ith + i];
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    //for (int i = 4; i < 16; i += 4) sum[ith] += mask2 * sum[ith + i];
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    //if (ith == 0) {
+    //    for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+    //    dst[r1*ne0 + r0] = sum[0];
+    //}
+
     //
     // Accumulate the sum from all threads in the threadgroup
     // This version is slightly faster than the commented out one below,
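
Note: the q2_k rewrite above changes bookkeeping, not math. Previously s[1] and s[3] accumulated plain sums of y and were multiplied by the mins m1/m2 after the loop; now smin accumulates y[l]*m1 + y[l+32]*m2 directly, which drops two float4 lanes, and the dall/dmin loads move below the loop. The forms agree because dmin*(s1*m1 + s3*m2) = dmin*smin. A quick equivalence check (sketch):

    #include <cassert>
    #include <cmath>

    // Old vs. new q2_k min handling over one pair of 8-element strips.
    void check_q2_k_min(const float y[40], float m1, float m2, float dmin) {
        float s1 = 0, s3 = 0, smin = 0;
        for (int l = 0; l < 8; ++l) {
            s1 += y[l];
            s3 += y[l+32];
            smin += y[l]*m1 + y[l+32]*m2;
        }
        assert(std::fabs(dmin*(s1*m1 + s3*m2) - dmin*smin) < 1e-4f);
    }
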
@@ -932,19 +1188,109 @@ kernel void kernel_mul_mat_q2_k_f32(
         for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
         dst[r1*ne0 + r0] = sum[0];
     }
+}
 
-    //// accumulate the sum from all threads in the threadgroup
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
-    //for (uint i = nth/2; i > 0; i /= 2) {
-    //    if (ith < i) {
-    //        sum[ith] += sum[ith + i];
-    //    }
-    //    threadgroup_barrier(mem_flags::mem_threadgroup);
-    //}
+kernel void kernel_mul_mat_q3_k_f32(
+        device const void * src0,
+        device const float * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne10,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        threadgroup float * sum [[threadgroup(0)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint2 tpitg[[thread_position_in_threadgroup]],
+        uint2 tptg[[threads_per_threadgroup]]) {
+
+    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask2 = 0x0f0f;
+
+    const uint8_t m3 = 3;
+    const int8_t m4 = 4;
+
+    const int nb = ne00/QK_K;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+
+    device const block_q3_k * x = (device const block_q3_k *) src0 + r0*nb;
+    device const float * yy = (device const float *) src1 + r1*ne10;
+
+    const int nth = tptg.x*tptg.y;
+    const int ith = tptg.y*tpitg.x + tpitg.y;
+
+    const int tid = tpitg.y;     // expecting 16
+    const int ip = tid/8;        // 0 or 1
+    const int il = tid/2 - 4*ip; // 0...3
+    const int ir = tid%2;
+    const int n = 8;
+    const int l0 = n*ir;
+
+    const uint8_t m = 1 << (4*ip + il);
+
+    const int shift = 2*il;
+
+    const uint16_t s_shift1 = 4*ip;
+    const uint16_t s_shift2 = s_shift1 + 2*(il/2);
+    const int ik = 4 + (il%2);
+
+    const int q_offset = 32*ip + l0;
+    const int y_offset = 128*ip + 32*il + l0;
+
+    //float sumf = 0;
+    float sumf1 = 0, sumf2 = 0;
+    for (int i = tpitg.x; i < nb; i += tptg.x) {
+
+        const float d_all = (float)(x[i].d);
+
+        device const uint8_t * q = x[i].qs + q_offset;
+        device const uint8_t * h = x[i].hmask + l0;
+        device const float * y = yy + i * QK_K + y_offset;
+
+        device const uint16_t * a = (device const uint16_t *)x[i].scales;
+        const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
+
+        float s = 0;
+        for (int l = 0; l < n; ++l) {
+            s += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & m3) - ((h[l+ 0] & m) ? 0 : m4));
+        }
+        float d = d_all * s;
+        sumf1 += d * scales[0];
+        sumf2 += d;
+        //sumf += d_all * s * (scales[0] - 32);
+
+        s = 0;
+        for (int l = 0; l < n; ++l) {
+            s += y[l+16] * ((int8_t)((q[l+16] >> shift) & m3) - ((h[l+16] & m) ? 0 : m4));
+        }
+        d = d_all * s;
+        sumf1 += d * scales[1];
+        sumf2 += d;
+        //sumf += d_all * s * (scales[1] - 32);
+
+    }
+
+    //sum[ith] = sumf;
+    sum[ith] = sumf1 - 32.f*sumf2;
+
+    //
+    // Accumulate the sum from all threads in the threadgroup
+    //
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%4 == 0) {
+        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%16 == 0) {
+        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith == 0) {
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+        dst[r1*ne0 + r0] = sum[0];
+    }
 
-    //if (ith == 0) {
-    //    dst[r1*ne0 + r0] = sum[0];
-    //}
 }
 
@@ -952,23 +1298,17 @@ kernel void kernel_mul_mat_q4_k_f32(
         device const float * src1,
         device float * dst,
         constant int64_t & ne00,
-        constant int64_t & ne01,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
         constant int64_t & ne10,
-        constant int64_t & ne11,
-        constant uint64_t & nb10,
-        constant uint64_t & nb11,
-        constant uint64_t & nb12,
         constant int64_t & ne0,
-        constant int64_t & ne1,
         threadgroup float * sum [[threadgroup(0)]],
         uint2 tgpig[[threadgroup_position_in_grid]],
-        uint2 tpig[[thread_position_in_grid]], // we don't use this for now
         uint2 tpitg[[thread_position_in_threadgroup]],
         uint2 tptg[[threads_per_threadgroup]]) {
 
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
     const int nb = ne00/QK_K;
 
     const int64_t r0 = tgpig.x;
@@ -977,37 +1317,55 @@ kernel void kernel_mul_mat_q4_k_f32(
     device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
     device const float * yy = (device const float *) src1 + r1*ne10;
 
-    const uint nth = tptg.x*tptg.y;
-    const uint ith = tptg.y*tpitg.x + tpitg.y;
+    const int nth = tptg.x*tptg.y;
+    const int ith = tptg.y*tpitg.x + tpitg.y;
 
     const int tid = tpitg.y; // 0...16
     const int il = tid/4;    // 0...3
-    const int ir = tid%4;    // 0...3
-    const int n = 8;
-    const int is = 2*il;
+    const int ir = tid - 4*il; // 0...3
+    const int n = 4;
+
+    const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+    const int in = il%2;
+
+    const int l0 = n*(2*ir + in);
+    const int q_offset = 32*im + l0;
+    const int y_offset = 64*im + l0;
 
     sum[ith] = 0.0f;
 
+    uchar2 sc1, sc2, sc3, sc4;
+
     float sumf = 0;
     for (int i = tpitg.x; i < nb; i += tptg.x) {
 
-        device const uint8_t * q = (x + i)->qs + 32*il + n*ir;
-        device const float * y = yy + i*QK_K + 64*il + n*ir;
-        device const uint8_t * scales = (x + i)->scales;
+        device const uint8_t * q1 = (x + i)->qs + q_offset;
+        device const uint8_t * q2 = q1 + 64;
+        device const float * y1 = yy + i*QK_K + y_offset;
+        device const float * y2 = y1 + 128;
 
         const float dall = (float)((x + i)->d);
         const float dmin = (float)((x + i)->dmin);
 
-        const uchar4 sc = get_scale_min_k4(is, scales);
+        device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
+        sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
+        sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
+        sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
+        sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
 
         float4 s = {0.f, 0.f, 0.f, 0.f};
+        float smin = 0;
         for (int l = 0; l < n; ++l) {
-            s[0] += y[l+ 0] * (q[l] & 0xF); s[1] += y[l+ 0];
-            s[2] += y[l+32] * (q[l] >> 4);  s[3] += y[l+32];
+
+            s[0] += y1[l] * (q1[l] & 0xF); s[1] += y1[l+32] * (q1[l] >> 4);
+            s[2] += y2[l] * (q2[l] & 0xF); s[3] += y2[l+32] * (q2[l] >> 4);
+            smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
+
         }
-        sumf += dall * (s[0] * sc[0] + s[2] * sc[2]) - dmin * (s[1] * sc[1] + s[3] * sc[3]);
+        sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
 
     }
+
     sum[ith] = sumf;
 
     //
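
Note: the q4_k kernel now decodes two 6-bit (scale, min) pairs per 16-bit word straight from the 12 packed scale bytes instead of calling get_scale_min_k4 byte by byte: kmask1 = 0x3f3f keeps the low 6 bits of both bytes of a word, while kmask2/kmask3 splice the low nibbles stored in the last 4 bytes together with the top 2 bits parked in the first 8 bytes. The as_type<uchar2> reinterpretation has a direct host-side analogue (sketch; memcpy plays the role of as_type and assumes a little-endian host, matching Apple hardware):

    #include <cstdint>
    #include <cstring>

    // Split a masked 16-bit word into two bytes, like as_type<uchar2> in Metal.
    void two_scales(uint16_t a, uint8_t out[2]) {
        const uint16_t masked = a & 0x3f3f; // kmask1: low 6 bits of each byte
        std::memcpy(out, &masked, 2);       // out[0], out[1]: two 6-bit scales
    }
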
@@ -1043,25 +1401,114 @@ kernel void kernel_mul_mat_q4_k_f32(
     //}
 }
 
+kernel void kernel_mul_mat_q5_k_f32(
+        device const void * src0,
+        device const float * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne10,
+        constant int64_t & ne0,
+        threadgroup float * sum [[threadgroup(0)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint2 tpitg[[thread_position_in_threadgroup]],
+        uint2 tptg[[threads_per_threadgroup]]) {
+
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
+    const int nb = ne00/QK_K;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+
+    device const block_q5_k * x = (device const block_q5_k *) src0 + r0*nb;
+    device const float * yy = (device const float *) src1 + r1*ne10;
+
+    const int nth = tptg.x*tptg.y;
+    const int ith = tptg.y*tpitg.x + tpitg.y;
+
+    const int tid = tpitg.y; // 0...16
+    const int il = tid/4;    // 0...3
+    const int ir = tid - 4*il; // 0...3
+    const int n = 4;
+
+    const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+    const int in = il%2;
+
+    const int l0 = n*(2*ir + in);
+    const int q_offset = 32*im + l0;
+    const int y_offset = 64*im + l0;
+
+    const uint8_t hm1 = 1u << (2*im);
+    const uint8_t hm2 = hm1 << 1;
+    const uint8_t hm3 = hm1 << 4;
+    const uint8_t hm4 = hm2 << 4;
+
+    uchar2 sc1, sc2, sc3, sc4;
+
+    float sumf = 0;
+    for (int i = tpitg.x; i < nb; i += tptg.x) {
+
+        device const uint8_t * q1 = (x + i)->qs + q_offset;
+        device const uint8_t * q2 = q1 + 64;
+        device const uint8_t * qh = (x + i)->qh + l0;
+        device const float * y1 = yy + i*QK_K + y_offset;
+        device const float * y2 = y1 + 128;
+
+        const float dall = (float)((x + i)->d);
+        const float dmin = (float)((x + i)->dmin);
+
+        device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
+        sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
+        sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
+        sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
+        sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
+
+        float4 s = {0.f, 0.f, 0.f, 0.f};
+        float smin = 0;
+        for (int l = 0; l < n; ++l) {
+
+            s[0] += y1[l+ 0] * ((q1[l] & 0xF) + (qh[l] & hm1 ? 16 : 0));
+            s[1] += y1[l+32] * ((q1[l] >> 4) + (qh[l] & hm2 ? 16 : 0));
+            s[2] += y2[l+ 0] * ((q2[l] & 0xF) + (qh[l] & hm3 ? 16 : 0));
+            s[3] += y2[l+32] * ((q2[l] >> 4) + (qh[l] & hm4 ? 16 : 0));
+            smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
+
+        }
+        sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
+
+    }
+    sum[ith] = sumf;
+
+    //
+    // Accumulate the sum from all threads in the threadgroup
+    //
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%4 == 0) {
+        sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%16 == 0) {
+        sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith == 0) {
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+        dst[r1*ne0 + r0] = sum[0];
+    }
+
+}
+
 kernel void kernel_mul_mat_q6_k_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,
         constant int64_t & ne00,
-        constant int64_t & ne01,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
         constant int64_t & ne10,
-        constant int64_t & ne11,
-        constant uint64_t & nb10,
-        constant uint64_t & nb11,
-        constant uint64_t & nb12,
         constant int64_t & ne0,
-        constant int64_t & ne1,
         threadgroup float * sum [[threadgroup(0)]],
         uint2 tgpig[[threadgroup_position_in_grid]],
-        uint2 tpig[[thread_position_in_grid]], // we don't use this for now
         uint2 tpitg[[thread_position_in_threadgroup]],
         uint2 tptg[[threads_per_threadgroup]]) {
 
@@ -1078,24 +1525,29 @@ kernel void kernel_mul_mat_q6_k_f32(
     device const block_q6_k * x = (device const block_q6_k *) src0 + r0*nb;
     device const float * yy = (device const float *) src1 + r1*ne10;
 
-    const uint nth = tptg.x*tptg.y;
-    const uint ith = tptg.y*tpitg.x + tpitg.y;
+    const int nth = tptg.x*tptg.y;
+    const int ith = tptg.y*tpitg.x + tpitg.y;
 
-    const int step = QK_K / tptg.y; // we expect this to be 16
-    const int iqs = step * tpitg.y; // 0...240 in steps of 16
+    // Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
+    const int iqs = 16 * tpitg.y;
     const int ip = iqs / 128;         // 0 or 1
     const int il = (iqs - 128*ip)/16; // 0...7
     const int n = 4;
-    const int is = 8*ip + (n*il)/16;
+    const int l0 = n*il;
+    const int is = 8*ip + l0/16;
+
+    const int y_offset = 128*ip + l0;
+    const int q_offset_l = 64*ip + l0;
+    const int q_offset_h = 32*ip + l0;
 
     float sumf = 0;
     for (int i = tpitg.x; i < nb; i += tptg.x) {
 
-        device const uint8_t * ql = x[i].ql + 64*ip + n*il;
-        device const uint8_t * qh = x[i].qh + 32*ip + n*il;
+        device const uint8_t * ql = x[i].ql + q_offset_l;
+        device const uint8_t * qh = x[i].qh + q_offset_h;
         device const int8_t * sc = x[i].scales + is;
 
-        device const float * y = yy + i * QK_K + 128*ip + n*il;
+        device const float * y = yy + i * QK_K + y_offset;
 
         const float dall = x[i].d;
1553