llama_cpp 0.3.3 → 0.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -331,26 +331,33 @@ kernel void kernel_rms_norm(
331
331
  threadgroup float * sum [[threadgroup(0)]],
332
332
  uint tgpig[[threadgroup_position_in_grid]],
333
333
  uint tpitg[[thread_position_in_threadgroup]],
334
+ uint sgitg[[simdgroup_index_in_threadgroup]],
335
+ uint tiisg[[thread_index_in_simdgroup]],
334
336
  uint ntg[[threads_per_threadgroup]]) {
335
- device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
337
+ device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
338
+ device const float * x_scalar = (device const float *) x;
339
+ float4 sumf=0;
340
+ float all_sum=0;
336
341
 
337
342
  // parallel sum
338
- sum[tpitg] = 0.0f;
339
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
340
- sum[tpitg] += x[i00] * x[i00];
343
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
344
+ sumf += x[i00] * x[i00];
345
+ }
346
+ all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
347
+ all_sum = simd_sum(all_sum);
348
+ if (tiisg == 0) {
349
+ sum[sgitg] = all_sum;
341
350
  }
342
351
 
343
- // reduce
344
352
  threadgroup_barrier(mem_flags::mem_threadgroup);
345
- for (uint i = ntg/2; i > 0; i /= 2) {
346
- if (tpitg < i) {
347
- sum[tpitg] += sum[tpitg + i];
348
- }
349
- threadgroup_barrier(mem_flags::mem_threadgroup);
353
+ // reduce the per-simdgroup partial sums; the number of simd groups is ntg / 32
354
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
355
+ if (tpitg < i) {
356
+ sum[tpitg] += sum[tpitg + i];
357
+ }
350
358
  }
351
-
352
- // broadcast
353
359
  if (tpitg == 0) {
360
+ for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
354
361
  sum[0] /= ne00;
355
362
  }
356
363
 
@@ -359,82 +366,94 @@ kernel void kernel_rms_norm(
359
366
  const float mean = sum[0];
360
367
  const float scale = 1.0f/sqrt(mean + eps);
361
368
 
362
- device float * y = dst + tgpig*ne00;
363
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
369
+ device float4 * y = (device float4 *) (dst + tgpig*ne00);
370
+ device float * y_scalar = (device float *) y;
371
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
364
372
  y[i00] = x[i00] * scale;
365
373
  }
374
+ if (tpitg == 0) {
375
+ for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
376
+ }
377
+ }
378
+
379
+ // Computes the inner product between one q4_0 block and 32 floats (yl); sumy is SUM(yl[i]) and is supplied by the caller.
380
+ float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) {
381
+ float d = qb_curr->d; // the block's d field; it multiplies the whole result in the return below
382
+ float4 acc = 0.f; // one accumulator per nibble position inside a 16-bit word
383
+ device uint16_t * qs = ((device uint16_t *)qb_curr + 1); // 16-bit view starting 2 bytes into the block (presumably just past d -- confirm against the block_q4_0 layout). NOTE(review): this cast drops the const qualifier of qb_curr.
384
+ for (int i = 0; i < 16; i+=2) { // i/2 walks 8 consecutive 16-bit words
385
+ acc[0] += yl[i] * (qs[i / 2] & 0x000F); // bits 0-3, already at weight 1
386
+ acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0); // bits 4-7 kept in place, so the value carries a factor of 16
387
+ acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00); // bits 8-11 kept in place: factor of 256
388
+ acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000); // bits 12-15 kept in place: factor of 4096
389
+ }
390
+ return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f); // the divisions undo the factors left by the unshifted masks; sumy * -8.f subtracts 8 for every element of yl (presumably the q4_0 offset -- confirm)
391
+ }
392
+
393
+ // Computes the inner product between one q4_1 block and 32 floats (yl); sumy is SUM(yl[i]) and is supplied by the caller.
394
+ float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) {
395
+ float d = qb_curr->d; // the block's d field; multiplies the nibble dot product in the return below
396
+ float m = qb_curr->m; // the block's m field; multiplied by sumy in the return below
397
+ float4 acc = 0.f; // one accumulator per nibble position inside a 16-bit word
398
+ device uint16_t * qs = ((device uint16_t *)qb_curr + 2); // 16-bit view starting 4 bytes into the block (presumably just past d and m -- confirm against the block_q4_1 layout). NOTE(review): this cast drops the const qualifier of qb_curr.
399
+ for (int i = 0; i < 16; i+=2) { // i/2 walks 8 consecutive 16-bit words
400
+ acc[0] += yl[i] * (qs[i / 2] & 0x000F); // bits 0-3, already at weight 1
401
+ acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0); // bits 4-7 kept in place, so the value carries a factor of 16
402
+ acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00); // bits 8-11 kept in place: factor of 256
403
+ acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000); // bits 12-15 kept in place: factor of 4096
404
+ }
405
+ return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m; // the divisions undo the factors left by the unshifted masks; m is applied once to the sum of yl instead of per element
366
406
  }
367
407
 
368
408
  // putting them in the kernel cause a significant performance penalty
369
409
  #define N_DST 4 // each SIMD group works on 4 rows
370
410
  #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
371
411
  #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
372
- kernel void kernel_mul_mat_q4_0_f32(
373
- device const void * src0,
374
- device const float * src1,
375
- device float * dst,
376
- constant int64_t & ne00,
377
- constant int64_t & ne10,
378
- constant int64_t & ne0,
379
- constant int64_t & ne01[[buffer(4)]],
380
- uint2 tgpig[[threadgroup_position_in_grid]],
381
- uint tiisg[[thread_index_in_simdgroup]],
382
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
412
+ template<typename block_q_type>
413
+ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
414
+ int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
415
+ uint2 tgpig, uint tiisg, uint sgitg) {
383
416
  const int nb = ne00/QK4_0;
384
417
  const int r0 = tgpig.x;
385
418
  const int r1 = tgpig.y;
386
- device const block_q4_0 * x = (device const block_q4_0 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
419
+ device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
387
420
  device const float * y = (device const float *) src1 + r1*ne10;
388
- block_q4_0 qb_curr, qb_next;
389
421
  float4 y_curr[8]; // src1 vector cache
390
422
  float sumf[N_DST]={0.f}, all_sum;
391
423
  thread float * yl=(thread float *)y_curr;
392
424
 
393
- // bootstrap
394
- qb_curr = x[tiisg];
395
425
  // each thread in a SIMD group deals with 1 block.
396
426
  for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
397
-
427
+ float sumy = 0;
398
428
  for (int i = 0; i < QK4_0 / 4; i++) {
399
- y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
429
+ y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i);
430
+ sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
400
431
  }
401
432
 
402
433
  for (int row = 0; row < N_DST; row++) {
403
- // prefetch next x block
404
- qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
405
-
406
- // calculate
407
- float d = qb_curr.d;
408
- float2 acc = {0.0f, 0.0f};
409
- for (int i = 0; i < 16; i++) {
410
- acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
411
- acc[1] += yl[i] + yl[i+16];
412
- }
413
- sumf[row] += d * (acc[0] - 8.f*acc[1]);
414
- qb_curr = qb_next;
434
+ sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl);
415
435
  }
416
436
  }
417
437
 
418
- for (int i = 0; i < QK4_0 / 4; i++) {
419
- y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
420
- }
421
-
422
- for (int row = 0; row < N_DST; row++) {
423
- // prefetch next x block
424
- qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
425
-
426
- // calculate
427
- float d = qb_curr.d;
428
- float2 acc = {0.0f, 0.0f};
429
- for (int i = 0; i < 16; i++) {
430
- acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
431
- acc[1] += yl[i] + yl[i+16];
438
+ // from here on, each pass loads two rows at a time and 16 blocks per row
439
+ int ir = tiisg / (N_SIMDWIDTH / 2);
440
+ int ib = tiisg % (N_SIMDWIDTH / 2);
441
+ for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) {
442
+ int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); //where the left blocks start
443
+ float sumy = 0;
444
+ for (int i = 0; i < QK4_0 / 4; i++) {
445
+ y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i);
446
+ sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
432
447
  }
433
- if (tiisg < nb % N_SIMDWIDTH) {
434
- sumf[row] += d * (acc[0] - 8.f*acc[1]);
448
+
449
+ for (int row = 0; row < N_DST; row+=2) {
450
+ if (nb_start + ib < nb) {
451
+ sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl);
452
+ }
435
453
  }
436
- qb_curr = qb_next;
454
+ }
437
455
 
456
+ for (int row = 0; row < N_DST; ++row) {
438
457
  all_sum = simd_sum(sumf[row]);
439
458
  if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
440
459
  dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
@@ -442,73 +461,32 @@ kernel void kernel_mul_mat_q4_0_f32(
442
461
  }
443
462
  }
444
463
 
445
- kernel void kernel_mul_mat_q4_1_f32(
464
+ kernel void kernel_mul_mat_q4_0_f32(
446
465
  device const void * src0,
447
466
  device const float * src1,
448
467
  device float * dst,
449
468
  constant int64_t & ne00,
450
469
  constant int64_t & ne10,
451
470
  constant int64_t & ne0,
452
- threadgroup float * sum [[threadgroup(0)]],
471
+ constant int64_t & ne01[[buffer(4)]],
453
472
  uint2 tgpig[[threadgroup_position_in_grid]],
454
- uint2 tpitg[[thread_position_in_threadgroup]],
455
- uint2 tptg[[threads_per_threadgroup]]) {
456
- const int nb = ne00/QK4_1;
457
-
458
- const int64_t r0 = tgpig.x;
459
- const int64_t r1 = tgpig.y;
460
-
461
- device const block_q4_1 * x = (device const block_q4_1 *) src0 + r0*nb;
462
- device const float * y = (device const float *) src1 + r1*ne10;
463
-
464
- const uint nth = tptg.x*tptg.y;
465
- const uint ith = tptg.y*tpitg.x + tpitg.y;
466
-
467
- const int ix = tpitg.y/4; // 0 or 1
468
- const int iy = tpitg.y - 4*ix; // 0...3
469
-
470
- const int first = 4 * iy;
471
-
472
- float sumf = 0;
473
-
474
- for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
475
-
476
- const float d = (float)x[i].d;
477
- const float m = (float)x[i].m;
478
-
479
- device const uint8_t * xl = x[i].qs + first;
480
- device const float * yl = y + i * QK4_1 + first;
481
-
482
- float2 acc = {0.0f, 0.0f};
483
-
484
- for (int j = 0; j < 4; ++j) {
485
-
486
- acc[0] += yl[j+ 0] * (d * (xl[j] & 0xF) + m);
487
- acc[1] += yl[j+16] * (d * (xl[j] >> 4) + m);
488
-
489
- }
490
-
491
- sumf += acc[0] + acc[1];
492
- }
493
-
494
- sum[ith] = sumf;
473
+ uint tiisg[[thread_index_in_simdgroup]],
474
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
475
+ mul_vec_q_n_f32<block_q4_0>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
476
+ }
495
477
 
496
- //
497
- // Accumulate the sum from all threads in the threadgroup
498
- //
499
- threadgroup_barrier(mem_flags::mem_threadgroup);
500
- if (ith%4 == 0) {
501
- sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
502
- }
503
- threadgroup_barrier(mem_flags::mem_threadgroup);
504
- if (ith%16 == 0) {
505
- sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
506
- }
507
- threadgroup_barrier(mem_flags::mem_threadgroup);
508
- if (ith == 0) {
509
- for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
510
- dst[r1*ne0 + r0] = sum[0];
511
- }
478
+ kernel void kernel_mul_mat_q4_1_f32( // kernel entry point for the block_q4_1 case; all work is delegated to mul_vec_q_n_f32
479
+ device const void * src0, // passed through; mul_vec_q_n_f32 reinterprets it as block_q4_1 data
480
+ device const float * src1, // float input; mul_vec_q_n_f32 offsets it by tgpig.y*ne10
481
+ device float * dst, // output; mul_vec_q_n_f32 writes at tgpig.y*ne0 + row index
482
+ constant int64_t & ne00, // mul_vec_q_n_f32 derives the per-row block count as ne00/QK4_0
483
+ constant int64_t & ne10, // stride used to offset src1
484
+ constant int64_t & ne0, // stride used to offset dst
485
+ constant int64_t & ne01[[buffer(4)]], // upper bound on the row index checked before each dst write; bound at buffer index 4
486
+ uint2 tgpig[[threadgroup_position_in_grid]],
487
+ uint tiisg[[thread_index_in_simdgroup]],
488
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
489
+ mul_vec_q_n_f32<block_q4_1>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); // forwards every argument unchanged; the references are passed by value
512
490
  }
513
491
 
514
492
  kernel void kernel_mul_mat_f16_f32(
@@ -624,17 +602,19 @@ kernel void kernel_rope(
624
602
  constant int & n_past,
625
603
  constant int & n_dims,
626
604
  constant int & mode,
605
+ constant float & freq_base,
606
+ constant float & freq_scale,
627
607
  uint3 tpig[[thread_position_in_grid]]) {
628
608
  const int64_t i3 = tpig[2];
629
609
  const int64_t i2 = tpig[1];
630
610
  const int64_t i1 = tpig[0];
631
611
 
632
612
  const bool is_neox = mode & 2;
633
- const float theta_scale = pow(10000.0, -2.0f/n_dims);
613
+ const float theta_scale = pow(freq_base, -2.0f/n_dims);
634
614
 
635
615
  const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
636
616
 
637
- float theta = (float)p;
617
+ float theta = freq_scale * (float)p;
638
618
 
639
619
  if (!is_neox) {
640
620
  for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
@@ -1229,111 +1209,137 @@ kernel void kernel_mul_mat_q2_K_f32(
1229
1209
  constant int64_t & ne00,
1230
1210
  constant int64_t & ne10,
1231
1211
  constant int64_t & ne0,
1232
- threadgroup float * sum [[threadgroup(0)]],
1212
+ constant int64_t & ne01[[buffer(4)]],
1233
1213
  uint2 tgpig[[threadgroup_position_in_grid]],
1234
- uint2 tpitg[[thread_position_in_threadgroup]],
1235
- uint2 tptg[[threads_per_threadgroup]]) {
1214
+ uint tiisg[[thread_index_in_simdgroup]],
1215
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1236
1216
 
1237
1217
  const int nb = ne00/QK_K;
1218
+ const int r0 = tgpig.x;
1219
+ const int r1 = tgpig.y;
1238
1220
 
1239
- const int64_t r0 = tgpig.x;
1240
- const int64_t r1 = tgpig.y;
1241
-
1242
- device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb;
1243
- device const float * yy = (device const float *) src1 + r1*ne10;
1244
-
1245
- const int nth = tptg.x*tptg.y;
1246
- const int ith = tptg.y*tpitg.x + tpitg.y;
1221
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1222
+ const int ib_row = first_row * nb;
1223
+ device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row;
1224
+ device const float * y = (device const float *) src1 + r1*ne10;
1225
+ float yl[32];
1226
+ float sumf[N_DST]={0.f}, all_sum;
1247
1227
 
1248
- float sumf = 0;
1228
+ const int step = sizeof(block_q2_K) * nb;
1249
1229
 
1250
1230
  #if QK_K == 256
1251
- const int tid = tpitg.y; // 0...16
1252
- const int il = tid/4; // 0...3
1253
- const int ir = tid%4; // 0...3
1254
- const int ip = il/2; // 0 or 1
1255
- const int shift1 = 4*(il%2);// 0 or 4
1256
- const int shift2 = shift1+2;// 2 or 6
1257
- const int n = 8;
1258
- const int is = 4*il + (n*ir)/16;
1259
-
1260
- const int y_offset = 64*il + n*ir;
1261
- const int q_offset = 32*ip + n*ir;
1262
-
1263
- for (int i = tpitg.x; i < nb; i += tptg.x) {
1264
-
1265
- device const uint8_t * q = x[i].qs + q_offset;
1266
- device const uint8_t * scales = x[i].scales + is;
1231
+ const int ix = tiisg/8; // 0...3
1232
+ const int it = tiisg%8; // 0...7
1233
+ const int im = it/4; // 0 or 1
1234
+ const int ir = it%4; // 0...3
1235
+ const int is = (8*ir)/16;// 0 or 1
1236
+
1237
+ device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir;
1238
+
1239
+ for (int ib = ix; ib < nb; ib += 4) {
1240
+
1241
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
1242
+ for (int i = 0; i < 8; ++i) {
1243
+ yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
1244
+ yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
1245
+ yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
1246
+ yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
1247
+ }
1267
1248
 
1268
- uint8_t d1 = scales[0] & 0xF;
1269
- uint8_t d2 = scales[2] & 0xF;
1270
- uint8_t m1 = scales[0] >> 4;
1271
- uint8_t m2 = scales[2] >> 4;
1249
+ device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*im + is;
1250
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
1251
+ device const half * dh = &x[ib].d;
1272
1252
 
1273
- device const float * y = yy + i*QK_K + y_offset;
1253
+ for (int row = 0; row < N_DST; row++) {
1274
1254
 
1275
- float2 s = {0.f, 0.f};
1276
- float smin = 0;
1277
- for (int l = 0; l < n; ++l) {
1278
- s[0] += y[l+ 0] * ((q[l] >> shift1) & 3);
1279
- s[1] += y[l+32] * ((q[l] >> shift2) & 3);
1280
- smin += y[l+ 0] * m1 + y[l+32] * m2;
1255
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
1256
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
1257
+ for (int i = 0; i < 8; i += 2) {
1258
+ acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
1259
+ acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
1260
+ acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
1261
+ acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
1262
+ acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
1263
+ acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
1264
+ acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
1265
+ acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
1266
+ }
1267
+ float dall = dh[0];
1268
+ float dmin = dh[1] * 1.f/16.f;
1269
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
1270
+ (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
1271
+ (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +
1272
+ (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
1273
+ dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
1274
+
1275
+ qs += step/2;
1276
+ sc += step;
1277
+ dh += step/2;
1281
1278
  }
1282
1279
 
1283
- const float dall = (float)x[i].d;
1284
- const float dmin = (float)x[i].dmin;
1285
-
1286
- sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
1287
-
1280
+ y4 += 4 * QK_K;
1288
1281
  }
1289
1282
  #else
1290
- const int il = 4 * tpitg.x;
1283
+ const int ix = tiisg/2; // 0...15
1284
+ const int it = tiisg%2; // 0...1
1291
1285
 
1292
- uint32_t aux[2];
1293
- thread const uint8_t * d = (thread const uint8_t *)aux;
1294
- thread const uint8_t * m = (thread const uint8_t *)aux + 4;
1286
+ device const float * y4 = y + ix * QK_K + 8 * it;
1295
1287
 
1296
- for (int i = tpitg.y; i < nb; i += tptg.y) {
1288
+ for (int ib = ix; ib < nb; ib += 16) {
1297
1289
 
1298
- device const uint8_t * q = x[i].qs + il;
1299
- device const float * y = yy + i*QK_K + il;
1290
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
1291
+ for (int i = 0; i < 8; ++i) {
1292
+ yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
1293
+ yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8];
1294
+ yl[i+16] = y4[i+32]; sumy[2] += yl[i+16];
1295
+ yl[i+24] = y4[i+48]; sumy[3] += yl[i+24];
1296
+ }
1300
1297
 
1301
- const float dall = (float)x[i].d;
1302
- const float dmin = (float)x[i].dmin;
1298
+ device const uint8_t * sc = (device const uint8_t *)x[ib].scales;
1299
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
1300
+ device const half * dh = &x[ib].d;
1303
1301
 
1304
- device const uint32_t * a = (device const uint32_t *)x[i].scales;
1305
- aux[0] = a[0] & 0x0f0f0f0f;
1306
- aux[1] = (a[0] >> 4) & 0x0f0f0f0f;
1302
+ for (int row = 0; row < N_DST; row++) {
1307
1303
 
1308
- for (int l = 0; l < 4; ++l) {
1309
- sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0])
1310
- + y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1])
1311
- + y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2])
1312
- + y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]);
1304
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
1305
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
1306
+ for (int i = 0; i < 8; i += 2) {
1307
+ acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
1308
+ acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
1309
+ acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
1310
+ acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
1311
+ acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
1312
+ acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
1313
+ acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
1314
+ acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
1315
+ }
1316
+
1317
+ float dall = dh[0];
1318
+ float dmin = dh[1];
1319
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
1320
+ (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
1321
+ (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
1322
+ (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
1323
+ dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));
1324
+
1325
+ qs += step/2;
1326
+ sc += step;
1327
+ dh += step/2;
1313
1328
  }
1329
+
1330
+ y4 += 16 * QK_K;
1314
1331
  }
1315
1332
  #endif
1316
1333
 
1317
- sum[ith] = sumf;
1318
-
1319
- //
1320
- // Accumulate the sum from all threads in the threadgroup
1321
- //
1322
- threadgroup_barrier(mem_flags::mem_threadgroup);
1323
- if (ith%4 == 0) {
1324
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
1325
- }
1326
- threadgroup_barrier(mem_flags::mem_threadgroup);
1327
- if (ith%16 == 0) {
1328
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
1329
- }
1330
- threadgroup_barrier(mem_flags::mem_threadgroup);
1331
- if (ith == 0) {
1332
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
1333
- dst[r1*ne0 + r0] = sum[0];
1334
+ for (int row = 0; row < N_DST; ++row) {
1335
+ all_sum = simd_sum(sumf[row]);
1336
+ if (tiisg == 0) {
1337
+ dst[r1*ne0 + first_row + row] = all_sum;
1338
+ }
1334
1339
  }
1335
1340
  }
1336
1341
 
1342
+ #if QK_K == 256
1337
1343
  kernel void kernel_mul_mat_q3_K_f32(
1338
1344
  device const void * src0,
1339
1345
  device const float * src1,
@@ -1342,40 +1348,41 @@ kernel void kernel_mul_mat_q3_K_f32(
1342
1348
  constant int64_t & ne10,
1343
1349
  constant int64_t & ne0,
1344
1350
  constant int64_t & ne1,
1345
- threadgroup float * sum [[threadgroup(0)]],
1346
1351
  uint2 tgpig[[threadgroup_position_in_grid]],
1347
- uint2 tpitg[[thread_position_in_threadgroup]],
1348
- uint2 tptg[[threads_per_threadgroup]]) {
1352
+ uint tiisg[[thread_index_in_simdgroup]],
1353
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1349
1354
 
1350
1355
  const int nb = ne00/QK_K;
1351
1356
 
1352
1357
  const int64_t r0 = tgpig.x;
1353
1358
  const int64_t r1 = tgpig.y;
1354
1359
 
1355
- device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb;
1356
- device const float * yy = (device const float *) src1 + r1*ne10;
1360
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
1357
1361
 
1358
- const int nth = tptg.x*tptg.y;
1359
- const int ith = tptg.y*tpitg.x + tpitg.y;
1360
-
1361
- #if QK_K == 256
1362
+ device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb;
1363
+ device const float * yy = (device const float *) src1 + r1*ne10;
1362
1364
 
1363
- const uint8_t m3 = 3;
1364
- const int8_t m4 = 4;
1365
+ float yl[16];
1365
1366
 
1366
1367
  const uint16_t kmask1 = 0x0303;
1367
1368
  const uint16_t kmask2 = 0x0f0f;
1368
1369
 
1369
- const int tid = tpitg.y; // expecting 16
1370
+ const int tid = tiisg/2;
1371
+ const int ix = tiisg%2;
1370
1372
  const int ip = tid/8; // 0 or 1
1371
1373
  const int il = tid/2 - 4*ip; // 0...3
1372
1374
  const int ir = tid%2;
1373
1375
  const int n = 8;
1374
1376
  const int l0 = n*ir;
1375
1377
 
1376
- const uint8_t m = 1 << (4*ip + il);
1378
+ const uint16_t m1 = 1 << (4*ip + il);
1379
+ const uint16_t m2 = m1 << 8;
1377
1380
 
1378
1381
  const int shift = 2*il;
1382
+ const uint16_t qm1 = 0x0003 << shift;
1383
+ const uint16_t qm2 = 0x0300 << shift;
1384
+ const int32_t v1 = 4 << shift;
1385
+ const int32_t v2 = 1024 << shift;
1379
1386
 
1380
1387
  const uint16_t s_shift1 = 4*ip;
1381
1388
  const uint16_t s_shift2 = s_shift1 + 2*(il/2);
@@ -1384,226 +1391,315 @@ kernel void kernel_mul_mat_q3_K_f32(
1384
1391
  const int q_offset = 32*ip + l0;
1385
1392
  const int y_offset = 128*ip + 32*il + l0;
1386
1393
 
1387
- //float sumf = 0;
1388
- float sumf1 = 0, sumf2 = 0;
1389
- for (int i = tpitg.x; i < nb; i += tptg.x) {
1390
-
1391
- const float d_all = (float)(x[i].d);
1394
+ const int step = sizeof(block_q3_K) * nb / 2;
1392
1395
 
1393
- device const uint8_t * q = x[i].qs + q_offset;
1394
- device const uint8_t * h = x[i].hmask + l0;
1395
- device const float * y = yy + i * QK_K + y_offset;
1396
+ device const float * y1 = yy + ix*QK_K + y_offset;
1396
1397
 
1397
- device const uint16_t * a = (device const uint16_t *)x[i].scales;
1398
- const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
1398
+ float sumf1[2] = {0.f}, sumf2[2] = {0.f};
1399
+ for (int i = ix; i < nb; i += 2) {
1399
1400
 
1400
- float s = 0;
1401
- for (int l = 0; l < n; ++l) {
1402
- s += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & m3) - ((h[l+ 0] & m) ? 0 : m4));
1401
+ for (int l = 0; l < 8; ++l) {
1402
+ yl[l+0] = y1[l+ 0];
1403
+ yl[l+8] = y1[l+16];
1403
1404
  }
1404
- float d = d_all * s;
1405
- sumf1 += d * scales[0];
1406
- sumf2 += d;
1407
- //sumf += d_all * s * (scales[0] - 32);
1408
1405
 
1409
- s = 0;
1410
- for (int l = 0; l < n; ++l) {
1411
- s += y[l+16] * ((int8_t)((q[l+16] >> shift) & m3) - ((h[l+16] & m) ? 0 : m4));
1412
- }
1413
- d = d_all * s;
1414
- sumf1 += d * scales[1];
1415
- sumf2 += d;
1416
- //sumf += d_all * s * (scales[1] - 32);
1406
+ device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
1407
+ device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
1408
+ device const uint16_t * a = (device const uint16_t *)(x[i].scales);
1409
+ device const half * dh = &x[i].d;
1417
1410
 
1418
- }
1411
+ for (int row = 0; row < 2; ++row) {
1419
1412
 
1420
- //sum[ith] = sumf;
1421
- sum[ith] = sumf1 - 32.f*sumf2;
1422
- #else
1423
- const int il = 4 * tpitg.x; // 0, 4, 8, 12
1424
- const int im = il/8; // 0, 0, 1, 1
1425
- const int in = il%8; // 0, 4, 0, 4
1413
+ const float d_all = (float)dh[0];
1414
+ const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
1426
1415
 
1427
- float sumf = 0;
1428
-
1429
- for (int i = tpitg.y; i < nb; i += tptg.y) {
1430
-
1431
- const float d_all = (float)(x[i].d);
1416
+ float s1 = 0, s2 = 0;
1417
+ for (int l = 0; l < n; l += 2) {
1418
+ const uint16_t qs = q[l/2];
1419
+ s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
1420
+ s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
1421
+ }
1422
+ float d = d_all * (s1 + 1.f/256.f * s2);
1423
+ sumf1[row] += d * scales[0];
1424
+ sumf2[row] += d;
1425
+
1426
+ s1 = s2 = 0;
1427
+ for (int l = 0; l < n; l += 2) {
1428
+ const uint16_t qs = q[l/2+8];
1429
+ s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
1430
+ s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
1431
+ }
1432
+ d = d_all * (s1 + 1.f/256.f * s2);
1433
+ sumf1[row] += d * scales[1];
1434
+ sumf2[row] += d;
1432
1435
 
1433
- device const uint8_t * q = x[i].qs + il;
1434
- device const uint8_t * h = x[i].hmask + in;
1435
- device const float * y = yy + i * QK_K + il;
1436
+ q += step;
1437
+ h += step;
1438
+ a += step;
1439
+ dh += step;
1436
1440
 
1437
- const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
1438
- const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
1439
- const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
1440
- const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
1441
-
1442
- for (int l = 0; l < 4; ++l) {
1443
- const uint8_t hm = h[l] >> im;
1444
- sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4))
1445
- + y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4))
1446
- + y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4))
1447
- + y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4));
1448
1441
  }
1449
1442
 
1450
- }
1451
-
1452
- sum[ith] = sumf;
1453
-
1454
- #endif
1443
+ y1 += 2 * QK_K;
1455
1444
 
1456
- //
1457
- // Accumulate the sum from all threads in the threadgroup
1458
- //
1459
- threadgroup_barrier(mem_flags::mem_threadgroup);
1460
- if (ith%4 == 0) {
1461
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
1462
- }
1463
- threadgroup_barrier(mem_flags::mem_threadgroup);
1464
- if (ith%16 == 0) {
1465
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
1466
- }
1467
- threadgroup_barrier(mem_flags::mem_threadgroup);
1468
- if (ith == 0) {
1469
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
1470
- dst[r1*ne0 + r0] = sum[0];
1471
1445
  }
1472
1446
 
1447
+ for (int row = 0; row < 2; ++row) {
1448
+ const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
1449
+ const float tot = simd_sum(sumf);
1450
+ if (tiisg == 0) {
1451
+ dst[r1*ne0 + first_row + row] = tot;
1452
+ }
1453
+ }
1473
1454
  }
1474
-
1475
- kernel void kernel_mul_mat_q4_K_f32(
1455
+ #else
1456
+ kernel void kernel_mul_mat_q3_K_f32(
1476
1457
  device const void * src0,
1477
1458
  device const float * src1,
1478
1459
  device float * dst,
1479
1460
  constant int64_t & ne00,
1480
1461
  constant int64_t & ne10,
1481
1462
  constant int64_t & ne0,
1482
- threadgroup float * sum [[threadgroup(0)]],
1463
+ constant int64_t & ne1,
1483
1464
  uint2 tgpig[[threadgroup_position_in_grid]],
1484
- uint2 tpitg[[thread_position_in_threadgroup]],
1485
- uint2 tptg[[threads_per_threadgroup]]) {
1465
+ uint tiisg[[thread_index_in_simdgroup]],
1466
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1486
1467
 
1487
1468
  const int nb = ne00/QK_K;
1488
1469
 
1489
1470
  const int64_t r0 = tgpig.x;
1490
1471
  const int64_t r1 = tgpig.y;
1491
1472
 
1492
- const int nth = tptg.x*tptg.y;
1493
- const int ith = tptg.y*tpitg.x + tpitg.y;
1473
+ const int row = 2 * r0 + sgitg;
1494
1474
 
1495
- device const block_q4_K * x = (device const block_q4_K *) src0 + r0*nb;
1475
+ device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb;
1496
1476
  device const float * yy = (device const float *) src1 + r1*ne10;
1477
+ const int ix = tiisg/4;
1478
+ const int il = 4 * (tiisg%4);// 0, 4, 8, 12
1479
+ const int im = il/8; // 0, 0, 1, 1
1480
+ const int in = il%8; // 0, 4, 0, 4
1497
1481
 
1498
- float sumf = 0;
1482
+ float2 sum = {0.f, 0.f};
1483
+
1484
+ for (int i = ix; i < nb; i += 8) {
1485
+
1486
+ const float d_all = (float)(x[i].d);
1487
+
1488
+ device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
1489
+ device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
1490
+ device const uint16_t * s = (device const uint16_t *)(x[i].scales);
1491
+ device const float * y = yy + i * QK_K + il;
1492
+
1493
+ const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
1494
+ const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
1495
+ const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
1496
+ const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
1497
+
1498
+ for (int l = 0; l < 4; l += 2) {
1499
+ const uint16_t hm = h[l/2] >> im;
1500
+ sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
1501
+ + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
1502
+ + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
1503
+ + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
1504
+ sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
1505
+ + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
1506
+ + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
1507
+ + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
1508
+ }
1509
+
1510
+ }
1511
+ const float sumf = sum[0] + sum[1] * 1.f/256.f;
1512
+
1513
+ const float tot = simd_sum(sumf);
1514
+ if (tiisg == 0) {
1515
+ dst[r1*ne0 + row] = tot;
1516
+ }
1517
+
1518
+ }
1519
+ #endif
1499
1520
 
1500
1521
  #if QK_K == 256
1522
+ kernel void kernel_mul_mat_q4_K_f32(
1523
+ device const void * src0,
1524
+ device const float * src1,
1525
+ device float * dst,
1526
+ constant int64_t & ne00,
1527
+ constant int64_t & ne10,
1528
+ constant int64_t & ne0,
1529
+ constant int64_t & ne01[[buffer(4)]],
1530
+ uint2 tgpig[[threadgroup_position_in_grid]],
1531
+ uint tiisg[[thread_index_in_simdgroup]],
1532
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1501
1533
 
1502
1534
  const uint16_t kmask1 = 0x3f3f;
1503
1535
  const uint16_t kmask2 = 0x0f0f;
1504
1536
  const uint16_t kmask3 = 0xc0c0;
1505
1537
 
1506
- const int tid = tpitg.y; // 0...16
1507
- const int il = tid/4; // 0...3
1508
- const int ir = tid - 4*il;// 0...3
1509
- const int n = 4;
1538
+ const int ix = tiisg/8; // 0...3
1539
+ const int it = tiisg%8; // 0...7
1540
+ const int im = it/4; // 0 or 1
1541
+ const int ir = it%4; // 0...3
1510
1542
 
1511
- const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
1512
- const int in = il%2;
1543
+ const int nb = ne00/QK_K;
1544
+ const int r0 = tgpig.x;
1545
+ const int r1 = tgpig.y;
1546
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1547
+ const int ib_row = first_row * nb;
1548
+ device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
1549
+ device const float * y = (device const float *) src1 + r1*ne10;
1550
+ float yl[16];
1551
+ float yh[16];
1552
+ float sumf[N_DST]={0.f}, all_sum;
1513
1553
 
1514
- const int l0 = n*(2*ir + in);
1515
- const int q_offset = 32*im + l0;
1516
- const int y_offset = 64*im + l0;
1554
+ const int step = sizeof(block_q4_K) * nb / 2;
1517
1555
 
1518
- uchar2 sc1, sc2, sc3, sc4;
1556
+ device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir;
1519
1557
 
1520
- for (int i = tpitg.x; i < nb; i += tptg.x) {
1558
+ uint16_t sc16[4];
1559
+ thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
1521
1560
 
1522
- device const uint8_t * q1 = (x + i)->qs + q_offset;
1523
- device const uint8_t * q2 = q1 + 64;
1524
- device const float * y1 = yy + i*QK_K + y_offset;
1525
- device const float * y2 = y1 + 128;
1561
+ for (int ib = ix; ib < nb; ib += 4) {
1526
1562
 
1527
- const float dall = (float)((x + i)->d);
1528
- const float dmin = (float)((x + i)->dmin);
1563
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
1564
+ for (int i = 0; i < 8; ++i) {
1565
+ yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
1566
+ yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
1567
+ yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
1568
+ yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
1569
+ }
1529
1570
 
1530
- device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
1531
- sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
1532
- sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
1533
- sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
1534
- sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
1571
+ device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im;
1572
+ device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
1573
+ device const half * dh = &x[ib].d;
1535
1574
 
1536
- float4 s = {0.f, 0.f, 0.f, 0.f};
1537
- float smin = 0;
1538
- for (int l = 0; l < n; ++l) {
1575
+ for (int row = 0; row < N_DST; row++) {
1539
1576
 
1540
- s[0] += y1[l] * (q1[l] & 0xF); s[1] += y1[l+32] * (q1[l] >> 4);
1541
- s[2] += y2[l] * (q2[l] & 0xF); s[3] += y2[l+32] * (q2[l] >> 4);
1542
- smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
1577
+ sc16[0] = sc[0] & kmask1;
1578
+ sc16[1] = sc[2] & kmask1;
1579
+ sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
1580
+ sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
1581
+
1582
+ device const uint16_t * q2 = q1 + 32;
1583
+
1584
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
1585
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
1586
+ for (int i = 0; i < 8; i += 2) {
1587
+ acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
1588
+ acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
1589
+ acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
1590
+ acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
1591
+ acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
1592
+ acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
1593
+ acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
1594
+ acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
1595
+ }
1543
1596
 
1597
+ float dall = dh[0];
1598
+ float dmin = dh[1];
1599
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
1600
+ (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
1601
+ (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
1602
+ (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
1603
+ dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
1604
+
1605
+ q1 += step;
1606
+ sc += step;
1607
+ dh += step;
1544
1608
  }
1545
- sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
1546
1609
 
1610
+ y4 += 4 * QK_K;
1611
+ }
1612
+
1613
+ for (int row = 0; row < N_DST; ++row) {
1614
+ all_sum = simd_sum(sumf[row]);
1615
+ if (tiisg == 0) {
1616
+ dst[r1*ne0 + first_row + row] = all_sum;
1617
+ }
1547
1618
  }
1619
+ }
1548
1620
  #else
1549
- uint16_t aux16[2];
1550
- thread const uint8_t * scales = (thread const uint8_t *)aux16;
1621
+ kernel void kernel_mul_mat_q4_K_f32(
1622
+ device const void * src0,
1623
+ device const float * src1,
1624
+ device float * dst,
1625
+ constant int64_t & ne00,
1626
+ constant int64_t & ne10,
1627
+ constant int64_t & ne0,
1628
+ constant int64_t & ne01[[buffer(4)]],
1629
+ uint2 tgpig[[threadgroup_position_in_grid]],
1630
+ uint tiisg[[thread_index_in_simdgroup]],
1631
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1551
1632
 
1552
- const int il = 4*tpitg.x;
1633
+ const int ix = tiisg/4; // 0...7
1634
+ const int it = tiisg%4; // 0...3
1553
1635
 
1554
- for (int i = tpitg.y; i < nb; i += tptg.y) {
1636
+ const int nb = ne00/QK_K;
1637
+ const int r0 = tgpig.x;
1638
+ const int r1 = tgpig.y;
1639
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1640
+ const int ib_row = first_row * nb;
1641
+ device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
1642
+ device const float * y = (device const float *) src1 + r1*ne10;
1643
+ float yl[8];
1644
+ float yh[8];
1645
+ float sumf[N_DST]={0.f}, all_sum;
1555
1646
 
1556
- device const uint8_t * q = x[i].qs + il;
1557
- device const float * y = yy + i * QK_K + il;
1647
+ const int step = sizeof(block_q4_K) * nb / 2;
1558
1648
 
1559
- const float d = (float)x[i].d[0];
1560
- const float m = (float)x[i].d[1];
1649
+ device const float * y4 = y + ix * QK_K + 8 * it;
1561
1650
 
1562
- device const uint16_t * a = (device const uint16_t *)x[i].scales;
1563
- aux16[0] = a[0] & 0x0f0f;
1564
- aux16[1] = (a[0] >> 4) & 0x0f0f;
1651
+ uint16_t sc16[4];
1565
1652
 
1566
- for (int l = 0; l < 4; ++l) {
1567
- sumf += d * scales[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - m * scales[2] * (y[l+ 0] + y[l+16])
1568
- + d * scales[1] * (y[l+32] * (q[l] >> 4) + y[l+48] * (q[l+16] >> 4)) - m * scales[3] * (y[l+32] + y[l+48]);
1653
+ for (int ib = ix; ib < nb; ib += 8) {
1654
+
1655
+ float2 sumy = {0.f, 0.f};
1656
+ for (int i = 0; i < 8; ++i) {
1657
+ yl[i] = y4[i+ 0]; sumy[0] += yl[i];
1658
+ yh[i] = y4[i+32]; sumy[1] += yh[i];
1569
1659
  }
1570
- }
1571
- #endif
1572
1660
 
1573
- sum[ith] = sumf;
1661
+ device const uint16_t * sc = (device const uint16_t *)x[ib].scales;
1662
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
1663
+ device const half * dh = x[ib].d;
1574
1664
 
1575
- //
1576
- // Accumulate the sum from all threads in the threadgroup
1577
- // This version is slightly faster than the commented out one below,
1578
- // which I copy-pasted from ggerganov's q4_0 dot product for metal.
1579
- //
1580
- threadgroup_barrier(mem_flags::mem_threadgroup);
1581
- if (ith%4 == 0) {
1582
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
1583
- }
1584
- threadgroup_barrier(mem_flags::mem_threadgroup);
1585
- if (ith%16 == 0) {
1586
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
1587
- }
1588
- threadgroup_barrier(mem_flags::mem_threadgroup);
1589
- if (ith == 0) {
1590
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
1591
- dst[r1*ne0 + r0] = sum[0];
1665
+ for (int row = 0; row < N_DST; row++) {
1666
+
1667
+ sc16[0] = sc[0] & 0x000f;
1668
+ sc16[1] = sc[0] & 0x0f00;
1669
+ sc16[2] = sc[0] & 0x00f0;
1670
+ sc16[3] = sc[0] & 0xf000;
1671
+
1672
+ float2 acc1 = {0.f, 0.f};
1673
+ float2 acc2 = {0.f, 0.f};
1674
+ for (int i = 0; i < 8; i += 2) {
1675
+ acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
1676
+ acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
1677
+ acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
1678
+ acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
1679
+ }
1680
+
1681
+ float dall = dh[0];
1682
+ float dmin = dh[1];
1683
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
1684
+ (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
1685
+ dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);
1686
+
1687
+ qs += step;
1688
+ sc += step;
1689
+ dh += step;
1690
+ }
1691
+
1692
+ y4 += 8 * QK_K;
1592
1693
  }
1593
1694
 
1594
- //// accumulate the sum from all threads in the threadgroup
1595
- //threadgroup_barrier(mem_flags::mem_threadgroup);
1596
- //for (uint i = nth/2; i > 0; i /= 2) {
1597
- // if (ith < i) {
1598
- // sum[ith] += sum[ith + i];
1599
- // }
1600
- // threadgroup_barrier(mem_flags::mem_threadgroup);
1601
- //}
1602
-
1603
- //if (ith == 0) {
1604
- // dst[r1*ne0 + r0] = sum[0];
1605
- //}
1695
+ for (int row = 0; row < N_DST; ++row) {
1696
+ all_sum = simd_sum(sumf[row]);
1697
+ if (tiisg == 0) {
1698
+ dst[r1*ne0 + first_row + row] = all_sum;
1699
+ }
1700
+ }
1606
1701
  }
1702
+ #endif
1607
1703
 
1608
1704
  kernel void kernel_mul_mat_q5_K_f32(
1609
1705
  device const void * src0,
@@ -1612,39 +1708,39 @@ kernel void kernel_mul_mat_q5_K_f32(
1612
1708
  constant int64_t & ne00,
1613
1709
  constant int64_t & ne10,
1614
1710
  constant int64_t & ne0,
1615
- threadgroup float * sum [[threadgroup(0)]],
1616
1711
  uint2 tgpig[[threadgroup_position_in_grid]],
1617
- uint2 tpitg[[thread_position_in_threadgroup]],
1618
- uint2 tptg[[threads_per_threadgroup]]) {
1712
+ uint tiisg[[thread_index_in_simdgroup]],
1713
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1619
1714
 
1620
1715
  const int nb = ne00/QK_K;
1621
1716
 
1622
1717
  const int64_t r0 = tgpig.x;
1623
1718
  const int64_t r1 = tgpig.y;
1624
1719
 
1625
- device const block_q5_K * x = (device const block_q5_K *) src0 + r0*nb;
1720
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
1721
+
1722
+ device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb;
1626
1723
  device const float * yy = (device const float *) src1 + r1*ne10;
1627
1724
 
1628
- const int nth = tptg.x*tptg.y;
1629
- const int ith = tptg.y*tpitg.x + tpitg.y;
1725
+ float sumf[2]={0.f};
1630
1726
 
1631
- float sumf = 0;
1727
+ const int step = sizeof(block_q5_K) * nb;
1632
1728
 
1633
1729
  #if QK_K == 256
1730
+ #
1731
+ float yl[16], yh[16];
1634
1732
 
1635
1733
  const uint16_t kmask1 = 0x3f3f;
1636
1734
  const uint16_t kmask2 = 0x0f0f;
1637
1735
  const uint16_t kmask3 = 0xc0c0;
1638
1736
 
1639
- const int tid = tpitg.y; // 0...16
1640
- const int il = tid/4; // 0...3
1641
- const int ir = tid - 4*il;// 0...3
1642
- const int n = 4;
1643
-
1644
- const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
1645
- const int in = il%2;
1737
+ const int tid = tiisg/4;
1738
+ const int ix = tiisg%4;
1739
+ const int im = tid/4;
1740
+ const int ir = tid%4;
1741
+ const int n = 8;
1646
1742
 
1647
- const int l0 = n*(2*ir + in);
1743
+ const int l0 = n*ir;
1648
1744
  const int q_offset = 32*im + l0;
1649
1745
  const int y_offset = 64*im + l0;
1650
1746
 
@@ -1653,78 +1749,113 @@ kernel void kernel_mul_mat_q5_K_f32(
1653
1749
  const uint8_t hm3 = hm1 << 4;
1654
1750
  const uint8_t hm4 = hm2 << 4;
1655
1751
 
1656
- uchar2 sc1, sc2, sc3, sc4;
1752
+ uint16_t sc16[4];
1753
+ thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
1657
1754
 
1658
- for (int i = tpitg.x; i < nb; i += tptg.x) {
1755
+ device const float * y1 = yy + ix*QK_K + y_offset;
1659
1756
 
1660
- device const uint8_t * q1 = (x + i)->qs + q_offset;
1661
- device const uint8_t * q2 = q1 + 64;
1662
- device const uint8_t * qh = (x + i)->qh + l0;
1663
- device const float * y1 = yy + i*QK_K + y_offset;
1664
- device const float * y2 = y1 + 128;
1757
+ for (int i = ix; i < nb; i += 4) {
1665
1758
 
1666
- const float dall = (float)((x + i)->d);
1667
- const float dmin = (float)((x + i)->dmin);
1759
+ device const uint8_t * q1 = x[i].qs + q_offset;
1760
+ device const uint8_t * qh = x[i].qh + l0;
1761
+ device const half * dh = &x[i].d;
1762
+ device const uint16_t * a = (device const uint16_t *)x[i].scales + im;
1668
1763
 
1669
- device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
1670
- sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
1671
- sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
1672
- sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
1673
- sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
1764
+ device const float * y2 = y1 + 128;
1765
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
1766
+ for (int l = 0; l < 8; ++l) {
1767
+ yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
1768
+ yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
1769
+ yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
1770
+ yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
1771
+ }
1674
1772
 
1675
- float4 s = {0.f, 0.f, 0.f, 0.f};
1676
- float smin = 0;
1677
- for (int l = 0; l < n; ++l) {
1773
+ for (int row = 0; row < 2; ++row) {
1774
+
1775
+ device const uint8_t * q2 = q1 + 64;
1776
+
1777
+ sc16[0] = a[0] & kmask1;
1778
+ sc16[1] = a[2] & kmask1;
1779
+ sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
1780
+ sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
1678
1781
 
1679
- s[0] += y1[l+ 0] * ((q1[l] & 0xF) + (qh[l] & hm1 ? 16 : 0));
1680
- s[1] += y1[l+32] * ((q1[l] >> 4) + (qh[l] & hm2 ? 16 : 0));
1681
- s[2] += y2[l+ 0] * ((q2[l] & 0xF) + (qh[l] & hm3 ? 16 : 0));
1682
- s[3] += y2[l+32] * ((q2[l] >> 4) + (qh[l] & hm4 ? 16 : 0));
1683
- smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
1782
+ float4 acc = {0.f, 0.f, 0.f, 0.f};
1783
+ for (int l = 0; l < n; ++l) {
1784
+ uint8_t h = qh[l];
1785
+ acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
1786
+ acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
1787
+ acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
1788
+ acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
1789
+ }
1790
+ const float dall = dh[0];
1791
+ const float dmin = dh[1];
1792
+ sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
1793
+ dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
1794
+
1795
+ q1 += step;
1796
+ qh += step;
1797
+ dh += step/2;
1798
+ a += step/2;
1684
1799
 
1685
1800
  }
1686
- sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
1801
+
1802
+ y1 += 4 * QK_K;
1687
1803
 
1688
1804
  }
1689
1805
  #else
1690
- const int il = 4 * tpitg.x; // 0, 4, 8, 12
1691
- const int im = il/8; // 0, 0, 1, 1
1692
- const int in = il%8; // 0, 4, 0, 4
1806
+ float yl[8], yh[8];
1807
+
1808
+ const int il = 4 * (tiisg/8); // 0, 4, 8, 12
1809
+ const int ix = tiisg%8;
1810
+ const int im = il/8; // 0, 0, 1, 1
1811
+ const int in = il%8; // 0, 4, 0, 4
1693
1812
 
1694
- for (int i = tpitg.y; i < nb; i += tptg.y) {
1813
+ device const float * y = yy + ix*QK_K + il;
1695
1814
 
1696
- const float d = (float)x[i].d;
1815
+ for (int i = ix; i < nb; i += 8) {
1816
+
1817
+ for (int l = 0; l < 4; ++l) {
1818
+ yl[l+0] = y[l+ 0];
1819
+ yl[l+4] = y[l+16];
1820
+ yh[l+0] = y[l+32];
1821
+ yh[l+4] = y[l+48];
1822
+ }
1823
+
1824
+ device const half * dh = &x[i].d;
1697
1825
  device const uint8_t * q = x[i].qs + il;
1698
1826
  device const uint8_t * h = x[i].qh + in;
1699
1827
  device const int8_t * s = x[i].scales;
1700
- device const float * y = yy + i*QK_K + il;
1701
1828
 
1702
- for (int l = 0; l < 4; ++l) {
1703
- const uint8_t hl = h[l] >> im;
1704
- sumf += y[l+ 0] * d * s[0] * ((q[l+ 0] & 0xF) - (hl & 0x01 ? 0 : 16))
1705
- + y[l+16] * d * s[1] * ((q[l+16] & 0xF) - (hl & 0x04 ? 0 : 16))
1706
- + y[l+32] * d * s[2] * ((q[l+ 0] >> 4) - (hl & 0x10 ? 0 : 16))
1707
- + y[l+48] * d * s[3] * ((q[l+16] >> 4) - (hl & 0x40 ? 0 : 16));
1829
+ for (int row = 0; row < 2; ++row) {
1830
+
1831
+ const float d = dh[0];
1832
+
1833
+ float2 acc = {0.f, 0.f};
1834
+ for (int l = 0; l < 4; ++l) {
1835
+ const uint8_t hl = h[l] >> im;
1836
+ acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
1837
+ + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
1838
+ acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
1839
+ + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256));
1840
+ }
1841
+ sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]);
1842
+
1843
+ q += step;
1844
+ h += step;
1845
+ s += step;
1846
+ dh += step/2;
1847
+
1708
1848
  }
1849
+
1850
+ y += 8 * QK_K;
1709
1851
  }
1710
1852
  #endif
1711
- sum[ith] = sumf;
1712
1853
 
1713
- //
1714
- // Accumulate the sum from all threads in the threadgroup
1715
- //
1716
- threadgroup_barrier(mem_flags::mem_threadgroup);
1717
- if (ith%4 == 0) {
1718
- sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
1719
- }
1720
- threadgroup_barrier(mem_flags::mem_threadgroup);
1721
- if (ith%16 == 0) {
1722
- sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
1723
- }
1724
- threadgroup_barrier(mem_flags::mem_threadgroup);
1725
- if (ith == 0) {
1726
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
1727
- dst[r1*ne0 + r0] = sum[0];
1854
+ for (int row = 0; row < 2; ++row) {
1855
+ const float tot = simd_sum(sumf[row]);
1856
+ if (tiisg == 0) {
1857
+ dst[r1*ne0 + first_row + row] = tot;
1858
+ }
1728
1859
  }
1729
1860
 
1730
1861
  }
@@ -1736,10 +1867,9 @@ kernel void kernel_mul_mat_q6_K_f32(
1736
1867
  constant int64_t & ne00,
1737
1868
  constant int64_t & ne10,
1738
1869
  constant int64_t & ne0,
1739
- threadgroup float * sum [[threadgroup(0)]],
1740
1870
  uint2 tgpig[[threadgroup_position_in_grid]],
1741
- uint2 tpitg[[thread_position_in_threadgroup]],
1742
- uint2 tptg[[threads_per_threadgroup]]) {
1871
+ uint tiisg[[thread_index_in_simdgroup]],
1872
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1743
1873
 
1744
1874
  const uint8_t kmask1 = 0x03;
1745
1875
  const uint8_t kmask2 = 0x0C;
@@ -1751,19 +1881,18 @@ kernel void kernel_mul_mat_q6_K_f32(
1751
1881
  const int64_t r0 = tgpig.x;
1752
1882
  const int64_t r1 = tgpig.y;
1753
1883
 
1754
- device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb;
1755
- device const float * yy = (device const float *) src1 + r1*ne10;
1884
+ const int row = 2 * r0 + sgitg;
1756
1885
 
1757
- const int nth = tptg.x*tptg.y;
1758
- const int ith = tptg.y*tpitg.x + tpitg.y;
1886
+ device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb; //r0*nb;
1887
+ device const float * yy = (device const float *) src1 + r1*ne10;
1759
1888
 
1760
1889
  float sumf = 0;
1761
1890
 
1762
1891
  #if QK_K == 256
1763
- // Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
1764
- const int iqs = 16 * tpitg.y;
1765
- const int ip = iqs / 128; // 0 or 1
1766
- const int il = (iqs - 128*ip)/16; // 0...7
1892
+ const int tid = tiisg/2;
1893
+ const int ix = tiisg%2;
1894
+ const int ip = tid/8; // 0 or 1
1895
+ const int il = tid%8;
1767
1896
  const int n = 4;
1768
1897
  const int l0 = n*il;
1769
1898
  const int is = 8*ip + l0/16;
@@ -1772,9 +1901,10 @@ kernel void kernel_mul_mat_q6_K_f32(
1772
1901
  const int q_offset_l = 64*ip + l0;
1773
1902
  const int q_offset_h = 32*ip + l0;
1774
1903
 
1775
- for (int i = tpitg.x; i < nb; i += tptg.x) {
1904
+ for (int i = ix; i < nb; i += 2) {
1776
1905
 
1777
- device const uint8_t * ql = x[i].ql + q_offset_l;
1906
+ device const uint8_t * q1 = x[i].ql + q_offset_l;
1907
+ device const uint8_t * q2 = q1 + 32;
1778
1908
  device const uint8_t * qh = x[i].qh + q_offset_h;
1779
1909
  device const int8_t * sc = x[i].scales + is;
1780
1910
 
@@ -1784,19 +1914,21 @@ kernel void kernel_mul_mat_q6_K_f32(
1784
1914
 
1785
1915
  float4 sums = {0.f, 0.f, 0.f, 0.f};
1786
1916
  for (int l = 0; l < n; ++l) {
1787
- sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
1788
- sums[1] += y[l+32] * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
1789
- sums[2] += y[l+64] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
1790
- sums[3] += y[l+96] * ((int8_t)((ql[l+32] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
1917
+ sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
1918
+ sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
1919
+ sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
1920
+ sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
1791
1921
  }
1792
1922
 
1793
1923
  sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
1794
1924
 
1795
1925
  }
1926
+
1796
1927
  #else
1797
- const int il = 4*tpitg.x; // 0, 4, 8, 12
1928
+ const int ix = tiisg/4;
1929
+ const int il = 4*(tiisg%4);
1798
1930
 
1799
- for (int i = tpitg.y; i < nb; i += tptg.y) {
1931
+ for (int i = ix; i < nb; i += 8) {
1800
1932
  device const float * y = yy + i * QK_K + il;
1801
1933
  device const uint8_t * ql = x[i].ql + il;
1802
1934
  device const uint8_t * qh = x[i].qh + il;
@@ -1816,23 +1948,8 @@ kernel void kernel_mul_mat_q6_K_f32(
1816
1948
 
1817
1949
  #endif
1818
1950
 
1819
- sum[ith] = sumf;
1820
-
1821
- //
1822
- // Accumulate the sum from all threads in the threadgroup
1823
- //
1824
- threadgroup_barrier(mem_flags::mem_threadgroup);
1825
- if (ith%4 == 0) {
1826
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
1951
+ const float tot = simd_sum(sumf);
1952
+ if (tiisg == 0) {
1953
+ dst[r1*ne0 + row] = tot;
1827
1954
  }
1828
- threadgroup_barrier(mem_flags::mem_threadgroup);
1829
- if (ith%16 == 0) {
1830
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
1831
- }
1832
- threadgroup_barrier(mem_flags::mem_threadgroup);
1833
- if (ith == 0) {
1834
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
1835
- dst[r1*ne0 + r0] = sum[0];
1836
- }
1837
-
1838
1955
  }