llama_cpp 0.3.2 → 0.3.4

Sign up to get free protection for your applications and to get access to all the features.
@@ -331,26 +331,33 @@ kernel void kernel_rms_norm(
331
331
  threadgroup float * sum [[threadgroup(0)]],
332
332
  uint tgpig[[threadgroup_position_in_grid]],
333
333
  uint tpitg[[thread_position_in_threadgroup]],
334
+ uint sgitg[[simdgroup_index_in_threadgroup]],
335
+ uint tiisg[[thread_index_in_simdgroup]],
334
336
  uint ntg[[threads_per_threadgroup]]) {
335
- device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
337
+ device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
338
+ device const float * x_scalar = (device const float *) x;
339
+ float4 sumf=0;
340
+ float all_sum=0;
336
341
 
337
342
  // parallel sum
338
- sum[tpitg] = 0.0f;
339
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
340
- sum[tpitg] += x[i00] * x[i00];
343
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
344
+ sumf += x[i00] * x[i00];
345
+ }
346
+ all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
347
+ all_sum = simd_sum(all_sum);
348
+ if (tiisg == 0) {
349
+ sum[sgitg] = all_sum;
341
350
  }
342
351
 
343
- // reduce
344
352
  threadgroup_barrier(mem_flags::mem_threadgroup);
345
- for (uint i = ntg/2; i > 0; i /= 2) {
346
- if (tpitg < i) {
347
- sum[tpitg] += sum[tpitg + i];
348
- }
349
- threadgroup_barrier(mem_flags::mem_threadgroup);
353
+ // broadcast, the number of simd groups is ntg / 32
354
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
355
+ if (tpitg < i) {
356
+ sum[tpitg] += sum[tpitg + i];
357
+ }
350
358
  }
351
-
352
- // broadcast
353
359
  if (tpitg == 0) {
360
+ for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
354
361
  sum[0] /= ne00;
355
362
  }
356
363
 
@@ -359,147 +366,127 @@ kernel void kernel_rms_norm(
359
366
  const float mean = sum[0];
360
367
  const float scale = 1.0f/sqrt(mean + eps);
361
368
 
362
- device float * y = dst + tgpig*ne00;
363
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
369
+ device float4 * y = (device float4 *) (dst + tgpig*ne00);
370
+ device float * y_scalar = (device float *) y;
371
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
364
372
  y[i00] = x[i00] * scale;
365
373
  }
374
+ if (tpitg == 0) {
375
+ for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
376
+ }
366
377
  }
367
378
 
368
- kernel void kernel_mul_mat_q4_0_f32(
369
- device const void * src0,
370
- device const float * src1,
371
- device float * dst,
372
- constant int64_t & ne00,
373
- constant int64_t & ne10,
374
- constant int64_t & ne0,
375
- threadgroup float * sum [[threadgroup(0)]],
376
- uint2 tgpig[[threadgroup_position_in_grid]],
377
- uint2 tpitg[[thread_position_in_threadgroup]],
378
- uint2 tptg[[threads_per_threadgroup]]) {
379
- const int nb = ne00/QK4_0;
379
+ // function to calculate the inner product between a q4_0 block and 32 floats (yl); sumy is SUM(yl[i])
380
+ float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) {
381
+ float d = qb_curr->d;
382
+ float4 acc = 0.f;
383
+ device uint16_t * qs = ((device uint16_t *)qb_curr + 1);
384
+ for (int i = 0; i < 16; i+=2) {
385
+ acc[0] += yl[i] * (qs[i / 2] & 0x000F);
386
+ acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
387
+ acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
388
+ acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
389
+ }
390
+ return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f);
391
+ }
380
392
 
381
- const int64_t r0 = tgpig.x;
382
- const int64_t r1 = tgpig.y;
393
+ // function to calculate the inner product between a q4_1 block and 32 floats (yl); sumy is SUM(yl[i])
394
+ float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) {
395
+ float d = qb_curr->d;
396
+ float m = qb_curr->m;
397
+ float4 acc = 0.f;
398
+ device uint16_t * qs = ((device uint16_t *)qb_curr + 2);
399
+ for (int i = 0; i < 16; i+=2) {
400
+ acc[0] += yl[i] * (qs[i / 2] & 0x000F);
401
+ acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
402
+ acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
403
+ acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
404
+ }
405
+ return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m;
406
+ }
383
407
 
384
- device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb;
408
+ // putting them in the kernel causes a significant performance penalty
409
+ #define N_DST 4 // each SIMD group works on 4 rows
410
+ #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
411
+ #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
412
+ template<typename block_q_type>
413
+ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
414
+ int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
415
+ uint2 tgpig, uint tiisg, uint sgitg) {
416
+ const int nb = ne00/QK4_0;
417
+ const int r0 = tgpig.x;
418
+ const int r1 = tgpig.y;
419
+ device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
385
420
  device const float * y = (device const float *) src1 + r1*ne10;
386
-
387
- const int nth = tptg.x*tptg.y;
388
- const int ith = tptg.y*tpitg.x + tpitg.y;
389
-
390
- const int ix = tpitg.y/4; // 0 or 1
391
- const int iy = tpitg.y - 4*ix; // 0...3
392
-
393
- const int first = 4 * iy;
394
-
395
- float sumf = 0;
396
-
397
- for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
398
-
399
- const float d = (float)x[i].d;
400
-
401
- device const uint8_t * xl = x[i].qs + first;
402
- device const float * yl = y + i * QK4_0 + first;
403
-
404
- float2 acc = {0.0f, 0.0f};
405
-
406
- for (int j = 0; j < 4; ++j) {
407
-
408
- acc[0] += yl[j] * (xl[j] & 0xF) + yl[j+16] * (xl[j] >> 4);
409
- acc[1] += yl[j] + yl[j+16];
410
-
421
+ float4 y_curr[8]; // src1 vector cache
422
+ float sumf[N_DST]={0.f}, all_sum;
423
+ thread float * yl=(thread float *)y_curr;
424
+
425
+ // each thread in a SIMD group deals with 1 block.
426
+ for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
427
+ float sumy = 0;
428
+ for (int i = 0; i < QK4_0 / 4; i++) {
429
+ y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i);
430
+ sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
411
431
  }
412
432
 
413
- sumf += d * (acc[0] - 8.f*acc[1]);
433
+ for (int row = 0; row < N_DST; row++) {
434
+ sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl);
435
+ }
414
436
  }
415
437
 
416
- sum[ith] = sumf;
438
+ // from here on, load two rows at a time and 16 blocks per row
439
+ int ir = tiisg / (N_SIMDWIDTH / 2);
440
+ int ib = tiisg % (N_SIMDWIDTH / 2);
441
+ for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) {
442
+ int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); //where the left blocks start
443
+ float sumy = 0;
444
+ for (int i = 0; i < QK4_0 / 4; i++) {
445
+ y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i);
446
+ sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
447
+ }
417
448
 
418
- //
419
- // Accumulate the sum from all threads in the threadgroup
420
- //
421
- threadgroup_barrier(mem_flags::mem_threadgroup);
422
- if (ith%4 == 0) {
423
- sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
424
- }
425
- threadgroup_barrier(mem_flags::mem_threadgroup);
426
- if (ith%16 == 0) {
427
- sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
449
+ for (int row = 0; row < N_DST; row+=2) {
450
+ if (nb_start + ib < nb) {
451
+ sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl);
452
+ }
453
+ }
428
454
  }
429
- threadgroup_barrier(mem_flags::mem_threadgroup);
430
- if (ith == 0) {
431
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
432
- dst[r1*ne0 + r0] = sum[0];
455
+
456
+ for (int row = 0; row < N_DST; ++row) {
457
+ all_sum = simd_sum(sumf[row]);
458
+ if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
459
+ dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
460
+ }
433
461
  }
434
462
  }
435
463
 
436
- kernel void kernel_mul_mat_q4_1_f32(
464
+ kernel void kernel_mul_mat_q4_0_f32(
437
465
  device const void * src0,
438
466
  device const float * src1,
439
467
  device float * dst,
440
468
  constant int64_t & ne00,
441
469
  constant int64_t & ne10,
442
470
  constant int64_t & ne0,
443
- threadgroup float * sum [[threadgroup(0)]],
471
+ constant int64_t & ne01[[buffer(4)]],
444
472
  uint2 tgpig[[threadgroup_position_in_grid]],
445
- uint2 tpitg[[thread_position_in_threadgroup]],
446
- uint2 tptg[[threads_per_threadgroup]]) {
447
- const int nb = ne00/QK4_1;
448
-
449
- const int64_t r0 = tgpig.x;
450
- const int64_t r1 = tgpig.y;
451
-
452
- device const block_q4_1 * x = (device const block_q4_1 *) src0 + r0*nb;
453
- device const float * y = (device const float *) src1 + r1*ne10;
454
-
455
- const uint nth = tptg.x*tptg.y;
456
- const uint ith = tptg.y*tpitg.x + tpitg.y;
457
-
458
- const int ix = tpitg.y/4; // 0 or 1
459
- const int iy = tpitg.y - 4*ix; // 0...3
460
-
461
- const int first = 4 * iy;
462
-
463
- float sumf = 0;
464
-
465
- for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
466
-
467
- const float d = (float)x[i].d;
468
- const float m = (float)x[i].m;
469
-
470
- device const uint8_t * xl = x[i].qs + first;
471
- device const float * yl = y + i * QK4_1 + first;
472
-
473
- float2 acc = {0.0f, 0.0f};
474
-
475
- for (int j = 0; j < 4; ++j) {
476
-
477
- acc[0] += yl[j+ 0] * (d * (xl[j] & 0xF) + m);
478
- acc[1] += yl[j+16] * (d * (xl[j] >> 4) + m);
479
-
480
- }
481
-
482
- sumf += acc[0] + acc[1];
483
- }
484
-
485
- sum[ith] = sumf;
473
+ uint tiisg[[thread_index_in_simdgroup]],
474
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
475
+ mul_vec_q_n_f32<block_q4_0>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
476
+ }
486
477
 
487
- //
488
- // Accumulate the sum from all threads in the threadgroup
489
- //
490
- threadgroup_barrier(mem_flags::mem_threadgroup);
491
- if (ith%4 == 0) {
492
- sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
493
- }
494
- threadgroup_barrier(mem_flags::mem_threadgroup);
495
- if (ith%16 == 0) {
496
- sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
497
- }
498
- threadgroup_barrier(mem_flags::mem_threadgroup);
499
- if (ith == 0) {
500
- for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
501
- dst[r1*ne0 + r0] = sum[0];
502
- }
478
+ kernel void kernel_mul_mat_q4_1_f32(
479
+ device const void * src0,
480
+ device const float * src1,
481
+ device float * dst,
482
+ constant int64_t & ne00,
483
+ constant int64_t & ne10,
484
+ constant int64_t & ne0,
485
+ constant int64_t & ne01[[buffer(4)]],
486
+ uint2 tgpig[[threadgroup_position_in_grid]],
487
+ uint tiisg[[thread_index_in_simdgroup]],
488
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
489
+ mul_vec_q_n_f32<block_q4_1>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
503
490
  }
504
491
 
505
492
  kernel void kernel_mul_mat_f16_f32(
@@ -615,17 +602,19 @@ kernel void kernel_rope(
615
602
  constant int & n_past,
616
603
  constant int & n_dims,
617
604
  constant int & mode,
605
+ constant float & freq_base,
606
+ constant float & freq_scale,
618
607
  uint3 tpig[[thread_position_in_grid]]) {
619
608
  const int64_t i3 = tpig[2];
620
609
  const int64_t i2 = tpig[1];
621
610
  const int64_t i1 = tpig[0];
622
611
 
623
612
  const bool is_neox = mode & 2;
624
- const float theta_scale = pow(10000.0, -2.0f/n_dims);
613
+ const float theta_scale = pow(freq_base, -2.0f/n_dims);
625
614
 
626
615
  const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
627
616
 
628
- float theta = (float)p;
617
+ float theta = freq_scale * (float)p;
629
618
 
630
619
  if (!is_neox) {
631
620
  for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
@@ -1220,111 +1209,137 @@ kernel void kernel_mul_mat_q2_K_f32(
1220
1209
  constant int64_t & ne00,
1221
1210
  constant int64_t & ne10,
1222
1211
  constant int64_t & ne0,
1223
- threadgroup float * sum [[threadgroup(0)]],
1212
+ constant int64_t & ne01[[buffer(4)]],
1224
1213
  uint2 tgpig[[threadgroup_position_in_grid]],
1225
- uint2 tpitg[[thread_position_in_threadgroup]],
1226
- uint2 tptg[[threads_per_threadgroup]]) {
1214
+ uint tiisg[[thread_index_in_simdgroup]],
1215
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1227
1216
 
1228
1217
  const int nb = ne00/QK_K;
1218
+ const int r0 = tgpig.x;
1219
+ const int r1 = tgpig.y;
1229
1220
 
1230
- const int64_t r0 = tgpig.x;
1231
- const int64_t r1 = tgpig.y;
1232
-
1233
- device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb;
1234
- device const float * yy = (device const float *) src1 + r1*ne10;
1221
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1222
+ const int ib_row = first_row * nb;
1223
+ device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row;
1224
+ device const float * y = (device const float *) src1 + r1*ne10;
1225
+ float yl[32];
1226
+ float sumf[N_DST]={0.f}, all_sum;
1235
1227
 
1236
- const int nth = tptg.x*tptg.y;
1237
- const int ith = tptg.y*tpitg.x + tpitg.y;
1238
-
1239
- float sumf = 0;
1228
+ const int step = sizeof(block_q2_K) * nb;
1240
1229
 
1241
1230
  #if QK_K == 256
1242
- const int tid = tpitg.y; // 0...16
1243
- const int il = tid/4; // 0...3
1244
- const int ir = tid%4; // 0...3
1245
- const int ip = il/2; // 0 or 1
1246
- const int shift1 = 4*(il%2);// 0 or 4
1247
- const int shift2 = shift1+2;// 2 or 6
1248
- const int n = 8;
1249
- const int is = 4*il + (n*ir)/16;
1250
-
1251
- const int y_offset = 64*il + n*ir;
1252
- const int q_offset = 32*ip + n*ir;
1253
-
1254
- for (int i = tpitg.x; i < nb; i += tptg.x) {
1255
-
1256
- device const uint8_t * q = x[i].qs + q_offset;
1257
- device const uint8_t * scales = x[i].scales + is;
1258
-
1259
- uint8_t d1 = scales[0] & 0xF;
1260
- uint8_t d2 = scales[2] & 0xF;
1261
- uint8_t m1 = scales[0] >> 4;
1262
- uint8_t m2 = scales[2] >> 4;
1263
-
1264
- device const float * y = yy + i*QK_K + y_offset;
1265
-
1266
- float2 s = {0.f, 0.f};
1267
- float smin = 0;
1268
- for (int l = 0; l < n; ++l) {
1269
- s[0] += y[l+ 0] * ((q[l] >> shift1) & 3);
1270
- s[1] += y[l+32] * ((q[l] >> shift2) & 3);
1271
- smin += y[l+ 0] * m1 + y[l+32] * m2;
1231
+ const int ix = tiisg/8; // 0...3
1232
+ const int it = tiisg%8; // 0...7
1233
+ const int im = it/4; // 0 or 1
1234
+ const int ir = it%4; // 0...3
1235
+ const int is = (8*ir)/16;// 0 or 1
1236
+
1237
+ device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir;
1238
+
1239
+ for (int ib = ix; ib < nb; ib += 4) {
1240
+
1241
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
1242
+ for (int i = 0; i < 8; ++i) {
1243
+ yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
1244
+ yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
1245
+ yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
1246
+ yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
1272
1247
  }
1273
1248
 
1274
- const float dall = (float)x[i].d;
1275
- const float dmin = (float)x[i].dmin;
1276
-
1277
- sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
1249
+ device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*im + is;
1250
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
1251
+ device const half * dh = &x[ib].d;
1252
+
1253
+ for (int row = 0; row < N_DST; row++) {
1254
+
1255
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
1256
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
1257
+ for (int i = 0; i < 8; i += 2) {
1258
+ acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
1259
+ acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
1260
+ acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
1261
+ acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
1262
+ acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
1263
+ acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
1264
+ acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
1265
+ acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
1266
+ }
1267
+ float dall = dh[0];
1268
+ float dmin = dh[1] * 1.f/16.f;
1269
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
1270
+ (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
1271
+ (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +
1272
+ (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
1273
+ dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
1274
+
1275
+ qs += step/2;
1276
+ sc += step;
1277
+ dh += step/2;
1278
+ }
1278
1279
 
1280
+ y4 += 4 * QK_K;
1279
1281
  }
1280
1282
  #else
1281
- const int il = 4 * tpitg.x;
1282
-
1283
- uint32_t aux[2];
1284
- thread const uint8_t * d = (thread const uint8_t *)aux;
1285
- thread const uint8_t * m = (thread const uint8_t *)aux + 4;
1283
+ const int ix = tiisg/2; // 0...15
1284
+ const int it = tiisg%2; // 0...1
1286
1285
 
1287
- for (int i = tpitg.y; i < nb; i += tptg.y) {
1286
+ device const float * y4 = y + ix * QK_K + 8 * it;
1288
1287
 
1289
- device const uint8_t * q = x[i].qs + il;
1290
- device const float * y = yy + i*QK_K + il;
1288
+ for (int ib = ix; ib < nb; ib += 16) {
1291
1289
 
1292
- const float dall = (float)x[i].d;
1293
- const float dmin = (float)x[i].dmin;
1290
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
1291
+ for (int i = 0; i < 8; ++i) {
1292
+ yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
1293
+ yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8];
1294
+ yl[i+16] = y4[i+32]; sumy[2] += yl[i+16];
1295
+ yl[i+24] = y4[i+48]; sumy[3] += yl[i+24];
1296
+ }
1294
1297
 
1295
- device const uint32_t * a = (device const uint32_t *)x[i].scales;
1296
- aux[0] = a[0] & 0x0f0f0f0f;
1297
- aux[1] = (a[0] >> 4) & 0x0f0f0f0f;
1298
+ device const uint8_t * sc = (device const uint8_t *)x[ib].scales;
1299
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
1300
+ device const half * dh = &x[ib].d;
1301
+
1302
+ for (int row = 0; row < N_DST; row++) {
1303
+
1304
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
1305
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
1306
+ for (int i = 0; i < 8; i += 2) {
1307
+ acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
1308
+ acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
1309
+ acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
1310
+ acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
1311
+ acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
1312
+ acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
1313
+ acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
1314
+ acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
1315
+ }
1298
1316
 
1299
- for (int l = 0; l < 4; ++l) {
1300
- sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0])
1301
- + y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1])
1302
- + y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2])
1303
- + y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]);
1317
+ float dall = dh[0];
1318
+ float dmin = dh[1];
1319
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
1320
+ (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
1321
+ (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
1322
+ (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
1323
+ dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));
1324
+
1325
+ qs += step/2;
1326
+ sc += step;
1327
+ dh += step/2;
1304
1328
  }
1329
+
1330
+ y4 += 16 * QK_K;
1305
1331
  }
1306
1332
  #endif
1307
1333
 
1308
- sum[ith] = sumf;
1309
-
1310
- //
1311
- // Accumulate the sum from all threads in the threadgroup
1312
- //
1313
- threadgroup_barrier(mem_flags::mem_threadgroup);
1314
- if (ith%4 == 0) {
1315
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
1316
- }
1317
- threadgroup_barrier(mem_flags::mem_threadgroup);
1318
- if (ith%16 == 0) {
1319
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
1320
- }
1321
- threadgroup_barrier(mem_flags::mem_threadgroup);
1322
- if (ith == 0) {
1323
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
1324
- dst[r1*ne0 + r0] = sum[0];
1334
+ for (int row = 0; row < N_DST; ++row) {
1335
+ all_sum = simd_sum(sumf[row]);
1336
+ if (tiisg == 0) {
1337
+ dst[r1*ne0 + first_row + row] = all_sum;
1338
+ }
1325
1339
  }
1326
1340
  }
1327
1341
 
1342
+ #if QK_K == 256
1328
1343
  kernel void kernel_mul_mat_q3_K_f32(
1329
1344
  device const void * src0,
1330
1345
  device const float * src1,
@@ -1333,40 +1348,41 @@ kernel void kernel_mul_mat_q3_K_f32(
1333
1348
  constant int64_t & ne10,
1334
1349
  constant int64_t & ne0,
1335
1350
  constant int64_t & ne1,
1336
- threadgroup float * sum [[threadgroup(0)]],
1337
1351
  uint2 tgpig[[threadgroup_position_in_grid]],
1338
- uint2 tpitg[[thread_position_in_threadgroup]],
1339
- uint2 tptg[[threads_per_threadgroup]]) {
1352
+ uint tiisg[[thread_index_in_simdgroup]],
1353
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1340
1354
 
1341
1355
  const int nb = ne00/QK_K;
1342
1356
 
1343
1357
  const int64_t r0 = tgpig.x;
1344
1358
  const int64_t r1 = tgpig.y;
1345
1359
 
1346
- device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb;
1347
- device const float * yy = (device const float *) src1 + r1*ne10;
1348
-
1349
- const int nth = tptg.x*tptg.y;
1350
- const int ith = tptg.y*tpitg.x + tpitg.y;
1360
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
1351
1361
 
1352
- #if QK_K == 256
1362
+ device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb;
1363
+ device const float * yy = (device const float *) src1 + r1*ne10;
1353
1364
 
1354
- const uint8_t m3 = 3;
1355
- const int8_t m4 = 4;
1365
+ float yl[16];
1356
1366
 
1357
1367
  const uint16_t kmask1 = 0x0303;
1358
1368
  const uint16_t kmask2 = 0x0f0f;
1359
1369
 
1360
- const int tid = tpitg.y; // expecting 16
1370
+ const int tid = tiisg/2;
1371
+ const int ix = tiisg%2;
1361
1372
  const int ip = tid/8; // 0 or 1
1362
1373
  const int il = tid/2 - 4*ip; // 0...3
1363
1374
  const int ir = tid%2;
1364
1375
  const int n = 8;
1365
1376
  const int l0 = n*ir;
1366
1377
 
1367
- const uint8_t m = 1 << (4*ip + il);
1378
+ const uint16_t m1 = 1 << (4*ip + il);
1379
+ const uint16_t m2 = m1 << 8;
1368
1380
 
1369
1381
  const int shift = 2*il;
1382
+ const uint16_t qm1 = 0x0003 << shift;
1383
+ const uint16_t qm2 = 0x0300 << shift;
1384
+ const int32_t v1 = 4 << shift;
1385
+ const int32_t v2 = 1024 << shift;
1370
1386
 
1371
1387
  const uint16_t s_shift1 = 4*ip;
1372
1388
  const uint16_t s_shift2 = s_shift1 + 2*(il/2);
@@ -1375,226 +1391,315 @@ kernel void kernel_mul_mat_q3_K_f32(
1375
1391
  const int q_offset = 32*ip + l0;
1376
1392
  const int y_offset = 128*ip + 32*il + l0;
1377
1393
 
1378
- //float sumf = 0;
1379
- float sumf1 = 0, sumf2 = 0;
1380
- for (int i = tpitg.x; i < nb; i += tptg.x) {
1394
+ const int step = sizeof(block_q3_K) * nb / 2;
1381
1395
 
1382
- const float d_all = (float)(x[i].d);
1396
+ device const float * y1 = yy + ix*QK_K + y_offset;
1383
1397
 
1384
- device const uint8_t * q = x[i].qs + q_offset;
1385
- device const uint8_t * h = x[i].hmask + l0;
1386
- device const float * y = yy + i * QK_K + y_offset;
1398
+ float sumf1[2] = {0.f}, sumf2[2] = {0.f};
1399
+ for (int i = ix; i < nb; i += 2) {
1387
1400
 
1388
- device const uint16_t * a = (device const uint16_t *)x[i].scales;
1389
- const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
1390
-
1391
- float s = 0;
1392
- for (int l = 0; l < n; ++l) {
1393
- s += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & m3) - ((h[l+ 0] & m) ? 0 : m4));
1394
- }
1395
- float d = d_all * s;
1396
- sumf1 += d * scales[0];
1397
- sumf2 += d;
1398
- //sumf += d_all * s * (scales[0] - 32);
1399
-
1400
- s = 0;
1401
- for (int l = 0; l < n; ++l) {
1402
- s += y[l+16] * ((int8_t)((q[l+16] >> shift) & m3) - ((h[l+16] & m) ? 0 : m4));
1401
+ for (int l = 0; l < 8; ++l) {
1402
+ yl[l+0] = y1[l+ 0];
1403
+ yl[l+8] = y1[l+16];
1403
1404
  }
1404
- d = d_all * s;
1405
- sumf1 += d * scales[1];
1406
- sumf2 += d;
1407
- //sumf += d_all * s * (scales[1] - 32);
1408
-
1409
- }
1410
1405
 
1411
- //sum[ith] = sumf;
1412
- sum[ith] = sumf1 - 32.f*sumf2;
1413
- #else
1414
- const int il = 4 * tpitg.x; // 0, 4, 8, 12
1415
- const int im = il/8; // 0, 0, 1, 1
1416
- const int in = il%8; // 0, 4, 0, 4
1406
+ device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
1407
+ device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
1408
+ device const uint16_t * a = (device const uint16_t *)(x[i].scales);
1409
+ device const half * dh = &x[i].d;
1417
1410
 
1418
- float sumf = 0;
1411
+ for (int row = 0; row < 2; ++row) {
1419
1412
 
1420
- for (int i = tpitg.y; i < nb; i += tptg.y) {
1413
+ const float d_all = (float)dh[0];
1414
+ const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
1421
1415
 
1422
- const float d_all = (float)(x[i].d);
1423
-
1424
- device const uint8_t * q = x[i].qs + il;
1425
- device const uint8_t * h = x[i].hmask + in;
1426
- device const float * y = yy + i * QK_K + il;
1416
+ float s1 = 0, s2 = 0;
1417
+ for (int l = 0; l < n; l += 2) {
1418
+ const uint16_t qs = q[l/2];
1419
+ s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
1420
+ s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
1421
+ }
1422
+ float d = d_all * (s1 + 1.f/256.f * s2);
1423
+ sumf1[row] += d * scales[0];
1424
+ sumf2[row] += d;
1425
+
1426
+ s1 = s2 = 0;
1427
+ for (int l = 0; l < n; l += 2) {
1428
+ const uint16_t qs = q[l/2+8];
1429
+ s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
1430
+ s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
1431
+ }
1432
+ d = d_all * (s1 + 1.f/256.f * s2);
1433
+ sumf1[row] += d * scales[1];
1434
+ sumf2[row] += d;
1427
1435
 
1428
- const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
1429
- const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
1430
- const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
1431
- const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
1436
+ q += step;
1437
+ h += step;
1438
+ a += step;
1439
+ dh += step;
1432
1440
 
1433
- for (int l = 0; l < 4; ++l) {
1434
- const uint8_t hm = h[l] >> im;
1435
- sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4))
1436
- + y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4))
1437
- + y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4))
1438
- + y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4));
1439
1441
  }
1440
1442
 
1441
- }
1442
-
1443
- sum[ith] = sumf;
1444
-
1445
- #endif
1443
+ y1 += 2 * QK_K;
1446
1444
 
1447
- //
1448
- // Accumulate the sum from all threads in the threadgroup
1449
- //
1450
- threadgroup_barrier(mem_flags::mem_threadgroup);
1451
- if (ith%4 == 0) {
1452
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
1453
- }
1454
- threadgroup_barrier(mem_flags::mem_threadgroup);
1455
- if (ith%16 == 0) {
1456
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
1457
- }
1458
- threadgroup_barrier(mem_flags::mem_threadgroup);
1459
- if (ith == 0) {
1460
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
1461
- dst[r1*ne0 + r0] = sum[0];
1462
1445
  }
1463
1446
 
1447
+ for (int row = 0; row < 2; ++row) {
1448
+ const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
1449
+ const float tot = simd_sum(sumf);
1450
+ if (tiisg == 0) {
1451
+ dst[r1*ne0 + first_row + row] = tot;
1452
+ }
1453
+ }
1464
1454
  }
1465
-
1466
- kernel void kernel_mul_mat_q4_K_f32(
1455
+ #else
1456
+ kernel void kernel_mul_mat_q3_K_f32(
1467
1457
  device const void * src0,
1468
1458
  device const float * src1,
1469
1459
  device float * dst,
1470
1460
  constant int64_t & ne00,
1471
1461
  constant int64_t & ne10,
1472
1462
  constant int64_t & ne0,
1473
- threadgroup float * sum [[threadgroup(0)]],
1463
+ constant int64_t & ne1,
1474
1464
  uint2 tgpig[[threadgroup_position_in_grid]],
1475
- uint2 tpitg[[thread_position_in_threadgroup]],
1476
- uint2 tptg[[threads_per_threadgroup]]) {
1465
+ uint tiisg[[thread_index_in_simdgroup]],
1466
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1477
1467
 
1478
1468
  const int nb = ne00/QK_K;
1479
1469
 
1480
1470
  const int64_t r0 = tgpig.x;
1481
1471
  const int64_t r1 = tgpig.y;
1482
1472
 
1483
- const int nth = tptg.x*tptg.y;
1484
- const int ith = tptg.y*tpitg.x + tpitg.y;
1473
+ const int row = 2 * r0 + sgitg;
1485
1474
 
1486
- device const block_q4_K * x = (device const block_q4_K *) src0 + r0*nb;
1475
+ device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb;
1487
1476
  device const float * yy = (device const float *) src1 + r1*ne10;
1477
+ const int ix = tiisg/4;
1478
+ const int il = 4 * (tiisg%4);// 0, 4, 8, 12
1479
+ const int im = il/8; // 0, 0, 1, 1
1480
+ const int in = il%8; // 0, 4, 0, 4
1488
1481
 
1489
- float sumf = 0;
1482
+ float2 sum = {0.f, 0.f};
1483
+
1484
+ for (int i = ix; i < nb; i += 8) {
1485
+
1486
+ const float d_all = (float)(x[i].d);
1487
+
1488
+ device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
1489
+ device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
1490
+ device const uint16_t * s = (device const uint16_t *)(x[i].scales);
1491
+ device const float * y = yy + i * QK_K + il;
1492
+
1493
+ const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
1494
+ const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
1495
+ const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
1496
+ const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
1497
+
1498
+ for (int l = 0; l < 4; l += 2) {
1499
+ const uint16_t hm = h[l/2] >> im;
1500
+ sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
1501
+ + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
1502
+ + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
1503
+ + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
1504
+ sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
1505
+ + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
1506
+ + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
1507
+ + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
1508
+ }
1509
+
1510
+ }
1511
+ const float sumf = sum[0] + sum[1] * 1.f/256.f;
1512
+
1513
+ const float tot = simd_sum(sumf);
1514
+ if (tiisg == 0) {
1515
+ dst[r1*ne0 + row] = tot;
1516
+ }
1517
+
1518
+ }
1519
+ #endif
1490
1520
 
1491
1521
  #if QK_K == 256
1522
+ kernel void kernel_mul_mat_q4_K_f32(
1523
+ device const void * src0,
1524
+ device const float * src1,
1525
+ device float * dst,
1526
+ constant int64_t & ne00,
1527
+ constant int64_t & ne10,
1528
+ constant int64_t & ne0,
1529
+ constant int64_t & ne01[[buffer(4)]],
1530
+ uint2 tgpig[[threadgroup_position_in_grid]],
1531
+ uint tiisg[[thread_index_in_simdgroup]],
1532
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1492
1533
 
1493
1534
  const uint16_t kmask1 = 0x3f3f;
1494
1535
  const uint16_t kmask2 = 0x0f0f;
1495
1536
  const uint16_t kmask3 = 0xc0c0;
1496
1537
 
1497
- const int tid = tpitg.y; // 0...16
1498
- const int il = tid/4; // 0...3
1499
- const int ir = tid - 4*il;// 0...3
1500
- const int n = 4;
1538
+ const int ix = tiisg/8; // 0...3
1539
+ const int it = tiisg%8; // 0...7
1540
+ const int im = it/4; // 0 or 1
1541
+ const int ir = it%4; // 0...3
1501
1542
 
1502
- const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
1503
- const int in = il%2;
1543
+ const int nb = ne00/QK_K;
1544
+ const int r0 = tgpig.x;
1545
+ const int r1 = tgpig.y;
1546
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1547
+ const int ib_row = first_row * nb;
1548
+ device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
1549
+ device const float * y = (device const float *) src1 + r1*ne10;
1550
+ float yl[16];
1551
+ float yh[16];
1552
+ float sumf[N_DST]={0.f}, all_sum;
1504
1553
 
1505
- const int l0 = n*(2*ir + in);
1506
- const int q_offset = 32*im + l0;
1507
- const int y_offset = 64*im + l0;
1554
+ const int step = sizeof(block_q4_K) * nb / 2;
1508
1555
 
1509
- uchar2 sc1, sc2, sc3, sc4;
1556
+ device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir;
1510
1557
 
1511
- for (int i = tpitg.x; i < nb; i += tptg.x) {
1558
+ uint16_t sc16[4];
1559
+ thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
1512
1560
 
1513
- device const uint8_t * q1 = (x + i)->qs + q_offset;
1514
- device const uint8_t * q2 = q1 + 64;
1515
- device const float * y1 = yy + i*QK_K + y_offset;
1516
- device const float * y2 = y1 + 128;
1561
+ for (int ib = ix; ib < nb; ib += 4) {
1517
1562
 
1518
- const float dall = (float)((x + i)->d);
1519
- const float dmin = (float)((x + i)->dmin);
1563
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
1564
+ for (int i = 0; i < 8; ++i) {
1565
+ yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
1566
+ yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
1567
+ yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
1568
+ yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
1569
+ }
1520
1570
 
1521
- device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
1522
- sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
1523
- sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
1524
- sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
1525
- sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
1571
+ device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im;
1572
+ device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
1573
+ device const half * dh = &x[ib].d;
1574
+
1575
+ for (int row = 0; row < N_DST; row++) {
1576
+
1577
+ sc16[0] = sc[0] & kmask1;
1578
+ sc16[1] = sc[2] & kmask1;
1579
+ sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
1580
+ sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
1581
+
1582
+ device const uint16_t * q2 = q1 + 32;
1583
+
1584
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
1585
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
1586
+ for (int i = 0; i < 8; i += 2) {
1587
+ acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
1588
+ acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
1589
+ acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
1590
+ acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
1591
+ acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
1592
+ acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
1593
+ acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
1594
+ acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
1595
+ }
1526
1596
 
1527
- float4 s = {0.f, 0.f, 0.f, 0.f};
1528
- float smin = 0;
1529
- for (int l = 0; l < n; ++l) {
1597
+ float dall = dh[0];
1598
+ float dmin = dh[1];
1599
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
1600
+ (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
1601
+ (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
1602
+ (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
1603
+ dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
1604
+
1605
+ q1 += step;
1606
+ sc += step;
1607
+ dh += step;
1608
+ }
1530
1609
 
1531
- s[0] += y1[l] * (q1[l] & 0xF); s[1] += y1[l+32] * (q1[l] >> 4);
1532
- s[2] += y2[l] * (q2[l] & 0xF); s[3] += y2[l+32] * (q2[l] >> 4);
1533
- smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
1610
+ y4 += 4 * QK_K;
1611
+ }
1534
1612
 
1613
+ for (int row = 0; row < N_DST; ++row) {
1614
+ all_sum = simd_sum(sumf[row]);
1615
+ if (tiisg == 0) {
1616
+ dst[r1*ne0 + first_row + row] = all_sum;
1535
1617
  }
1536
- sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
1537
-
1538
1618
  }
1619
+ }
1539
1620
  #else
1540
- uint16_t aux16[2];
1541
- thread const uint8_t * scales = (thread const uint8_t *)aux16;
1621
+ kernel void kernel_mul_mat_q4_K_f32(
1622
+ device const void * src0,
1623
+ device const float * src1,
1624
+ device float * dst,
1625
+ constant int64_t & ne00,
1626
+ constant int64_t & ne10,
1627
+ constant int64_t & ne0,
1628
+ constant int64_t & ne01[[buffer(4)]],
1629
+ uint2 tgpig[[threadgroup_position_in_grid]],
1630
+ uint tiisg[[thread_index_in_simdgroup]],
1631
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1542
1632
 
1543
- const int il = 4*tpitg.x;
1633
+ const int ix = tiisg/4; // 0...7
1634
+ const int it = tiisg%4; // 0...3
1544
1635
 
1545
- for (int i = tpitg.y; i < nb; i += tptg.y) {
1636
+ const int nb = ne00/QK_K;
1637
+ const int r0 = tgpig.x;
1638
+ const int r1 = tgpig.y;
1639
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1640
+ const int ib_row = first_row * nb;
1641
+ device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
1642
+ device const float * y = (device const float *) src1 + r1*ne10;
1643
+ float yl[8];
1644
+ float yh[8];
1645
+ float sumf[N_DST]={0.f}, all_sum;
1546
1646
 
1547
- device const uint8_t * q = x[i].qs + il;
1548
- device const float * y = yy + i * QK_K + il;
1647
+ const int step = sizeof(block_q4_K) * nb / 2;
1549
1648
 
1550
- const float d = (float)x[i].d[0];
1551
- const float m = (float)x[i].d[1];
1649
+ device const float * y4 = y + ix * QK_K + 8 * it;
1552
1650
 
1553
- device const uint16_t * a = (device const uint16_t *)x[i].scales;
1554
- aux16[0] = a[0] & 0x0f0f;
1555
- aux16[1] = (a[0] >> 4) & 0x0f0f;
1651
+ uint16_t sc16[4];
1556
1652
 
1557
- for (int l = 0; l < 4; ++l) {
1558
- sumf += d * scales[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - m * scales[2] * (y[l+ 0] + y[l+16])
1559
- + d * scales[1] * (y[l+32] * (q[l] >> 4) + y[l+48] * (q[l+16] >> 4)) - m * scales[3] * (y[l+32] + y[l+48]);
1653
+ for (int ib = ix; ib < nb; ib += 8) {
1654
+
1655
+ float2 sumy = {0.f, 0.f};
1656
+ for (int i = 0; i < 8; ++i) {
1657
+ yl[i] = y4[i+ 0]; sumy[0] += yl[i];
1658
+ yh[i] = y4[i+32]; sumy[1] += yh[i];
1560
1659
  }
1561
- }
1562
- #endif
1563
1660
 
1564
- sum[ith] = sumf;
1661
+ device const uint16_t * sc = (device const uint16_t *)x[ib].scales;
1662
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
1663
+ device const half * dh = x[ib].d;
1565
1664
 
1566
- //
1567
- // Accumulate the sum from all threads in the threadgroup
1568
- // This version is slightly faster than the commented out one below,
1569
- // which I copy-pasted from ggerganov's q4_0 dot product for metal.
1570
- //
1571
- threadgroup_barrier(mem_flags::mem_threadgroup);
1572
- if (ith%4 == 0) {
1573
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
1574
- }
1575
- threadgroup_barrier(mem_flags::mem_threadgroup);
1576
- if (ith%16 == 0) {
1577
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
1578
- }
1579
- threadgroup_barrier(mem_flags::mem_threadgroup);
1580
- if (ith == 0) {
1581
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
1582
- dst[r1*ne0 + r0] = sum[0];
1665
+ for (int row = 0; row < N_DST; row++) {
1666
+
1667
+ sc16[0] = sc[0] & 0x000f;
1668
+ sc16[1] = sc[0] & 0x0f00;
1669
+ sc16[2] = sc[0] & 0x00f0;
1670
+ sc16[3] = sc[0] & 0xf000;
1671
+
1672
+ float2 acc1 = {0.f, 0.f};
1673
+ float2 acc2 = {0.f, 0.f};
1674
+ for (int i = 0; i < 8; i += 2) {
1675
+ acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
1676
+ acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
1677
+ acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
1678
+ acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
1679
+ }
1680
+
1681
+ float dall = dh[0];
1682
+ float dmin = dh[1];
1683
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
1684
+ (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
1685
+ dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);
1686
+
1687
+ qs += step;
1688
+ sc += step;
1689
+ dh += step;
1690
+ }
1691
+
1692
+ y4 += 8 * QK_K;
1583
1693
  }
1584
1694
 
1585
- //// accumulate the sum from all threads in the threadgroup
1586
- //threadgroup_barrier(mem_flags::mem_threadgroup);
1587
- //for (uint i = nth/2; i > 0; i /= 2) {
1588
- // if (ith < i) {
1589
- // sum[ith] += sum[ith + i];
1590
- // }
1591
- // threadgroup_barrier(mem_flags::mem_threadgroup);
1592
- //}
1593
-
1594
- //if (ith == 0) {
1595
- // dst[r1*ne0 + r0] = sum[0];
1596
- //}
1695
+ for (int row = 0; row < N_DST; ++row) {
1696
+ all_sum = simd_sum(sumf[row]);
1697
+ if (tiisg == 0) {
1698
+ dst[r1*ne0 + first_row + row] = all_sum;
1699
+ }
1700
+ }
1597
1701
  }
1702
+ #endif
1598
1703
 
1599
1704
  kernel void kernel_mul_mat_q5_K_f32(
1600
1705
  device const void * src0,
@@ -1603,39 +1708,39 @@ kernel void kernel_mul_mat_q5_K_f32(
1603
1708
  constant int64_t & ne00,
1604
1709
  constant int64_t & ne10,
1605
1710
  constant int64_t & ne0,
1606
- threadgroup float * sum [[threadgroup(0)]],
1607
1711
  uint2 tgpig[[threadgroup_position_in_grid]],
1608
- uint2 tpitg[[thread_position_in_threadgroup]],
1609
- uint2 tptg[[threads_per_threadgroup]]) {
1712
+ uint tiisg[[thread_index_in_simdgroup]],
1713
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1610
1714
 
1611
1715
  const int nb = ne00/QK_K;
1612
1716
 
1613
1717
  const int64_t r0 = tgpig.x;
1614
1718
  const int64_t r1 = tgpig.y;
1615
1719
 
1616
- device const block_q5_K * x = (device const block_q5_K *) src0 + r0*nb;
1720
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
1721
+
1722
+ device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb;
1617
1723
  device const float * yy = (device const float *) src1 + r1*ne10;
1618
1724
 
1619
- const int nth = tptg.x*tptg.y;
1620
- const int ith = tptg.y*tpitg.x + tpitg.y;
1725
+ float sumf[2]={0.f};
1621
1726
 
1622
- float sumf = 0;
1727
+ const int step = sizeof(block_q5_K) * nb;
1623
1728
 
1624
1729
  #if QK_K == 256
1730
+ #
1731
+ float yl[16], yh[16];
1625
1732
 
1626
1733
  const uint16_t kmask1 = 0x3f3f;
1627
1734
  const uint16_t kmask2 = 0x0f0f;
1628
1735
  const uint16_t kmask3 = 0xc0c0;
1629
1736
 
1630
- const int tid = tpitg.y; // 0...16
1631
- const int il = tid/4; // 0...3
1632
- const int ir = tid - 4*il;// 0...3
1633
- const int n = 4;
1634
-
1635
- const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
1636
- const int in = il%2;
1737
+ const int tid = tiisg/4;
1738
+ const int ix = tiisg%4;
1739
+ const int im = tid/4;
1740
+ const int ir = tid%4;
1741
+ const int n = 8;
1637
1742
 
1638
- const int l0 = n*(2*ir + in);
1743
+ const int l0 = n*ir;
1639
1744
  const int q_offset = 32*im + l0;
1640
1745
  const int y_offset = 64*im + l0;
1641
1746
 
@@ -1644,78 +1749,113 @@ kernel void kernel_mul_mat_q5_K_f32(
1644
1749
  const uint8_t hm3 = hm1 << 4;
1645
1750
  const uint8_t hm4 = hm2 << 4;
1646
1751
 
1647
- uchar2 sc1, sc2, sc3, sc4;
1752
+ uint16_t sc16[4];
1753
+ thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
1648
1754
 
1649
- for (int i = tpitg.x; i < nb; i += tptg.x) {
1755
+ device const float * y1 = yy + ix*QK_K + y_offset;
1650
1756
 
1651
- device const uint8_t * q1 = (x + i)->qs + q_offset;
1652
- device const uint8_t * q2 = q1 + 64;
1653
- device const uint8_t * qh = (x + i)->qh + l0;
1654
- device const float * y1 = yy + i*QK_K + y_offset;
1655
- device const float * y2 = y1 + 128;
1757
+ for (int i = ix; i < nb; i += 4) {
1656
1758
 
1657
- const float dall = (float)((x + i)->d);
1658
- const float dmin = (float)((x + i)->dmin);
1759
+ device const uint8_t * q1 = x[i].qs + q_offset;
1760
+ device const uint8_t * qh = x[i].qh + l0;
1761
+ device const half * dh = &x[i].d;
1762
+ device const uint16_t * a = (device const uint16_t *)x[i].scales + im;
1659
1763
 
1660
- device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
1661
- sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
1662
- sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
1663
- sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
1664
- sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
1764
+ device const float * y2 = y1 + 128;
1765
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
1766
+ for (int l = 0; l < 8; ++l) {
1767
+ yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
1768
+ yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
1769
+ yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
1770
+ yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
1771
+ }
1665
1772
 
1666
- float4 s = {0.f, 0.f, 0.f, 0.f};
1667
- float smin = 0;
1668
- for (int l = 0; l < n; ++l) {
1773
+ for (int row = 0; row < 2; ++row) {
1774
+
1775
+ device const uint8_t * q2 = q1 + 64;
1669
1776
 
1670
- s[0] += y1[l+ 0] * ((q1[l] & 0xF) + (qh[l] & hm1 ? 16 : 0));
1671
- s[1] += y1[l+32] * ((q1[l] >> 4) + (qh[l] & hm2 ? 16 : 0));
1672
- s[2] += y2[l+ 0] * ((q2[l] & 0xF) + (qh[l] & hm3 ? 16 : 0));
1673
- s[3] += y2[l+32] * ((q2[l] >> 4) + (qh[l] & hm4 ? 16 : 0));
1674
- smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
1777
+ sc16[0] = a[0] & kmask1;
1778
+ sc16[1] = a[2] & kmask1;
1779
+ sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
1780
+ sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
1781
+
1782
+ float4 acc = {0.f, 0.f, 0.f, 0.f};
1783
+ for (int l = 0; l < n; ++l) {
1784
+ uint8_t h = qh[l];
1785
+ acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
1786
+ acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
1787
+ acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
1788
+ acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
1789
+ }
1790
+ const float dall = dh[0];
1791
+ const float dmin = dh[1];
1792
+ sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
1793
+ dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
1794
+
1795
+ q1 += step;
1796
+ qh += step;
1797
+ dh += step/2;
1798
+ a += step/2;
1675
1799
 
1676
1800
  }
1677
- sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
1801
+
1802
+ y1 += 4 * QK_K;
1678
1803
 
1679
1804
  }
1680
1805
  #else
1681
- const int il = 4 * tpitg.x; // 0, 4, 8, 12
1682
- const int im = il/8; // 0, 0, 1, 1
1683
- const int in = il%8; // 0, 4, 0, 4
1806
+ float yl[8], yh[8];
1684
1807
 
1685
- for (int i = tpitg.y; i < nb; i += tptg.y) {
1808
+ const int il = 4 * (tiisg/8); // 0, 4, 8, 12
1809
+ const int ix = tiisg%8;
1810
+ const int im = il/8; // 0, 0, 1, 1
1811
+ const int in = il%8; // 0, 4, 0, 4
1686
1812
 
1687
- const float d = (float)x[i].d;
1813
+ device const float * y = yy + ix*QK_K + il;
1814
+
1815
+ for (int i = ix; i < nb; i += 8) {
1816
+
1817
+ for (int l = 0; l < 4; ++l) {
1818
+ yl[l+0] = y[l+ 0];
1819
+ yl[l+4] = y[l+16];
1820
+ yh[l+0] = y[l+32];
1821
+ yh[l+4] = y[l+48];
1822
+ }
1823
+
1824
+ device const half * dh = &x[i].d;
1688
1825
  device const uint8_t * q = x[i].qs + il;
1689
1826
  device const uint8_t * h = x[i].qh + in;
1690
1827
  device const int8_t * s = x[i].scales;
1691
- device const float * y = yy + i*QK_K + il;
1692
1828
 
1693
- for (int l = 0; l < 4; ++l) {
1694
- const uint8_t hl = h[l] >> im;
1695
- sumf += y[l+ 0] * d * s[0] * ((q[l+ 0] & 0xF) - (hl & 0x01 ? 0 : 16))
1696
- + y[l+16] * d * s[1] * ((q[l+16] & 0xF) - (hl & 0x04 ? 0 : 16))
1697
- + y[l+32] * d * s[2] * ((q[l+ 0] >> 4) - (hl & 0x10 ? 0 : 16))
1698
- + y[l+48] * d * s[3] * ((q[l+16] >> 4) - (hl & 0x40 ? 0 : 16));
1829
+ for (int row = 0; row < 2; ++row) {
1830
+
1831
+ const float d = dh[0];
1832
+
1833
+ float2 acc = {0.f, 0.f};
1834
+ for (int l = 0; l < 4; ++l) {
1835
+ const uint8_t hl = h[l] >> im;
1836
+ acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
1837
+ + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
1838
+ acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
1839
+ + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256));
1840
+ }
1841
+ sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]);
1842
+
1843
+ q += step;
1844
+ h += step;
1845
+ s += step;
1846
+ dh += step/2;
1847
+
1699
1848
  }
1849
+
1850
+ y += 8 * QK_K;
1700
1851
  }
1701
1852
  #endif
1702
- sum[ith] = sumf;
1703
1853
 
1704
- //
1705
- // Accumulate the sum from all threads in the threadgroup
1706
- //
1707
- threadgroup_barrier(mem_flags::mem_threadgroup);
1708
- if (ith%4 == 0) {
1709
- sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
1710
- }
1711
- threadgroup_barrier(mem_flags::mem_threadgroup);
1712
- if (ith%16 == 0) {
1713
- sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
1714
- }
1715
- threadgroup_barrier(mem_flags::mem_threadgroup);
1716
- if (ith == 0) {
1717
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
1718
- dst[r1*ne0 + r0] = sum[0];
1854
+ for (int row = 0; row < 2; ++row) {
1855
+ const float tot = simd_sum(sumf[row]);
1856
+ if (tiisg == 0) {
1857
+ dst[r1*ne0 + first_row + row] = tot;
1858
+ }
1719
1859
  }
1720
1860
 
1721
1861
  }
@@ -1727,10 +1867,9 @@ kernel void kernel_mul_mat_q6_K_f32(
1727
1867
  constant int64_t & ne00,
1728
1868
  constant int64_t & ne10,
1729
1869
  constant int64_t & ne0,
1730
- threadgroup float * sum [[threadgroup(0)]],
1731
1870
  uint2 tgpig[[threadgroup_position_in_grid]],
1732
- uint2 tpitg[[thread_position_in_threadgroup]],
1733
- uint2 tptg[[threads_per_threadgroup]]) {
1871
+ uint tiisg[[thread_index_in_simdgroup]],
1872
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1734
1873
 
1735
1874
  const uint8_t kmask1 = 0x03;
1736
1875
  const uint8_t kmask2 = 0x0C;
@@ -1742,19 +1881,18 @@ kernel void kernel_mul_mat_q6_K_f32(
1742
1881
  const int64_t r0 = tgpig.x;
1743
1882
  const int64_t r1 = tgpig.y;
1744
1883
 
1745
- device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb;
1746
- device const float * yy = (device const float *) src1 + r1*ne10;
1884
+ const int row = 2 * r0 + sgitg;
1747
1885
 
1748
- const int nth = tptg.x*tptg.y;
1749
- const int ith = tptg.y*tpitg.x + tpitg.y;
1886
+ device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb; //r0*nb;
1887
+ device const float * yy = (device const float *) src1 + r1*ne10;
1750
1888
 
1751
1889
  float sumf = 0;
1752
1890
 
1753
1891
  #if QK_K == 256
1754
- // Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
1755
- const int iqs = 16 * tpitg.y;
1756
- const int ip = iqs / 128; // 0 or 1
1757
- const int il = (iqs - 128*ip)/16; // 0...7
1892
+ const int tid = tiisg/2;
1893
+ const int ix = tiisg%2;
1894
+ const int ip = tid/8; // 0 or 1
1895
+ const int il = tid%8;
1758
1896
  const int n = 4;
1759
1897
  const int l0 = n*il;
1760
1898
  const int is = 8*ip + l0/16;
@@ -1763,9 +1901,10 @@ kernel void kernel_mul_mat_q6_K_f32(
1763
1901
  const int q_offset_l = 64*ip + l0;
1764
1902
  const int q_offset_h = 32*ip + l0;
1765
1903
 
1766
- for (int i = tpitg.x; i < nb; i += tptg.x) {
1904
+ for (int i = ix; i < nb; i += 2) {
1767
1905
 
1768
- device const uint8_t * ql = x[i].ql + q_offset_l;
1906
+ device const uint8_t * q1 = x[i].ql + q_offset_l;
1907
+ device const uint8_t * q2 = q1 + 32;
1769
1908
  device const uint8_t * qh = x[i].qh + q_offset_h;
1770
1909
  device const int8_t * sc = x[i].scales + is;
1771
1910
 
@@ -1775,19 +1914,21 @@ kernel void kernel_mul_mat_q6_K_f32(
1775
1914
 
1776
1915
  float4 sums = {0.f, 0.f, 0.f, 0.f};
1777
1916
  for (int l = 0; l < n; ++l) {
1778
- sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
1779
- sums[1] += y[l+32] * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
1780
- sums[2] += y[l+64] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
1781
- sums[3] += y[l+96] * ((int8_t)((ql[l+32] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
1917
+ sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
1918
+ sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
1919
+ sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
1920
+ sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
1782
1921
  }
1783
1922
 
1784
1923
  sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
1785
1924
 
1786
1925
  }
1926
+
1787
1927
  #else
1788
- const int il = 4*tpitg.x; // 0, 4, 8, 12
1928
+ const int ix = tiisg/4;
1929
+ const int il = 4*(tiisg%4);
1789
1930
 
1790
- for (int i = tpitg.y; i < nb; i += tptg.y) {
1931
+ for (int i = ix; i < nb; i += 8) {
1791
1932
  device const float * y = yy + i * QK_K + il;
1792
1933
  device const uint8_t * ql = x[i].ql + il;
1793
1934
  device const uint8_t * qh = x[i].qh + il;
@@ -1807,23 +1948,8 @@ kernel void kernel_mul_mat_q6_K_f32(
1807
1948
 
1808
1949
  #endif
1809
1950
 
1810
- sum[ith] = sumf;
1811
-
1812
- //
1813
- // Accumulate the sum from all threads in the threadgroup
1814
- //
1815
- threadgroup_barrier(mem_flags::mem_threadgroup);
1816
- if (ith%4 == 0) {
1817
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
1818
- }
1819
- threadgroup_barrier(mem_flags::mem_threadgroup);
1820
- if (ith%16 == 0) {
1821
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
1822
- }
1823
- threadgroup_barrier(mem_flags::mem_threadgroup);
1824
- if (ith == 0) {
1825
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
1826
- dst[r1*ne0 + r0] = sum[0];
1951
+ const float tot = simd_sum(sumf);
1952
+ if (tiisg == 0) {
1953
+ dst[r1*ne0 + row] = tot;
1827
1954
  }
1828
-
1829
1955
  }