llama_cpp 0.3.2 → 0.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -331,26 +331,33 @@ kernel void kernel_rms_norm(
331
331
  threadgroup float * sum [[threadgroup(0)]],
332
332
  uint tgpig[[threadgroup_position_in_grid]],
333
333
  uint tpitg[[thread_position_in_threadgroup]],
334
+ uint sgitg[[simdgroup_index_in_threadgroup]],
335
+ uint tiisg[[thread_index_in_simdgroup]],
334
336
  uint ntg[[threads_per_threadgroup]]) {
335
- device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
337
+ device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
338
+ device const float * x_scalar = (device const float *) x;
339
+ float4 sumf=0;
340
+ float all_sum=0;
336
341
 
337
342
  // parallel sum
338
- sum[tpitg] = 0.0f;
339
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
340
- sum[tpitg] += x[i00] * x[i00];
343
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
344
+ sumf += x[i00] * x[i00];
345
+ }
346
+ all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
347
+ all_sum = simd_sum(all_sum);
348
+ if (tiisg == 0) {
349
+ sum[sgitg] = all_sum;
341
350
  }
342
351
 
343
- // reduce
344
352
  threadgroup_barrier(mem_flags::mem_threadgroup);
345
- for (uint i = ntg/2; i > 0; i /= 2) {
346
- if (tpitg < i) {
347
- sum[tpitg] += sum[tpitg + i];
348
- }
349
- threadgroup_barrier(mem_flags::mem_threadgroup);
353
+ // broadcast, number of simd groups is ntg / 32
354
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
355
+ if (tpitg < i) {
356
+ sum[tpitg] += sum[tpitg + i];
357
+ }
350
358
  }
351
-
352
- // broadcast
353
359
  if (tpitg == 0) {
360
+ for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
354
361
  sum[0] /= ne00;
355
362
  }
356
363
 
@@ -359,147 +366,127 @@ kernel void kernel_rms_norm(
359
366
  const float mean = sum[0];
360
367
  const float scale = 1.0f/sqrt(mean + eps);
361
368
 
362
- device float * y = dst + tgpig*ne00;
363
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
369
+ device float4 * y = (device float4 *) (dst + tgpig*ne00);
370
+ device float * y_scalar = (device float *) y;
371
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
364
372
  y[i00] = x[i00] * scale;
365
373
  }
374
+ if (tpitg == 0) {
375
+ for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
376
+ }
366
377
  }
367
378
 
368
- kernel void kernel_mul_mat_q4_0_f32(
369
- device const void * src0,
370
- device const float * src1,
371
- device float * dst,
372
- constant int64_t & ne00,
373
- constant int64_t & ne10,
374
- constant int64_t & ne0,
375
- threadgroup float * sum [[threadgroup(0)]],
376
- uint2 tgpig[[threadgroup_position_in_grid]],
377
- uint2 tpitg[[thread_position_in_threadgroup]],
378
- uint2 tptg[[threads_per_threadgroup]]) {
379
- const int nb = ne00/QK4_0;
379
+ // function to calculate the inner product between a q4_0 block and 32 floats (yl); sumy is SUM(yl[i])
380
+ float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) {
381
+ float d = qb_curr->d;
382
+ float4 acc = 0.f;
383
+ device uint16_t * qs = ((device uint16_t *)qb_curr + 1);
384
+ for (int i = 0; i < 16; i+=2) {
385
+ acc[0] += yl[i] * (qs[i / 2] & 0x000F);
386
+ acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
387
+ acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
388
+ acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
389
+ }
390
+ return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f);
391
+ }
380
392
 
381
- const int64_t r0 = tgpig.x;
382
- const int64_t r1 = tgpig.y;
393
+ // function to calculate the inner product between a q4_1 block and 32 floats (yl); sumy is SUM(yl[i])
394
+ float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) {
395
+ float d = qb_curr->d;
396
+ float m = qb_curr->m;
397
+ float4 acc = 0.f;
398
+ device uint16_t * qs = ((device uint16_t *)qb_curr + 2);
399
+ for (int i = 0; i < 16; i+=2) {
400
+ acc[0] += yl[i] * (qs[i / 2] & 0x000F);
401
+ acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
402
+ acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
403
+ acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
404
+ }
405
+ return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m;
406
+ }
383
407
 
384
- device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb;
408
+ // putting these definitions inside the kernel causes a significant performance penalty
409
+ #define N_DST 4 // each SIMD group works on 4 rows
410
+ #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
411
+ #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
412
+ template<typename block_q_type>
413
+ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
414
+ int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
415
+ uint2 tgpig, uint tiisg, uint sgitg) {
416
+ const int nb = ne00/QK4_0;
417
+ const int r0 = tgpig.x;
418
+ const int r1 = tgpig.y;
419
+ device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
385
420
  device const float * y = (device const float *) src1 + r1*ne10;
386
-
387
- const int nth = tptg.x*tptg.y;
388
- const int ith = tptg.y*tpitg.x + tpitg.y;
389
-
390
- const int ix = tpitg.y/4; // 0 or 1
391
- const int iy = tpitg.y - 4*ix; // 0...3
392
-
393
- const int first = 4 * iy;
394
-
395
- float sumf = 0;
396
-
397
- for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
398
-
399
- const float d = (float)x[i].d;
400
-
401
- device const uint8_t * xl = x[i].qs + first;
402
- device const float * yl = y + i * QK4_0 + first;
403
-
404
- float2 acc = {0.0f, 0.0f};
405
-
406
- for (int j = 0; j < 4; ++j) {
407
-
408
- acc[0] += yl[j] * (xl[j] & 0xF) + yl[j+16] * (xl[j] >> 4);
409
- acc[1] += yl[j] + yl[j+16];
410
-
421
+ float4 y_curr[8]; // src1 vector cache
422
+ float sumf[N_DST]={0.f}, all_sum;
423
+ thread float * yl=(thread float *)y_curr;
424
+
425
+ // each thread in a SIMD group deals with 1 block.
426
+ for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
427
+ float sumy = 0;
428
+ for (int i = 0; i < QK4_0 / 4; i++) {
429
+ y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i);
430
+ sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
411
431
  }
412
432
 
413
- sumf += d * (acc[0] - 8.f*acc[1]);
433
+ for (int row = 0; row < N_DST; row++) {
434
+ sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl);
435
+ }
414
436
  }
415
437
 
416
- sum[ith] = sumf;
438
+ // from here on, load two rows at a time and 16 blocks per row
439
+ int ir = tiisg / (N_SIMDWIDTH / 2);
440
+ int ib = tiisg % (N_SIMDWIDTH / 2);
441
+ for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) {
442
+ int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); //where the left blocks start
443
+ float sumy = 0;
444
+ for (int i = 0; i < QK4_0 / 4; i++) {
445
+ y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i);
446
+ sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
447
+ }
417
448
 
418
- //
419
- // Accumulate the sum from all threads in the threadgroup
420
- //
421
- threadgroup_barrier(mem_flags::mem_threadgroup);
422
- if (ith%4 == 0) {
423
- sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
424
- }
425
- threadgroup_barrier(mem_flags::mem_threadgroup);
426
- if (ith%16 == 0) {
427
- sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
449
+ for (int row = 0; row < N_DST; row+=2) {
450
+ if (nb_start + ib < nb) {
451
+ sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl);
452
+ }
453
+ }
428
454
  }
429
- threadgroup_barrier(mem_flags::mem_threadgroup);
430
- if (ith == 0) {
431
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
432
- dst[r1*ne0 + r0] = sum[0];
455
+
456
+ for (int row = 0; row < N_DST; ++row) {
457
+ all_sum = simd_sum(sumf[row]);
458
+ if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
459
+ dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
460
+ }
433
461
  }
434
462
  }
435
463
 
436
- kernel void kernel_mul_mat_q4_1_f32(
464
+ kernel void kernel_mul_mat_q4_0_f32(
437
465
  device const void * src0,
438
466
  device const float * src1,
439
467
  device float * dst,
440
468
  constant int64_t & ne00,
441
469
  constant int64_t & ne10,
442
470
  constant int64_t & ne0,
443
- threadgroup float * sum [[threadgroup(0)]],
471
+ constant int64_t & ne01[[buffer(4)]],
444
472
  uint2 tgpig[[threadgroup_position_in_grid]],
445
- uint2 tpitg[[thread_position_in_threadgroup]],
446
- uint2 tptg[[threads_per_threadgroup]]) {
447
- const int nb = ne00/QK4_1;
448
-
449
- const int64_t r0 = tgpig.x;
450
- const int64_t r1 = tgpig.y;
451
-
452
- device const block_q4_1 * x = (device const block_q4_1 *) src0 + r0*nb;
453
- device const float * y = (device const float *) src1 + r1*ne10;
454
-
455
- const uint nth = tptg.x*tptg.y;
456
- const uint ith = tptg.y*tpitg.x + tpitg.y;
457
-
458
- const int ix = tpitg.y/4; // 0 or 1
459
- const int iy = tpitg.y - 4*ix; // 0...3
460
-
461
- const int first = 4 * iy;
462
-
463
- float sumf = 0;
464
-
465
- for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
466
-
467
- const float d = (float)x[i].d;
468
- const float m = (float)x[i].m;
469
-
470
- device const uint8_t * xl = x[i].qs + first;
471
- device const float * yl = y + i * QK4_1 + first;
472
-
473
- float2 acc = {0.0f, 0.0f};
474
-
475
- for (int j = 0; j < 4; ++j) {
476
-
477
- acc[0] += yl[j+ 0] * (d * (xl[j] & 0xF) + m);
478
- acc[1] += yl[j+16] * (d * (xl[j] >> 4) + m);
479
-
480
- }
481
-
482
- sumf += acc[0] + acc[1];
483
- }
484
-
485
- sum[ith] = sumf;
473
+ uint tiisg[[thread_index_in_simdgroup]],
474
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
475
+ mul_vec_q_n_f32<block_q4_0>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
476
+ }
486
477
 
487
- //
488
- // Accumulate the sum from all threads in the threadgroup
489
- //
490
- threadgroup_barrier(mem_flags::mem_threadgroup);
491
- if (ith%4 == 0) {
492
- sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
493
- }
494
- threadgroup_barrier(mem_flags::mem_threadgroup);
495
- if (ith%16 == 0) {
496
- sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
497
- }
498
- threadgroup_barrier(mem_flags::mem_threadgroup);
499
- if (ith == 0) {
500
- for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
501
- dst[r1*ne0 + r0] = sum[0];
502
- }
478
+ kernel void kernel_mul_mat_q4_1_f32(
479
+ device const void * src0,
480
+ device const float * src1,
481
+ device float * dst,
482
+ constant int64_t & ne00,
483
+ constant int64_t & ne10,
484
+ constant int64_t & ne0,
485
+ constant int64_t & ne01[[buffer(4)]],
486
+ uint2 tgpig[[threadgroup_position_in_grid]],
487
+ uint tiisg[[thread_index_in_simdgroup]],
488
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
489
+ mul_vec_q_n_f32<block_q4_1>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
503
490
  }
504
491
 
505
492
  kernel void kernel_mul_mat_f16_f32(
@@ -615,17 +602,19 @@ kernel void kernel_rope(
615
602
  constant int & n_past,
616
603
  constant int & n_dims,
617
604
  constant int & mode,
605
+ constant float & freq_base,
606
+ constant float & freq_scale,
618
607
  uint3 tpig[[thread_position_in_grid]]) {
619
608
  const int64_t i3 = tpig[2];
620
609
  const int64_t i2 = tpig[1];
621
610
  const int64_t i1 = tpig[0];
622
611
 
623
612
  const bool is_neox = mode & 2;
624
- const float theta_scale = pow(10000.0, -2.0f/n_dims);
613
+ const float theta_scale = pow(freq_base, -2.0f/n_dims);
625
614
 
626
615
  const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
627
616
 
628
- float theta = (float)p;
617
+ float theta = freq_scale * (float)p;
629
618
 
630
619
  if (!is_neox) {
631
620
  for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
@@ -1220,111 +1209,137 @@ kernel void kernel_mul_mat_q2_K_f32(
1220
1209
  constant int64_t & ne00,
1221
1210
  constant int64_t & ne10,
1222
1211
  constant int64_t & ne0,
1223
- threadgroup float * sum [[threadgroup(0)]],
1212
+ constant int64_t & ne01[[buffer(4)]],
1224
1213
  uint2 tgpig[[threadgroup_position_in_grid]],
1225
- uint2 tpitg[[thread_position_in_threadgroup]],
1226
- uint2 tptg[[threads_per_threadgroup]]) {
1214
+ uint tiisg[[thread_index_in_simdgroup]],
1215
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1227
1216
 
1228
1217
  const int nb = ne00/QK_K;
1218
+ const int r0 = tgpig.x;
1219
+ const int r1 = tgpig.y;
1229
1220
 
1230
- const int64_t r0 = tgpig.x;
1231
- const int64_t r1 = tgpig.y;
1232
-
1233
- device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb;
1234
- device const float * yy = (device const float *) src1 + r1*ne10;
1221
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1222
+ const int ib_row = first_row * nb;
1223
+ device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row;
1224
+ device const float * y = (device const float *) src1 + r1*ne10;
1225
+ float yl[32];
1226
+ float sumf[N_DST]={0.f}, all_sum;
1235
1227
 
1236
- const int nth = tptg.x*tptg.y;
1237
- const int ith = tptg.y*tpitg.x + tpitg.y;
1238
-
1239
- float sumf = 0;
1228
+ const int step = sizeof(block_q2_K) * nb;
1240
1229
 
1241
1230
  #if QK_K == 256
1242
- const int tid = tpitg.y; // 0...16
1243
- const int il = tid/4; // 0...3
1244
- const int ir = tid%4; // 0...3
1245
- const int ip = il/2; // 0 or 1
1246
- const int shift1 = 4*(il%2);// 0 or 4
1247
- const int shift2 = shift1+2;// 2 or 6
1248
- const int n = 8;
1249
- const int is = 4*il + (n*ir)/16;
1250
-
1251
- const int y_offset = 64*il + n*ir;
1252
- const int q_offset = 32*ip + n*ir;
1253
-
1254
- for (int i = tpitg.x; i < nb; i += tptg.x) {
1255
-
1256
- device const uint8_t * q = x[i].qs + q_offset;
1257
- device const uint8_t * scales = x[i].scales + is;
1258
-
1259
- uint8_t d1 = scales[0] & 0xF;
1260
- uint8_t d2 = scales[2] & 0xF;
1261
- uint8_t m1 = scales[0] >> 4;
1262
- uint8_t m2 = scales[2] >> 4;
1263
-
1264
- device const float * y = yy + i*QK_K + y_offset;
1265
-
1266
- float2 s = {0.f, 0.f};
1267
- float smin = 0;
1268
- for (int l = 0; l < n; ++l) {
1269
- s[0] += y[l+ 0] * ((q[l] >> shift1) & 3);
1270
- s[1] += y[l+32] * ((q[l] >> shift2) & 3);
1271
- smin += y[l+ 0] * m1 + y[l+32] * m2;
1231
+ const int ix = tiisg/8; // 0...3
1232
+ const int it = tiisg%8; // 0...7
1233
+ const int im = it/4; // 0 or 1
1234
+ const int ir = it%4; // 0...3
1235
+ const int is = (8*ir)/16;// 0 or 1
1236
+
1237
+ device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir;
1238
+
1239
+ for (int ib = ix; ib < nb; ib += 4) {
1240
+
1241
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
1242
+ for (int i = 0; i < 8; ++i) {
1243
+ yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
1244
+ yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
1245
+ yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
1246
+ yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
1272
1247
  }
1273
1248
 
1274
- const float dall = (float)x[i].d;
1275
- const float dmin = (float)x[i].dmin;
1276
-
1277
- sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
1249
+ device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*im + is;
1250
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
1251
+ device const half * dh = &x[ib].d;
1252
+
1253
+ for (int row = 0; row < N_DST; row++) {
1254
+
1255
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
1256
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
1257
+ for (int i = 0; i < 8; i += 2) {
1258
+ acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
1259
+ acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
1260
+ acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
1261
+ acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
1262
+ acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
1263
+ acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
1264
+ acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
1265
+ acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
1266
+ }
1267
+ float dall = dh[0];
1268
+ float dmin = dh[1] * 1.f/16.f;
1269
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
1270
+ (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
1271
+ (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +
1272
+ (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
1273
+ dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
1274
+
1275
+ qs += step/2;
1276
+ sc += step;
1277
+ dh += step/2;
1278
+ }
1278
1279
 
1280
+ y4 += 4 * QK_K;
1279
1281
  }
1280
1282
  #else
1281
- const int il = 4 * tpitg.x;
1282
-
1283
- uint32_t aux[2];
1284
- thread const uint8_t * d = (thread const uint8_t *)aux;
1285
- thread const uint8_t * m = (thread const uint8_t *)aux + 4;
1283
+ const int ix = tiisg/2; // 0...15
1284
+ const int it = tiisg%2; // 0...1
1286
1285
 
1287
- for (int i = tpitg.y; i < nb; i += tptg.y) {
1286
+ device const float * y4 = y + ix * QK_K + 8 * it;
1288
1287
 
1289
- device const uint8_t * q = x[i].qs + il;
1290
- device const float * y = yy + i*QK_K + il;
1288
+ for (int ib = ix; ib < nb; ib += 16) {
1291
1289
 
1292
- const float dall = (float)x[i].d;
1293
- const float dmin = (float)x[i].dmin;
1290
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
1291
+ for (int i = 0; i < 8; ++i) {
1292
+ yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
1293
+ yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8];
1294
+ yl[i+16] = y4[i+32]; sumy[2] += yl[i+16];
1295
+ yl[i+24] = y4[i+48]; sumy[3] += yl[i+24];
1296
+ }
1294
1297
 
1295
- device const uint32_t * a = (device const uint32_t *)x[i].scales;
1296
- aux[0] = a[0] & 0x0f0f0f0f;
1297
- aux[1] = (a[0] >> 4) & 0x0f0f0f0f;
1298
+ device const uint8_t * sc = (device const uint8_t *)x[ib].scales;
1299
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
1300
+ device const half * dh = &x[ib].d;
1301
+
1302
+ for (int row = 0; row < N_DST; row++) {
1303
+
1304
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
1305
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
1306
+ for (int i = 0; i < 8; i += 2) {
1307
+ acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
1308
+ acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
1309
+ acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
1310
+ acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
1311
+ acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
1312
+ acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
1313
+ acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
1314
+ acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
1315
+ }
1298
1316
 
1299
- for (int l = 0; l < 4; ++l) {
1300
- sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0])
1301
- + y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1])
1302
- + y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2])
1303
- + y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]);
1317
+ float dall = dh[0];
1318
+ float dmin = dh[1];
1319
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
1320
+ (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
1321
+ (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
1322
+ (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
1323
+ dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));
1324
+
1325
+ qs += step/2;
1326
+ sc += step;
1327
+ dh += step/2;
1304
1328
  }
1329
+
1330
+ y4 += 16 * QK_K;
1305
1331
  }
1306
1332
  #endif
1307
1333
 
1308
- sum[ith] = sumf;
1309
-
1310
- //
1311
- // Accumulate the sum from all threads in the threadgroup
1312
- //
1313
- threadgroup_barrier(mem_flags::mem_threadgroup);
1314
- if (ith%4 == 0) {
1315
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
1316
- }
1317
- threadgroup_barrier(mem_flags::mem_threadgroup);
1318
- if (ith%16 == 0) {
1319
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
1320
- }
1321
- threadgroup_barrier(mem_flags::mem_threadgroup);
1322
- if (ith == 0) {
1323
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
1324
- dst[r1*ne0 + r0] = sum[0];
1334
+ for (int row = 0; row < N_DST; ++row) {
1335
+ all_sum = simd_sum(sumf[row]);
1336
+ if (tiisg == 0) {
1337
+ dst[r1*ne0 + first_row + row] = all_sum;
1338
+ }
1325
1339
  }
1326
1340
  }
1327
1341
 
1342
+ #if QK_K == 256
1328
1343
  kernel void kernel_mul_mat_q3_K_f32(
1329
1344
  device const void * src0,
1330
1345
  device const float * src1,
@@ -1333,40 +1348,41 @@ kernel void kernel_mul_mat_q3_K_f32(
1333
1348
  constant int64_t & ne10,
1334
1349
  constant int64_t & ne0,
1335
1350
  constant int64_t & ne1,
1336
- threadgroup float * sum [[threadgroup(0)]],
1337
1351
  uint2 tgpig[[threadgroup_position_in_grid]],
1338
- uint2 tpitg[[thread_position_in_threadgroup]],
1339
- uint2 tptg[[threads_per_threadgroup]]) {
1352
+ uint tiisg[[thread_index_in_simdgroup]],
1353
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1340
1354
 
1341
1355
  const int nb = ne00/QK_K;
1342
1356
 
1343
1357
  const int64_t r0 = tgpig.x;
1344
1358
  const int64_t r1 = tgpig.y;
1345
1359
 
1346
- device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb;
1347
- device const float * yy = (device const float *) src1 + r1*ne10;
1348
-
1349
- const int nth = tptg.x*tptg.y;
1350
- const int ith = tptg.y*tpitg.x + tpitg.y;
1360
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
1351
1361
 
1352
- #if QK_K == 256
1362
+ device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb;
1363
+ device const float * yy = (device const float *) src1 + r1*ne10;
1353
1364
 
1354
- const uint8_t m3 = 3;
1355
- const int8_t m4 = 4;
1365
+ float yl[16];
1356
1366
 
1357
1367
  const uint16_t kmask1 = 0x0303;
1358
1368
  const uint16_t kmask2 = 0x0f0f;
1359
1369
 
1360
- const int tid = tpitg.y; // expecting 16
1370
+ const int tid = tiisg/2;
1371
+ const int ix = tiisg%2;
1361
1372
  const int ip = tid/8; // 0 or 1
1362
1373
  const int il = tid/2 - 4*ip; // 0...3
1363
1374
  const int ir = tid%2;
1364
1375
  const int n = 8;
1365
1376
  const int l0 = n*ir;
1366
1377
 
1367
- const uint8_t m = 1 << (4*ip + il);
1378
+ const uint16_t m1 = 1 << (4*ip + il);
1379
+ const uint16_t m2 = m1 << 8;
1368
1380
 
1369
1381
  const int shift = 2*il;
1382
+ const uint16_t qm1 = 0x0003 << shift;
1383
+ const uint16_t qm2 = 0x0300 << shift;
1384
+ const int32_t v1 = 4 << shift;
1385
+ const int32_t v2 = 1024 << shift;
1370
1386
 
1371
1387
  const uint16_t s_shift1 = 4*ip;
1372
1388
  const uint16_t s_shift2 = s_shift1 + 2*(il/2);
@@ -1375,226 +1391,315 @@ kernel void kernel_mul_mat_q3_K_f32(
1375
1391
  const int q_offset = 32*ip + l0;
1376
1392
  const int y_offset = 128*ip + 32*il + l0;
1377
1393
 
1378
- //float sumf = 0;
1379
- float sumf1 = 0, sumf2 = 0;
1380
- for (int i = tpitg.x; i < nb; i += tptg.x) {
1394
+ const int step = sizeof(block_q3_K) * nb / 2;
1381
1395
 
1382
- const float d_all = (float)(x[i].d);
1396
+ device const float * y1 = yy + ix*QK_K + y_offset;
1383
1397
 
1384
- device const uint8_t * q = x[i].qs + q_offset;
1385
- device const uint8_t * h = x[i].hmask + l0;
1386
- device const float * y = yy + i * QK_K + y_offset;
1398
+ float sumf1[2] = {0.f}, sumf2[2] = {0.f};
1399
+ for (int i = ix; i < nb; i += 2) {
1387
1400
 
1388
- device const uint16_t * a = (device const uint16_t *)x[i].scales;
1389
- const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
1390
-
1391
- float s = 0;
1392
- for (int l = 0; l < n; ++l) {
1393
- s += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & m3) - ((h[l+ 0] & m) ? 0 : m4));
1394
- }
1395
- float d = d_all * s;
1396
- sumf1 += d * scales[0];
1397
- sumf2 += d;
1398
- //sumf += d_all * s * (scales[0] - 32);
1399
-
1400
- s = 0;
1401
- for (int l = 0; l < n; ++l) {
1402
- s += y[l+16] * ((int8_t)((q[l+16] >> shift) & m3) - ((h[l+16] & m) ? 0 : m4));
1401
+ for (int l = 0; l < 8; ++l) {
1402
+ yl[l+0] = y1[l+ 0];
1403
+ yl[l+8] = y1[l+16];
1403
1404
  }
1404
- d = d_all * s;
1405
- sumf1 += d * scales[1];
1406
- sumf2 += d;
1407
- //sumf += d_all * s * (scales[1] - 32);
1408
-
1409
- }
1410
1405
 
1411
- //sum[ith] = sumf;
1412
- sum[ith] = sumf1 - 32.f*sumf2;
1413
- #else
1414
- const int il = 4 * tpitg.x; // 0, 4, 8, 12
1415
- const int im = il/8; // 0, 0, 1, 1
1416
- const int in = il%8; // 0, 4, 0, 4
1406
+ device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
1407
+ device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
1408
+ device const uint16_t * a = (device const uint16_t *)(x[i].scales);
1409
+ device const half * dh = &x[i].d;
1417
1410
 
1418
- float sumf = 0;
1411
+ for (int row = 0; row < 2; ++row) {
1419
1412
 
1420
- for (int i = tpitg.y; i < nb; i += tptg.y) {
1413
+ const float d_all = (float)dh[0];
1414
+ const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
1421
1415
 
1422
- const float d_all = (float)(x[i].d);
1423
-
1424
- device const uint8_t * q = x[i].qs + il;
1425
- device const uint8_t * h = x[i].hmask + in;
1426
- device const float * y = yy + i * QK_K + il;
1416
+ float s1 = 0, s2 = 0;
1417
+ for (int l = 0; l < n; l += 2) {
1418
+ const uint16_t qs = q[l/2];
1419
+ s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
1420
+ s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
1421
+ }
1422
+ float d = d_all * (s1 + 1.f/256.f * s2);
1423
+ sumf1[row] += d * scales[0];
1424
+ sumf2[row] += d;
1425
+
1426
+ s1 = s2 = 0;
1427
+ for (int l = 0; l < n; l += 2) {
1428
+ const uint16_t qs = q[l/2+8];
1429
+ s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
1430
+ s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
1431
+ }
1432
+ d = d_all * (s1 + 1.f/256.f * s2);
1433
+ sumf1[row] += d * scales[1];
1434
+ sumf2[row] += d;
1427
1435
 
1428
- const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
1429
- const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
1430
- const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
1431
- const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
1436
+ q += step;
1437
+ h += step;
1438
+ a += step;
1439
+ dh += step;
1432
1440
 
1433
- for (int l = 0; l < 4; ++l) {
1434
- const uint8_t hm = h[l] >> im;
1435
- sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4))
1436
- + y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4))
1437
- + y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4))
1438
- + y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4));
1439
1441
  }
1440
1442
 
1441
- }
1442
-
1443
- sum[ith] = sumf;
1444
-
1445
- #endif
1443
+ y1 += 2 * QK_K;
1446
1444
 
1447
- //
1448
- // Accumulate the sum from all threads in the threadgroup
1449
- //
1450
- threadgroup_barrier(mem_flags::mem_threadgroup);
1451
- if (ith%4 == 0) {
1452
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
1453
- }
1454
- threadgroup_barrier(mem_flags::mem_threadgroup);
1455
- if (ith%16 == 0) {
1456
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
1457
- }
1458
- threadgroup_barrier(mem_flags::mem_threadgroup);
1459
- if (ith == 0) {
1460
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
1461
- dst[r1*ne0 + r0] = sum[0];
1462
1445
  }
1463
1446
 
1447
+ for (int row = 0; row < 2; ++row) {
1448
+ const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
1449
+ const float tot = simd_sum(sumf);
1450
+ if (tiisg == 0) {
1451
+ dst[r1*ne0 + first_row + row] = tot;
1452
+ }
1453
+ }
1464
1454
  }
1465
-
1466
- kernel void kernel_mul_mat_q4_K_f32(
1455
+ #else
1456
+ kernel void kernel_mul_mat_q3_K_f32(
1467
1457
  device const void * src0,
1468
1458
  device const float * src1,
1469
1459
  device float * dst,
1470
1460
  constant int64_t & ne00,
1471
1461
  constant int64_t & ne10,
1472
1462
  constant int64_t & ne0,
1473
- threadgroup float * sum [[threadgroup(0)]],
1463
+ constant int64_t & ne1,
1474
1464
  uint2 tgpig[[threadgroup_position_in_grid]],
1475
- uint2 tpitg[[thread_position_in_threadgroup]],
1476
- uint2 tptg[[threads_per_threadgroup]]) {
1465
+ uint tiisg[[thread_index_in_simdgroup]],
1466
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1477
1467
 
1478
1468
  const int nb = ne00/QK_K;
1479
1469
 
1480
1470
  const int64_t r0 = tgpig.x;
1481
1471
  const int64_t r1 = tgpig.y;
1482
1472
 
1483
- const int nth = tptg.x*tptg.y;
1484
- const int ith = tptg.y*tpitg.x + tpitg.y;
1473
+ const int row = 2 * r0 + sgitg;
1485
1474
 
1486
- device const block_q4_K * x = (device const block_q4_K *) src0 + r0*nb;
1475
+ device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb;
1487
1476
  device const float * yy = (device const float *) src1 + r1*ne10;
1477
+ const int ix = tiisg/4;
1478
+ const int il = 4 * (tiisg%4);// 0, 4, 8, 12
1479
+ const int im = il/8; // 0, 0, 1, 1
1480
+ const int in = il%8; // 0, 4, 0, 4
1488
1481
 
1489
- float sumf = 0;
1482
+ float2 sum = {0.f, 0.f};
1483
+
1484
+ for (int i = ix; i < nb; i += 8) {
1485
+
1486
+ const float d_all = (float)(x[i].d);
1487
+
1488
+ device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
1489
+ device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
1490
+ device const uint16_t * s = (device const uint16_t *)(x[i].scales);
1491
+ device const float * y = yy + i * QK_K + il;
1492
+
1493
+ const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
1494
+ const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
1495
+ const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
1496
+ const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
1497
+
1498
+ for (int l = 0; l < 4; l += 2) {
1499
+ const uint16_t hm = h[l/2] >> im;
1500
+ sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
1501
+ + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
1502
+ + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
1503
+ + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
1504
+ sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
1505
+ + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
1506
+ + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
1507
+ + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
1508
+ }
1509
+
1510
+ }
1511
+ const float sumf = sum[0] + sum[1] * 1.f/256.f;
1512
+
1513
+ const float tot = simd_sum(sumf);
1514
+ if (tiisg == 0) {
1515
+ dst[r1*ne0 + row] = tot;
1516
+ }
1517
+
1518
+ }
1519
+ #endif
1490
1520
 
1491
1521
  #if QK_K == 256
1522
+ kernel void kernel_mul_mat_q4_K_f32(
1523
+ device const void * src0,
1524
+ device const float * src1,
1525
+ device float * dst,
1526
+ constant int64_t & ne00,
1527
+ constant int64_t & ne10,
1528
+ constant int64_t & ne0,
1529
+ constant int64_t & ne01[[buffer(4)]],
1530
+ uint2 tgpig[[threadgroup_position_in_grid]],
1531
+ uint tiisg[[thread_index_in_simdgroup]],
1532
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1492
1533
 
1493
1534
  const uint16_t kmask1 = 0x3f3f;
1494
1535
  const uint16_t kmask2 = 0x0f0f;
1495
1536
  const uint16_t kmask3 = 0xc0c0;
1496
1537
 
1497
- const int tid = tpitg.y; // 0...16
1498
- const int il = tid/4; // 0...3
1499
- const int ir = tid - 4*il;// 0...3
1500
- const int n = 4;
1538
+ const int ix = tiisg/8; // 0...3
1539
+ const int it = tiisg%8; // 0...7
1540
+ const int im = it/4; // 0 or 1
1541
+ const int ir = it%4; // 0...3
1501
1542
 
1502
- const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
1503
- const int in = il%2;
1543
+ const int nb = ne00/QK_K;
1544
+ const int r0 = tgpig.x;
1545
+ const int r1 = tgpig.y;
1546
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1547
+ const int ib_row = first_row * nb;
1548
+ device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
1549
+ device const float * y = (device const float *) src1 + r1*ne10;
1550
+ float yl[16];
1551
+ float yh[16];
1552
+ float sumf[N_DST]={0.f}, all_sum;
1504
1553
 
1505
- const int l0 = n*(2*ir + in);
1506
- const int q_offset = 32*im + l0;
1507
- const int y_offset = 64*im + l0;
1554
+ const int step = sizeof(block_q4_K) * nb / 2;
1508
1555
 
1509
- uchar2 sc1, sc2, sc3, sc4;
1556
+ device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir;
1510
1557
 
1511
- for (int i = tpitg.x; i < nb; i += tptg.x) {
1558
+ uint16_t sc16[4];
1559
+ thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
1512
1560
 
1513
- device const uint8_t * q1 = (x + i)->qs + q_offset;
1514
- device const uint8_t * q2 = q1 + 64;
1515
- device const float * y1 = yy + i*QK_K + y_offset;
1516
- device const float * y2 = y1 + 128;
1561
+ for (int ib = ix; ib < nb; ib += 4) {
1517
1562
 
1518
- const float dall = (float)((x + i)->d);
1519
- const float dmin = (float)((x + i)->dmin);
1563
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
1564
+ for (int i = 0; i < 8; ++i) {
1565
+ yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
1566
+ yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
1567
+ yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
1568
+ yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
1569
+ }
1520
1570
 
1521
- device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
1522
- sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
1523
- sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
1524
- sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
1525
- sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
1571
+ device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im;
1572
+ device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
1573
+ device const half * dh = &x[ib].d;
1574
+
1575
+ for (int row = 0; row < N_DST; row++) {
1576
+
1577
+ sc16[0] = sc[0] & kmask1;
1578
+ sc16[1] = sc[2] & kmask1;
1579
+ sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
1580
+ sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
1581
+
1582
+ device const uint16_t * q2 = q1 + 32;
1583
+
1584
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
1585
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
1586
+ for (int i = 0; i < 8; i += 2) {
1587
+ acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
1588
+ acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
1589
+ acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
1590
+ acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
1591
+ acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
1592
+ acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
1593
+ acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
1594
+ acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
1595
+ }
1526
1596
 
1527
- float4 s = {0.f, 0.f, 0.f, 0.f};
1528
- float smin = 0;
1529
- for (int l = 0; l < n; ++l) {
1597
+ float dall = dh[0];
1598
+ float dmin = dh[1];
1599
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
1600
+ (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
1601
+ (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
1602
+ (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
1603
+ dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
1604
+
1605
+ q1 += step;
1606
+ sc += step;
1607
+ dh += step;
1608
+ }
1530
1609
 
1531
- s[0] += y1[l] * (q1[l] & 0xF); s[1] += y1[l+32] * (q1[l] >> 4);
1532
- s[2] += y2[l] * (q2[l] & 0xF); s[3] += y2[l+32] * (q2[l] >> 4);
1533
- smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
1610
+ y4 += 4 * QK_K;
1611
+ }
1534
1612
 
1613
+ for (int row = 0; row < N_DST; ++row) {
1614
+ all_sum = simd_sum(sumf[row]);
1615
+ if (tiisg == 0) {
1616
+ dst[r1*ne0 + first_row + row] = all_sum;
1535
1617
  }
1536
- sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
1537
-
1538
1618
  }
1619
+ }
1539
1620
  #else
1540
- uint16_t aux16[2];
1541
- thread const uint8_t * scales = (thread const uint8_t *)aux16;
1621
+ kernel void kernel_mul_mat_q4_K_f32(
1622
+ device const void * src0,
1623
+ device const float * src1,
1624
+ device float * dst,
1625
+ constant int64_t & ne00,
1626
+ constant int64_t & ne10,
1627
+ constant int64_t & ne0,
1628
+ constant int64_t & ne01[[buffer(4)]],
1629
+ uint2 tgpig[[threadgroup_position_in_grid]],
1630
+ uint tiisg[[thread_index_in_simdgroup]],
1631
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1542
1632
 
1543
- const int il = 4*tpitg.x;
1633
+ const int ix = tiisg/4; // 0...7
1634
+ const int it = tiisg%4; // 0...3
1544
1635
 
1545
- for (int i = tpitg.y; i < nb; i += tptg.y) {
1636
+ const int nb = ne00/QK_K;
1637
+ const int r0 = tgpig.x;
1638
+ const int r1 = tgpig.y;
1639
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1640
+ const int ib_row = first_row * nb;
1641
+ device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
1642
+ device const float * y = (device const float *) src1 + r1*ne10;
1643
+ float yl[8];
1644
+ float yh[8];
1645
+ float sumf[N_DST]={0.f}, all_sum;
1546
1646
 
1547
- device const uint8_t * q = x[i].qs + il;
1548
- device const float * y = yy + i * QK_K + il;
1647
+ const int step = sizeof(block_q4_K) * nb / 2;
1549
1648
 
1550
- const float d = (float)x[i].d[0];
1551
- const float m = (float)x[i].d[1];
1649
+ device const float * y4 = y + ix * QK_K + 8 * it;
1552
1650
 
1553
- device const uint16_t * a = (device const uint16_t *)x[i].scales;
1554
- aux16[0] = a[0] & 0x0f0f;
1555
- aux16[1] = (a[0] >> 4) & 0x0f0f;
1651
+ uint16_t sc16[4];
1556
1652
 
1557
- for (int l = 0; l < 4; ++l) {
1558
- sumf += d * scales[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - m * scales[2] * (y[l+ 0] + y[l+16])
1559
- + d * scales[1] * (y[l+32] * (q[l] >> 4) + y[l+48] * (q[l+16] >> 4)) - m * scales[3] * (y[l+32] + y[l+48]);
1653
+ for (int ib = ix; ib < nb; ib += 8) {
1654
+
1655
+ float2 sumy = {0.f, 0.f};
1656
+ for (int i = 0; i < 8; ++i) {
1657
+ yl[i] = y4[i+ 0]; sumy[0] += yl[i];
1658
+ yh[i] = y4[i+32]; sumy[1] += yh[i];
1560
1659
  }
1561
- }
1562
- #endif
1563
1660
 
1564
- sum[ith] = sumf;
1661
+ device const uint16_t * sc = (device const uint16_t *)x[ib].scales;
1662
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
1663
+ device const half * dh = x[ib].d;
1565
1664
 
1566
- //
1567
- // Accumulate the sum from all threads in the threadgroup
1568
- // This version is slightly faster than the commented out one below,
1569
- // which I copy-pasted from ggerganov's q4_0 dot product for metal.
1570
- //
1571
- threadgroup_barrier(mem_flags::mem_threadgroup);
1572
- if (ith%4 == 0) {
1573
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
1574
- }
1575
- threadgroup_barrier(mem_flags::mem_threadgroup);
1576
- if (ith%16 == 0) {
1577
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
1578
- }
1579
- threadgroup_barrier(mem_flags::mem_threadgroup);
1580
- if (ith == 0) {
1581
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
1582
- dst[r1*ne0 + r0] = sum[0];
1665
+ for (int row = 0; row < N_DST; row++) {
1666
+
1667
+ sc16[0] = sc[0] & 0x000f;
1668
+ sc16[1] = sc[0] & 0x0f00;
1669
+ sc16[2] = sc[0] & 0x00f0;
1670
+ sc16[3] = sc[0] & 0xf000;
1671
+
1672
+ float2 acc1 = {0.f, 0.f};
1673
+ float2 acc2 = {0.f, 0.f};
1674
+ for (int i = 0; i < 8; i += 2) {
1675
+ acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
1676
+ acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
1677
+ acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
1678
+ acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
1679
+ }
1680
+
1681
+ float dall = dh[0];
1682
+ float dmin = dh[1];
1683
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
1684
+ (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
1685
+ dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);
1686
+
1687
+ qs += step;
1688
+ sc += step;
1689
+ dh += step;
1690
+ }
1691
+
1692
+ y4 += 8 * QK_K;
1583
1693
  }
1584
1694
 
1585
- //// accumulate the sum from all threads in the threadgroup
1586
- //threadgroup_barrier(mem_flags::mem_threadgroup);
1587
- //for (uint i = nth/2; i > 0; i /= 2) {
1588
- // if (ith < i) {
1589
- // sum[ith] += sum[ith + i];
1590
- // }
1591
- // threadgroup_barrier(mem_flags::mem_threadgroup);
1592
- //}
1593
-
1594
- //if (ith == 0) {
1595
- // dst[r1*ne0 + r0] = sum[0];
1596
- //}
1695
+ for (int row = 0; row < N_DST; ++row) {
1696
+ all_sum = simd_sum(sumf[row]);
1697
+ if (tiisg == 0) {
1698
+ dst[r1*ne0 + first_row + row] = all_sum;
1699
+ }
1700
+ }
1597
1701
  }
1702
+ #endif
1598
1703
 
1599
1704
  kernel void kernel_mul_mat_q5_K_f32(
1600
1705
  device const void * src0,
@@ -1603,39 +1708,39 @@ kernel void kernel_mul_mat_q5_K_f32(
1603
1708
  constant int64_t & ne00,
1604
1709
  constant int64_t & ne10,
1605
1710
  constant int64_t & ne0,
1606
- threadgroup float * sum [[threadgroup(0)]],
1607
1711
  uint2 tgpig[[threadgroup_position_in_grid]],
1608
- uint2 tpitg[[thread_position_in_threadgroup]],
1609
- uint2 tptg[[threads_per_threadgroup]]) {
1712
+ uint tiisg[[thread_index_in_simdgroup]],
1713
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1610
1714
 
1611
1715
  const int nb = ne00/QK_K;
1612
1716
 
1613
1717
  const int64_t r0 = tgpig.x;
1614
1718
  const int64_t r1 = tgpig.y;
1615
1719
 
1616
- device const block_q5_K * x = (device const block_q5_K *) src0 + r0*nb;
1720
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
1721
+
1722
+ device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb;
1617
1723
  device const float * yy = (device const float *) src1 + r1*ne10;
1618
1724
 
1619
- const int nth = tptg.x*tptg.y;
1620
- const int ith = tptg.y*tpitg.x + tpitg.y;
1725
+ float sumf[2]={0.f};
1621
1726
 
1622
- float sumf = 0;
1727
+ const int step = sizeof(block_q5_K) * nb;
1623
1728
 
1624
1729
  #if QK_K == 256
1730
+ #
1731
+ float yl[16], yh[16];
1625
1732
 
1626
1733
  const uint16_t kmask1 = 0x3f3f;
1627
1734
  const uint16_t kmask2 = 0x0f0f;
1628
1735
  const uint16_t kmask3 = 0xc0c0;
1629
1736
 
1630
- const int tid = tpitg.y; // 0...16
1631
- const int il = tid/4; // 0...3
1632
- const int ir = tid - 4*il;// 0...3
1633
- const int n = 4;
1634
-
1635
- const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
1636
- const int in = il%2;
1737
+ const int tid = tiisg/4;
1738
+ const int ix = tiisg%4;
1739
+ const int im = tid/4;
1740
+ const int ir = tid%4;
1741
+ const int n = 8;
1637
1742
 
1638
- const int l0 = n*(2*ir + in);
1743
+ const int l0 = n*ir;
1639
1744
  const int q_offset = 32*im + l0;
1640
1745
  const int y_offset = 64*im + l0;
1641
1746
 
@@ -1644,78 +1749,113 @@ kernel void kernel_mul_mat_q5_K_f32(
1644
1749
  const uint8_t hm3 = hm1 << 4;
1645
1750
  const uint8_t hm4 = hm2 << 4;
1646
1751
 
1647
- uchar2 sc1, sc2, sc3, sc4;
1752
+ uint16_t sc16[4];
1753
+ thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
1648
1754
 
1649
- for (int i = tpitg.x; i < nb; i += tptg.x) {
1755
+ device const float * y1 = yy + ix*QK_K + y_offset;
1650
1756
 
1651
- device const uint8_t * q1 = (x + i)->qs + q_offset;
1652
- device const uint8_t * q2 = q1 + 64;
1653
- device const uint8_t * qh = (x + i)->qh + l0;
1654
- device const float * y1 = yy + i*QK_K + y_offset;
1655
- device const float * y2 = y1 + 128;
1757
+ for (int i = ix; i < nb; i += 4) {
1656
1758
 
1657
- const float dall = (float)((x + i)->d);
1658
- const float dmin = (float)((x + i)->dmin);
1759
+ device const uint8_t * q1 = x[i].qs + q_offset;
1760
+ device const uint8_t * qh = x[i].qh + l0;
1761
+ device const half * dh = &x[i].d;
1762
+ device const uint16_t * a = (device const uint16_t *)x[i].scales + im;
1659
1763
 
1660
- device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
1661
- sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
1662
- sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
1663
- sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
1664
- sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
1764
+ device const float * y2 = y1 + 128;
1765
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
1766
+ for (int l = 0; l < 8; ++l) {
1767
+ yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
1768
+ yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
1769
+ yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
1770
+ yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
1771
+ }
1665
1772
 
1666
- float4 s = {0.f, 0.f, 0.f, 0.f};
1667
- float smin = 0;
1668
- for (int l = 0; l < n; ++l) {
1773
+ for (int row = 0; row < 2; ++row) {
1774
+
1775
+ device const uint8_t * q2 = q1 + 64;
1669
1776
 
1670
- s[0] += y1[l+ 0] * ((q1[l] & 0xF) + (qh[l] & hm1 ? 16 : 0));
1671
- s[1] += y1[l+32] * ((q1[l] >> 4) + (qh[l] & hm2 ? 16 : 0));
1672
- s[2] += y2[l+ 0] * ((q2[l] & 0xF) + (qh[l] & hm3 ? 16 : 0));
1673
- s[3] += y2[l+32] * ((q2[l] >> 4) + (qh[l] & hm4 ? 16 : 0));
1674
- smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
1777
+ sc16[0] = a[0] & kmask1;
1778
+ sc16[1] = a[2] & kmask1;
1779
+ sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
1780
+ sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
1781
+
1782
+ float4 acc = {0.f, 0.f, 0.f, 0.f};
1783
+ for (int l = 0; l < n; ++l) {
1784
+ uint8_t h = qh[l];
1785
+ acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
1786
+ acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
1787
+ acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
1788
+ acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
1789
+ }
1790
+ const float dall = dh[0];
1791
+ const float dmin = dh[1];
1792
+ sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
1793
+ dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
1794
+
1795
+ q1 += step;
1796
+ qh += step;
1797
+ dh += step/2;
1798
+ a += step/2;
1675
1799
 
1676
1800
  }
1677
- sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
1801
+
1802
+ y1 += 4 * QK_K;
1678
1803
 
1679
1804
  }
1680
1805
  #else
1681
- const int il = 4 * tpitg.x; // 0, 4, 8, 12
1682
- const int im = il/8; // 0, 0, 1, 1
1683
- const int in = il%8; // 0, 4, 0, 4
1806
+ float yl[8], yh[8];
1684
1807
 
1685
- for (int i = tpitg.y; i < nb; i += tptg.y) {
1808
+ const int il = 4 * (tiisg/8); // 0, 4, 8, 12
1809
+ const int ix = tiisg%8;
1810
+ const int im = il/8; // 0, 0, 1, 1
1811
+ const int in = il%8; // 0, 4, 0, 4
1686
1812
 
1687
- const float d = (float)x[i].d;
1813
+ device const float * y = yy + ix*QK_K + il;
1814
+
1815
+ for (int i = ix; i < nb; i += 8) {
1816
+
1817
+ for (int l = 0; l < 4; ++l) {
1818
+ yl[l+0] = y[l+ 0];
1819
+ yl[l+4] = y[l+16];
1820
+ yh[l+0] = y[l+32];
1821
+ yh[l+4] = y[l+48];
1822
+ }
1823
+
1824
+ device const half * dh = &x[i].d;
1688
1825
  device const uint8_t * q = x[i].qs + il;
1689
1826
  device const uint8_t * h = x[i].qh + in;
1690
1827
  device const int8_t * s = x[i].scales;
1691
- device const float * y = yy + i*QK_K + il;
1692
1828
 
1693
- for (int l = 0; l < 4; ++l) {
1694
- const uint8_t hl = h[l] >> im;
1695
- sumf += y[l+ 0] * d * s[0] * ((q[l+ 0] & 0xF) - (hl & 0x01 ? 0 : 16))
1696
- + y[l+16] * d * s[1] * ((q[l+16] & 0xF) - (hl & 0x04 ? 0 : 16))
1697
- + y[l+32] * d * s[2] * ((q[l+ 0] >> 4) - (hl & 0x10 ? 0 : 16))
1698
- + y[l+48] * d * s[3] * ((q[l+16] >> 4) - (hl & 0x40 ? 0 : 16));
1829
+ for (int row = 0; row < 2; ++row) {
1830
+
1831
+ const float d = dh[0];
1832
+
1833
+ float2 acc = {0.f, 0.f};
1834
+ for (int l = 0; l < 4; ++l) {
1835
+ const uint8_t hl = h[l] >> im;
1836
+ acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
1837
+ + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
1838
+ acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
1839
+ + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256));
1840
+ }
1841
+ sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]);
1842
+
1843
+ q += step;
1844
+ h += step;
1845
+ s += step;
1846
+ dh += step/2;
1847
+
1699
1848
  }
1849
+
1850
+ y += 8 * QK_K;
1700
1851
  }
1701
1852
  #endif
1702
- sum[ith] = sumf;
1703
1853
 
1704
- //
1705
- // Accumulate the sum from all threads in the threadgroup
1706
- //
1707
- threadgroup_barrier(mem_flags::mem_threadgroup);
1708
- if (ith%4 == 0) {
1709
- sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
1710
- }
1711
- threadgroup_barrier(mem_flags::mem_threadgroup);
1712
- if (ith%16 == 0) {
1713
- sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
1714
- }
1715
- threadgroup_barrier(mem_flags::mem_threadgroup);
1716
- if (ith == 0) {
1717
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
1718
- dst[r1*ne0 + r0] = sum[0];
1854
+ for (int row = 0; row < 2; ++row) {
1855
+ const float tot = simd_sum(sumf[row]);
1856
+ if (tiisg == 0) {
1857
+ dst[r1*ne0 + first_row + row] = tot;
1858
+ }
1719
1859
  }
1720
1860
 
1721
1861
  }
@@ -1727,10 +1867,9 @@ kernel void kernel_mul_mat_q6_K_f32(
1727
1867
  constant int64_t & ne00,
1728
1868
  constant int64_t & ne10,
1729
1869
  constant int64_t & ne0,
1730
- threadgroup float * sum [[threadgroup(0)]],
1731
1870
  uint2 tgpig[[threadgroup_position_in_grid]],
1732
- uint2 tpitg[[thread_position_in_threadgroup]],
1733
- uint2 tptg[[threads_per_threadgroup]]) {
1871
+ uint tiisg[[thread_index_in_simdgroup]],
1872
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1734
1873
 
1735
1874
  const uint8_t kmask1 = 0x03;
1736
1875
  const uint8_t kmask2 = 0x0C;
@@ -1742,19 +1881,18 @@ kernel void kernel_mul_mat_q6_K_f32(
1742
1881
  const int64_t r0 = tgpig.x;
1743
1882
  const int64_t r1 = tgpig.y;
1744
1883
 
1745
- device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb;
1746
- device const float * yy = (device const float *) src1 + r1*ne10;
1884
+ const int row = 2 * r0 + sgitg;
1747
1885
 
1748
- const int nth = tptg.x*tptg.y;
1749
- const int ith = tptg.y*tpitg.x + tpitg.y;
1886
+ device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb; //r0*nb;
1887
+ device const float * yy = (device const float *) src1 + r1*ne10;
1750
1888
 
1751
1889
  float sumf = 0;
1752
1890
 
1753
1891
  #if QK_K == 256
1754
- // Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
1755
- const int iqs = 16 * tpitg.y;
1756
- const int ip = iqs / 128; // 0 or 1
1757
- const int il = (iqs - 128*ip)/16; // 0...7
1892
+ const int tid = tiisg/2;
1893
+ const int ix = tiisg%2;
1894
+ const int ip = tid/8; // 0 or 1
1895
+ const int il = tid%8;
1758
1896
  const int n = 4;
1759
1897
  const int l0 = n*il;
1760
1898
  const int is = 8*ip + l0/16;
@@ -1763,9 +1901,10 @@ kernel void kernel_mul_mat_q6_K_f32(
1763
1901
  const int q_offset_l = 64*ip + l0;
1764
1902
  const int q_offset_h = 32*ip + l0;
1765
1903
 
1766
- for (int i = tpitg.x; i < nb; i += tptg.x) {
1904
+ for (int i = ix; i < nb; i += 2) {
1767
1905
 
1768
- device const uint8_t * ql = x[i].ql + q_offset_l;
1906
+ device const uint8_t * q1 = x[i].ql + q_offset_l;
1907
+ device const uint8_t * q2 = q1 + 32;
1769
1908
  device const uint8_t * qh = x[i].qh + q_offset_h;
1770
1909
  device const int8_t * sc = x[i].scales + is;
1771
1910
 
@@ -1775,19 +1914,21 @@ kernel void kernel_mul_mat_q6_K_f32(
1775
1914
 
1776
1915
  float4 sums = {0.f, 0.f, 0.f, 0.f};
1777
1916
  for (int l = 0; l < n; ++l) {
1778
- sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
1779
- sums[1] += y[l+32] * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
1780
- sums[2] += y[l+64] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
1781
- sums[3] += y[l+96] * ((int8_t)((ql[l+32] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
1917
+ sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
1918
+ sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
1919
+ sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
1920
+ sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
1782
1921
  }
1783
1922
 
1784
1923
  sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
1785
1924
 
1786
1925
  }
1926
+
1787
1927
  #else
1788
- const int il = 4*tpitg.x; // 0, 4, 8, 12
1928
+ const int ix = tiisg/4;
1929
+ const int il = 4*(tiisg%4);
1789
1930
 
1790
- for (int i = tpitg.y; i < nb; i += tptg.y) {
1931
+ for (int i = ix; i < nb; i += 8) {
1791
1932
  device const float * y = yy + i * QK_K + il;
1792
1933
  device const uint8_t * ql = x[i].ql + il;
1793
1934
  device const uint8_t * qh = x[i].qh + il;
@@ -1807,23 +1948,8 @@ kernel void kernel_mul_mat_q6_K_f32(
1807
1948
 
1808
1949
  #endif
1809
1950
 
1810
- sum[ith] = sumf;
1811
-
1812
- //
1813
- // Accumulate the sum from all threads in the threadgroup
1814
- //
1815
- threadgroup_barrier(mem_flags::mem_threadgroup);
1816
- if (ith%4 == 0) {
1817
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
1818
- }
1819
- threadgroup_barrier(mem_flags::mem_threadgroup);
1820
- if (ith%16 == 0) {
1821
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
1822
- }
1823
- threadgroup_barrier(mem_flags::mem_threadgroup);
1824
- if (ith == 0) {
1825
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
1826
- dst[r1*ne0 + r0] = sum[0];
1951
+ const float tot = simd_sum(sumf);
1952
+ if (tiisg == 0) {
1953
+ dst[r1*ne0 + row] = tot;
1827
1954
  }
1828
-
1829
1955
  }