llama_cpp 0.3.7 → 0.3.8

Sign up to get free protection for your applications and to get access to all the features.
@@ -18,47 +18,6 @@ typedef struct {
18
18
  uint8_t qs[QK4_1 / 2]; // nibbles / quants
19
19
  } block_q4_1;
20
20
 
21
- static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
22
- const int qk = QK4_0;
23
-
24
- assert(k % qk == 0);
25
-
26
- const int nb = k / qk;
27
-
28
- for (int i = 0; i < nb; i++) {
29
- const half d = x[i].d;
30
-
31
- for (int j = 0; j < qk/2; ++j) {
32
- const int x0 = (x[i].qs[j] & 0x0F) - 8;
33
- const int x1 = (x[i].qs[j] >> 4) - 8;
34
-
35
- y[i*qk + j + 0 ] = x0*d;
36
- y[i*qk + j + qk/2] = x1*d;
37
- }
38
- }
39
- }
40
-
41
- static void dequantize_row_q4_1(device const block_q4_1 * x, device float * y, int k) {
42
- const int qk = QK4_1;
43
-
44
- assert(k % qk == 0);
45
-
46
- const int nb = k / qk;
47
-
48
- for (int i = 0; i < nb; i++) {
49
- const half d = x[i].d;
50
- const half m = x[i].m;
51
-
52
- for (int j = 0; j < qk/2; ++j) {
53
- const int x0 = (x[i].qs[j] & 0x0F);
54
- const int x1 = (x[i].qs[j] >> 4);
55
-
56
- y[i*qk + j + 0 ] = x0*d + m;
57
- y[i*qk + j + qk/2] = x1*d + m;
58
- }
59
- }
60
- }
61
-
62
21
  kernel void kernel_add(
63
22
  device const float * src0,
64
23
  device const float * src1,
@@ -219,54 +178,6 @@ kernel void kernel_diag_mask_inf(
219
178
  }
220
179
  }
221
180
 
222
- kernel void kernel_get_rows_f16(
223
- device const void * src0,
224
- device const int * src1,
225
- device float * dst,
226
- constant int64_t & ne00,
227
- constant uint64_t & nb01,
228
- constant uint64_t & nb1,
229
- uint tpig[[thread_position_in_grid]]) {
230
- const int i = tpig;
231
- const int r = ((device int32_t *) src1)[i];
232
-
233
- for (int j = 0; j < ne00; j++) {
234
- dst[i*nb1 + j] = ((device half *) ((device char *) src0 + r*nb01))[j];
235
- }
236
- }
237
-
238
- kernel void kernel_get_rows_q4_0(
239
- device const void * src0,
240
- device const int * src1,
241
- device float * dst,
242
- constant int64_t & ne00,
243
- constant uint64_t & nb01,
244
- constant uint64_t & nb1,
245
- uint tpig[[thread_position_in_grid]]) {
246
- const int i = tpig;
247
- const int r = ((device int32_t *) src1)[i];
248
-
249
- dequantize_row_q4_0(
250
- (device const block_q4_0 *) ((device char *) src0 + r*nb01),
251
- (device float *) ((device char *) dst + i*nb1), ne00);
252
- }
253
-
254
- kernel void kernel_get_rows_q4_1(
255
- device const void * src0,
256
- device const int * src1,
257
- device float * dst,
258
- constant int64_t & ne00,
259
- constant uint64_t & nb01,
260
- constant uint64_t & nb1,
261
- uint tpig[[thread_position_in_grid]]) {
262
- const int i = tpig;
263
- const int r = ((device int32_t *) src1)[i];
264
-
265
- dequantize_row_q4_1(
266
- (device const block_q4_1 *) ((device char *) src0 + r*nb01),
267
- (device float *) ((device char *) dst + i*nb1), ne00);
268
- }
269
-
270
181
  kernel void kernel_norm(
271
182
  device const void * src0,
272
183
  device float * dst,
@@ -432,14 +343,16 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
432
343
  // N_DST, so this is another explicit assumption of the implementation.
433
344
  template<typename block_q_type, int nr, int nsg, int nw>
434
345
  void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
435
- int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
436
- uint2 tgpig, uint tiisg, uint sgitg) {
346
+ int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
347
+ uint3 tgpig, uint tiisg, uint sgitg) {
437
348
  const int nb = ne00/QK4_0;
438
349
  const int r0 = tgpig.x;
439
350
  const int r1 = tgpig.y;
351
+ const int im = tgpig.z;
440
352
  const int first_row = (r0 * nsg + sgitg) * nr;
441
- device const block_q_type * x = (device const block_q_type *) src0 + first_row * nb;
442
- device const float * y = (device const float *) src1 + r1*ne10;
353
+ const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
354
+ device const block_q_type * x = (device const block_q_type *) src0 + offset0;
355
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
443
356
  float yl[16]; // src1 vector cache
444
357
  float sumf[nr]={0.f};
445
358
 
@@ -470,7 +383,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
470
383
  for (int row = 0; row < nr; ++row) {
471
384
  const float tot = simd_sum(sumf[row]);
472
385
  if (tiisg == 0 && first_row + row < ne01) {
473
- dst[r1*ne0 + first_row + row] = tot;
386
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
474
387
  }
475
388
  }
476
389
  }
@@ -480,13 +393,17 @@ kernel void kernel_mul_mat_q4_0_f32(
480
393
  device const float * src1,
481
394
  device float * dst,
482
395
  constant int64_t & ne00,
483
- constant int64_t & ne10,
484
- constant int64_t & ne0,
485
396
  constant int64_t & ne01[[buffer(4)]],
486
- uint2 tgpig[[threadgroup_position_in_grid]],
397
+ constant int64_t & ne02[[buffer(5)]],
398
+ constant int64_t & ne10[[buffer(9)]],
399
+ constant int64_t & ne12[[buffer(11)]],
400
+ constant int64_t & ne0[[buffer(15)]],
401
+ constant int64_t & ne1[[buffer(16)]],
402
+ constant uint & gqa[[buffer(17)]],
403
+ uint3 tgpig[[threadgroup_position_in_grid]],
487
404
  uint tiisg[[thread_index_in_simdgroup]],
488
405
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
489
- mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
406
+ mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
490
407
  }
491
408
 
492
409
  kernel void kernel_mul_mat_q4_1_f32(
@@ -494,13 +411,17 @@ kernel void kernel_mul_mat_q4_1_f32(
494
411
  device const float * src1,
495
412
  device float * dst,
496
413
  constant int64_t & ne00,
497
- constant int64_t & ne10,
498
- constant int64_t & ne0,
499
414
  constant int64_t & ne01[[buffer(4)]],
500
- uint2 tgpig[[threadgroup_position_in_grid]],
415
+ constant int64_t & ne02[[buffer(5)]],
416
+ constant int64_t & ne10[[buffer(9)]],
417
+ constant int64_t & ne12[[buffer(11)]],
418
+ constant int64_t & ne0[[buffer(15)]],
419
+ constant int64_t & ne1[[buffer(16)]],
420
+ constant uint & gqa[[buffer(17)]],
421
+ uint3 tgpig[[threadgroup_position_in_grid]],
501
422
  uint tiisg[[thread_index_in_simdgroup]],
502
423
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
503
- mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
424
+ mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
504
425
  }
505
426
 
506
427
  kernel void kernel_mul_mat_f16_f32(
@@ -869,354 +790,6 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
869
790
  return r;
870
791
  }
871
792
 
872
- //========================================== dequantization =============================
873
-
874
- static void dequantize_row_q2_K(device const block_q2_K * x, device float * y, int k) {
875
- assert(k % QK_K == 0);
876
- const int nb = k / QK_K;
877
-
878
- for (int i = 0; i < nb; i++) {
879
-
880
- const float d = x[i].d;
881
- const float min = x[i].dmin;
882
-
883
- device const uint8_t * q = x[i].qs;
884
-
885
- #if QK_K == 256
886
- int is = 0;
887
- float dl, ml;
888
- for (int n = 0; n < QK_K; n += 128) {
889
- int shift = 0;
890
- for (int j = 0; j < 4; ++j) {
891
-
892
- uint8_t sc = x[i].scales[is++];
893
- dl = d * (sc & 0xF); ml = min * (sc >> 4);
894
- for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
895
-
896
- sc = x[i].scales[is++];
897
- dl = d * (sc & 0xF); ml = min * (sc >> 4);
898
- for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;
899
-
900
- shift += 2;
901
- }
902
- q += 32;
903
- }
904
- #else
905
- float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
906
- float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
907
- float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
908
- float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
909
- for (int l = 0; l < 16; ++l) {
910
- y[l+ 0] = dl1 * ((q[l] >> 0) & 3) - ml1;
911
- y[l+16] = dl2 * ((q[l] >> 2) & 3) - ml2;
912
- y[l+32] = dl3 * ((q[l] >> 4) & 3) - ml3;
913
- y[l+48] = dl4 * ((q[l] >> 6) & 3) - ml4;
914
- }
915
- y += QK_K;
916
- #endif
917
-
918
- }
919
- }
920
-
921
- static void dequantize_row_q3_K(device const block_q3_K * x, device float * y, int k) {
922
- assert(k % QK_K == 0);
923
- const int nb = k / QK_K;
924
-
925
- #if QK_K == 256
926
-
927
- const uint16_t kmask1 = 0x0303;
928
- const uint16_t kmask2 = 0x0f0f;
929
-
930
- uint16_t aux[8];
931
- thread const int8_t * scales = (thread const int8_t*)aux;
932
-
933
- for (int i = 0; i < nb; i++) {
934
-
935
- const float d_all = (float)(x[i].d);
936
-
937
- device const uint8_t * q = x[i].qs;
938
- device const uint8_t * h = x[i].hmask;
939
- uint8_t m = 1;
940
-
941
- device const uint16_t * a = (device const uint16_t *)x[i].scales;
942
- aux[0] = (a[0] & kmask2) | (((a[4] >> 0) & kmask1) << 4);
943
- aux[1] = (a[1] & kmask2) | (((a[5] >> 0) & kmask1) << 4);
944
- aux[2] = (a[2] & kmask2) | (((a[4] >> 2) & kmask1) << 4);
945
- aux[3] = (a[3] & kmask2) | (((a[5] >> 2) & kmask1) << 4);
946
- aux[4] = ((a[0] >> 4) & kmask2) | (((a[4] >> 4) & kmask1) << 4);
947
- aux[5] = ((a[1] >> 4) & kmask2) | (((a[5] >> 4) & kmask1) << 4);
948
- aux[6] = ((a[2] >> 4) & kmask2) | (((a[4] >> 6) & kmask1) << 4);
949
- aux[7] = ((a[3] >> 4) & kmask2) | (((a[5] >> 6) & kmask1) << 4);
950
-
951
- int is = 0;
952
- float dl;
953
- for (int n = 0; n < QK_K; n += 128) {
954
- int shift = 0;
955
- for (int j = 0; j < 4; ++j) {
956
-
957
- dl = d_all * (scales[is++] - 32);
958
- for (int l = 0; l < 16; ++l) {
959
- *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((h[l+ 0] & m) ? 0 : 4));
960
- }
961
-
962
- dl = d_all * (scales[is++] - 32);
963
- for (int l = 0; l < 16; ++l) {
964
- *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((h[l+16] & m) ? 0 : 4));
965
- }
966
-
967
- shift += 2;
968
- m <<= 1;
969
- }
970
- q += 32;
971
- }
972
- }
973
- #else
974
- for (int i = 0; i < nb; i++) {
975
-
976
- const float d_all = (float)(x[i].d);
977
-
978
- device const uint8_t * q = x[i].qs;
979
- device const uint8_t * hm = x[i].hmask;
980
-
981
- const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
982
- const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
983
- const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
984
- const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
985
-
986
- for (int l = 0; l < 8; ++l) {
987
- uint8_t h = hm[l];
988
- y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
989
- y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
990
- y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
991
- y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
992
- y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
993
- y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
994
- y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
995
- y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
996
- }
997
- y += QK_K;
998
- }
999
- #endif
1000
-
1001
- }
1002
-
1003
- static void dequantize_row_q4_K(device const block_q4_K * x, device float * y, int k) {
1004
- assert(k % QK_K == 0);
1005
- const int nb = k / QK_K;
1006
-
1007
- for (int i = 0; i < nb; i++) {
1008
-
1009
- device const uint8_t * q = x[i].qs;
1010
-
1011
- #if QK_K == 256
1012
- const float d = x[i].d;
1013
- const float min = x[i].dmin;
1014
-
1015
- device const uint8_t * scales = x[i].scales;
1016
-
1017
- int is = 0;
1018
- for (int j = 0; j < QK_K; j += 64) {
1019
- const uchar4 sc = get_scale_min_k4(is, scales);
1020
- const float d1 = d * sc[0]; const float m1 = min * sc[1];
1021
- const float d2 = d * sc[2]; const float m2 = min * sc[3];
1022
- for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
1023
- for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
1024
- q += 32; is += 2;
1025
- }
1026
- #else
1027
- device const uint8_t * s = x[i].scales;
1028
- device const half2 * dh = (device const half2 *)x[i].d;
1029
- const float2 d = (float2)dh[0];
1030
- const float d1 = d[0] * (s[0] & 0xF);
1031
- const float d2 = d[0] * (s[1] & 0xF);
1032
- const float m1 = d[1] * (s[0] >> 4);
1033
- const float m2 = d[1] * (s[1] >> 4);
1034
- for (int l = 0; l < 32; ++l) {
1035
- y[l+ 0] = d1 * (q[l] & 0xF) - m1;
1036
- y[l+32] = d2 * (q[l] >> 4) - m2;
1037
- }
1038
- y += QK_K;
1039
- #endif
1040
-
1041
- }
1042
- }
1043
-
1044
- static void dequantize_row_q5_K(device const block_q5_K * x, device float * y, int k) {
1045
- assert(k % QK_K == 0);
1046
- const int nb = k / QK_K;
1047
-
1048
- #if QK_K == 256
1049
- for (int i = 0; i < nb; i++) {
1050
-
1051
- const float d = (float)(x[i].d);
1052
- const float min = (float)(x[i].dmin);
1053
-
1054
- device const uint8_t * ql = x[i].qs;
1055
- device const uint8_t * qh = x[i].qh;
1056
-
1057
- int is = 0;
1058
- uint8_t u1 = 1, u2 = 2;
1059
- for (int j = 0; j < QK_K; j += 64) {
1060
- const uchar4 sc = get_scale_min_k4(is, x[i].scales);
1061
- const float d1 = d * sc[0]; const float m1 = min * sc[1];
1062
- const float d2 = d * sc[2]; const float m2 = min * sc[3];
1063
- for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
1064
- for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
1065
- ql += 32; is += 2;
1066
- u1 <<= 2; u2 <<= 2;
1067
- }
1068
- }
1069
- #else
1070
- for (int i = 0; i < nb; i++) {
1071
-
1072
- const float d = (float)x[i].d;
1073
-
1074
- device const uint8_t * ql = x[i].qs;
1075
- device const uint8_t * qh = x[i].qh;
1076
- device const int8_t * sc = x[i].scales;
1077
-
1078
- for (int l = 0; l < 8; ++l) {
1079
- y[l+ 0] = d * sc[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
1080
- y[l+ 8] = d * sc[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
1081
- y[l+16] = d * sc[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
1082
- y[l+24] = d * sc[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
1083
- y[l+32] = d * sc[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16));
1084
- y[l+40] = d * sc[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16));
1085
- y[l+48] = d * sc[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16));
1086
- y[l+56] = d * sc[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16));
1087
- }
1088
- y += QK_K;
1089
- }
1090
- #endif
1091
-
1092
- }
1093
-
1094
- static void dequantize_row_q6_K(device const block_q6_K * x, device float * y, int k) {
1095
- assert(k % QK_K == 0);
1096
- const int nb = k / QK_K;
1097
-
1098
- for (int i = 0; i < nb; i++) {
1099
-
1100
- device const uint8_t * ql = x[i].ql;
1101
- device const uint8_t * qh = x[i].qh;
1102
- device const int8_t * sc = x[i].scales;
1103
-
1104
- const float d = x[i].d;
1105
-
1106
- #if QK_K == 256
1107
- for (int n = 0; n < QK_K; n += 128) {
1108
- for (int l = 0; l < 32; ++l) {
1109
- int is = l/16;
1110
- const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
1111
- const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
1112
- const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
1113
- const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
1114
- y[l + 0] = d * sc[is + 0] * q1;
1115
- y[l + 32] = d * sc[is + 2] * q2;
1116
- y[l + 64] = d * sc[is + 4] * q3;
1117
- y[l + 96] = d * sc[is + 6] * q4;
1118
- }
1119
- y += 128;
1120
- ql += 64;
1121
- qh += 32;
1122
- sc += 8;
1123
- }
1124
- #else
1125
- for (int l = 0; l < 16; ++l) {
1126
- const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
1127
- const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
1128
- const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
1129
- const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
1130
- y[l+ 0] = d * sc[0] * q1;
1131
- y[l+16] = d * sc[1] * q2;
1132
- y[l+32] = d * sc[2] * q3;
1133
- y[l+48] = d * sc[3] * q4;
1134
- }
1135
- y += 64;
1136
- #endif
1137
- }
1138
- }
1139
-
1140
- kernel void kernel_get_rows_q2_K(
1141
- device const void * src0,
1142
- device const int * src1,
1143
- device float * dst,
1144
- constant int64_t & ne00,
1145
- constant uint64_t & nb01,
1146
- constant uint64_t & nb1,
1147
- uint tpig[[thread_position_in_grid]]) {
1148
- const int i = tpig;
1149
- const int r = ((device int32_t *) src1)[i];
1150
-
1151
- dequantize_row_q2_K(
1152
- (device const block_q2_K *) ((device char *) src0 + r*nb01),
1153
- (device float *) ((device char *) dst + i*nb1), ne00);
1154
- }
1155
-
1156
- kernel void kernel_get_rows_q3_K(
1157
- device const void * src0,
1158
- device const int * src1,
1159
- device float * dst,
1160
- constant int64_t & ne00,
1161
- constant uint64_t & nb01,
1162
- constant uint64_t & nb1,
1163
- uint tpig[[thread_position_in_grid]]) {
1164
- const int i = tpig;
1165
- const int r = ((device int32_t *) src1)[i];
1166
-
1167
- dequantize_row_q3_K(
1168
- (device const block_q3_K *) ((device char *) src0 + r*nb01),
1169
- (device float *) ((device char *) dst + i*nb1), ne00);
1170
- }
1171
-
1172
- kernel void kernel_get_rows_q4_K(
1173
- device const void * src0,
1174
- device const int * src1,
1175
- device float * dst,
1176
- constant int64_t & ne00,
1177
- constant uint64_t & nb01,
1178
- constant uint64_t & nb1,
1179
- uint tpig[[thread_position_in_grid]]) {
1180
- const int i = tpig;
1181
- const int r = ((device int32_t *) src1)[i];
1182
-
1183
- dequantize_row_q4_K(
1184
- (device const block_q4_K *) ((device char *) src0 + r*nb01),
1185
- (device float *) ((device char *) dst + i*nb1), ne00);
1186
- }
1187
-
1188
- kernel void kernel_get_rows_q5_K(
1189
- device const void * src0,
1190
- device const int * src1,
1191
- device float * dst,
1192
- constant int64_t & ne00,
1193
- constant uint64_t & nb01,
1194
- constant uint64_t & nb1,
1195
- uint tpig[[thread_position_in_grid]]) {
1196
- const int i = tpig;
1197
- const int r = ((device int32_t *) src1)[i];
1198
-
1199
- dequantize_row_q5_K(
1200
- (device const block_q5_K *) ((device char *) src0 + r*nb01),
1201
- (device float *) ((device char *) dst + i*nb1), ne00);
1202
- }
1203
-
1204
- kernel void kernel_get_rows_q6_K(
1205
- device const void * src0,
1206
- device const int * src1,
1207
- device float * dst,
1208
- constant int64_t & ne00,
1209
- constant uint64_t & nb01,
1210
- constant uint64_t & nb1,
1211
- uint tpig[[thread_position_in_grid]]) {
1212
- const int i = tpig;
1213
- const int r = ((device int32_t *) src1)[i];
1214
-
1215
- dequantize_row_q6_K(
1216
- (device const block_q6_K *) ((device char *) src0 + r*nb01),
1217
- (device float *) ((device char *) dst + i*nb1), ne00);
1218
- }
1219
-
1220
793
  //====================================== dot products =========================
1221
794
 
1222
795
  kernel void kernel_mul_mat_q2_K_f32(
@@ -1224,21 +797,27 @@ kernel void kernel_mul_mat_q2_K_f32(
1224
797
  device const float * src1,
1225
798
  device float * dst,
1226
799
  constant int64_t & ne00,
1227
- constant int64_t & ne10,
1228
- constant int64_t & ne0,
1229
800
  constant int64_t & ne01[[buffer(4)]],
1230
- uint2 tgpig[[threadgroup_position_in_grid]],
801
+ constant int64_t & ne02[[buffer(5)]],
802
+ constant int64_t & ne10[[buffer(9)]],
803
+ constant int64_t & ne12[[buffer(11)]],
804
+ constant int64_t & ne0[[buffer(15)]],
805
+ constant int64_t & ne1[[buffer(16)]],
806
+ constant uint & gqa[[buffer(17)]],
807
+ uint3 tgpig[[threadgroup_position_in_grid]],
1231
808
  uint tiisg[[thread_index_in_simdgroup]],
1232
809
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
1233
810
 
1234
811
  const int nb = ne00/QK_K;
1235
812
  const int r0 = tgpig.x;
1236
813
  const int r1 = tgpig.y;
814
+ const int r2 = tgpig.z;
1237
815
 
1238
816
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1239
817
  const int ib_row = first_row * nb;
1240
- device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row;
1241
- device const float * y = (device const float *) src1 + r1*ne10;
818
+ const uint offset0 = r2/gqa*(nb*ne0);
819
+ device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
820
+ device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
1242
821
  float yl[32];
1243
822
  float sumf[N_DST]={0.f}, all_sum;
1244
823
 
@@ -1351,7 +930,7 @@ kernel void kernel_mul_mat_q2_K_f32(
1351
930
  for (int row = 0; row < N_DST; ++row) {
1352
931
  all_sum = simd_sum(sumf[row]);
1353
932
  if (tiisg == 0) {
1354
- dst[r1*ne0 + first_row + row] = all_sum;
933
+ dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
1355
934
  }
1356
935
  }
1357
936
  }
@@ -1362,10 +941,14 @@ kernel void kernel_mul_mat_q3_K_f32(
1362
941
  device const float * src1,
1363
942
  device float * dst,
1364
943
  constant int64_t & ne00,
1365
- constant int64_t & ne10,
1366
- constant int64_t & ne0,
1367
- constant int64_t & ne1,
1368
- uint2 tgpig[[threadgroup_position_in_grid]],
944
+ constant int64_t & ne01[[buffer(4)]],
945
+ constant int64_t & ne02[[buffer(5)]],
946
+ constant int64_t & ne10[[buffer(9)]],
947
+ constant int64_t & ne12[[buffer(11)]],
948
+ constant int64_t & ne0[[buffer(15)]],
949
+ constant int64_t & ne1[[buffer(16)]],
950
+ constant uint & gqa[[buffer(17)]],
951
+ uint3 tgpig[[threadgroup_position_in_grid]],
1369
952
  uint tiisg[[thread_index_in_simdgroup]],
1370
953
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
1371
954
 
@@ -1373,11 +956,12 @@ kernel void kernel_mul_mat_q3_K_f32(
1373
956
 
1374
957
  const int64_t r0 = tgpig.x;
1375
958
  const int64_t r1 = tgpig.y;
959
+ const int64_t r2 = tgpig.z;
1376
960
 
1377
961
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
1378
-
1379
- device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb;
1380
- device const float * yy = (device const float *) src1 + r1*ne10;
962
+ const uint offset0 = r2/gqa*(nb*ne0);
963
+ device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
964
+ device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
1381
965
 
1382
966
  float yl[16];
1383
967
 
@@ -1465,7 +1049,7 @@ kernel void kernel_mul_mat_q3_K_f32(
1465
1049
  const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
1466
1050
  const float tot = simd_sum(sumf);
1467
1051
  if (tiisg == 0) {
1468
- dst[r1*ne0 + first_row + row] = tot;
1052
+ dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
1469
1053
  }
1470
1054
  }
1471
1055
  }
@@ -1475,10 +1059,14 @@ kernel void kernel_mul_mat_q3_K_f32(
1475
1059
  device const float * src1,
1476
1060
  device float * dst,
1477
1061
  constant int64_t & ne00,
1478
- constant int64_t & ne10,
1479
- constant int64_t & ne0,
1480
- constant int64_t & ne1,
1481
- uint2 tgpig[[threadgroup_position_in_grid]],
1062
+ constant int64_t & ne01[[buffer(4)]],
1063
+ constant int64_t & ne02[[buffer(5)]],
1064
+ constant int64_t & ne10[[buffer(9)]],
1065
+ constant int64_t & ne12[[buffer(11)]],
1066
+ constant int64_t & ne0[[buffer(15)]],
1067
+ constant int64_t & ne1[[buffer(16)]],
1068
+ constant uint & gqa[[buffer(17)]],
1069
+ uint3 tgpig[[threadgroup_position_in_grid]],
1482
1070
  uint tiisg[[thread_index_in_simdgroup]],
1483
1071
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
1484
1072
 
@@ -1486,11 +1074,12 @@ kernel void kernel_mul_mat_q3_K_f32(
1486
1074
 
1487
1075
  const int64_t r0 = tgpig.x;
1488
1076
  const int64_t r1 = tgpig.y;
1077
+ const int64_t r2 = tgpig.z;
1489
1078
 
1490
1079
  const int row = 2 * r0 + sgitg;
1491
-
1492
- device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb;
1493
- device const float * yy = (device const float *) src1 + r1*ne10;
1080
+ const uint offset0 = r2/gqa*(nb*ne0);
1081
+ device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
1082
+ device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
1494
1083
  const int ix = tiisg/4;
1495
1084
  const int il = 4 * (tiisg%4);// 0, 4, 8, 12
1496
1085
  const int im = il/8; // 0, 0, 1, 1
@@ -1529,7 +1118,7 @@ kernel void kernel_mul_mat_q3_K_f32(
1529
1118
 
1530
1119
  const float tot = simd_sum(sumf);
1531
1120
  if (tiisg == 0) {
1532
- dst[r1*ne0 + row] = tot;
1121
+ dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
1533
1122
  }
1534
1123
 
1535
1124
  }
@@ -1541,10 +1130,14 @@ kernel void kernel_mul_mat_q4_K_f32(
1541
1130
  device const float * src1,
1542
1131
  device float * dst,
1543
1132
  constant int64_t & ne00,
1544
- constant int64_t & ne10,
1545
- constant int64_t & ne0,
1546
1133
  constant int64_t & ne01[[buffer(4)]],
1547
- uint2 tgpig[[threadgroup_position_in_grid]],
1134
+ constant int64_t & ne02[[buffer(5)]],
1135
+ constant int64_t & ne10[[buffer(9)]],
1136
+ constant int64_t & ne12[[buffer(11)]],
1137
+ constant int64_t & ne0[[buffer(15)]],
1138
+ constant int64_t & ne1[[buffer(16)]],
1139
+ constant uint & gqa[[buffer(17)]],
1140
+ uint3 tgpig[[threadgroup_position_in_grid]],
1548
1141
  uint tiisg[[thread_index_in_simdgroup]],
1549
1142
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
1550
1143
 
@@ -1560,10 +1153,12 @@ kernel void kernel_mul_mat_q4_K_f32(
1560
1153
  const int nb = ne00/QK_K;
1561
1154
  const int r0 = tgpig.x;
1562
1155
  const int r1 = tgpig.y;
1156
+ const int r2 = tgpig.z;
1563
1157
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1564
1158
  const int ib_row = first_row * nb;
1565
- device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
1566
- device const float * y = (device const float *) src1 + r1*ne10;
1159
+ const uint offset0 = r2/gqa*(nb*ne0);
1160
+ device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
1161
+ device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
1567
1162
  float yl[16];
1568
1163
  float yh[16];
1569
1164
  float sumf[N_DST]={0.f}, all_sum;
@@ -1630,7 +1225,7 @@ kernel void kernel_mul_mat_q4_K_f32(
1630
1225
  for (int row = 0; row < N_DST; ++row) {
1631
1226
  all_sum = simd_sum(sumf[row]);
1632
1227
  if (tiisg == 0) {
1633
- dst[r1*ne0 + first_row + row] = all_sum;
1228
+ dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
1634
1229
  }
1635
1230
  }
1636
1231
  }
@@ -1640,10 +1235,14 @@ kernel void kernel_mul_mat_q4_K_f32(
1640
1235
  device const float * src1,
1641
1236
  device float * dst,
1642
1237
  constant int64_t & ne00,
1643
- constant int64_t & ne10,
1644
- constant int64_t & ne0,
1645
1238
  constant int64_t & ne01[[buffer(4)]],
1646
- uint2 tgpig[[threadgroup_position_in_grid]],
1239
+ constant int64_t & ne02[[buffer(5)]],
1240
+ constant int64_t & ne10[[buffer(9)]],
1241
+ constant int64_t & ne12[[buffer(11)]],
1242
+ constant int64_t & ne0[[buffer(15)]],
1243
+ constant int64_t & ne1[[buffer(16)]],
1244
+ constant uint & gqa[[buffer(17)]],
1245
+ uint3 tgpig[[threadgroup_position_in_grid]],
1647
1246
  uint tiisg[[thread_index_in_simdgroup]],
1648
1247
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
1649
1248
 
@@ -1653,10 +1252,12 @@ kernel void kernel_mul_mat_q4_K_f32(
1653
1252
  const int nb = ne00/QK_K;
1654
1253
  const int r0 = tgpig.x;
1655
1254
  const int r1 = tgpig.y;
1255
+ const int r2 = tgpig.z;
1656
1256
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1657
1257
  const int ib_row = first_row * nb;
1658
- device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
1659
- device const float * y = (device const float *) src1 + r1*ne10;
1258
+ const uint offset0 = r2/gqa*(nb*ne0);
1259
+ device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
1260
+ device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
1660
1261
  float yl[8];
1661
1262
  float yh[8];
1662
1263
  float sumf[N_DST]={0.f}, all_sum;
@@ -1712,7 +1313,7 @@ kernel void kernel_mul_mat_q4_K_f32(
1712
1313
  for (int row = 0; row < N_DST; ++row) {
1713
1314
  all_sum = simd_sum(sumf[row]);
1714
1315
  if (tiisg == 0) {
1715
- dst[r1*ne0 + first_row + row] = all_sum;
1316
+ dst[r1*ne0+ r2*ne0*ne1 + first_row + row] = all_sum;
1716
1317
  }
1717
1318
  }
1718
1319
  }
@@ -1723,9 +1324,14 @@ kernel void kernel_mul_mat_q5_K_f32(
1723
1324
  device const float * src1,
1724
1325
  device float * dst,
1725
1326
  constant int64_t & ne00,
1726
- constant int64_t & ne10,
1727
- constant int64_t & ne0,
1728
- uint2 tgpig[[threadgroup_position_in_grid]],
1327
+ constant int64_t & ne01[[buffer(4)]],
1328
+ constant int64_t & ne02[[buffer(5)]],
1329
+ constant int64_t & ne10[[buffer(9)]],
1330
+ constant int64_t & ne12[[buffer(11)]],
1331
+ constant int64_t & ne0[[buffer(15)]],
1332
+ constant int64_t & ne1[[buffer(16)]],
1333
+ constant uint & gqa[[buffer(17)]],
1334
+ uint3 tgpig[[threadgroup_position_in_grid]],
1729
1335
  uint tiisg[[thread_index_in_simdgroup]],
1730
1336
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
1731
1337
 
@@ -1733,11 +1339,12 @@ kernel void kernel_mul_mat_q5_K_f32(
1733
1339
 
1734
1340
  const int64_t r0 = tgpig.x;
1735
1341
  const int64_t r1 = tgpig.y;
1342
+ const int r2 = tgpig.z;
1736
1343
 
1737
1344
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
1738
-
1739
- device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb;
1740
- device const float * yy = (device const float *) src1 + r1*ne10;
1345
+ const uint offset0 = r2/gqa*(nb*ne0);
1346
+ device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
1347
+ device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
1741
1348
 
1742
1349
  float sumf[2]={0.f};
1743
1350
 
@@ -1871,7 +1478,7 @@ kernel void kernel_mul_mat_q5_K_f32(
1871
1478
  for (int row = 0; row < 2; ++row) {
1872
1479
  const float tot = simd_sum(sumf[row]);
1873
1480
  if (tiisg == 0) {
1874
- dst[r1*ne0 + first_row + row] = tot;
1481
+ dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
1875
1482
  }
1876
1483
  }
1877
1484
 
@@ -1882,9 +1489,14 @@ kernel void kernel_mul_mat_q6_K_f32(
1882
1489
  device const float * src1,
1883
1490
  device float * dst,
1884
1491
  constant int64_t & ne00,
1885
- constant int64_t & ne10,
1886
- constant int64_t & ne0,
1887
- uint2 tgpig[[threadgroup_position_in_grid]],
1492
+ constant int64_t & ne01[[buffer(4)]],
1493
+ constant int64_t & ne02[[buffer(5)]],
1494
+ constant int64_t & ne10[[buffer(9)]],
1495
+ constant int64_t & ne12[[buffer(11)]],
1496
+ constant int64_t & ne0[[buffer(15)]],
1497
+ constant int64_t & ne1[[buffer(16)]],
1498
+ constant uint & gqa[[buffer(17)]],
1499
+ uint3 tgpig[[threadgroup_position_in_grid]],
1888
1500
  uint tiisg[[thread_index_in_simdgroup]],
1889
1501
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
1890
1502
 
@@ -1897,11 +1509,12 @@ kernel void kernel_mul_mat_q6_K_f32(
1897
1509
 
1898
1510
  const int64_t r0 = tgpig.x;
1899
1511
  const int64_t r1 = tgpig.y;
1512
+ const int r2 = tgpig.z;
1900
1513
 
1901
1514
  const int row = 2 * r0 + sgitg;
1902
-
1903
- device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb; //r0*nb;
1904
- device const float * yy = (device const float *) src1 + r1*ne10;
1515
+ const uint offset0 = r2/gqa*(nb*ne0);
1516
+ device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
1517
+ device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
1905
1518
 
1906
1519
  float sumf = 0;
1907
1520
 
@@ -1967,6 +1580,366 @@ kernel void kernel_mul_mat_q6_K_f32(
1967
1580
 
1968
1581
  const float tot = simd_sum(sumf);
1969
1582
  if (tiisg == 0) {
1970
- dst[r1*ne0 + row] = tot;
1583
+ dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
1584
+ }
1585
+ }
1586
+
1587
+ //============================= templates and their specializations =============================
1588
+
1589
+ template <typename type4x4>
1590
+ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
1591
+ half4x4 temp = *(((device half4x4 *)src));
1592
+ for (int i = 0; i < 16; i++){
1593
+ reg[i/4][i%4] = temp[i/4][i%4];
1594
+ }
1595
+ }
1596
+
1597
+ template <typename type4x4>
1598
+ void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
1599
+ device const uint16_t * qs = ((device const uint16_t *)xb + 1);
1600
+ const half d = il ? (xb->d / 16.h) : xb->d;
1601
+ const half m = il ? (-8.h * 16.h) : -8.h;
1602
+ const ushort mask0 = il ? 0x00F0 : 0x000F;
1603
+ const ushort mask1 = il ? 0xF000 : 0x0F00;
1604
+
1605
+ for (int i=0;i<8;i++) {
1606
+ reg[i/2][2*(i%2)] = (((qs[i] & mask0)) + m) * d;
1607
+ reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
1608
+ }
1609
+ }
1610
+
1611
+ template <typename type4x4>
1612
+ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
1613
+ device const uint16_t * qs = ((device const uint16_t *)xb + 2);
1614
+ const half d = il ? (xb->d / 16.h) : xb->d;
1615
+ const half m = xb->m;
1616
+ const ushort mask0 = il ? 0x00F0 : 0x000F;
1617
+ const ushort mask1 = il ? 0xF000 : 0x0F00;
1618
+
1619
+ for (int i=0;i<8;i++) {
1620
+ reg[i/2][2*(i%2)] = (((qs[i] & mask0)) * d) + m;
1621
+ reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
1622
+ }
1623
+ }
1624
+
1625
+ template <typename type4x4>
1626
+ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
1627
+ const half d = xb->d;
1628
+ const half min = xb->dmin;
1629
+ device const uint8_t * q = (device const uint8_t *)xb->qs;
1630
+ half dl, ml;
1631
+ uint8_t sc = xb->scales[il];
1632
+
1633
+ #if QK_K == 256
1634
+ q = q + 32*(il/8) + 16*(il&1);
1635
+ il = (il/2)%4;
1636
+ #endif
1637
+ half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
1638
+ uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
1639
+ dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
1640
+ for (int i = 0; i < 16; ++i) {
1641
+ reg[i/4][i%4] = dl * (q[i] & mask) - ml;
1642
+ }
1643
+ }
1644
+
1645
+ template <typename type4x4>
1646
+ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
1647
+ const float d_all = (float)(xb->d);
1648
+ device const uint8_t * q = (device const uint8_t *)xb->qs;
1649
+ device const uint8_t * h = (device const uint8_t *)xb->hmask;
1650
+ device const int8_t * scales = (device const int8_t *)xb->scales;
1651
+
1652
+ #if QK_K == 256
1653
+ q = q + 32 * (il/8) + 16 * (il&1);
1654
+ h = h + 16 * (il&1);
1655
+ uint8_t m = 1 << (il/2);
1656
+ uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
1657
+ ((il/4)>0 ? 12 : 3);
1658
+ uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
1659
+ uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
1660
+ int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : \
1661
+ (scale_2&kmask2) | ((scale_1&kmask1) << 4);
1662
+ float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
1663
+
1664
+ il = (il/2)%4;
1665
+ float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
1666
+ uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
1667
+
1668
+ for (int i = 0; i < 16; ++i) {
1669
+ reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i] & m) ? 0 : 4.f/coef));
1670
+ }
1671
+ #else
1672
+ float kcoef = il&1 ? 1.f/16.f : 1.f;
1673
+ uint16_t kmask = il&1 ? 0xF0 : 0x0F;
1674
+ float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8);
1675
+ float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
1676
+ uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
1677
+ uint8_t m = 1<<(il*2);
1678
+ for (int i = 0; i < 16; ++i) {
1679
+ reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef));
1680
+ }
1681
+ #endif
1682
+ }
1683
+
1684
+ template <typename type4x4>
1685
+ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
1686
+ device const uint8_t * q = xb->qs;
1687
+
1688
+ #if QK_K == 256
1689
+ const float d = (float)(xb->d);
1690
+ const float min = (float)(xb->dmin);
1691
+ short is = (il/4) * 2;
1692
+ q = q + (il/4) * 32 + 16 * (il&1);
1693
+ il = il%4;
1694
+ const uchar4 sc = get_scale_min_k4(is, xb->scales);
1695
+ const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
1696
+ const float ml = il<2 ? min * sc[1] : min * sc[3];
1697
+ #else
1698
+ q = q + 16 * (il&1);
1699
+ device const uint8_t * s = xb->scales;
1700
+ device const half2 * dh = (device const half2 *)xb->d;
1701
+ const float2 d = (float2)dh[0];
1702
+ const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
1703
+ const float ml = il<2 ? d[1] * (s[0]>>4) : d[1 ]* (s[1]>>4);
1704
+ #endif
1705
+ const ushort mask = il<2 ? 0x0F : 0xF0;
1706
+ for (int i = 0; i < 16; ++i) {
1707
+ reg[i/4][i%4] = dl * (q[i] & mask) - ml;
1708
+ }
1709
+ }
1710
+
1711
+ template <typename type4x4>
1712
+ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
1713
+ device const uint8_t * q = xb->qs;
1714
+ device const uint8_t * qh = xb->qh;
1715
+
1716
+ #if QK_K == 256
1717
+ const float d = (float)(xb->d);
1718
+ const float min = (float)(xb->dmin);
1719
+ short is = (il/4) * 2;
1720
+ q = q + 32 * (il/4) + 16 * (il&1);
1721
+ qh = qh + 16 * (il&1);
1722
+ uint8_t ul = 1 << (il/2);
1723
+ il = il%4;
1724
+ const uchar4 sc = get_scale_min_k4(is, xb->scales);
1725
+ const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
1726
+ const float ml = il<2 ? min * sc[1] : min * sc[3];
1727
+
1728
+ const ushort mask = il<2 ? 0x0F : 0xF0;
1729
+ const float qh_val = il<2 ? 16.f : 256.f;
1730
+ for (int i = 0; i < 16; ++i) {
1731
+ reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
1732
+ }
1733
+ #else
1734
+ q = q + 16 * (il&1);
1735
+ device const int8_t * s = xb->scales;
1736
+ const float dl = xb->d * s[il];
1737
+ uint8_t m = 1<<(il*2);
1738
+ const float coef = il<2 ? 1.f : 1.f/16.f;
1739
+ const ushort mask = il<2 ? 0x0F : 0xF0;
1740
+ for (int i = 0; i < 16; ++i) {
1741
+ reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef));
1742
+ }
1743
+ #endif
1744
+ }
1745
+
1746
+ template <typename type4x4>
1747
+ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
1748
+ const float d_all = (float)(xb->d);
1749
+ device const uint8_t * ql = (device const uint8_t *)xb->ql;
1750
+ device const uint8_t * qh = (device const uint8_t *)xb->qh;
1751
+ device const int8_t * scales = (device const int8_t *)xb->scales;
1752
+
1753
+ #if QK_K == 256
1754
+ ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
1755
+ qh = qh + 32*(il/8) + 16*(il&1);
1756
+ float sc = scales[(il%2) + 2 * ((il/2))];
1757
+ il = (il/2)%4;
1758
+ #else
1759
+ ql = ql + 16 * (il&1);
1760
+ float sc = scales[il];
1761
+ #endif
1762
+ for (int i = 0; i < 16; ++i) {
1763
+ uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
1764
+ uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
1765
+ const float coef = il>1 ? 1.f/16.f : 1.f;
1766
+ float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \
1767
+ ((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef;
1768
+ reg[i/4][i%4] = d_all * sc * q * coef;
1769
+ }
1770
+ }
1771
+
1772
+ template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
1773
+ kernel void kernel_get_rows(
1774
+ device const void * src0,
1775
+ device const int * src1,
1776
+ device float * dst,
1777
+ constant int64_t & ne00,
1778
+ constant uint64_t & nb01,
1779
+ constant uint64_t & nb1,
1780
+ uint tgpig[[threadgroup_position_in_grid]],
1781
+ uint tiitg[[thread_index_in_threadgroup]],
1782
+ uint tptg[[threads_per_threadgroup]]) {
1783
+ const int i = tgpig;
1784
+ const int r = ((device int32_t *) src1)[i];
1785
+
1786
+ for (int ind = tiitg; ind < ne00/16; ind += tptg) {
1787
+ float4x4 temp;
1788
+ dequantize_func(
1789
+ ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
1790
+ *(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp;
1791
+ }
1792
+ }
1793
+
1794
+ #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
1795
+ #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix A
1796
+ #define BLOCK_SIZE_K 32
1797
+ #define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
1798
+ #define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
1799
+ #define THREAD_PER_BLOCK 128
1800
+ #define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers
1801
+ #define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers
1802
+ #define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
1803
+ #define SG_MAT_ROW 8
1804
+
1805
+ // each block_q contains 16*nl weights
1806
+ template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
1807
+ kernel void kernel_mul_mm(device const uchar * src0,
1808
+ device const float * src1,
1809
+ device float * dst,
1810
+ constant int64_t & ne00,
1811
+ constant int64_t & ne02,
1812
+ constant int64_t & nb01,
1813
+ constant int64_t & nb02,
1814
+ constant int64_t & ne12,
1815
+ constant int64_t & ne0,
1816
+ constant int64_t & ne1,
1817
+ constant uint & gqa,
1818
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
1819
+ uint3 tgpig[[threadgroup_position_in_grid]],
1820
+ uint tiitg[[thread_index_in_threadgroup]],
1821
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1822
+
1823
+ threadgroup half * sa = ((threadgroup half *)shared_memory);
1824
+ threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
1825
+
1826
+ const uint r0 = tgpig.y;
1827
+ const uint r1 = tgpig.x;
1828
+ const uint im = tgpig.z;
1829
+ // if this block is of 64x32 shape or smaller
1830
+ short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
1831
+ short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
1832
+ // a thread shouldn't load data outside of the matrix
1833
+ short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
1834
+ short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
1835
+
1836
+ simdgroup_half8x8 ma[4];
1837
+ simdgroup_float8x8 mb[2];
1838
+ simdgroup_float8x8 c_res[8];
1839
+ for (int i = 0; i < 8; i++){
1840
+ c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
1841
+ }
1842
+
1843
+ short il = (tiitg % THREAD_PER_ROW);
1844
+ uint offset0 = im/gqa*nb02; ushort offset1 = il/nl;
1845
+ device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
1846
+ device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
1847
+ + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1;
1848
+
1849
+ for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
1850
+ //load data and store to threadgroup memory
1851
+ half4x4 temp_a;
1852
+ dequantize_func(x, il, temp_a);
1853
+ #pragma unroll(16)
1854
+ for (int i = 0; i < 16; i++) {
1855
+ *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
1856
+ + 16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8)) \
1857
+ + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
1858
+ }
1859
+ *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) \
1860
+ = *((device float2x4 *)y);
1861
+ il = (il + 2 < nl) ? il + 2 : il % 2;
1862
+ x = (il < 2) ? x + (2+nl-1)/nl : x;
1863
+ y += BLOCK_SIZE_K;
1864
+
1865
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1866
+ //load matrices from threadgroup memory and conduct outer products
1867
+ threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
1868
+ threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
1869
+ #pragma unroll(4)
1870
+ for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
1871
+ #pragma unroll(4)
1872
+ for (int i = 0; i < 4; i++) {
1873
+ simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
1874
+ }
1875
+ simdgroup_barrier(mem_flags::mem_none);
1876
+ #pragma unroll(2)
1877
+ for (int i = 0; i < 2; i++) {
1878
+ simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
1879
+ }
1880
+
1881
+ lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
1882
+ lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
1883
+ #pragma unroll(8)
1884
+ for (int i = 0; i < 8; i++){
1885
+ simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
1886
+ }
1887
+ }
1888
+ }
1889
+
1890
+ if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
1891
+ device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
1892
+ + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0;
1893
+ for (int i = 0; i < 8; i++) {
1894
+ simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
1895
+ }
1896
+ } else {
1897
+ // block is smaller than 64x32, we should avoid writing data outside of the matrix
1898
+ threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
1899
+ + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
1900
+ for (int i = 0; i < 8; i++) {
1901
+ simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
1902
+ }
1903
+
1904
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1905
+ device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
1906
+ if (sgitg==0) {
1907
+ for (int i = 0; i < n_rows; i++) {
1908
+ for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) {
1909
+ *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
1910
+ }
1911
+ }
1912
+ }
1971
1913
  }
1972
1914
  }
1915
+
1916
+ #if QK_K == 256
1917
+ #define QK_NL 16
1918
+ #else
1919
+ #define QK_NL 4
1920
+ #endif
1921
+
1922
+ typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
1923
+ constant uint64_t &, constant uint64_t &, uint, uint, uint);
1924
+
1925
+ template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
1926
+ template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
1927
+ template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
1928
+ template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
1929
+ template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
1930
+ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
1931
+ template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
1932
+ template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
1933
+
1934
+ typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\
1935
+ constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
1936
+ constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
1937
+
1938
+ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
1939
+ template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
1940
+ template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
1941
+ template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
1942
+ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
1943
+ template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
1944
+ template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
1945
+ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;