llama_cpp 0.3.6 → 0.3.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -18,47 +18,6 @@ typedef struct {
18
18
  uint8_t qs[QK4_1 / 2]; // nibbles / quants
19
19
  } block_q4_1;
20
20
 
21
- static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
22
- const int qk = QK4_0;
23
-
24
- assert(k % qk == 0);
25
-
26
- const int nb = k / qk;
27
-
28
- for (int i = 0; i < nb; i++) {
29
- const half d = x[i].d;
30
-
31
- for (int j = 0; j < qk/2; ++j) {
32
- const int x0 = (x[i].qs[j] & 0x0F) - 8;
33
- const int x1 = (x[i].qs[j] >> 4) - 8;
34
-
35
- y[i*qk + j + 0 ] = x0*d;
36
- y[i*qk + j + qk/2] = x1*d;
37
- }
38
- }
39
- }
40
-
41
- static void dequantize_row_q4_1(device const block_q4_1 * x, device float * y, int k) {
42
- const int qk = QK4_1;
43
-
44
- assert(k % qk == 0);
45
-
46
- const int nb = k / qk;
47
-
48
- for (int i = 0; i < nb; i++) {
49
- const half d = x[i].d;
50
- const half m = x[i].m;
51
-
52
- for (int j = 0; j < qk/2; ++j) {
53
- const int x0 = (x[i].qs[j] & 0x0F);
54
- const int x1 = (x[i].qs[j] >> 4);
55
-
56
- y[i*qk + j + 0 ] = x0*d + m;
57
- y[i*qk + j + qk/2] = x1*d + m;
58
- }
59
- }
60
- }
61
-
62
21
  kernel void kernel_add(
63
22
  device const float * src0,
64
23
  device const float * src1,
@@ -219,54 +178,6 @@ kernel void kernel_diag_mask_inf(
219
178
  }
220
179
  }
221
180
 
222
- kernel void kernel_get_rows_f16(
223
- device const void * src0,
224
- device const int * src1,
225
- device float * dst,
226
- constant int64_t & ne00,
227
- constant uint64_t & nb01,
228
- constant uint64_t & nb1,
229
- uint tpig[[thread_position_in_grid]]) {
230
- const int i = tpig;
231
- const int r = ((device int32_t *) src1)[i];
232
-
233
- for (int j = 0; j < ne00; j++) {
234
- dst[i*nb1 + j] = ((device half *) ((device char *) src0 + r*nb01))[j];
235
- }
236
- }
237
-
238
- kernel void kernel_get_rows_q4_0(
239
- device const void * src0,
240
- device const int * src1,
241
- device float * dst,
242
- constant int64_t & ne00,
243
- constant uint64_t & nb01,
244
- constant uint64_t & nb1,
245
- uint tpig[[thread_position_in_grid]]) {
246
- const int i = tpig;
247
- const int r = ((device int32_t *) src1)[i];
248
-
249
- dequantize_row_q4_0(
250
- (device const block_q4_0 *) ((device char *) src0 + r*nb01),
251
- (device float *) ((device char *) dst + i*nb1), ne00);
252
- }
253
-
254
- kernel void kernel_get_rows_q4_1(
255
- device const void * src0,
256
- device const int * src1,
257
- device float * dst,
258
- constant int64_t & ne00,
259
- constant uint64_t & nb01,
260
- constant uint64_t & nb1,
261
- uint tpig[[thread_position_in_grid]]) {
262
- const int i = tpig;
263
- const int r = ((device int32_t *) src1)[i];
264
-
265
- dequantize_row_q4_1(
266
- (device const block_q4_1 *) ((device char *) src0 + r*nb01),
267
- (device float *) ((device char *) dst + i*nb1), ne00);
268
- }
269
-
270
181
  kernel void kernel_norm(
271
182
  device const void * src0,
272
183
  device float * dst,
@@ -432,14 +343,16 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
432
343
  // N_DST, so this is another explicit assumption of the implementation.
433
344
  template<typename block_q_type, int nr, int nsg, int nw>
434
345
  void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
435
- int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
436
- uint2 tgpig, uint tiisg, uint sgitg) {
346
+ int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
347
+ uint3 tgpig, uint tiisg, uint sgitg) {
437
348
  const int nb = ne00/QK4_0;
438
349
  const int r0 = tgpig.x;
439
350
  const int r1 = tgpig.y;
351
+ const int im = tgpig.z;
440
352
  const int first_row = (r0 * nsg + sgitg) * nr;
441
- device const block_q_type * x = (device const block_q_type *) src0 + first_row * nb;
442
- device const float * y = (device const float *) src1 + r1*ne10;
353
+ const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
354
+ device const block_q_type * x = (device const block_q_type *) src0 + offset0;
355
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
443
356
  float yl[16]; // src1 vector cache
444
357
  float sumf[nr]={0.f};
445
358
 
@@ -470,7 +383,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
470
383
  for (int row = 0; row < nr; ++row) {
471
384
  const float tot = simd_sum(sumf[row]);
472
385
  if (tiisg == 0 && first_row + row < ne01) {
473
- dst[r1*ne0 + first_row + row] = tot;
386
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
474
387
  }
475
388
  }
476
389
  }
@@ -480,13 +393,17 @@ kernel void kernel_mul_mat_q4_0_f32(
480
393
  device const float * src1,
481
394
  device float * dst,
482
395
  constant int64_t & ne00,
483
- constant int64_t & ne10,
484
- constant int64_t & ne0,
485
396
  constant int64_t & ne01[[buffer(4)]],
486
- uint2 tgpig[[threadgroup_position_in_grid]],
397
+ constant int64_t & ne02[[buffer(5)]],
398
+ constant int64_t & ne10[[buffer(9)]],
399
+ constant int64_t & ne12[[buffer(11)]],
400
+ constant int64_t & ne0[[buffer(15)]],
401
+ constant int64_t & ne1[[buffer(16)]],
402
+ constant uint & gqa[[buffer(17)]],
403
+ uint3 tgpig[[threadgroup_position_in_grid]],
487
404
  uint tiisg[[thread_index_in_simdgroup]],
488
405
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
489
- mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
406
+ mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
490
407
  }
491
408
 
492
409
  kernel void kernel_mul_mat_q4_1_f32(
@@ -494,13 +411,17 @@ kernel void kernel_mul_mat_q4_1_f32(
494
411
  device const float * src1,
495
412
  device float * dst,
496
413
  constant int64_t & ne00,
497
- constant int64_t & ne10,
498
- constant int64_t & ne0,
499
414
  constant int64_t & ne01[[buffer(4)]],
500
- uint2 tgpig[[threadgroup_position_in_grid]],
415
+ constant int64_t & ne02[[buffer(5)]],
416
+ constant int64_t & ne10[[buffer(9)]],
417
+ constant int64_t & ne12[[buffer(11)]],
418
+ constant int64_t & ne0[[buffer(15)]],
419
+ constant int64_t & ne1[[buffer(16)]],
420
+ constant uint & gqa[[buffer(17)]],
421
+ uint3 tgpig[[threadgroup_position_in_grid]],
501
422
  uint tiisg[[thread_index_in_simdgroup]],
502
423
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
503
- mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
424
+ mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
504
425
  }
505
426
 
506
427
  kernel void kernel_mul_mat_f16_f32(
@@ -869,354 +790,6 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
869
790
  return r;
870
791
  }
871
792
 
872
- //========================================== dequantization =============================
873
-
874
- static void dequantize_row_q2_K(device const block_q2_K * x, device float * y, int k) {
875
- assert(k % QK_K == 0);
876
- const int nb = k / QK_K;
877
-
878
- for (int i = 0; i < nb; i++) {
879
-
880
- const float d = x[i].d;
881
- const float min = x[i].dmin;
882
-
883
- device const uint8_t * q = x[i].qs;
884
-
885
- #if QK_K == 256
886
- int is = 0;
887
- float dl, ml;
888
- for (int n = 0; n < QK_K; n += 128) {
889
- int shift = 0;
890
- for (int j = 0; j < 4; ++j) {
891
-
892
- uint8_t sc = x[i].scales[is++];
893
- dl = d * (sc & 0xF); ml = min * (sc >> 4);
894
- for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
895
-
896
- sc = x[i].scales[is++];
897
- dl = d * (sc & 0xF); ml = min * (sc >> 4);
898
- for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;
899
-
900
- shift += 2;
901
- }
902
- q += 32;
903
- }
904
- #else
905
- float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
906
- float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
907
- float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
908
- float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
909
- for (int l = 0; l < 16; ++l) {
910
- y[l+ 0] = dl1 * ((q[l] >> 0) & 3) - ml1;
911
- y[l+16] = dl2 * ((q[l] >> 2) & 3) - ml2;
912
- y[l+32] = dl3 * ((q[l] >> 4) & 3) - ml3;
913
- y[l+48] = dl4 * ((q[l] >> 6) & 3) - ml4;
914
- }
915
- y += QK_K;
916
- #endif
917
-
918
- }
919
- }
920
-
921
- static void dequantize_row_q3_K(device const block_q3_K * x, device float * y, int k) {
922
- assert(k % QK_K == 0);
923
- const int nb = k / QK_K;
924
-
925
- #if QK_K == 256
926
-
927
- const uint16_t kmask1 = 0x0303;
928
- const uint16_t kmask2 = 0x0f0f;
929
-
930
- uint16_t aux[8];
931
- thread const int8_t * scales = (thread const int8_t*)aux;
932
-
933
- for (int i = 0; i < nb; i++) {
934
-
935
- const float d_all = (float)(x[i].d);
936
-
937
- device const uint8_t * q = x[i].qs;
938
- device const uint8_t * h = x[i].hmask;
939
- uint8_t m = 1;
940
-
941
- device const uint16_t * a = (device const uint16_t *)x[i].scales;
942
- aux[0] = (a[0] & kmask2) | (((a[4] >> 0) & kmask1) << 4);
943
- aux[1] = (a[1] & kmask2) | (((a[5] >> 0) & kmask1) << 4);
944
- aux[2] = (a[2] & kmask2) | (((a[4] >> 2) & kmask1) << 4);
945
- aux[3] = (a[3] & kmask2) | (((a[5] >> 2) & kmask1) << 4);
946
- aux[4] = ((a[0] >> 4) & kmask2) | (((a[4] >> 4) & kmask1) << 4);
947
- aux[5] = ((a[1] >> 4) & kmask2) | (((a[5] >> 4) & kmask1) << 4);
948
- aux[6] = ((a[2] >> 4) & kmask2) | (((a[4] >> 6) & kmask1) << 4);
949
- aux[7] = ((a[3] >> 4) & kmask2) | (((a[5] >> 6) & kmask1) << 4);
950
-
951
- int is = 0;
952
- float dl;
953
- for (int n = 0; n < QK_K; n += 128) {
954
- int shift = 0;
955
- for (int j = 0; j < 4; ++j) {
956
-
957
- dl = d_all * (scales[is++] - 32);
958
- for (int l = 0; l < 16; ++l) {
959
- *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((h[l+ 0] & m) ? 0 : 4));
960
- }
961
-
962
- dl = d_all * (scales[is++] - 32);
963
- for (int l = 0; l < 16; ++l) {
964
- *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((h[l+16] & m) ? 0 : 4));
965
- }
966
-
967
- shift += 2;
968
- m <<= 1;
969
- }
970
- q += 32;
971
- }
972
- }
973
- #else
974
- for (int i = 0; i < nb; i++) {
975
-
976
- const float d_all = (float)(x[i].d);
977
-
978
- device const uint8_t * q = x[i].qs;
979
- device const uint8_t * hm = x[i].hmask;
980
-
981
- const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
982
- const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
983
- const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
984
- const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
985
-
986
- for (int l = 0; l < 8; ++l) {
987
- uint8_t h = hm[l];
988
- y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
989
- y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
990
- y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
991
- y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
992
- y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
993
- y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
994
- y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
995
- y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
996
- }
997
- y += QK_K;
998
- }
999
- #endif
1000
-
1001
- }
1002
-
1003
- static void dequantize_row_q4_K(device const block_q4_K * x, device float * y, int k) {
1004
- assert(k % QK_K == 0);
1005
- const int nb = k / QK_K;
1006
-
1007
- for (int i = 0; i < nb; i++) {
1008
-
1009
- device const uint8_t * q = x[i].qs;
1010
-
1011
- #if QK_K == 256
1012
- const float d = x[i].d;
1013
- const float min = x[i].dmin;
1014
-
1015
- device const uint8_t * scales = x[i].scales;
1016
-
1017
- int is = 0;
1018
- for (int j = 0; j < QK_K; j += 64) {
1019
- const uchar4 sc = get_scale_min_k4(is, scales);
1020
- const float d1 = d * sc[0]; const float m1 = min * sc[1];
1021
- const float d2 = d * sc[2]; const float m2 = min * sc[3];
1022
- for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
1023
- for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
1024
- q += 32; is += 2;
1025
- }
1026
- #else
1027
- device const uint8_t * s = x[i].scales;
1028
- device const half2 * dh = (device const half2 *)x[i].d;
1029
- const float2 d = (float2)dh[0];
1030
- const float d1 = d[0] * (s[0] & 0xF);
1031
- const float d2 = d[0] * (s[1] & 0xF);
1032
- const float m1 = d[1] * (s[0] >> 4);
1033
- const float m2 = d[1] * (s[1] >> 4);
1034
- for (int l = 0; l < 32; ++l) {
1035
- y[l+ 0] = d1 * (q[l] & 0xF) - m1;
1036
- y[l+32] = d2 * (q[l] >> 4) - m2;
1037
- }
1038
- y += QK_K;
1039
- #endif
1040
-
1041
- }
1042
- }
1043
-
1044
- static void dequantize_row_q5_K(device const block_q5_K * x, device float * y, int k) {
1045
- assert(k % QK_K == 0);
1046
- const int nb = k / QK_K;
1047
-
1048
- #if QK_K == 256
1049
- for (int i = 0; i < nb; i++) {
1050
-
1051
- const float d = (float)(x[i].d);
1052
- const float min = (float)(x[i].dmin);
1053
-
1054
- device const uint8_t * ql = x[i].qs;
1055
- device const uint8_t * qh = x[i].qh;
1056
-
1057
- int is = 0;
1058
- uint8_t u1 = 1, u2 = 2;
1059
- for (int j = 0; j < QK_K; j += 64) {
1060
- const uchar4 sc = get_scale_min_k4(is, x[i].scales);
1061
- const float d1 = d * sc[0]; const float m1 = min * sc[1];
1062
- const float d2 = d * sc[2]; const float m2 = min * sc[3];
1063
- for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
1064
- for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
1065
- ql += 32; is += 2;
1066
- u1 <<= 2; u2 <<= 2;
1067
- }
1068
- }
1069
- #else
1070
- for (int i = 0; i < nb; i++) {
1071
-
1072
- const float d = (float)x[i].d;
1073
-
1074
- device const uint8_t * ql = x[i].qs;
1075
- device const uint8_t * qh = x[i].qh;
1076
- device const int8_t * sc = x[i].scales;
1077
-
1078
- for (int l = 0; l < 8; ++l) {
1079
- y[l+ 0] = d * sc[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
1080
- y[l+ 8] = d * sc[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
1081
- y[l+16] = d * sc[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
1082
- y[l+24] = d * sc[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
1083
- y[l+32] = d * sc[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16));
1084
- y[l+40] = d * sc[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16));
1085
- y[l+48] = d * sc[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16));
1086
- y[l+56] = d * sc[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16));
1087
- }
1088
- y += QK_K;
1089
- }
1090
- #endif
1091
-
1092
- }
1093
-
1094
- static void dequantize_row_q6_K(device const block_q6_K * x, device float * y, int k) {
1095
- assert(k % QK_K == 0);
1096
- const int nb = k / QK_K;
1097
-
1098
- for (int i = 0; i < nb; i++) {
1099
-
1100
- device const uint8_t * ql = x[i].ql;
1101
- device const uint8_t * qh = x[i].qh;
1102
- device const int8_t * sc = x[i].scales;
1103
-
1104
- const float d = x[i].d;
1105
-
1106
- #if QK_K == 256
1107
- for (int n = 0; n < QK_K; n += 128) {
1108
- for (int l = 0; l < 32; ++l) {
1109
- int is = l/16;
1110
- const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
1111
- const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
1112
- const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
1113
- const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
1114
- y[l + 0] = d * sc[is + 0] * q1;
1115
- y[l + 32] = d * sc[is + 2] * q2;
1116
- y[l + 64] = d * sc[is + 4] * q3;
1117
- y[l + 96] = d * sc[is + 6] * q4;
1118
- }
1119
- y += 128;
1120
- ql += 64;
1121
- qh += 32;
1122
- sc += 8;
1123
- }
1124
- #else
1125
- for (int l = 0; l < 16; ++l) {
1126
- const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
1127
- const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
1128
- const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
1129
- const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
1130
- y[l+ 0] = d * sc[0] * q1;
1131
- y[l+16] = d * sc[1] * q2;
1132
- y[l+32] = d * sc[2] * q3;
1133
- y[l+48] = d * sc[3] * q4;
1134
- }
1135
- y += 64;
1136
- #endif
1137
- }
1138
- }
1139
-
1140
- kernel void kernel_get_rows_q2_K(
1141
- device const void * src0,
1142
- device const int * src1,
1143
- device float * dst,
1144
- constant int64_t & ne00,
1145
- constant uint64_t & nb01,
1146
- constant uint64_t & nb1,
1147
- uint tpig[[thread_position_in_grid]]) {
1148
- const int i = tpig;
1149
- const int r = ((device int32_t *) src1)[i];
1150
-
1151
- dequantize_row_q2_K(
1152
- (device const block_q2_K *) ((device char *) src0 + r*nb01),
1153
- (device float *) ((device char *) dst + i*nb1), ne00);
1154
- }
1155
-
1156
- kernel void kernel_get_rows_q3_K(
1157
- device const void * src0,
1158
- device const int * src1,
1159
- device float * dst,
1160
- constant int64_t & ne00,
1161
- constant uint64_t & nb01,
1162
- constant uint64_t & nb1,
1163
- uint tpig[[thread_position_in_grid]]) {
1164
- const int i = tpig;
1165
- const int r = ((device int32_t *) src1)[i];
1166
-
1167
- dequantize_row_q3_K(
1168
- (device const block_q3_K *) ((device char *) src0 + r*nb01),
1169
- (device float *) ((device char *) dst + i*nb1), ne00);
1170
- }
1171
-
1172
- kernel void kernel_get_rows_q4_K(
1173
- device const void * src0,
1174
- device const int * src1,
1175
- device float * dst,
1176
- constant int64_t & ne00,
1177
- constant uint64_t & nb01,
1178
- constant uint64_t & nb1,
1179
- uint tpig[[thread_position_in_grid]]) {
1180
- const int i = tpig;
1181
- const int r = ((device int32_t *) src1)[i];
1182
-
1183
- dequantize_row_q4_K(
1184
- (device const block_q4_K *) ((device char *) src0 + r*nb01),
1185
- (device float *) ((device char *) dst + i*nb1), ne00);
1186
- }
1187
-
1188
- kernel void kernel_get_rows_q5_K(
1189
- device const void * src0,
1190
- device const int * src1,
1191
- device float * dst,
1192
- constant int64_t & ne00,
1193
- constant uint64_t & nb01,
1194
- constant uint64_t & nb1,
1195
- uint tpig[[thread_position_in_grid]]) {
1196
- const int i = tpig;
1197
- const int r = ((device int32_t *) src1)[i];
1198
-
1199
- dequantize_row_q5_K(
1200
- (device const block_q5_K *) ((device char *) src0 + r*nb01),
1201
- (device float *) ((device char *) dst + i*nb1), ne00);
1202
- }
1203
-
1204
- kernel void kernel_get_rows_q6_K(
1205
- device const void * src0,
1206
- device const int * src1,
1207
- device float * dst,
1208
- constant int64_t & ne00,
1209
- constant uint64_t & nb01,
1210
- constant uint64_t & nb1,
1211
- uint tpig[[thread_position_in_grid]]) {
1212
- const int i = tpig;
1213
- const int r = ((device int32_t *) src1)[i];
1214
-
1215
- dequantize_row_q6_K(
1216
- (device const block_q6_K *) ((device char *) src0 + r*nb01),
1217
- (device float *) ((device char *) dst + i*nb1), ne00);
1218
- }
1219
-
1220
793
  //====================================== dot products =========================
1221
794
 
1222
795
  kernel void kernel_mul_mat_q2_K_f32(
@@ -1224,21 +797,27 @@ kernel void kernel_mul_mat_q2_K_f32(
1224
797
  device const float * src1,
1225
798
  device float * dst,
1226
799
  constant int64_t & ne00,
1227
- constant int64_t & ne10,
1228
- constant int64_t & ne0,
1229
800
  constant int64_t & ne01[[buffer(4)]],
1230
- uint2 tgpig[[threadgroup_position_in_grid]],
801
+ constant int64_t & ne02[[buffer(5)]],
802
+ constant int64_t & ne10[[buffer(9)]],
803
+ constant int64_t & ne12[[buffer(11)]],
804
+ constant int64_t & ne0[[buffer(15)]],
805
+ constant int64_t & ne1[[buffer(16)]],
806
+ constant uint & gqa[[buffer(17)]],
807
+ uint3 tgpig[[threadgroup_position_in_grid]],
1231
808
  uint tiisg[[thread_index_in_simdgroup]],
1232
809
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
1233
810
 
1234
811
  const int nb = ne00/QK_K;
1235
812
  const int r0 = tgpig.x;
1236
813
  const int r1 = tgpig.y;
814
+ const int r2 = tgpig.z;
1237
815
 
1238
816
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1239
817
  const int ib_row = first_row * nb;
1240
- device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row;
1241
- device const float * y = (device const float *) src1 + r1*ne10;
818
+ const uint offset0 = r2/gqa*(nb*ne0);
819
+ device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
820
+ device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
1242
821
  float yl[32];
1243
822
  float sumf[N_DST]={0.f}, all_sum;
1244
823
 
@@ -1351,7 +930,7 @@ kernel void kernel_mul_mat_q2_K_f32(
1351
930
  for (int row = 0; row < N_DST; ++row) {
1352
931
  all_sum = simd_sum(sumf[row]);
1353
932
  if (tiisg == 0) {
1354
- dst[r1*ne0 + first_row + row] = all_sum;
933
+ dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
1355
934
  }
1356
935
  }
1357
936
  }
@@ -1362,10 +941,14 @@ kernel void kernel_mul_mat_q3_K_f32(
1362
941
  device const float * src1,
1363
942
  device float * dst,
1364
943
  constant int64_t & ne00,
1365
- constant int64_t & ne10,
1366
- constant int64_t & ne0,
1367
- constant int64_t & ne1,
1368
- uint2 tgpig[[threadgroup_position_in_grid]],
944
+ constant int64_t & ne01[[buffer(4)]],
945
+ constant int64_t & ne02[[buffer(5)]],
946
+ constant int64_t & ne10[[buffer(9)]],
947
+ constant int64_t & ne12[[buffer(11)]],
948
+ constant int64_t & ne0[[buffer(15)]],
949
+ constant int64_t & ne1[[buffer(16)]],
950
+ constant uint & gqa[[buffer(17)]],
951
+ uint3 tgpig[[threadgroup_position_in_grid]],
1369
952
  uint tiisg[[thread_index_in_simdgroup]],
1370
953
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
1371
954
 
@@ -1373,11 +956,12 @@ kernel void kernel_mul_mat_q3_K_f32(
1373
956
 
1374
957
  const int64_t r0 = tgpig.x;
1375
958
  const int64_t r1 = tgpig.y;
959
+ const int64_t r2 = tgpig.z;
1376
960
 
1377
961
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
1378
-
1379
- device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb;
1380
- device const float * yy = (device const float *) src1 + r1*ne10;
962
+ const uint offset0 = r2/gqa*(nb*ne0);
963
+ device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
964
+ device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
1381
965
 
1382
966
  float yl[16];
1383
967
 
@@ -1465,7 +1049,7 @@ kernel void kernel_mul_mat_q3_K_f32(
1465
1049
  const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
1466
1050
  const float tot = simd_sum(sumf);
1467
1051
  if (tiisg == 0) {
1468
- dst[r1*ne0 + first_row + row] = tot;
1052
+ dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
1469
1053
  }
1470
1054
  }
1471
1055
  }
@@ -1475,10 +1059,14 @@ kernel void kernel_mul_mat_q3_K_f32(
1475
1059
  device const float * src1,
1476
1060
  device float * dst,
1477
1061
  constant int64_t & ne00,
1478
- constant int64_t & ne10,
1479
- constant int64_t & ne0,
1480
- constant int64_t & ne1,
1481
- uint2 tgpig[[threadgroup_position_in_grid]],
1062
+ constant int64_t & ne01[[buffer(4)]],
1063
+ constant int64_t & ne02[[buffer(5)]],
1064
+ constant int64_t & ne10[[buffer(9)]],
1065
+ constant int64_t & ne12[[buffer(11)]],
1066
+ constant int64_t & ne0[[buffer(15)]],
1067
+ constant int64_t & ne1[[buffer(16)]],
1068
+ constant uint & gqa[[buffer(17)]],
1069
+ uint3 tgpig[[threadgroup_position_in_grid]],
1482
1070
  uint tiisg[[thread_index_in_simdgroup]],
1483
1071
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
1484
1072
 
@@ -1486,11 +1074,12 @@ kernel void kernel_mul_mat_q3_K_f32(
1486
1074
 
1487
1075
  const int64_t r0 = tgpig.x;
1488
1076
  const int64_t r1 = tgpig.y;
1077
+ const int64_t r2 = tgpig.z;
1489
1078
 
1490
1079
  const int row = 2 * r0 + sgitg;
1491
-
1492
- device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb;
1493
- device const float * yy = (device const float *) src1 + r1*ne10;
1080
+ const uint offset0 = r2/gqa*(nb*ne0);
1081
+ device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
1082
+ device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
1494
1083
  const int ix = tiisg/4;
1495
1084
  const int il = 4 * (tiisg%4);// 0, 4, 8, 12
1496
1085
  const int im = il/8; // 0, 0, 1, 1
@@ -1529,7 +1118,7 @@ kernel void kernel_mul_mat_q3_K_f32(
1529
1118
 
1530
1119
  const float tot = simd_sum(sumf);
1531
1120
  if (tiisg == 0) {
1532
- dst[r1*ne0 + row] = tot;
1121
+ dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
1533
1122
  }
1534
1123
 
1535
1124
  }
@@ -1541,10 +1130,14 @@ kernel void kernel_mul_mat_q4_K_f32(
1541
1130
  device const float * src1,
1542
1131
  device float * dst,
1543
1132
  constant int64_t & ne00,
1544
- constant int64_t & ne10,
1545
- constant int64_t & ne0,
1546
1133
  constant int64_t & ne01[[buffer(4)]],
1547
- uint2 tgpig[[threadgroup_position_in_grid]],
1134
+ constant int64_t & ne02[[buffer(5)]],
1135
+ constant int64_t & ne10[[buffer(9)]],
1136
+ constant int64_t & ne12[[buffer(11)]],
1137
+ constant int64_t & ne0[[buffer(15)]],
1138
+ constant int64_t & ne1[[buffer(16)]],
1139
+ constant uint & gqa[[buffer(17)]],
1140
+ uint3 tgpig[[threadgroup_position_in_grid]],
1548
1141
  uint tiisg[[thread_index_in_simdgroup]],
1549
1142
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
1550
1143
 
@@ -1560,10 +1153,12 @@ kernel void kernel_mul_mat_q4_K_f32(
1560
1153
  const int nb = ne00/QK_K;
1561
1154
  const int r0 = tgpig.x;
1562
1155
  const int r1 = tgpig.y;
1156
+ const int r2 = tgpig.z;
1563
1157
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1564
1158
  const int ib_row = first_row * nb;
1565
- device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
1566
- device const float * y = (device const float *) src1 + r1*ne10;
1159
+ const uint offset0 = r2/gqa*(nb*ne0);
1160
+ device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
1161
+ device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
1567
1162
  float yl[16];
1568
1163
  float yh[16];
1569
1164
  float sumf[N_DST]={0.f}, all_sum;
@@ -1630,7 +1225,7 @@ kernel void kernel_mul_mat_q4_K_f32(
1630
1225
  for (int row = 0; row < N_DST; ++row) {
1631
1226
  all_sum = simd_sum(sumf[row]);
1632
1227
  if (tiisg == 0) {
1633
- dst[r1*ne0 + first_row + row] = all_sum;
1228
+ dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
1634
1229
  }
1635
1230
  }
1636
1231
  }
@@ -1640,10 +1235,14 @@ kernel void kernel_mul_mat_q4_K_f32(
1640
1235
  device const float * src1,
1641
1236
  device float * dst,
1642
1237
  constant int64_t & ne00,
1643
- constant int64_t & ne10,
1644
- constant int64_t & ne0,
1645
1238
  constant int64_t & ne01[[buffer(4)]],
1646
- uint2 tgpig[[threadgroup_position_in_grid]],
1239
+ constant int64_t & ne02[[buffer(5)]],
1240
+ constant int64_t & ne10[[buffer(9)]],
1241
+ constant int64_t & ne12[[buffer(11)]],
1242
+ constant int64_t & ne0[[buffer(15)]],
1243
+ constant int64_t & ne1[[buffer(16)]],
1244
+ constant uint & gqa[[buffer(17)]],
1245
+ uint3 tgpig[[threadgroup_position_in_grid]],
1647
1246
  uint tiisg[[thread_index_in_simdgroup]],
1648
1247
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
1649
1248
 
@@ -1653,10 +1252,12 @@ kernel void kernel_mul_mat_q4_K_f32(
1653
1252
  const int nb = ne00/QK_K;
1654
1253
  const int r0 = tgpig.x;
1655
1254
  const int r1 = tgpig.y;
1255
+ const int r2 = tgpig.z;
1656
1256
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1657
1257
  const int ib_row = first_row * nb;
1658
- device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
1659
- device const float * y = (device const float *) src1 + r1*ne10;
1258
+ const uint offset0 = r2/gqa*(nb*ne0);
1259
+ device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
1260
+ device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
1660
1261
  float yl[8];
1661
1262
  float yh[8];
1662
1263
  float sumf[N_DST]={0.f}, all_sum;
@@ -1712,7 +1313,7 @@ kernel void kernel_mul_mat_q4_K_f32(
1712
1313
  for (int row = 0; row < N_DST; ++row) {
1713
1314
  all_sum = simd_sum(sumf[row]);
1714
1315
  if (tiisg == 0) {
1715
- dst[r1*ne0 + first_row + row] = all_sum;
1316
+ dst[r1*ne0+ r2*ne0*ne1 + first_row + row] = all_sum;
1716
1317
  }
1717
1318
  }
1718
1319
  }
@@ -1723,9 +1324,14 @@ kernel void kernel_mul_mat_q5_K_f32(
1723
1324
  device const float * src1,
1724
1325
  device float * dst,
1725
1326
  constant int64_t & ne00,
1726
- constant int64_t & ne10,
1727
- constant int64_t & ne0,
1728
- uint2 tgpig[[threadgroup_position_in_grid]],
1327
+ constant int64_t & ne01[[buffer(4)]],
1328
+ constant int64_t & ne02[[buffer(5)]],
1329
+ constant int64_t & ne10[[buffer(9)]],
1330
+ constant int64_t & ne12[[buffer(11)]],
1331
+ constant int64_t & ne0[[buffer(15)]],
1332
+ constant int64_t & ne1[[buffer(16)]],
1333
+ constant uint & gqa[[buffer(17)]],
1334
+ uint3 tgpig[[threadgroup_position_in_grid]],
1729
1335
  uint tiisg[[thread_index_in_simdgroup]],
1730
1336
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
1731
1337
 
@@ -1733,11 +1339,12 @@ kernel void kernel_mul_mat_q5_K_f32(
1733
1339
 
1734
1340
  const int64_t r0 = tgpig.x;
1735
1341
  const int64_t r1 = tgpig.y;
1342
+ const int r2 = tgpig.z;
1736
1343
 
1737
1344
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
1738
-
1739
- device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb;
1740
- device const float * yy = (device const float *) src1 + r1*ne10;
1345
+ const uint offset0 = r2/gqa*(nb*ne0);
1346
+ device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
1347
+ device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
1741
1348
 
1742
1349
  float sumf[2]={0.f};
1743
1350
 
@@ -1871,7 +1478,7 @@ kernel void kernel_mul_mat_q5_K_f32(
1871
1478
  for (int row = 0; row < 2; ++row) {
1872
1479
  const float tot = simd_sum(sumf[row]);
1873
1480
  if (tiisg == 0) {
1874
- dst[r1*ne0 + first_row + row] = tot;
1481
+ dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
1875
1482
  }
1876
1483
  }
1877
1484
 
@@ -1882,9 +1489,14 @@ kernel void kernel_mul_mat_q6_K_f32(
1882
1489
  device const float * src1,
1883
1490
  device float * dst,
1884
1491
  constant int64_t & ne00,
1885
- constant int64_t & ne10,
1886
- constant int64_t & ne0,
1887
- uint2 tgpig[[threadgroup_position_in_grid]],
1492
+ constant int64_t & ne01[[buffer(4)]],
1493
+ constant int64_t & ne02[[buffer(5)]],
1494
+ constant int64_t & ne10[[buffer(9)]],
1495
+ constant int64_t & ne12[[buffer(11)]],
1496
+ constant int64_t & ne0[[buffer(15)]],
1497
+ constant int64_t & ne1[[buffer(16)]],
1498
+ constant uint & gqa[[buffer(17)]],
1499
+ uint3 tgpig[[threadgroup_position_in_grid]],
1888
1500
  uint tiisg[[thread_index_in_simdgroup]],
1889
1501
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
1890
1502
 
@@ -1897,11 +1509,12 @@ kernel void kernel_mul_mat_q6_K_f32(
1897
1509
 
1898
1510
  const int64_t r0 = tgpig.x;
1899
1511
  const int64_t r1 = tgpig.y;
1512
+ const int r2 = tgpig.z;
1900
1513
 
1901
1514
  const int row = 2 * r0 + sgitg;
1902
-
1903
- device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb; //r0*nb;
1904
- device const float * yy = (device const float *) src1 + r1*ne10;
1515
+ const uint offset0 = r2/gqa*(nb*ne0);
1516
+ device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
1517
+ device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
1905
1518
 
1906
1519
  float sumf = 0;
1907
1520
 
@@ -1967,6 +1580,366 @@ kernel void kernel_mul_mat_q6_K_f32(
1967
1580
 
1968
1581
  const float tot = simd_sum(sumf);
1969
1582
  if (tiisg == 0) {
1970
- dst[r1*ne0 + row] = tot;
1583
+ dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
1584
+ }
1585
+ }
1586
+
1587
+ //============================= templates and their specializations =============================
1588
+
1589
// Loads one 4x4 tile of half values from `src` and copies it element by element
// into `reg`, converting each value to the element type of `type4x4`.
// `il` is not read; it is presumably kept so the signature matches the other
// dequantize_* helpers used as template arguments.
template <typename type4x4>
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
    const half4x4 tile = *((device half4x4 *)src);
    for (int c = 0; c < 4; c++) {
        for (int r = 0; r < 4; r++) {
            reg[c][r] = tile[c][r];
        }
    }
}
1596
+
1597
// Dequantizes 16 of the weights of one block_q4_0 into `reg`.
//   il == 0 : decode the low nibble of every quant byte
//   il != 0 : decode the high nibble of every quant byte
// The quant bytes are read two at a time as 16-bit words. For the high nibbles
// the value is left in place (i.e. 16x too large) and compensated by dividing
// the scale by 16 and multiplying the -8 offset by 16.
template <typename type4x4>
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
    const bool high_nibble = (il != 0);

    const half   scale   = high_nibble ? (xb->d / 16.h) : xb->d;
    const half   offset  = high_nibble ? (-8.h * 16.h) : -8.h;
    const ushort mask_b0 = high_nibble ? 0x00F0 : 0x000F; // nibble in the low-order byte of a word
    const ushort mask_b1 = high_nibble ? 0xF000 : 0x0F00; // nibble in the high-order byte of a word

    // Skip the first 16-bit word of the block (presumably the half scale `d`).
    device const uint16_t * words = ((device const uint16_t *)xb + 1);

    for (int c = 0; c < 4; c++) {
        for (int p = 0; p < 2; p++) {
            const int w = 2*c + p;
            reg[c][2*p]     = (((words[w] & mask_b0)) + offset) * scale;
            reg[c][2*p + 1] = (((words[w] & mask_b1) >> 8) + offset) * scale;
        }
    }
}
1610
+
1611
// Dequantizes 16 of the weights of one block_q4_1 into `reg` as q * d + m.
//   il == 0 : decode the low nibble of every quant byte
//   il != 0 : decode the high nibble of every quant byte
// High nibbles are left in place (16x too large) and compensated by using d / 16.
template <typename type4x4>
void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
    const bool high_nibble = (il != 0);

    const half   scale   = high_nibble ? (xb->d / 16.h) : xb->d;
    const half   base    = xb->m;
    const ushort mask_b0 = high_nibble ? 0x00F0 : 0x000F; // nibble in the low-order byte of a word
    const ushort mask_b1 = high_nibble ? 0xF000 : 0x0F00; // nibble in the high-order byte of a word

    // Skip the first two 16-bit words of the block (presumably the halves `d` and `m`).
    device const uint16_t * words = ((device const uint16_t *)xb + 2);

    for (int c = 0; c < 4; c++) {
        for (int p = 0; p < 2; p++) {
            const int w = 2*c + p;
            reg[c][2*p]     = (((words[w] & mask_b0)) * scale) + base;
            reg[c][2*p + 1] = (((words[w] & mask_b1) >> 8) * scale) + base;
        }
    }
}
1624
+
1625
// Dequantizes 16 weights of one block_q2_K into `reg`.
// `il` selects which group of 16 weights is decoded. The packed scale byte is
// indexed with the original `il`; its low nibble scales `d` and its high nibble
// scales `dmin`. The 2-bit quants are masked in place and the positional factor
// (1, 4, 16 or 64) is folded into the scale through `coef`.
template <typename type4x4>
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
    const half super_scale = xb->d;
    const half super_min   = xb->dmin;
    device const uint8_t * quants = (device const uint8_t *)xb->qs;

    // Read before `il` is remapped below.
    const uint8_t packed_sc = xb->scales[il];

#if QK_K == 256
    quants = quants + 32*(il/8) + 16*(il&1);
    il = (il/2)%4;
#endif

    // Select the 2-bit field inside each byte and the factor that undoes its position.
    half  coef;
    uchar mask;
    if (il > 2) {
        coef = 1/64.h;
        mask = 192;
    } else if (il > 1) {
        coef = 1/16.h;
        mask = 48;
    } else if (il > 0) {
        coef = 1/4.h;
        mask = 12;
    } else {
        coef = 1.h;
        mask = 3;
    }

    const half dl = super_scale * (packed_sc & 0xF) * coef;
    const half ml = super_min * (packed_sc >> 4);

    for (int k = 0; k < 16; ++k) {
        reg[k/4][k%4] = dl * (quants[k] & mask) - ml;
    }
}
1644
+
1645
// Dequantizes 16 weights of one block_q3_K into `reg`.
// Each weight combines a 2-bit field from `qs` with one bit from `hmask`:
// when the hmask bit is clear, 4 is subtracted from the 2-bit value. The 2-bit
// field is masked in place, so the subtrahend is expressed as 4/coef and the
// product is rescaled by `coef`. `il` selects the group of 16 weights.
template <typename type4x4>
void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
    const float super_scale = (float)(xb->d);
    device const uint8_t * quants    = (device const uint8_t *)xb->qs;
    device const uint8_t * high_bits = (device const uint8_t *)xb->hmask;
    device const int8_t  * scales    = (device const int8_t *)xb->scales;

#if QK_K == 256
    quants    = quants + 32 * (il/8) + 16 * (il&1);
    high_bits = high_bits + 16 * (il&1);
    const uint8_t hbit = 1 << (il/2);

    // Rebuild the sub-block scale: a nibble from scales[il%8] combined with
    // a 2-bit field from scales[8 + il%4].
    uint16_t kmask1;
    if ((il/4) > 2) {
        kmask1 = 192;
    } else if ((il/4) > 1) {
        kmask1 = 48;
    } else if ((il/4) > 0) {
        kmask1 = 12;
    } else {
        kmask1 = 3;
    }
    const uint16_t kmask2   = il/8 ? 0xF0 : 0x0F;
    const uint16_t scale_lo = scales[il%8];
    const uint16_t scale_hi = scales[8 + il%4];
    const int16_t  dl_int   = (il/4)&1 ? (scale_lo&kmask2) | ((scale_hi&kmask1) << 2)
                                       : (scale_lo&kmask2) | ((scale_hi&kmask1) << 4);
    // For il >= 8 the nibble sits in the upper half of the byte, hence the /16.
    const float dl = il<8 ? super_scale * (dl_int - 32.f) : super_scale * (dl_int / 16.f - 32.f);

    il = (il/2)%4;

    float   coef;
    uint8_t mask;
    if (il > 2) {
        coef = 1/64.h;
        mask = 192;
    } else if (il > 1) {
        coef = 1/16.h;
        mask = 48;
    } else if (il > 0) {
        coef = 1/4.h;
        mask = 12;
    } else {
        coef = 1.h;
        mask = 3;
    }

    for (int k = 0; k < 16; ++k) {
        reg[k/4][k%4] = coef * dl * ((quants[k] & mask) - ((high_bits[k] & hbit) ? 0 : 4.f/coef));
    }
#else
    const float    kcoef = il&1 ? 1.f/16.f : 1.f;
    const uint16_t kmask = il&1 ? 0xF0 : 0x0F;
    const float    dl    = super_scale * ((scales[il/2] & kmask) * kcoef - 8);

    float   coef;
    uint8_t mask;
    if (il > 2) {
        coef = 1/64.h;
        mask = 192;
    } else if (il > 1) {
        coef = 1/16.h;
        mask = 48;
    } else if (il > 0) {
        coef = 1/4.h;
        mask = 12;
    } else {
        coef = 1.h;
        mask = 3;
    }

    // Weights 0..7 use bit `hbit`, weights 8..15 use the next bit (hbit * 2).
    const uint8_t hbit = 1<<(il*2);
    for (int k = 0; k < 16; ++k) {
        reg[k/4][k%4] = coef * dl * ((quants[k] & mask) - ((high_bits[k%8] & (hbit * (1 + k/8))) ? 0 : 4.f/coef));
    }
#endif
}
1683
+
1684
// Dequantizes 16 weights of one block_q4_K into `reg` as dl * q - ml.
// After the remap, il < 2 decodes low nibbles and il >= 2 decodes high nibbles;
// high nibbles are masked in place and compensated by dividing the scale by 16.
// get_scale_min_k4 is defined elsewhere in the file; from its use here it
// presumably returns (scale, min) pairs in components 0/1 and 2/3 - confirm there.
template <typename type4x4>
void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
    device const uint8_t * quants = xb->qs;

#if QK_K == 256
    const float super_scale = (float)(xb->d);
    const float super_min   = (float)(xb->dmin);
    const short scale_idx   = (il/4) * 2;
    quants = quants + (il/4) * 32 + 16 * (il&1);
    il = il%4;
    const bool   low_nibble = il < 2;
    const uchar4 sc = get_scale_min_k4(scale_idx, xb->scales);
    const float  dl = low_nibble ? super_scale * sc[0] : super_scale * sc[2]/16.h;
    const float  ml = low_nibble ? super_min * sc[1] : super_min * sc[3];
#else
    quants = quants + 16 * (il&1);
    const bool low_nibble = il < 2;
    device const uint8_t * s  = xb->scales;
    device const half2   * dh = (device const half2 *)xb->d;
    const float2 dm = (float2)dh[0]; // dm[0] multiplies the scale nibble, dm[1] the min nibble
    const float  dl = low_nibble ? dm[0] * (s[0]&0xF) : dm[0] * (s[1]&0xF)/16.h;
    const float  ml = low_nibble ? dm[1] * (s[0]>>4) : dm[1] * (s[1]>>4);
#endif

    const ushort mask = low_nibble ? 0x0F : 0xF0;
    for (int k = 0; k < 16; ++k) {
        reg[k/4][k%4] = dl * (quants[k] & mask) - ml;
    }
}
1710
+
1711
// Dequantizes 16 weights of one block_q5_K into `reg`.
// Each weight is a nibble from `qs` extended by one extra bit taken from `qh`.
// High nibbles are masked in place (16x too large) and compensated through the
// scale, which is why the extra-bit contribution is 256 instead of 16 in that case.
// get_scale_min_k4 is defined elsewhere in the file; from its use here it
// presumably returns (scale, min) pairs in components 0/1 and 2/3 - confirm there.
template <typename type4x4>
void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
    device const uint8_t * quants    = xb->qs;
    device const uint8_t * high_bits = xb->qh;

#if QK_K == 256
    const float super_scale = (float)(xb->d);
    const float super_min   = (float)(xb->dmin);
    const short scale_idx   = (il/4) * 2;
    quants    = quants + 32 * (il/4) + 16 * (il&1);
    high_bits = high_bits + 16 * (il&1);
    const uint8_t hbit = 1 << (il/2); // computed from the original il, before the remap
    il = il%4;

    const bool   low_nibble = il < 2;
    const uchar4 sc = get_scale_min_k4(scale_idx, xb->scales);
    const float  dl = low_nibble ? super_scale * sc[0] : super_scale * sc[2]/16.h;
    const float  ml = low_nibble ? super_min * sc[1] : super_min * sc[3];

    const ushort mask      = low_nibble ? 0x0F : 0xF0;
    const float  extra_val = low_nibble ? 16.f : 256.f;
    for (int k = 0; k < 16; ++k) {
        reg[k/4][k%4] = dl * ((quants[k] & mask) + (high_bits[k] & hbit ? extra_val : 0)) - ml;
    }
#else
    quants = quants + 16 * (il&1);
    device const int8_t * s = xb->scales;
    const float   dl   = xb->d * s[il];
    const uint8_t hbit = 1<<(il*2);
    const bool    low_nibble = il < 2;
    const float   coef = low_nibble ? 1.f : 1.f/16.f;
    const ushort  mask = low_nibble ? 0x0F : 0xF0;
    // Weights 0..7 use bit `hbit`, weights 8..15 use the next bit; a clear bit subtracts 16.
    for (int k = 0; k < 16; ++k) {
        reg[k/4][k%4] = coef * dl * ((quants[k] & mask) - (high_bits[k%8] & (hbit*(1+k/8)) ? 0.f : 16.f/coef));
    }
#endif
}
1745
+
1746
// Dequantizes 16 weights of one block_q6_K into `reg`.
// Each weight is a nibble from `ql` combined with a 2-bit field from `qh`,
// minus 32. For il > 1 (after the remap) the nibble is the high one and is
// left in place, so the 2-bit field is shifted to sit above it and the result
// is rescaled by coef = 1/16.
template <typename type4x4>
void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
    const float super_scale = (float)(xb->d);
    device const uint8_t * low_q  = (device const uint8_t *)xb->ql;
    device const uint8_t * high_q = (device const uint8_t *)xb->qh;
    device const int8_t  * scales = (device const int8_t *)xb->scales;

#if QK_K == 256
    low_q  = low_q + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
    high_q = high_q + 32*(il/8) + 16*(il&1);
    // (il%2) + 2*(il/2) is il itself; written directly.
    const float sc = scales[il];
    il = (il/2)%4;
#else
    low_q = low_q + 16 * (il&1);
    const float sc = scales[il];
#endif

    // None of these depend on the loop index, so they are computed once.
    uint16_t kmask1;
    if (il > 2) {
        kmask1 = 192;
    } else if (il > 1) {
        kmask1 = 48;
    } else if (il > 0) {
        kmask1 = 12;
    } else {
        kmask1 = 3;
    }
    const uint16_t kmask2  = il>1 ? 0xF0 : 0x0F;
    const float    coef    = il>1 ? 1.f/16.f : 1.f;
    const bool     odd_sub = (il&1) != 0;

    for (int k = 0; k < 16; ++k) {
        const float q = odd_sub ? ((low_q[k]&kmask2)|((high_q[k]&kmask1)<<2)) - 32.f/coef
                                : ((low_q[k]&kmask2)|((high_q[k]&kmask1)<<4)) - 32.f/coef;
        reg[k/4][k%4] = super_scale * sc * q * coef;
    }
}
1771
+
1772
// Gathers rows of a (possibly quantized) matrix into float output.
// One threadgroup handles one output row: threadgroup i reads the row index
// r = src1[i], dequantizes row r of src0 in 4x4 tiles (16 values each) and
// writes them to row i of dst. The threads of the group stride over the tiles.
//   ne00 : number of values per row; only ne00/16 whole tiles are processed
//   nb01 : byte stride between rows of src0
//   nb1  : byte stride between rows of dst
//   nl   : number of 16-value tiles per block_q
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
kernel void kernel_get_rows(
        device const  void * src0,
        device const   int * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb1,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tiitg[[thread_index_in_threadgroup]],
        uint tptg[[threads_per_threadgroup]]) {
    const int out_row = tgpig;
    const int src_row = ((device int32_t *) src1)[out_row];

    // Base pointers of the source and destination rows (strides are in bytes).
    device const block_q * row_blocks = (device const block_q *) ((device char *) src0 + src_row*nb01);
    device float4x4      * row_tiles  = (device float4x4 *) ((device char *) dst + out_row*nb1);

    for (int tile = tiitg; tile < ne00/16; tile += tptg) {
        float4x4 values;
        dequantize_func(row_blocks + tile/nl, tile%nl, values);
        *(row_tiles + tile) = values;
    }
}
1793
+
1794
// Tiling parameters of kernel_mul_mm.
#define SG_MAT_ROW 8                            // a simdgroup matrix is 8x8
#define SG_MAT_SIZE (SG_MAT_ROW * SG_MAT_ROW)   // 64 elements per simdgroup matrix
#define BLOCK_SIZE_M 64      // src0 (matrix A) rows per threadgroup tile: 8 simdgroup matrices
#define BLOCK_SIZE_N 32      // src1 (matrix B) rows per threadgroup tile: 4 simdgroup matrices
#define BLOCK_SIZE_K 32      // values consumed along the shared dimension per loop step
#define THREAD_MAT_M 4       // simdgroup matrices of A held per simdgroup (ma[4])
#define THREAD_MAT_N 2       // simdgroup matrices of B held per simdgroup (mb[2])
#define THREAD_PER_BLOCK 128 // threads per threadgroup
#define THREAD_PER_ROW 2     // threads cooperating to load one row of A
#define THREAD_PER_COL 4     // threads cooperating to load one row of B
1804
+
1805
// each block_q contains 16*nl weights
//
// Tiled matrix multiplication of a (possibly quantized) src0 with a float src1.
// Every threadgroup produces one BLOCK_SIZE_M x BLOCK_SIZE_N tile of dst:
//   tgpig.y (r0) selects BLOCK_SIZE_M consecutive rows of src0,
//   tgpig.x (r1) selects BLOCK_SIZE_N consecutive rows of src1,
//   tgpig.z (im) selects the batch; src0 uses batch im/gqa.
// The result for src0 row i and src1 row j is written to dst[im*ne1*ne0 + j*ne0 + i].
//
// Launch requirements implied by the indexing below:
//   - 128 threads (4 simdgroups) per threadgroup;
//   - at least 8192 bytes of threadgroup memory (4096 for the A tile, 4096 for
//     the B tile; the partial-tile path reuses all of it as a 64x32 float tile);
//   - ne00 is presumably a multiple of BLOCK_SIZE_K, since the k-loop always
//     loads whole BLOCK_SIZE_K slices - confirm against the host code.
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
kernel void kernel_mul_mm(device const  uchar * src0,
                          device const  float * src1,
                          device        float * dst,
                          constant    int64_t & ne00,
                          constant    int64_t & ne02,
                          constant    int64_t & nb01,
                          constant    int64_t & nb02,
                          constant    int64_t & ne12,
                          constant    int64_t & ne0,
                          constant    int64_t & ne1,
                          constant       uint & gqa,
                          threadgroup   uchar * shared_memory [[threadgroup(0)]],
                          uint3                 tgpig[[threadgroup_position_in_grid]],
                          uint                  tiitg[[thread_index_in_threadgroup]],
                          uint                  sgitg[[simdgroup_index_in_threadgroup]]) {

    // A tile (half) lives in the first 4096 bytes, B tile (float) right after it.
    threadgroup half  * sa = ((threadgroup half  *)shared_memory);
    threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);

    const uint r0 = tgpig.y;
    const uint r1 = tgpig.x;
    const uint im = tgpig.z;
    // if this block is of 64x32 shape or smaller
    short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
    short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
    // a thread shouldn't load data outside of the matrix:
    // threads past the edge re-read the last valid row instead
    short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
    short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;

    simdgroup_half8x8  ma[4];
    simdgroup_float8x8 mb[2];
    simdgroup_float8x8 c_res[8];
    for (int i = 0; i < 8; i++){
        c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
    }

    // il is the index of the 16-weight tile inside the current block_q
    short il = (tiitg % THREAD_PER_ROW);
    uint   offset0 = im/gqa*nb02; ushort offset1 = il/nl;
    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
    device const float   * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
                               + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1;

    for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
        //load data and store to threadgroup memory
        half4x4 temp_a;
        dequantize_func(x, il, temp_a);

        // The tiles in sa/sb are about to be overwritten. Wait until every
        // simdgroup has finished the simdgroup_load calls of the previous
        // iteration, otherwise a fast simdgroup can clobber data that a slower
        // one is still reading.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        #pragma unroll(16)
        for (int i = 0; i < 16; i++) {
            *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
            +                     16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8)) \
            +                     (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
        }
        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) \
                = *((device float2x4 *)y);

        // advance to the next pair of 16-weight tiles, moving to the next block_q when needed
        il = (il + 2 < nl) ? il + 2 : il % 2;
        x  = (il < 2) ? x + (2+nl-1)/nl : x;
        y += BLOCK_SIZE_K;

        // make the freshly written tiles visible to all simdgroups
        threadgroup_barrier(mem_flags::mem_threadgroup);
        //load matrices from threadgroup memory and conduct outer products
        threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
        threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
        #pragma unroll(4)
        for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
            #pragma unroll(4)
            for (int i = 0; i < 4; i++) {
                simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
            }
            simdgroup_barrier(mem_flags::mem_none);
            #pragma unroll(2)
            for (int i = 0; i < 2; i++) {
                simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
            }

            lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
            lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
            #pragma unroll(8)
            for (int i = 0; i < 8; i++){
                simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
            }
        }
    }

    // The branch condition depends only on tgpig and kernel constants, so it is
    // uniform across the threadgroup and the barriers inside it are safe.
    if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
        // full tile: every simdgroup stores its 32x16 part straight to dst
        device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
                          + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0;
        for (int i = 0; i < 8; i++) {
            simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
        }
    } else {
        // block is smaller than 64x32, we should avoid writing data outside of the matrix

        // The staging tile below aliases sa/sb. Wait until all simdgroups have
        // finished reading the last k-slice before overwriting it.
        threadgroup_barrier(mem_flags::mem_threadgroup);
        threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
                                      + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
        for (int i = 0; i < 8; i++) {
            simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);
        device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
        // simdgroup 0 (whose temp_str is the start of the staging tile) copies
        // only the valid n_rows x n_cols part to dst
        if (sgitg==0) {
            for (int i = 0; i < n_rows; i++) {
                for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) {
                    *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
                }
            }
        }
    }
}
1915
+
1916
// Number of 16-weight tiles per k-quant super-block, passed as the `nl`
// template argument of the instantiations that use the *_K block types.
#if QK_K != 256
#define QK_NL 4
#else
#define QK_NL 16
#endif
1921
+
1922
// Function type shared by all kernel_get_rows instantiations
// (parameter names refer to the kernel_get_rows template).
typedef void (get_rows_t)(
        device const void *,   // src0
        device const int *,    // src1
        device float *,        // dst
        constant int64_t &,    // ne00
        constant uint64_t &,   // nb01
        constant uint64_t &,   // nb1
        uint,                  // tgpig
        uint,                  // tiitg
        uint);                 // tptg

// Explicit instantiations: host_name is the kernel name exposed for each source type.
template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1,     dequantize_f16>;
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2,     dequantize_q4_0>;
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2,     dequantize_q4_1>;
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
1933
+
1934
// Function type shared by all kernel_mul_mm instantiations
// (parameter names refer to the kernel_mul_mm template).
typedef void (mat_mm_t)(
        device const uchar *,  // src0
        device const float *,  // src1
        device float *,        // dst
        constant int64_t &,    // ne00
        constant int64_t &,    // ne02
        constant int64_t &,    // nb01
        constant int64_t &,    // nb02
        constant int64_t &,    // ne12
        constant int64_t &,    // ne0
        constant int64_t &,    // ne1
        constant uint &,       // gqa
        threadgroup uchar *,   // shared_memory
        uint3,                 // tgpig
        uint,                  // tiitg
        uint);                 // sgitg

// Explicit instantiations: host_name is the kernel name exposed for each source type.
template [[host_name("kernel_mul_mm_f16_f32")]]  kernel mat_mm_t kernel_mul_mm<half4x4,    1,     dequantize_f16>;
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2,     dequantize_q4_0>;
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2,     dequantize_q4_1>;
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;