llama_cpp 0.5.1 → 0.5.3

@@ -38,7 +38,7 @@ kernel void kernel_add_row(
  device const float4 * src0,
  device const float4 * src1,
  device float4 * dst,
- constant int64_t & nb,
+ constant int64_t & nb,
  uint tpig[[thread_position_in_grid]]) {
  dst[tpig] = src0[tpig] + src1[tpig % nb];
  }
@@ -63,18 +63,18 @@ kernel void kernel_mul_row(
  }

  kernel void kernel_scale(
- device const float * src0,
- device float * dst,
+ device const float4 * src0,
+ device float4 * dst,
  constant float & scale,
  uint tpig[[thread_position_in_grid]]) {
  dst[tpig] = src0[tpig] * scale;
  }

  kernel void kernel_silu(
- device const float * src0,
- device float * dst,
+ device const float4 * src0,
+ device float4 * dst,
  uint tpig[[thread_position_in_grid]]) {
- float x = src0[tpig];
+ device const float4 & x = src0[tpig];
  dst[tpig] = x / (1.0f + exp(-x));
  }

@@ -89,10 +89,10 @@ constant float GELU_COEF_A = 0.044715f;
  constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;

  kernel void kernel_gelu(
- device const float * src0,
- device float * dst,
+ device const float4 * src0,
+ device float4 * dst,
  uint tpig[[thread_position_in_grid]]) {
- float x = src0[tpig];
+ device const float4 & x = src0[tpig];

  // BEWARE !!!
  // Simply using "tanh" instead of "precise::tanh" will sometimes result in NaNs!
@@ -107,7 +107,6 @@ kernel void kernel_soft_max(
  constant int64_t & ne00,
  constant int64_t & ne01,
  constant int64_t & ne02,
- threadgroup float * buf [[threadgroup(0)]],
  uint3 tgpig[[threadgroup_position_in_grid]],
  uint3 tpitg[[thread_position_in_threadgroup]],
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -119,61 +118,67 @@ kernel void kernel_soft_max(
  device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

  // parallel max
- buf[tpitg[0]] = -INFINITY;
- for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
- buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]);
+ float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY;
+ for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
+ lmax = MAX(lmax, psrc0[i00]);
  }
-
- // reduce
- threadgroup_barrier(mem_flags::mem_threadgroup);
- for (uint i = ntg[0]/2; i > 0; i /= 2) {
- if (tpitg[0] < i) {
- buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- }
-
- //// broadcast - not needed. There is a threadgroup barrier above in the last iteration of
- // the loop, and when that is done, buf[0] has the correct (synchronized) value
- //if (tpitg[0] == 0) {
- // buf[0] = buf[0];
- //}
-
- //threadgroup_barrier(mem_flags::mem_threadgroup);
-
- const float max = buf[0];
+ const float max = simd_max(lmax);

  // parallel sum
- buf[tpitg[0]] = 0.0f;
+ float lsum = 0.0f;
  for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
  const float exp_psrc0 = exp(psrc0[i00] - max);
- buf[tpitg[0]] += exp_psrc0;
+ lsum += exp_psrc0;
  // Remember the result of exp here. exp is expensive, so we really do not
  // wish to compute it twice.
  pdst[i00] = exp_psrc0;
  }

- // reduce
- threadgroup_barrier(mem_flags::mem_threadgroup);
- for (uint i = ntg[0]/2; i > 0; i /= 2) {
- if (tpitg[0] < i) {
- buf[tpitg[0]] += buf[tpitg[0] + i];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ const float sum = simd_sum(lsum);
+
+ for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+ pdst[i00] /= sum;
+ }
+ }
+
+ kernel void kernel_soft_max_4(
+ device const float * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+ device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+
+ // parallel max
+ float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY;
+ for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
+ lmax4 = fmax(lmax4, psrc4[i00]);
  }
+ float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));

- // broadcast - not needed, see above
- //// broadcast
- //if (tpitg[0] == 0) {
- // buf[0] = buf[0];
- //}
+ const float max = simd_max(lmax);

- //threadgroup_barrier(mem_flags::mem_threadgroup);
+ // parallel sum
+ float4 lsum4 = 0.0f;
+ for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+ const float4 exp_psrc4 = exp(psrc4[i00] - max);
+ lsum4 += exp_psrc4;
+ pdst4[i00] = exp_psrc4;
+ }
+ float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];

- const float sum = buf[0];
+ const float sum = simd_sum(lsum);

- for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
- pdst[i00] /= sum;
+ for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+ pdst4[i00] /= sum;
  }
  }

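Note: the softmax hunks above replace the explicit threadgroup-memory tree reduction (and its barriers) with the simd_max/simd_sum intrinsics, which combine a value across the 32 lanes of a SIMD group in hardware. A minimal sketch of the pattern, not taken from the diff; the kernel name and buffer layout are hypothetical, and it assumes the threadgroup is a single SIMD group:

    #include <metal_stdlib>
    using namespace metal;

    kernel void reduce_row_max(device const float * src [[buffer(0)]],
                               device float       * out [[buffer(1)]],
                               constant int64_t   & n   [[buffer(2)]],
                               uint tid [[thread_position_in_threadgroup]],
                               uint ntg [[threads_per_threadgroup]]) {
        // each lane folds a strided slice of the row into a private maximum
        float lmax = -INFINITY;
        for (int64_t i = tid; i < n; i += ntg) {
            lmax = max(lmax, src[i]);
        }
        // cross-lane reduction: no threadgroup memory, no barrier needed
        const float gmax = simd_max(lmax);
        if (tid == 0) {
            out[0] = gmax;
        }
    }
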
@@ -192,6 +197,33 @@ kernel void kernel_diag_mask_inf(
  dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
  } else {
  dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
+ }
+ }
+
+ kernel void kernel_diag_mask_inf_8(
+ device const float4 * src0,
+ device float4 * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int & n_past,
+ uint3 tpig[[thread_position_in_grid]]) {
+
+ const int64_t i = 2*tpig[0];
+
+ dst[i+0] = src0[i+0];
+ dst[i+1] = src0[i+1];
+ int64_t i4 = 4*i;
+ const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
+ const int64_t i01 = i4/(ne00); i4 -= i01*ne00;
+ const int64_t i00 = i4;
+ for (int k = 3; k >= 0; --k) {
+ if (i00 + 4 + k <= n_past + i01) {
+ break;
+ }
+ dst[i+1][k] = -INFINITY;
+ if (i00 + k > n_past + i01) {
+ dst[i][k] = -INFINITY;
+ }
  }
  }

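Note: kernel_diag_mask_inf_8 copies two float4 values (8 elements) per thread unconditionally and only then overwrites lanes past the causal boundary, walking k from 3 down and stopping early once a lane is inside the allowed region. A small worked example of the index recovery, under assumed shapes ne00 = 8, ne01 = 2 (hypothetical values, just to illustrate the arithmetic):

    // flattened element index i4 = 19 with ne00 = 8 columns, ne01 = 2 rows:
    //   i02 = 19 / (8*2) = 1   -> remainder 19 - 16 = 3
    //   i01 =  3 / 8     = 0   -> remainder 3
    //   i00 =  3
    // element 19 is column 3 of row 0 in the second matrix, and it is
    // masked exactly when 3 > n_past + 0.
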
@@ -491,6 +523,79 @@ kernel void kernel_mul_mat_q8_0_f32(
  }
  }

+ #define N_F32_F32 4
+
+ kernel void kernel_mul_mat_f32_f32(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+
+ const int64_t r0 = tgpig.x;
+ const int64_t rb = tgpig.y*N_F32_F32;
+ const int64_t im = tgpig.z;
+
+ device const float * x = (device const float *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+ if (ne00 < 128) {
+ for (int row = 0; row < N_F32_F32; ++row) {
+ int r1 = rb + row;
+ if (r1 >= ne11) {
+ break;
+ }
+
+ device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+
+ float sumf = 0;
+ for (int i = tiisg; i < ne00; i += 32) {
+ sumf += (float) x[i] * (float) y[i];
+ }
+
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ }
+ } else {
+ device const float4 * x4 = (device const float4 *)x;
+ for (int row = 0; row < N_F32_F32; ++row) {
+ int r1 = rb + row;
+ if (r1 >= ne11) {
+ break;
+ }
+
+ device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+ device const float4 * y4 = (device const float4 *) y;
+
+ float sumf = 0;
+ for (int i = tiisg; i < ne00/4; i += 32) {
+ for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+ }
+
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ }
+ }
+ }
+
  kernel void kernel_mul_mat_f16_f32_1row(
  device const char * src0,
  device const char * src1,
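Note: both branches of the new kernel_mul_mat_f32_f32 distribute a row's dot product over the 32 lanes of a SIMD group and combine with simd_sum; the float4 branch additionally lets lane 0 pick up the scalar tail when ne00 is not a multiple of 4. A standalone sketch of that pattern (the helper name is hypothetical, not in the file):

    // lane = thread_index_in_simdgroup; only lane 0 holds the full sum
    static float dot_f32(device const float * x, device const float * y,
                         int n, uint lane) {
        device const float4 * x4 = (device const float4 *) x;
        device const float4 * y4 = (device const float4 *) y;
        float sumf = 0;
        for (int i = lane; i < n/4; i += 32) {              // vector body
            for (int k = 0; k < 4; ++k) sumf += x4[i][k] * y4[i][k];
        }
        float all_sum = simd_sum(sumf);                     // cross-lane total
        if (lane == 0) {
            for (int i = 4*(n/4); i < n; ++i) all_sum += x[i] * y[i]; // tail
        }
        return all_sum;
    }
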
@@ -616,6 +721,49 @@ kernel void kernel_mul_mat_f16_f32(
  }
  }

+ // Assumes row size (ne00) is a multiple of 4
+ kernel void kernel_mul_mat_f16_f32_l4(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+
+ const int nrows = ne11;
+ const int64_t r0 = tgpig.x;
+ const int64_t im = tgpig.z;
+
+ device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+ for (int r1 = 0; r1 < nrows; ++r1) {
+ device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
+
+ float sumf = 0;
+ for (int i = tiisg; i < ne00/4; i += 32) {
+ for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+ }
+
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ }
+ }
+
  kernel void kernel_alibi_f32(
  device const float * src0,
  device float * dst,
@@ -1123,31 +1271,40 @@ kernel void kernel_mul_mat_q3_K_f32(
  device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
  device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;

- float yl[16];
+ float yl[32];

- const uint16_t kmask1 = 0x0303;
+ const uint16_t kmask1 = 0x3030;
  const uint16_t kmask2 = 0x0f0f;

- const int tid = tiisg/2;
- const int ix = tiisg%2;
- const int ip = tid/8; // 0 or 1
- const int il = tid/2 - 4*ip; // 0...3
+ const int tid = tiisg/4;
+ const int ix = tiisg%4;
+ const int ip = tid/4; // 0 or 1
+ const int il = 2*((tid%4)/2); // 0 or 2
  const int ir = tid%2;
  const int n = 8;
  const int l0 = n*ir;

- const uint16_t m1 = 1 << (4*ip + il);
- const uint16_t m2 = m1 << 8;
+ // One would think that the Metal compiler would figure out that ip and il can only have
+ // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
+ // with these two tables.
+ //
+ // Possible masks for the high bit
+ const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200}, // ip = 0, il = 0
+ {0x0004, 0x0400, 0x0008, 0x0800}, // ip = 0, il = 2
+ {0x0010, 0x1000, 0x0020, 0x2000}, // ip = 1, il = 0
+ {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
+
+ // Possible masks for the low 2 bits
+ const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
+
+ const ushort4 hm = mm[2*ip + il/2];

  const int shift = 2*il;
- const uint16_t qm1 = 0x0003 << shift;
- const uint16_t qm2 = 0x0300 << shift;
- const int32_t v1 = 4 << shift;
- const int32_t v2 = 1024 << shift;
+ const float v1 = il == 0 ? 4.f : 64.f;
+ const float v2 = 4.f * v1;

  const uint16_t s_shift1 = 4*ip;
- const uint16_t s_shift2 = s_shift1 + 2*(il/2);
- const int ik = 4 + (il%2);
+ const uint16_t s_shift2 = s_shift1 + il;

  const int q_offset = 32*ip + l0;
  const int y_offset = 128*ip + 32*il + l0;
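Note: the mm table enumerates the four (ip, il) states of the variable shifts it replaces; indexing a small constant table lets the compiler resolve the masks to immediates instead of materializing shifts in the inner loop. Spelled out for the high-bit masks (illustration only, equivalent to the old m1/m2):

    // mm[2*ip + il/2][0] == ushort(1) << (4*ip + il)          // even bytes
    // mm[2*ip + il/2][1] == (ushort(1) << (4*ip + il)) << 8   // odd bytes
    // mm[2*ip + il/2][2] == ushort(1) << (4*ip + il + 1)
    // mm[2*ip + il/2][3] == (ushort(1) << (4*ip + il + 1)) << 8
    // e.g. ip = 1, il = 2: 1 << 6 = 0x0040, then 0x4000, 0x0080, 0x8000
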
@@ -1156,12 +1313,19 @@ kernel void kernel_mul_mat_q3_K_f32(

  device const float * y1 = yy + ix*QK_K + y_offset;

- float sumf1[2] = {0.f}, sumf2[2] = {0.f};
- for (int i = ix; i < nb; i += 2) {
+ uint32_t scales32, aux32;
+ thread uint16_t * scales16 = (thread uint16_t *)&scales32;
+ thread const int8_t * scales = (thread const int8_t *)&scales32;
+
+ float sumf1[2] = {0.f};
+ float sumf2[2] = {0.f};
+ for (int i = ix; i < nb; i += 4) {

  for (int l = 0; l < 8; ++l) {
- yl[l+0] = y1[l+ 0];
- yl[l+8] = y1[l+16];
+ yl[l+ 0] = y1[l+ 0];
+ yl[l+ 8] = y1[l+16];
+ yl[l+16] = y1[l+32];
+ yl[l+24] = y1[l+48];
  }

  device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
@@ -1172,27 +1336,43 @@ kernel void kernel_mul_mat_q3_K_f32(
  for (int row = 0; row < 2; ++row) {

  const float d_all = (float)dh[0];
- const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));

- float s1 = 0, s2 = 0;
+ scales16[0] = a[4];
+ scales16[1] = a[5];
+ aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
+ scales16[0] = a[il+0];
+ scales16[1] = a[il+1];
+ scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
+
+ float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
  for (int l = 0; l < n; l += 2) {
- const uint16_t qs = q[l/2];
- s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
- s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
+ const int32_t qs = q[l/2];
+ s1 += yl[l+0] * (qs & qm[il/2][0]);
+ s2 += yl[l+1] * (qs & qm[il/2][1]);
+ s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
+ s4 += yl[l+16] * (qs & qm[il/2][2]);
+ s5 += yl[l+17] * (qs & qm[il/2][3]);
+ s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
  }
- float d = d_all * (s1 + 1.f/256.f * s2);
- sumf1[row] += d * scales[0];
- sumf2[row] += d;
+ float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+ float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+ sumf1[row] += d1 * (scales[0] - 32);
+ sumf2[row] += d2 * (scales[2] - 32);

- s1 = s2 = 0;
+ s1 = s2 = s3 = s4 = s5 = s6 = 0;
  for (int l = 0; l < n; l += 2) {
- const uint16_t qs = q[l/2+8];
- s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
- s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
+ const int32_t qs = q[l/2+8];
+ s1 += yl[l+8] * (qs & qm[il/2][0]);
+ s2 += yl[l+9] * (qs & qm[il/2][1]);
+ s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
+ s4 += yl[l+24] * (qs & qm[il/2][2]);
+ s5 += yl[l+25] * (qs & qm[il/2][3]);
+ s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
  }
- d = d_all * (s1 + 1.f/256.f * s2);
- sumf1[row] += d * scales[1];
- sumf2[row] += d;
+ d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+ d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+ sumf1[row] += d1 * (scales[1] - 32);
+ sumf2[row] += d2 * (scales[3] - 32);

  q += step;
  h += step;
@@ -1201,15 +1381,17 @@ kernel void kernel_mul_mat_q3_K_f32(

  }

- y1 += 2 * QK_K;
+ y1 += 4 * QK_K;

  }

  for (int row = 0; row < 2; ++row) {
- const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
- const float tot = simd_sum(sumf);
- if (tiisg == 0) {
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
+ const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
+ sumf1[row] = simd_sum(sumf);
+ }
+ if (tiisg == 0) {
+ for (int row = 0; row < 2; ++row) {
+ dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
  }
  }
  }
@@ -1290,13 +1472,13 @@ kernel void kernel_mul_mat_q4_K_f32(
  device const float * src1,
  device float * dst,
  constant int64_t & ne00,
- constant int64_t & ne01[[buffer(4)]],
- constant int64_t & ne02[[buffer(5)]],
- constant int64_t & ne10[[buffer(9)]],
- constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0[[buffer(15)]],
- constant int64_t & ne1[[buffer(16)]],
- constant uint & gqa[[buffer(17)]],
+ constant int64_t & ne01 [[buffer(4)]],
+ constant int64_t & ne02 [[buffer(5)]],
+ constant int64_t & ne10 [[buffer(9)]],
+ constant int64_t & ne12 [[buffer(11)]],
+ constant int64_t & ne0 [[buffer(15)]],
+ constant int64_t & ne1 [[buffer(16)]],
+ constant uint & gqa [[buffer(17)]],
  uint3 tgpig[[threadgroup_position_in_grid]],
  uint tiisg[[thread_index_in_simdgroup]],
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1564,17 +1746,25 @@ kernel void kernel_mul_mat_q5_K_f32(
  sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
  sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);

- float4 acc = {0.f, 0.f, 0.f, 0.f};
+ float4 acc1 = {0.f};
+ float4 acc2 = {0.f};
  for (int l = 0; l < n; ++l) {
  uint8_t h = qh[l];
- acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
- acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
- acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
- acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
+ acc1[0] += yl[l+0] * (q1[l] & 0x0F);
+ acc1[1] += yl[l+8] * (q1[l] & 0xF0);
+ acc1[2] += yh[l+0] * (q2[l] & 0x0F);
+ acc1[3] += yh[l+8] * (q2[l] & 0xF0);
+ acc2[0] += h & hm1 ? yl[l+0] : 0.f;
+ acc2[1] += h & hm2 ? yl[l+8] : 0.f;
+ acc2[2] += h & hm3 ? yh[l+0] : 0.f;
+ acc2[3] += h & hm4 ? yh[l+8] : 0.f;
  }
  const float dall = dh[0];
  const float dmin = dh[1];
- sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
+ sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
+ sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
+ sc8[4] * (acc1[2] + 16.f*acc2[2]) +
+ sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
  dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);

  q1 += step;
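Note: the q5_K change is an algebraic regrouping, not a different formula. Each old acc[j] folded the high bit in as +16 (low nibble) or +256 (high nibble); the new code keeps the nibble sums in acc1 and the high-bit contributions in acc2, and the scale factors the terms. Checking the low- and high-nibble terms (illustration only):

    // old: sc8[0] * sum_l yl[l] * ((q1[l] & 0x0F) + (h & hm1 ? 16 : 0))
    //    = sc8[0] * (acc1[0] + 16.f*acc2[0])                       // new
    // old: sc8[1]/16 * sum_l yl[l+8] * ((q1[l] & 0xF0) + (h & hm2 ? 256 : 0))
    //    = sc8[1] * (acc1[1]/16.f + 16.f*acc2[1])                  // new
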
@@ -1747,6 +1937,15 @@ kernel void kernel_mul_mat_q6_K_f32(

  //============================= templates and their specializations =============================

+ // NOTE: this is not dequantizing - we are simply fitting the template
+ template <typename type4x4>
+ void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
+ float4x4 temp = *(((device float4x4 *)src));
+ for (int i = 0; i < 16; i++){
+ reg[i/4][i%4] = temp[i/4][i%4];
+ }
+ }
+
  template <typename type4x4>
  void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
  half4x4 temp = *(((device half4x4 *)src));
@@ -1758,28 +1957,30 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
  template <typename type4x4>
  void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
  device const uint16_t * qs = ((device const uint16_t *)xb + 1);
- const half d = il ? (xb->d / 16.h) : xb->d;
- const half m = il ? ( -8.h * 16.h) : -8.h;
+ const float d1 = il ? (xb->d / 16.h) : xb->d;
+ const float d2 = d1 / 256.f;
+ const float md = -8.h * xb->d;
  const ushort mask0 = il ? 0x00F0 : 0x000F;
- const ushort mask1 = il ? 0xF000 : 0x0F00;
+ const ushort mask1 = mask0 << 8;

  for (int i=0;i<8;i++) {
- reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d;
- reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
+ reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
+ reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
  }
  }

  template <typename type4x4>
  void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
  device const uint16_t * qs = ((device const uint16_t *)xb + 2);
- const half d = il ? (xb->d / 16.h) : xb->d;
- const half m = xb->m;
+ const float d1 = il ? (xb->d / 16.h) : xb->d;
+ const float d2 = d1 / 256.f;
+ const float m = xb->m;
  const ushort mask0 = il ? 0x00F0 : 0x000F;
- const ushort mask1 = il ? 0xF000 : 0x0F00;
+ const ushort mask1 = mask0 << 8;

  for (int i=0;i<8;i++) {
- reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) * d) + m;
- reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
+ reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
+ reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
  }
  }

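Note: in the reworked dequantize_q4_0/q4_1 the >> 8 on the odd-byte nibble disappears because it is folded into the scale: with d2 = d1/256, d2 * (qs[i] & mask1) equals d1 * ((qs[i] & mask1) >> 8) exactly, since mask1 selects a whole byte. The offset changes the same way; for q4_0 the il-dependent m collapses to a single md = -8.h * xb->d once the scale, rather than the value, carries the 1/16. A worked example with assumed inputs:

    // qs[i] = 0x3A51, mask1 = 0x0F00, d1 = 2.0f:
    // old: ((0x3A51 & 0x0F00) >> 8) * 2.0f = 10 * 2.0f        = 20.0f
    // new:  (0x3A51 & 0x0F00) * (2.0f/256) = 2560 * 0.0078125 = 20.0f
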
@@ -1815,7 +2016,7 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg

  template <typename type4x4>
  void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
- const float d_all = (float)(xb->d);
+ const half d_all = xb->d;
  device const uint8_t * q = (device const uint8_t *)xb->qs;
  device const uint8_t * h = (device const uint8_t *)xb->hmask;
  device const int8_t * scales = (device const int8_t *)xb->scales;
@@ -1828,16 +2029,18 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
  ((il/4)>0 ? 12 : 3);
  uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
  uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
- int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : \
- (scale_2&kmask2) | ((scale_1&kmask1) << 4);
- float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
+ int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
+ : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
+ half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
+ const half ml = 4.h * dl;

- il = (il/2)%4;
- float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
- uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+ il = (il/2) & 3;
+ const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+ const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+ dl *= coef;

  for (int i = 0; i < 16; ++i) {
- reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i] & m) ? 0 : 4.f/coef));
+ reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
  }
  #else
  float kcoef = il&1 ? 1.f/16.f : 1.f;
@@ -1852,26 +2055,31 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
  #endif
  }

+ static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
+ return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
+ : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
+ }
+
  template <typename type4x4>
  void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
- device const uint8_t * q = xb->qs;
+ device const uchar * q = xb->qs;

  #if QK_K == 256
- const float d = (float)(xb->d);
- const float min = (float)(xb->dmin);
  short is = (il/4) * 2;
  q = q + (il/4) * 32 + 16 * (il&1);
- il = il%4;
- const uchar4 sc = get_scale_min_k4(is, xb->scales);
- const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
- const float ml = il<2 ? min * sc[1] : min * sc[3];
+ il = il & 3;
+ const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+ const half d = il < 2 ? xb->d : xb->d / 16.h;
+ const half min = xb->dmin;
+ const half dl = d * sc[0];
+ const half ml = min * sc[1];
  #else
  q = q + 16 * (il&1);
  device const uint8_t * s = xb->scales;
  device const half2 * dh = (device const half2 *)xb->d;
  const float2 d = (float2)dh[0];
  const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
- const float ml = il<2 ? d[1] * (s[0]>>4) : d[1 ]* (s[1]>>4);
+ const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
  #endif
  const ushort mask = il<2 ? 0x0F : 0xF0;
  for (int i = 0; i < 16; ++i) {
@@ -1885,19 +2093,19 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
  device const uint8_t * qh = xb->qh;

  #if QK_K == 256
- const float d = (float)(xb->d);
- const float min = (float)(xb->dmin);
  short is = (il/4) * 2;
  q = q + 32 * (il/4) + 16 * (il&1);
  qh = qh + 16 * (il&1);
  uint8_t ul = 1 << (il/2);
- il = il%4;
- const uchar4 sc = get_scale_min_k4(is, xb->scales);
- const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
- const float ml = il<2 ? min * sc[1] : min * sc[3];
+ il = il & 3;
+ const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+ const half d = il < 2 ? xb->d : xb->d / 16.h;
+ const half min = xb->dmin;
+ const half dl = d * sc[0];
+ const half ml = min * sc[1];

- const ushort mask = il<2 ? 0x0F : 0xF0;
- const float qh_val = il<2 ? 16.f : 256.f;
+ const ushort mask = il<2 ? 0x0F : 0xF0;
+ const half qh_val = il<2 ? 16.h : 256.h;
  for (int i = 0; i < 16; ++i) {
  reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
  }
@@ -1916,7 +2124,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg

  template <typename type4x4>
  void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
- const float d_all = (float)(xb->d);
+ const half d_all = xb->d;
  device const uint8_t * ql = (device const uint8_t *)xb->ql;
  device const uint8_t * qh = (device const uint8_t *)xb->qh;
  device const int8_t * scales = (device const int8_t *)xb->scales;
@@ -1924,19 +2132,21 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
  #if QK_K == 256
  ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
  qh = qh + 32*(il/8) + 16*(il&1);
- float sc = scales[(il%2) + 2 * ((il/2))];
- il = (il/2)%4;
+ half sc = scales[(il%2) + 2 * ((il/2))];
+ il = (il/2) & 3;
  #else
  ql = ql + 16 * (il&1);
- float sc = scales[il];
+ half sc = scales[il];
  #endif
+ const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+ const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
+ const half coef = il>1 ? 1.f/16.h : 1.h;
+ const half ml = d_all * sc * 32.h;
+ const half dl = d_all * sc * coef;
  for (int i = 0; i < 16; ++i) {
- uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
- uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
- const float coef = il>1 ? 1.f/16.f : 1.f;
- float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \
- ((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef;
- reg[i/4][i%4] = d_all * sc * q * coef;
+ const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
+ : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
+ reg[i/4][i%4] = dl * q - ml;
  }
  }

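Note: the q6_K rewrite hoists the loop-invariant masks and coefficient out of the 16-iteration loop and factors the old expression, so the per-element work shrinks to one half-precision multiply-subtract. The derivation:

    // old per-element:  reg = d_all * sc * coef * (q' - 32/coef)
    // distribute coef:  reg = (d_all * sc * coef) * q' - d_all * sc * 32
    // precompute:       dl  =  d_all * sc * coef,  ml = d_all * sc * 32
    // new per-element:  reg = dl * q' - ml
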
@@ -1976,22 +2186,25 @@ kernel void kernel_get_rows(
  // each block_q contains 16*nl weights
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
  kernel void kernel_mul_mm(device const uchar * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne02,
- constant int64_t & nb01,
- constant int64_t & nb02,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & gqa,
- threadgroup uchar * shared_memory [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
- threadgroup half * sa = ((threadgroup half *)shared_memory);
+ device const uchar * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne02,
+ constant int64_t & nb01,
+ constant int64_t & nb02,
+ constant int64_t & ne12,
+ constant int64_t & nb10,
+ constant int64_t & nb11,
+ constant int64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & gqa,
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ threadgroup half * sa = (threadgroup half *)(shared_memory);
  threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);

  const uint r0 = tgpig.y;
@@ -2004,7 +2217,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
  short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
  short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;

- simdgroup_half8x8 ma[4];
+ simdgroup_half8x8 ma[4];
  simdgroup_float8x8 mb[2];
  simdgroup_float8x8 c_res[8];
  for (int i = 0; i < 8; i++){
@@ -2012,10 +2225,15 @@ kernel void kernel_mul_mm(device const uchar * src0,
  }

  short il = (tiitg % THREAD_PER_ROW);
- uint offset0 = im/gqa*nb02; ushort offset1 = il/nl;
- device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
- device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
- + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1;
+
+ uint offset0 = im/gqa*nb02;
+ ushort offset1 = il/nl;
+
+ device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
+ device const float * y = (device const float *)(src1
+ + nb12 * im
+ + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
+ + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));

  for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
  //load data and store to threadgroup memory
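Note: kernel_mul_mm now receives src1 as raw bytes plus the nb10/nb11/nb12 byte strides, so a row of src1 is located by byte arithmetic instead of assuming ne00-contiguous float rows; this is what lets the same kernel handle non-contiguous second operands. A hypothetical helper showing just the addressing:

    // base points at the tensor's first byte; nb11/nb12 are byte strides
    static device const float * src1_row(device const uchar * base,
                                         int64_t i1, int64_t i2,
                                         int64_t nb11, int64_t nb12) {
        return (device const float *)(base + i2*nb12 + i1*nb11);
    }
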
@@ -2095,6 +2313,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
  typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
  constant uint64_t &, constant uint64_t &, uint, uint, uint);

+ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
  template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
  template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
  template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
@@ -2105,14 +2324,28 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
  template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
  template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;

- typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\
- constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
- constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
-
- template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
- template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
- template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
- template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
+ typedef void (mat_mm_t)(
+ device const uchar * src0,
+ device const uchar * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne02,
+ constant int64_t & nb01,
+ constant int64_t & nb02,
+ constant int64_t & ne12,
+ constant int64_t & nb10,
+ constant int64_t & nb11,
+ constant int64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & gqa,
+ threadgroup uchar *, uint3, uint, uint);
+
+ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
+ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
+ template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
+ template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
+ template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
  template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
  template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
  template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;