llama_cpp 0.5.1 → 0.5.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
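All of the changes are confined to the bundled Metal shader source (the kernels shown correspond to llama.cpp's ggml-metal.metal). At a glance: kernel_scale, kernel_silu, and kernel_gelu move from float to float4; soft-max drops its threadgroup scratch buffer in favor of SIMD-group reductions and gains a float4 variant; new mat-vec kernels kernel_mul_mat_f32_f32 and kernel_mul_mat_f16_f32_l4 are added; the q3_K and q5_K mat-vec paths are reworked; the dequantizers hoist loop-invariant arithmetic; and kernel_mul_mm now takes src1 as raw bytes with explicit byte strides (nb10/nb11/nb12).
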
@@ -38,7 +38,7 @@ kernel void kernel_add_row(
         device const float4 * src0,
         device const float4 * src1,
         device float4 * dst,
-        constant int64_t & nb,
+        constant int64_t & nb,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] + src1[tpig % nb];
 }
@@ -63,18 +63,18 @@ kernel void kernel_mul_row(
 }
 
 kernel void kernel_scale(
-        device const float * src0,
-        device float * dst,
+        device const float4 * src0,
+        device float4 * dst,
         constant float & scale,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] * scale;
 }
 
 kernel void kernel_silu(
-        device const float * src0,
-        device float * dst,
+        device const float4 * src0,
+        device float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
-    float x = src0[tpig];
+    device const float4 & x = src0[tpig];
     dst[tpig] = x / (1.0f + exp(-x));
 }
 
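kernel_scale and kernel_silu now read and write float4, so each thread processes four elements (the host presumably dispatches a quarter as many threads). The per-component math is unchanged; a scalar C++ reference of one SiLU lane, as a sketch rather than code from the package:

    #include <cmath>

    // silu(x) = x / (1 + e^-x). Metal's exp() and arithmetic operators are
    // component-wise on float4, so the kernel body is identical to this
    // scalar form, just applied to four values at once.
    float silu_ref(float x) {
        return x / (1.0f + std::exp(-x));
    }
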
@@ -89,10 +89,10 @@ constant float GELU_COEF_A = 0.044715f;
 constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
 
 kernel void kernel_gelu(
-        device const float * src0,
-        device float * dst,
+        device const float4 * src0,
+        device float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
-    float x = src0[tpig];
+    device const float4 & x = src0[tpig];
 
     // BEWARE !!!
     // Simply using "tanh" instead of "precise::tanh" will sometimes result in NaNs!
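GELU_COEF_A and SQRT_2_OVER_PI are the constants of the usual tanh approximation of GELU; the full expression lies outside this hunk, but a scalar C++ sketch of that approximation (illustrative; std::tanh stands in for the precise::tanh the comment above insists on):

    #include <cmath>

    // gelu(x) ~= 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3)))
    float gelu_ref(float x) {
        const float GELU_COEF_A    = 0.044715f;
        const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
        return 0.5f * x * (1.0f + std::tanh(SQRT_2_OVER_PI * x * (1.0f + GELU_COEF_A * x * x)));
    }
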
@@ -107,7 +107,6 @@ kernel void kernel_soft_max(
         constant int64_t & ne00,
         constant int64_t & ne01,
         constant int64_t & ne02,
-        threadgroup float * buf [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3 ntg[[threads_per_threadgroup]]) {
@@ -119,61 +118,67 @@ kernel void kernel_soft_max(
     device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
     // parallel max
-    buf[tpitg[0]] = -INFINITY;
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
-        buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]);
+    float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY;
+    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
+        lmax = MAX(lmax, psrc0[i00]);
     }
-
-    // reduce
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = ntg[0]/2; i > 0; i /= 2) {
-        if (tpitg[0] < i) {
-            buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-    }
-
-    //// broadcast - not needed. There is a threadgroup barrier above in the last iteration of
-    // the loop, and when that is done, buf[0] has the correct (synchronized) value
-    //if (tpitg[0] == 0) {
-    //    buf[0] = buf[0];
-    //}
-
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    const float max = buf[0];
+    const float max = simd_max(lmax);
 
     // parallel sum
-    buf[tpitg[0]] = 0.0f;
+    float lsum = 0.0f;
     for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
         const float exp_psrc0 = exp(psrc0[i00] - max);
-        buf[tpitg[0]] += exp_psrc0;
+        lsum += exp_psrc0;
         // Remember the result of exp here. exp is expensive, so we really do not
         // wish to compute it twice.
         pdst[i00] = exp_psrc0;
     }
 
-    // reduce
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = ntg[0]/2; i > 0; i /= 2) {
-        if (tpitg[0] < i) {
-            buf[tpitg[0]] += buf[tpitg[0] + i];
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+    const float sum = simd_sum(lsum);
+
+    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+        pdst[i00] /= sum;
+    }
+}
+
+kernel void kernel_soft_max_4(
+        device const float * src0,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+
+    // parallel max
+    float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY;
+    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
+        lmax4 = fmax(lmax4, psrc4[i00]);
     }
+    float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
 
-    // broadcast - not needed, see above
-    //// broadcast
-    //if (tpitg[0] == 0) {
-    //    buf[0] = buf[0];
-    //}
+    const float max = simd_max(lmax);
 
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    // parallel sum
+    float4 lsum4 = 0.0f;
+    for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+        const float4 exp_psrc4 = exp(psrc4[i00] - max);
+        lsum4 += exp_psrc4;
+        pdst4[i00] = exp_psrc4;
+    }
+    float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
 
-    const float sum = buf[0];
+    const float sum = simd_sum(lsum);
 
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
-        pdst[i00] /= sum;
+    for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+        pdst4[i00] /= sum;
     }
 }
 
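Both soft-max variants replace the threadgroup-memory tree reduction (and its barriers) with per-thread partial results merged by simd_max()/simd_sum(). The row-level computation they parallelize is the standard numerically stable soft-max; a scalar C++ sketch (names illustrative):

    #include <algorithm>
    #include <cmath>

    // Numerically stable soft-max of one row: subtract the row max before
    // exp() so the largest exponent is 0, cache exp() results in dst, then
    // normalize. The kernels stride these loops across threads and combine
    // the per-thread partials with SIMD-group reductions.
    void softmax_row_ref(const float * src, float * dst, int ne00) {
        float max = -INFINITY;
        for (int i = 0; i < ne00; ++i) max = std::max(max, src[i]);
        float sum = 0.0f;
        for (int i = 0; i < ne00; ++i) {
            dst[i] = std::exp(src[i] - max); // exp is expensive; computed once
            sum += dst[i];
        }
        for (int i = 0; i < ne00; ++i) dst[i] /= sum;
    }
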
@@ -192,6 +197,33 @@ kernel void kernel_diag_mask_inf(
         dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
     } else {
         dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
+    }
+}
+
+kernel void kernel_diag_mask_inf_8(
+        device const float4 * src0,
+        device float4 * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int & n_past,
+        uint3 tpig[[thread_position_in_grid]]) {
+
+    const int64_t i = 2*tpig[0];
+
+    dst[i+0] = src0[i+0];
+    dst[i+1] = src0[i+1];
+    int64_t i4 = 4*i;
+    const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
+    const int64_t i01 = i4/(ne00);      i4 -= i01*ne00;
+    const int64_t i00 = i4;
+    for (int k = 3; k >= 0; --k) {
+        if (i00 + 4 + k <= n_past + i01) {
+            break;
+        }
+        dst[i+1][k] = -INFINITY;
+        if (i00 + k > n_past + i01) {
+            dst[i][k] = -INFINITY;
+        }
     }
 }
 
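kernel_diag_mask_inf_8 copies two float4 vectors (eight floats) per thread and then walks the lanes from the highest column down, breaking out early once columns fall at or below the n_past + i01 boundary. The masking rule itself is the same as in the scalar kernel; a one-line C++ statement of it (illustrative):

    #include <cmath>

    // Causal mask: within row i01, any column i00 past n_past + i01 becomes
    // -INF so attention cannot look at future tokens; other values pass through.
    float diag_mask_ref(float v, long i00, long i01, int n_past) {
        return (i00 > n_past + i01) ? -INFINITY : v;
    }
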
@@ -491,6 +523,79 @@ kernel void kernel_mul_mat_q8_0_f32(
     }
 }
 
+#define N_F32_F32 4
+
+kernel void kernel_mul_mat_f32_f32(
+        device const char * src0,
+        device const char * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]]) {
+
+    const int64_t r0 = tgpig.x;
+    const int64_t rb = tgpig.y*N_F32_F32;
+    const int64_t im = tgpig.z;
+
+    device const float * x = (device const float *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+    if (ne00 < 128) {
+        for (int row = 0; row < N_F32_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00; i += 32) {
+                sumf += (float) x[i] * (float) y[i];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    } else {
+        device const float4 * x4 = (device const float4 *)x;
+        for (int row = 0; row < N_F32_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+            device const float4 * y4 = (device const float4 *) y;
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00/4; i += 32) {
+                for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    }
+}
+
 kernel void kernel_mul_mat_f16_f32_1row(
         device const char * src0,
         device const char * src1,
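kernel_mul_mat_f32_f32 spreads each row-times-column dot product over the 32 lanes of a simdgroup, with a float4 fast path when ne00 >= 128 and a scalar tail for the ne00 % 4 remainder. Collapsing the lane striding into plain loops, the per-output computation is the following C++ sketch (illustrative):

    // One output element: dot(x, y) over ne00 entries. In the kernel, the
    // index starts at the lane id (tiisg) and advances by 32, partials are
    // merged with simd_sum(), and lane 0 adds the scalar tail and stores.
    float dot_ref(const float * x, const float * y, int ne00) {
        float sum = 0.0f;
        int i = 0;
        for (; i + 4 <= ne00; i += 4)                 // float4 fast path
            for (int k = 0; k < 4; ++k) sum += x[i+k] * y[i+k];
        for (; i < ne00; ++i) sum += x[i] * y[i];     // scalar tail
        return sum;
    }
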
@@ -616,6 +721,49 @@ kernel void kernel_mul_mat_f16_f32(
     }
 }
 
+// Assumes row size (ne00) is a multiple of 4
+kernel void kernel_mul_mat_f16_f32_l4(
+        device const char * src0,
+        device const char * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]]) {
+
+    const int nrows = ne11;
+    const int64_t r0 = tgpig.x;
+    const int64_t im = tgpig.z;
+
+    device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+    for (int r1 = 0; r1 < nrows; ++r1) {
+        device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
+
+        float sumf = 0;
+        for (int i = tiisg; i < ne00/4; i += 32) {
+            for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+        }
+
+        float all_sum = simd_sum(sumf);
+        if (tiisg == 0) {
+            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+        }
+    }
+}
+
 kernel void kernel_alibi_f32(
         device const float * src0,
         device float * dst,
@@ -1123,31 +1271,40 @@ kernel void kernel_mul_mat_q3_K_f32(
     device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
     device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
 
-    float yl[16];
+    float yl[32];
 
-    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask1 = 0x3030;
     const uint16_t kmask2 = 0x0f0f;
 
-    const int tid = tiisg/2;
-    const int ix = tiisg%2;
-    const int ip = tid/8;         // 0 or 1
-    const int il = tid/2 - 4*ip;  // 0...3
+    const int tid = tiisg/4;
+    const int ix = tiisg%4;
+    const int ip = tid/4;          // 0 or 1
+    const int il = 2*((tid%4)/2);  // 0 or 2
     const int ir = tid%2;
     const int n = 8;
     const int l0 = n*ir;
 
-    const uint16_t m1 = 1 << (4*ip + il);
-    const uint16_t m2 = m1 << 8;
+    // One would think that the Metal compiler would figure out that ip and il can only have
+    // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
+    // with these two tables.
+    //
+    // Possible masks for the high bit
+    const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200},  // ip = 0, il = 0
+                           {0x0004, 0x0400, 0x0008, 0x0800},  // ip = 0, il = 2
+                           {0x0010, 0x1000, 0x0020, 0x2000},  // ip = 1, il = 0
+                           {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
+
+    // Possible masks for the low 2 bits
+    const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
+
+    const ushort4 hm = mm[2*ip + il/2];
 
     const int shift = 2*il;
-    const uint16_t qm1 = 0x0003 << shift;
-    const uint16_t qm2 = 0x0300 << shift;
-    const int32_t v1 = 4 << shift;
-    const int32_t v2 = 1024 << shift;
+    const float v1 = il == 0 ? 4.f : 64.f;
+    const float v2 = 4.f * v1;
 
     const uint16_t s_shift1 = 4*ip;
-    const uint16_t s_shift2 = s_shift1 + 2*(il/2);
-    const int ik = 4 + (il%2);
+    const uint16_t s_shift2 = s_shift1 + il;
 
     const int q_offset = 32*ip + l0;
     const int y_offset = 128*ip + 32*il + l0;
@@ -1156,12 +1313,19 @@ kernel void kernel_mul_mat_q3_K_f32(
 
     device const float * y1 = yy + ix*QK_K + y_offset;
 
-    float sumf1[2] = {0.f}, sumf2[2] = {0.f};
-    for (int i = ix; i < nb; i += 2) {
+    uint32_t scales32, aux32;
+    thread uint16_t * scales16 = (thread uint16_t *)&scales32;
+    thread const int8_t * scales = (thread const int8_t *)&scales32;
+
+    float sumf1[2] = {0.f};
+    float sumf2[2] = {0.f};
+    for (int i = ix; i < nb; i += 4) {
 
         for (int l = 0; l < 8; ++l) {
-            yl[l+0] = y1[l+ 0];
-            yl[l+8] = y1[l+16];
+            yl[l+ 0] = y1[l+ 0];
+            yl[l+ 8] = y1[l+16];
+            yl[l+16] = y1[l+32];
+            yl[l+24] = y1[l+48];
         }
 
         device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
@@ -1172,27 +1336,43 @@ kernel void kernel_mul_mat_q3_K_f32(
     for (int row = 0; row < 2; ++row) {
 
         const float d_all = (float)dh[0];
-        const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
 
-        float s1 = 0, s2 = 0;
+        scales16[0] = a[4];
+        scales16[1] = a[5];
+        aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
+        scales16[0] = a[il+0];
+        scales16[1] = a[il+1];
+        scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
+
+        float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
         for (int l = 0; l < n; l += 2) {
-            const uint16_t qs = q[l/2];
-            s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
-            s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
+            const int32_t qs = q[l/2];
+            s1 += yl[l+0] * (qs & qm[il/2][0]);
+            s2 += yl[l+1] * (qs & qm[il/2][1]);
+            s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
+            s4 += yl[l+16] * (qs & qm[il/2][2]);
+            s5 += yl[l+17] * (qs & qm[il/2][3]);
+            s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
         }
-        float d = d_all * (s1 + 1.f/256.f * s2);
-        sumf1[row] += d * scales[0];
-        sumf2[row] += d;
+        float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+        float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+        sumf1[row] += d1 * (scales[0] - 32);
+        sumf2[row] += d2 * (scales[2] - 32);
 
-        s1 = s2 = 0;
+        s1 = s2 = s3 = s4 = s5 = s6 = 0;
        for (int l = 0; l < n; l += 2) {
-            const uint16_t qs = q[l/2+8];
-            s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
-            s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
+            const int32_t qs = q[l/2+8];
+            s1 += yl[l+8] * (qs & qm[il/2][0]);
+            s2 += yl[l+9] * (qs & qm[il/2][1]);
+            s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
+            s4 += yl[l+24] * (qs & qm[il/2][2]);
+            s5 += yl[l+25] * (qs & qm[il/2][3]);
+            s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
         }
-        d = d_all * (s1 + 1.f/256.f * s2);
-        sumf1[row] += d * scales[1];
-        sumf2[row] += d;
+        d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+        d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+        sumf1[row] += d1 * (scales[1] - 32);
+        sumf2[row] += d2 * (scales[3] - 32);
 
         q += step;
         h += step;
@@ -1201,15 +1381,17 @@ kernel void kernel_mul_mat_q3_K_f32(
 
         }
 
-        y1 += 2 * QK_K;
+        y1 += 4 * QK_K;
 
     }
 
     for (int row = 0; row < 2; ++row) {
-        const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
-        const float tot = simd_sum(sumf);
-        if (tiisg == 0) {
-            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
+        const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
+        sumf1[row] = simd_sum(sumf);
+    }
+    if (tiisg == 0) {
+        for (int row = 0; row < 2; ++row) {
+            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
         }
     }
 }
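A recurring trick in the reworked q3_K path: two adjacent quant bytes are read as one uint16_t, and the high byte is never shifted down. Masking it (the 0x0300-style entries of qm) leaves a value 256 times too large, which the kernel compensates for once per partial sum with the 1.f/256.f factor. A tiny C++ check of the identity (illustrative):

    #include <cassert>
    #include <cstdint>

    // (qs & 0x0300) / 256 recovers the 2-bit field of the high byte without
    // a per-element shift in the hot loop.
    int main() {
        const uint16_t qs = 0x0201;             // low byte 1, high byte 2
        const float lo = qs & 0x0003;           // -> 1
        const float hi = (qs & 0x0300) / 256.f; // -> 2
        assert(lo == 1.0f && hi == 2.0f);
        return 0;
    }
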
@@ -1290,13 +1472,13 @@ kernel void kernel_mul_mat_q4_K_f32(
         device const float * src1,
         device float * dst,
         constant int64_t & ne00,
-        constant int64_t & ne01[[buffer(4)]],
-        constant int64_t & ne02[[buffer(5)]],
-        constant int64_t & ne10[[buffer(9)]],
-        constant int64_t & ne12[[buffer(11)]],
-        constant int64_t & ne0[[buffer(15)]],
-        constant int64_t & ne1[[buffer(16)]],
-        constant uint & gqa[[buffer(17)]],
+        constant int64_t & ne01 [[buffer(4)]],
+        constant int64_t & ne02 [[buffer(5)]],
+        constant int64_t & ne10 [[buffer(9)]],
+        constant int64_t & ne12 [[buffer(11)]],
+        constant int64_t & ne0 [[buffer(15)]],
+        constant int64_t & ne1 [[buffer(16)]],
+        constant uint & gqa [[buffer(17)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1564,17 +1746,25 @@ kernel void kernel_mul_mat_q5_K_f32(
         sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
         sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
 
-        float4 acc = {0.f, 0.f, 0.f, 0.f};
+        float4 acc1 = {0.f};
+        float4 acc2 = {0.f};
         for (int l = 0; l < n; ++l) {
             uint8_t h = qh[l];
-            acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
-            acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
-            acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
-            acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
+            acc1[0] += yl[l+0] * (q1[l] & 0x0F);
+            acc1[1] += yl[l+8] * (q1[l] & 0xF0);
+            acc1[2] += yh[l+0] * (q2[l] & 0x0F);
+            acc1[3] += yh[l+8] * (q2[l] & 0xF0);
+            acc2[0] += h & hm1 ? yl[l+0] : 0.f;
+            acc2[1] += h & hm2 ? yl[l+8] : 0.f;
+            acc2[2] += h & hm3 ? yh[l+0] : 0.f;
+            acc2[3] += h & hm4 ? yh[l+8] : 0.f;
         }
         const float dall = dh[0];
         const float dmin = dh[1];
-        sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
+        sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
+                             sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
+                             sc8[4] * (acc1[2] + 16.f*acc2[2]) +
+                             sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
                      dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
 
         q1 += step;
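The q5_K change splits the accumulator so the high-bit contribution (acc2) no longer needs a per-element add inside the loop; it is an algebraic regrouping of the same sum, e.g. for the 0xF0 lane, y*((q & 0xF0) + (h ? 256 : 0))/16 equals y*(q & 0xF0)/16 + 16*(h ? y : 0) because 256/16 == 16. A C++ spot check (illustrative):

    #include <cassert>
    #include <cstdint>

    int main() {
        const float y = 3.0f; const uint8_t q = 0xA0; const bool h = true;
        const float old_form = y * ((q & 0xF0) + (h ? 256 : 0)) / 16.f;
        const float new_form = y * (q & 0xF0) / 16.f + 16.f * (h ? y : 0.f);
        assert(old_form == new_form); // exact for these small values
        return 0;
    }
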
@@ -1747,6 +1937,15 @@ kernel void kernel_mul_mat_q6_K_f32(
 
 //============================= templates and their specializations =============================
 
+// NOTE: this is not dequantizing - we are simply fitting the template
+template <typename type4x4>
+void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
+    float4x4 temp = *(((device float4x4 *)src));
+    for (int i = 0; i < 16; i++){
+        reg[i/4][i%4] = temp[i/4][i%4];
+    }
+}
+
 template <typename type4x4>
 void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
     half4x4 temp = *(((device half4x4 *)src));
@@ -1758,28 +1957,30 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
 template <typename type4x4>
 void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
     device const uint16_t * qs = ((device const uint16_t *)xb + 1);
-    const half d = il ? (xb->d / 16.h) : xb->d;
-    const half m = il ? ( -8.h * 16.h) : -8.h;
+    const float d1 = il ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float md = -8.h * xb->d;
     const ushort mask0 = il ? 0x00F0 : 0x000F;
-    const ushort mask1 = il ? 0xF000 : 0x0F00;
+    const ushort mask1 = mask0 << 8;
 
     for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)]   = (((qs[i] & mask0)) + m) * d;
-        reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
+        reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
+        reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
     }
 }
 
 template <typename type4x4>
 void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
     device const uint16_t * qs = ((device const uint16_t *)xb + 2);
-    const half d = il ? (xb->d / 16.h) : xb->d;
-    const half m = xb->m;
+    const float d1 = il ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float m = xb->m;
     const ushort mask0 = il ? 0x00F0 : 0x000F;
-    const ushort mask1 = il ? 0xF000 : 0x0F00;
+    const ushort mask1 = mask0 << 8;
 
     for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)]   = ((qs[i] & mask0) * d) + m;
-        reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
+        reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
+        reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
     }
 }
 
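Both nibble dequantizers fold the >> 8 into the scale: the high-nibble lane lives in the upper byte of the uint16_t, and d1 * ((qs & 0x0F00) >> 8) equals (d1/256) * (qs & 0x0F00), so precomputing d2 = d1/256 drops a shift from the inner loop (q4_0 additionally premultiplies the -8 offset into md = -8*d). A quick C++ check (illustrative):

    #include <cassert>
    #include <cstdint>

    int main() {
        const uint16_t qs = 0x0700; // high-byte nibble = 7
        const float d1 = 0.5f;
        const float d2 = d1 / 256.f;
        assert(d1 * ((qs & 0x0F00) >> 8) == d2 * (qs & 0x0F00));
        return 0;
    }
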
@@ -1815,7 +2016,7 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg
 
 template <typename type4x4>
 void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
-    const float d_all = (float)(xb->d);
+    const half d_all = xb->d;
     device const uint8_t * q = (device const uint8_t *)xb->qs;
     device const uint8_t * h = (device const uint8_t *)xb->hmask;
     device const int8_t * scales = (device const int8_t *)xb->scales;
@@ -1828,16 +2029,18 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
         ((il/4)>0 ? 12 : 3);
     uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
     uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
-    int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : \
-                                (scale_2&kmask2) | ((scale_1&kmask1) << 4);
-    float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
+    int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
+                              : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
+    half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
+    const half ml = 4.h * dl;
 
-    il = (il/2)%4;
-    float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
-    uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+    il = (il/2) & 3;
+    const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+    const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+    dl *= coef;
 
     for (int i = 0; i < 16; ++i) {
-        reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i] & m) ? 0 : 4.f/coef));
+        reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
     }
 #else
     float kcoef = il&1 ? 1.f/16.f : 1.f;
@@ -1852,26 +2055,31 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
 #endif
 }
 
+static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
+    return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
+                 : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
+}
+
 template <typename type4x4>
 void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
-    device const uint8_t * q = xb->qs;
+    device const uchar * q = xb->qs;
 
 #if QK_K == 256
-    const float d = (float)(xb->d);
-    const float min = (float)(xb->dmin);
     short is = (il/4) * 2;
     q = q + (il/4) * 32 + 16 * (il&1);
-    il = il%4;
-    const uchar4 sc = get_scale_min_k4(is, xb->scales);
-    const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
-    const float ml = il<2 ? min * sc[1] : min * sc[3];
+    il = il & 3;
+    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+    const half d = il < 2 ? xb->d : xb->d / 16.h;
+    const half min = xb->dmin;
+    const half dl = d * sc[0];
+    const half ml = min * sc[1];
 #else
     q = q + 16 * (il&1);
     device const uint8_t * s = xb->scales;
     device const half2 * dh = (device const half2 *)xb->d;
     const float2 d = (float2)dh[0];
     const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
-    const float ml = il<2 ? d[1] * (s[0]>>4) : d[1 ]* (s[1]>>4);
+    const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
 #endif
     const ushort mask = il<2 ? 0x0F : 0xF0;
     for (int i = 0; i < 16; ++i) {
@@ -1885,19 +2093,19 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
     device const uint8_t * qh = xb->qh;
 
 #if QK_K == 256
-    const float d = (float)(xb->d);
-    const float min = (float)(xb->dmin);
     short is = (il/4) * 2;
     q = q + 32 * (il/4) + 16 * (il&1);
     qh = qh + 16 * (il&1);
     uint8_t ul = 1 << (il/2);
-    il = il%4;
-    const uchar4 sc = get_scale_min_k4(is, xb->scales);
-    const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
-    const float ml = il<2 ? min * sc[1] : min * sc[3];
+    il = il & 3;
+    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+    const half d = il < 2 ? xb->d : xb->d / 16.h;
+    const half min = xb->dmin;
+    const half dl = d * sc[0];
+    const half ml = min * sc[1];
 
-    const ushort mask = il<2 ? 0x0F : 0xF0;
-    const float qh_val = il<2 ? 16.f : 256.f;
+    const ushort mask = il<2 ? 0x0F : 0xF0;
+    const half qh_val = il<2 ? 16.h : 256.h;
     for (int i = 0; i < 16; ++i) {
         reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
     }
@@ -1916,7 +2124,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
 
 template <typename type4x4>
 void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
-    const float d_all = (float)(xb->d);
+    const half d_all = xb->d;
     device const uint8_t * ql = (device const uint8_t *)xb->ql;
     device const uint8_t * qh = (device const uint8_t *)xb->qh;
     device const int8_t * scales = (device const int8_t *)xb->scales;
1924
2132
  #if QK_K == 256
1925
2133
  ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
1926
2134
  qh = qh + 32*(il/8) + 16*(il&1);
1927
- float sc = scales[(il%2) + 2 * ((il/2))];
1928
- il = (il/2)%4;
2135
+ half sc = scales[(il%2) + 2 * ((il/2))];
2136
+ il = (il/2) & 3;
1929
2137
  #else
1930
2138
  ql = ql + 16 * (il&1);
1931
- float sc = scales[il];
2139
+ half sc = scales[il];
1932
2140
  #endif
2141
+ const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
2142
+ const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
2143
+ const half coef = il>1 ? 1.f/16.h : 1.h;
2144
+ const half ml = d_all * sc * 32.h;
2145
+ const half dl = d_all * sc * coef;
1933
2146
  for (int i = 0; i < 16; ++i) {
1934
- uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
1935
- uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
1936
- const float coef = il>1 ? 1.f/16.f : 1.f;
1937
- float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \
1938
- ((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef;
1939
- reg[i/4][i%4] = d_all * sc * q * coef;
2147
+ const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
2148
+ : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
2149
+ reg[i/4][i%4] = dl * q - ml;
1940
2150
  }
1941
2151
  }
1942
2152
 
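dequantize_q6_K hoists the masks, coef, and scale products out of the 16-iteration loop and folds the -32/coef offset into ml: d_all*sc*(v - 32/coef)*coef equals (d_all*sc*coef)*v - d_all*sc*32. A C++ spot check of the equivalence (illustrative):

    #include <cassert>

    int main() {
        const float d_all = 0.25f, sc = 3.0f, coef = 1.0f/16.0f, v = 40.0f;
        const float old_form = d_all * sc * (v - 32.0f/coef) * coef;
        const float dl = d_all * sc * coef, ml = d_all * sc * 32.0f;
        assert(old_form == dl * v - ml); // invariant products computed once
        return 0;
    }
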
@@ -1976,22 +2186,25 @@ kernel void kernel_get_rows(
 // each block_q contains 16*nl weights
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
 kernel void kernel_mul_mm(device const uchar * src0,
-        device const float * src1,
-        device float * dst,
-        constant int64_t & ne00,
-        constant int64_t & ne02,
-        constant int64_t & nb01,
-        constant int64_t & nb02,
-        constant int64_t & ne12,
-        constant int64_t & ne0,
-        constant int64_t & ne1,
-        constant uint & gqa,
-        threadgroup uchar * shared_memory [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiitg[[thread_index_in_threadgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    threadgroup half * sa = ((threadgroup half *)shared_memory);
+        device const uchar * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne02,
+        constant int64_t & nb01,
+        constant int64_t & nb02,
+        constant int64_t & ne12,
+        constant int64_t & nb10,
+        constant int64_t & nb11,
+        constant int64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant uint & gqa,
+        threadgroup uchar * shared_memory [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiitg[[thread_index_in_threadgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    threadgroup half * sa = (threadgroup half *)(shared_memory);
     threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
 
     const uint r0 = tgpig.y;
@@ -2004,7 +2217,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
     short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
     short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
 
-    simdgroup_half8x8 ma[4];
+    simdgroup_half8x8  ma[4];
     simdgroup_float8x8 mb[2];
     simdgroup_float8x8 c_res[8];
     for (int i = 0; i < 8; i++){
@@ -2012,10 +2225,15 @@ kernel void kernel_mul_mm(device const uchar * src0,
     }
 
     short il = (tiitg % THREAD_PER_ROW);
-    uint offset0 = im/gqa*nb02; ushort offset1 = il/nl;
-    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
-    device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
-                           + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1;
+
+    uint offset0 = im/gqa*nb02;
+    ushort offset1 = il/nl;
+
+    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
+    device const float * y = (device const float *)(src1
+        + nb12 * im
+        + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
+        + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
 
     for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
         //load data and store to threadgroup memory
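With src1 passed as raw bytes plus the nb10/nb11/nb12 byte strides, kernel_mul_mm no longer assumes a dense float layout for its second operand; an element is located purely by byte offsets, which accommodates non-contiguous views. A C++ sketch of the addressing (illustrative):

    #include <cstdint>

    // y element (i10, i11, i12) lives at src1 + i12*nb12 + i11*nb11 + i10*nb10.
    const float * elem_ptr(const uint8_t * src1, int64_t i10, int64_t i11,
                           int64_t i12, int64_t nb10, int64_t nb11, int64_t nb12) {
        return reinterpret_cast<const float *>(src1 + i12*nb12 + i11*nb11 + i10*nb10);
    }
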
@@ -2095,6 +2313,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
 typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
                           constant uint64_t &, constant uint64_t &, uint, uint, uint);
 
+template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
 template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
 template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
@@ -2105,14 +2324,28 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
 template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
 template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
 
-typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\
-                        constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
-                        constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
-
-template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
-template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
-template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
-template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
+typedef void (mat_mm_t)(
+        device const uchar * src0,
+        device const uchar * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne02,
+        constant int64_t & nb01,
+        constant int64_t & nb02,
+        constant int64_t & ne12,
+        constant int64_t & nb10,
+        constant int64_t & nb11,
+        constant int64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant uint & gqa,
+        threadgroup uchar *, uint3, uint, uint);
+
+template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
+template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
+template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
 template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
 template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;