llama_cpp 0.3.3 → 0.3.4

Sign up to get free protection for your applications and to get access to all the features.
@@ -331,26 +331,33 @@ kernel void kernel_rms_norm(
331
331
  threadgroup float * sum [[threadgroup(0)]],
332
332
  uint tgpig[[threadgroup_position_in_grid]],
333
333
  uint tpitg[[thread_position_in_threadgroup]],
334
+ uint sgitg[[simdgroup_index_in_threadgroup]],
335
+ uint tiisg[[thread_index_in_simdgroup]],
334
336
  uint ntg[[threads_per_threadgroup]]) {
335
- device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
337
+ device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
338
+ device const float * x_scalar = (device const float *) x;
339
+ float4 sumf=0;
340
+ float all_sum=0;
336
341
 
337
342
  // parallel sum
338
- sum[tpitg] = 0.0f;
339
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
340
- sum[tpitg] += x[i00] * x[i00];
343
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
344
+ sumf += x[i00] * x[i00];
345
+ }
346
+ all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
347
+ all_sum = simd_sum(all_sum);
348
+ if (tiisg == 0) {
349
+ sum[sgitg] = all_sum;
341
350
  }
342
351
 
343
- // reduce
344
352
  threadgroup_barrier(mem_flags::mem_threadgroup);
345
- for (uint i = ntg/2; i > 0; i /= 2) {
346
- if (tpitg < i) {
347
- sum[tpitg] += sum[tpitg + i];
348
- }
349
- threadgroup_barrier(mem_flags::mem_threadgroup);
353
+ // reduce the per-simdgroup partial sums; the number of simd groups is ntg / 32
354
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
355
+ if (tpitg < i) {
356
+ sum[tpitg] += sum[tpitg + i];
357
+ }
350
358
  }
351
-
352
- // broadcast
353
359
  if (tpitg == 0) {
360
+ for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
354
361
  sum[0] /= ne00;
355
362
  }
356
363
 
@@ -359,82 +366,94 @@ kernel void kernel_rms_norm(
359
366
  const float mean = sum[0];
360
367
  const float scale = 1.0f/sqrt(mean + eps);
361
368
 
362
- device float * y = dst + tgpig*ne00;
363
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
369
+ device float4 * y = (device float4 *) (dst + tgpig*ne00);
370
+ device float * y_scalar = (device float *) y;
371
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
364
372
  y[i00] = x[i00] * scale;
365
373
  }
374
+ if (tpitg == 0) {
375
+ for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
376
+ }
377
+ }
378
+
379
+ // computes the inner product between a q4_0 block and 32 floats (yl); sumy is SUM(yl[i])
380
+ float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) {
381
+ float d = qb_curr->d;
382
+ float4 acc = 0.f;
383
+ device uint16_t * qs = ((device uint16_t *)qb_curr + 1);
384
+ for (int i = 0; i < 16; i+=2) {
385
+ acc[0] += yl[i] * (qs[i / 2] & 0x000F);
386
+ acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
387
+ acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
388
+ acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
389
+ }
390
+ return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f);
391
+ }
392
+
393
+ // computes the inner product between a q4_1 block and 32 floats (yl); sumy is SUM(yl[i])
394
+ float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) {
395
+ float d = qb_curr->d;
396
+ float m = qb_curr->m;
397
+ float4 acc = 0.f;
398
+ device uint16_t * qs = ((device uint16_t *)qb_curr + 2);
399
+ for (int i = 0; i < 16; i+=2) {
400
+ acc[0] += yl[i] * (qs[i / 2] & 0x000F);
401
+ acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
402
+ acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
403
+ acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
404
+ }
405
+ return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m;
366
406
  }
367
407
 
368
408
  // putting them in the kernel causes a significant performance penalty
369
409
  #define N_DST 4 // each SIMD group works on 4 rows
370
410
  #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
371
411
  #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
372
- kernel void kernel_mul_mat_q4_0_f32(
373
- device const void * src0,
374
- device const float * src1,
375
- device float * dst,
376
- constant int64_t & ne00,
377
- constant int64_t & ne10,
378
- constant int64_t & ne0,
379
- constant int64_t & ne01[[buffer(4)]],
380
- uint2 tgpig[[threadgroup_position_in_grid]],
381
- uint tiisg[[thread_index_in_simdgroup]],
382
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
412
+ template<typename block_q_type>
413
+ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
414
+ int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
415
+ uint2 tgpig, uint tiisg, uint sgitg) {
383
416
  const int nb = ne00/QK4_0;
384
417
  const int r0 = tgpig.x;
385
418
  const int r1 = tgpig.y;
386
- device const block_q4_0 * x = (device const block_q4_0 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
419
+ device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
387
420
  device const float * y = (device const float *) src1 + r1*ne10;
388
- block_q4_0 qb_curr, qb_next;
389
421
  float4 y_curr[8]; // src1 vector cache
390
422
  float sumf[N_DST]={0.f}, all_sum;
391
423
  thread float * yl=(thread float *)y_curr;
392
424
 
393
- // bootstrap
394
- qb_curr = x[tiisg];
395
425
  // each thread in a SIMD group deals with 1 block.
396
426
  for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
397
-
427
+ float sumy = 0;
398
428
  for (int i = 0; i < QK4_0 / 4; i++) {
399
- y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
429
+ y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i);
430
+ sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
400
431
  }
401
432
 
402
433
  for (int row = 0; row < N_DST; row++) {
403
- // prefetch next x block
404
- qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
405
-
406
- // calculate
407
- float d = qb_curr.d;
408
- float2 acc = {0.0f, 0.0f};
409
- for (int i = 0; i < 16; i++) {
410
- acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
411
- acc[1] += yl[i] + yl[i+16];
412
- }
413
- sumf[row] += d * (acc[0] - 8.f*acc[1]);
414
- qb_curr = qb_next;
434
+ sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl);
415
435
  }
416
436
  }
417
437
 
418
- for (int i = 0; i < QK4_0 / 4; i++) {
419
- y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
420
- }
421
-
422
- for (int row = 0; row < N_DST; row++) {
423
- // prefetch next x block
424
- qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
425
-
426
- // calculate
427
- float d = qb_curr.d;
428
- float2 acc = {0.0f, 0.0f};
429
- for (int i = 0; i < 16; i++) {
430
- acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
431
- acc[1] += yl[i] + yl[i+16];
438
+ // from here on, load two rows at a time and 16 blocks per row
439
+ int ir = tiisg / (N_SIMDWIDTH / 2);
440
+ int ib = tiisg % (N_SIMDWIDTH / 2);
441
+ for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) {
442
+ int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); //where the left blocks start
443
+ float sumy = 0;
444
+ for (int i = 0; i < QK4_0 / 4; i++) {
445
+ y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i);
446
+ sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
432
447
  }
433
- if (tiisg < nb % N_SIMDWIDTH) {
434
- sumf[row] += d * (acc[0] - 8.f*acc[1]);
448
+
449
+ for (int row = 0; row < N_DST; row+=2) {
450
+ if (nb_start + ib < nb) {
451
+ sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl);
452
+ }
435
453
  }
436
- qb_curr = qb_next;
454
+ }
437
455
 
456
+ for (int row = 0; row < N_DST; ++row) {
438
457
  all_sum = simd_sum(sumf[row]);
439
458
  if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
440
459
  dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
@@ -442,73 +461,32 @@ kernel void kernel_mul_mat_q4_0_f32(
442
461
  }
443
462
  }
444
463
 
445
- kernel void kernel_mul_mat_q4_1_f32(
464
+ kernel void kernel_mul_mat_q4_0_f32(
446
465
  device const void * src0,
447
466
  device const float * src1,
448
467
  device float * dst,
449
468
  constant int64_t & ne00,
450
469
  constant int64_t & ne10,
451
470
  constant int64_t & ne0,
452
- threadgroup float * sum [[threadgroup(0)]],
471
+ constant int64_t & ne01[[buffer(4)]],
453
472
  uint2 tgpig[[threadgroup_position_in_grid]],
454
- uint2 tpitg[[thread_position_in_threadgroup]],
455
- uint2 tptg[[threads_per_threadgroup]]) {
456
- const int nb = ne00/QK4_1;
457
-
458
- const int64_t r0 = tgpig.x;
459
- const int64_t r1 = tgpig.y;
460
-
461
- device const block_q4_1 * x = (device const block_q4_1 *) src0 + r0*nb;
462
- device const float * y = (device const float *) src1 + r1*ne10;
463
-
464
- const uint nth = tptg.x*tptg.y;
465
- const uint ith = tptg.y*tpitg.x + tpitg.y;
466
-
467
- const int ix = tpitg.y/4; // 0 or 1
468
- const int iy = tpitg.y - 4*ix; // 0...3
469
-
470
- const int first = 4 * iy;
471
-
472
- float sumf = 0;
473
-
474
- for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
475
-
476
- const float d = (float)x[i].d;
477
- const float m = (float)x[i].m;
478
-
479
- device const uint8_t * xl = x[i].qs + first;
480
- device const float * yl = y + i * QK4_1 + first;
481
-
482
- float2 acc = {0.0f, 0.0f};
483
-
484
- for (int j = 0; j < 4; ++j) {
485
-
486
- acc[0] += yl[j+ 0] * (d * (xl[j] & 0xF) + m);
487
- acc[1] += yl[j+16] * (d * (xl[j] >> 4) + m);
488
-
489
- }
490
-
491
- sumf += acc[0] + acc[1];
492
- }
493
-
494
- sum[ith] = sumf;
473
+ uint tiisg[[thread_index_in_simdgroup]],
474
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
475
+ mul_vec_q_n_f32<block_q4_0>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
476
+ }
495
477
 
496
- //
497
- // Accumulate the sum from all threads in the threadgroup
498
- //
499
- threadgroup_barrier(mem_flags::mem_threadgroup);
500
- if (ith%4 == 0) {
501
- sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
502
- }
503
- threadgroup_barrier(mem_flags::mem_threadgroup);
504
- if (ith%16 == 0) {
505
- sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
506
- }
507
- threadgroup_barrier(mem_flags::mem_threadgroup);
508
- if (ith == 0) {
509
- for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
510
- dst[r1*ne0 + r0] = sum[0];
511
- }
478
+ kernel void kernel_mul_mat_q4_1_f32(
479
+ device const void * src0,
480
+ device const float * src1,
481
+ device float * dst,
482
+ constant int64_t & ne00,
483
+ constant int64_t & ne10,
484
+ constant int64_t & ne0,
485
+ constant int64_t & ne01[[buffer(4)]],
486
+ uint2 tgpig[[threadgroup_position_in_grid]],
487
+ uint tiisg[[thread_index_in_simdgroup]],
488
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
489
+ mul_vec_q_n_f32<block_q4_1>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
512
490
  }
513
491
 
514
492
  kernel void kernel_mul_mat_f16_f32(
@@ -624,17 +602,19 @@ kernel void kernel_rope(
624
602
  constant int & n_past,
625
603
  constant int & n_dims,
626
604
  constant int & mode,
605
+ constant float & freq_base,
606
+ constant float & freq_scale,
627
607
  uint3 tpig[[thread_position_in_grid]]) {
628
608
  const int64_t i3 = tpig[2];
629
609
  const int64_t i2 = tpig[1];
630
610
  const int64_t i1 = tpig[0];
631
611
 
632
612
  const bool is_neox = mode & 2;
633
- const float theta_scale = pow(10000.0, -2.0f/n_dims);
613
+ const float theta_scale = pow(freq_base, -2.0f/n_dims);
634
614
 
635
615
  const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
636
616
 
637
- float theta = (float)p;
617
+ float theta = freq_scale * (float)p;
638
618
 
639
619
  if (!is_neox) {
640
620
  for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
@@ -1229,111 +1209,137 @@ kernel void kernel_mul_mat_q2_K_f32(
1229
1209
  constant int64_t & ne00,
1230
1210
  constant int64_t & ne10,
1231
1211
  constant int64_t & ne0,
1232
- threadgroup float * sum [[threadgroup(0)]],
1212
+ constant int64_t & ne01[[buffer(4)]],
1233
1213
  uint2 tgpig[[threadgroup_position_in_grid]],
1234
- uint2 tpitg[[thread_position_in_threadgroup]],
1235
- uint2 tptg[[threads_per_threadgroup]]) {
1214
+ uint tiisg[[thread_index_in_simdgroup]],
1215
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1236
1216
 
1237
1217
  const int nb = ne00/QK_K;
1218
+ const int r0 = tgpig.x;
1219
+ const int r1 = tgpig.y;
1238
1220
 
1239
- const int64_t r0 = tgpig.x;
1240
- const int64_t r1 = tgpig.y;
1241
-
1242
- device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb;
1243
- device const float * yy = (device const float *) src1 + r1*ne10;
1244
-
1245
- const int nth = tptg.x*tptg.y;
1246
- const int ith = tptg.y*tpitg.x + tpitg.y;
1221
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1222
+ const int ib_row = first_row * nb;
1223
+ device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row;
1224
+ device const float * y = (device const float *) src1 + r1*ne10;
1225
+ float yl[32];
1226
+ float sumf[N_DST]={0.f}, all_sum;
1247
1227
 
1248
- float sumf = 0;
1228
+ const int step = sizeof(block_q2_K) * nb;
1249
1229
 
1250
1230
  #if QK_K == 256
1251
- const int tid = tpitg.y; // 0...16
1252
- const int il = tid/4; // 0...3
1253
- const int ir = tid%4; // 0...3
1254
- const int ip = il/2; // 0 or 1
1255
- const int shift1 = 4*(il%2);// 0 or 4
1256
- const int shift2 = shift1+2;// 2 or 6
1257
- const int n = 8;
1258
- const int is = 4*il + (n*ir)/16;
1259
-
1260
- const int y_offset = 64*il + n*ir;
1261
- const int q_offset = 32*ip + n*ir;
1262
-
1263
- for (int i = tpitg.x; i < nb; i += tptg.x) {
1264
-
1265
- device const uint8_t * q = x[i].qs + q_offset;
1266
- device const uint8_t * scales = x[i].scales + is;
1231
+ const int ix = tiisg/8; // 0...3
1232
+ const int it = tiisg%8; // 0...7
1233
+ const int im = it/4; // 0 or 1
1234
+ const int ir = it%4; // 0...3
1235
+ const int is = (8*ir)/16;// 0 or 1
1236
+
1237
+ device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir;
1238
+
1239
+ for (int ib = ix; ib < nb; ib += 4) {
1240
+
1241
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
1242
+ for (int i = 0; i < 8; ++i) {
1243
+ yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
1244
+ yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
1245
+ yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
1246
+ yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
1247
+ }
1267
1248
 
1268
- uint8_t d1 = scales[0] & 0xF;
1269
- uint8_t d2 = scales[2] & 0xF;
1270
- uint8_t m1 = scales[0] >> 4;
1271
- uint8_t m2 = scales[2] >> 4;
1249
+ device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*im + is;
1250
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
1251
+ device const half * dh = &x[ib].d;
1272
1252
 
1273
- device const float * y = yy + i*QK_K + y_offset;
1253
+ for (int row = 0; row < N_DST; row++) {
1274
1254
 
1275
- float2 s = {0.f, 0.f};
1276
- float smin = 0;
1277
- for (int l = 0; l < n; ++l) {
1278
- s[0] += y[l+ 0] * ((q[l] >> shift1) & 3);
1279
- s[1] += y[l+32] * ((q[l] >> shift2) & 3);
1280
- smin += y[l+ 0] * m1 + y[l+32] * m2;
1255
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
1256
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
1257
+ for (int i = 0; i < 8; i += 2) {
1258
+ acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
1259
+ acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
1260
+ acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
1261
+ acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
1262
+ acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
1263
+ acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
1264
+ acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
1265
+ acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
1266
+ }
1267
+ float dall = dh[0];
1268
+ float dmin = dh[1] * 1.f/16.f;
1269
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
1270
+ (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
1271
+ (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +
1272
+ (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
1273
+ dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
1274
+
1275
+ qs += step/2;
1276
+ sc += step;
1277
+ dh += step/2;
1281
1278
  }
1282
1279
 
1283
- const float dall = (float)x[i].d;
1284
- const float dmin = (float)x[i].dmin;
1285
-
1286
- sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
1287
-
1280
+ y4 += 4 * QK_K;
1288
1281
  }
1289
1282
  #else
1290
- const int il = 4 * tpitg.x;
1283
+ const int ix = tiisg/2; // 0...15
1284
+ const int it = tiisg%2; // 0...1
1291
1285
 
1292
- uint32_t aux[2];
1293
- thread const uint8_t * d = (thread const uint8_t *)aux;
1294
- thread const uint8_t * m = (thread const uint8_t *)aux + 4;
1286
+ device const float * y4 = y + ix * QK_K + 8 * it;
1295
1287
 
1296
- for (int i = tpitg.y; i < nb; i += tptg.y) {
1288
+ for (int ib = ix; ib < nb; ib += 16) {
1297
1289
 
1298
- device const uint8_t * q = x[i].qs + il;
1299
- device const float * y = yy + i*QK_K + il;
1290
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
1291
+ for (int i = 0; i < 8; ++i) {
1292
+ yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
1293
+ yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8];
1294
+ yl[i+16] = y4[i+32]; sumy[2] += yl[i+16];
1295
+ yl[i+24] = y4[i+48]; sumy[3] += yl[i+24];
1296
+ }
1300
1297
 
1301
- const float dall = (float)x[i].d;
1302
- const float dmin = (float)x[i].dmin;
1298
+ device const uint8_t * sc = (device const uint8_t *)x[ib].scales;
1299
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
1300
+ device const half * dh = &x[ib].d;
1303
1301
 
1304
- device const uint32_t * a = (device const uint32_t *)x[i].scales;
1305
- aux[0] = a[0] & 0x0f0f0f0f;
1306
- aux[1] = (a[0] >> 4) & 0x0f0f0f0f;
1302
+ for (int row = 0; row < N_DST; row++) {
1307
1303
 
1308
- for (int l = 0; l < 4; ++l) {
1309
- sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0])
1310
- + y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1])
1311
- + y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2])
1312
- + y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]);
1304
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
1305
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
1306
+ for (int i = 0; i < 8; i += 2) {
1307
+ acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
1308
+ acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
1309
+ acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
1310
+ acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
1311
+ acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
1312
+ acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
1313
+ acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
1314
+ acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
1315
+ }
1316
+
1317
+ float dall = dh[0];
1318
+ float dmin = dh[1];
1319
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
1320
+ (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
1321
+ (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
1322
+ (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
1323
+ dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));
1324
+
1325
+ qs += step/2;
1326
+ sc += step;
1327
+ dh += step/2;
1313
1328
  }
1329
+
1330
+ y4 += 16 * QK_K;
1314
1331
  }
1315
1332
  #endif
1316
1333
 
1317
- sum[ith] = sumf;
1318
-
1319
- //
1320
- // Accumulate the sum from all threads in the threadgroup
1321
- //
1322
- threadgroup_barrier(mem_flags::mem_threadgroup);
1323
- if (ith%4 == 0) {
1324
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
1325
- }
1326
- threadgroup_barrier(mem_flags::mem_threadgroup);
1327
- if (ith%16 == 0) {
1328
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
1329
- }
1330
- threadgroup_barrier(mem_flags::mem_threadgroup);
1331
- if (ith == 0) {
1332
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
1333
- dst[r1*ne0 + r0] = sum[0];
1334
+ for (int row = 0; row < N_DST; ++row) {
1335
+ all_sum = simd_sum(sumf[row]);
1336
+ if (tiisg == 0) {
1337
+ dst[r1*ne0 + first_row + row] = all_sum;
1338
+ }
1334
1339
  }
1335
1340
  }
1336
1341
 
1342
+ #if QK_K == 256
1337
1343
  kernel void kernel_mul_mat_q3_K_f32(
1338
1344
  device const void * src0,
1339
1345
  device const float * src1,
@@ -1342,40 +1348,41 @@ kernel void kernel_mul_mat_q3_K_f32(
1342
1348
  constant int64_t & ne10,
1343
1349
  constant int64_t & ne0,
1344
1350
  constant int64_t & ne1,
1345
- threadgroup float * sum [[threadgroup(0)]],
1346
1351
  uint2 tgpig[[threadgroup_position_in_grid]],
1347
- uint2 tpitg[[thread_position_in_threadgroup]],
1348
- uint2 tptg[[threads_per_threadgroup]]) {
1352
+ uint tiisg[[thread_index_in_simdgroup]],
1353
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1349
1354
 
1350
1355
  const int nb = ne00/QK_K;
1351
1356
 
1352
1357
  const int64_t r0 = tgpig.x;
1353
1358
  const int64_t r1 = tgpig.y;
1354
1359
 
1355
- device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb;
1356
- device const float * yy = (device const float *) src1 + r1*ne10;
1360
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
1357
1361
 
1358
- const int nth = tptg.x*tptg.y;
1359
- const int ith = tptg.y*tpitg.x + tpitg.y;
1360
-
1361
- #if QK_K == 256
1362
+ device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb;
1363
+ device const float * yy = (device const float *) src1 + r1*ne10;
1362
1364
 
1363
- const uint8_t m3 = 3;
1364
- const int8_t m4 = 4;
1365
+ float yl[16];
1365
1366
 
1366
1367
  const uint16_t kmask1 = 0x0303;
1367
1368
  const uint16_t kmask2 = 0x0f0f;
1368
1369
 
1369
- const int tid = tpitg.y; // expecting 16
1370
+ const int tid = tiisg/2;
1371
+ const int ix = tiisg%2;
1370
1372
  const int ip = tid/8; // 0 or 1
1371
1373
  const int il = tid/2 - 4*ip; // 0...3
1372
1374
  const int ir = tid%2;
1373
1375
  const int n = 8;
1374
1376
  const int l0 = n*ir;
1375
1377
 
1376
- const uint8_t m = 1 << (4*ip + il);
1378
+ const uint16_t m1 = 1 << (4*ip + il);
1379
+ const uint16_t m2 = m1 << 8;
1377
1380
 
1378
1381
  const int shift = 2*il;
1382
+ const uint16_t qm1 = 0x0003 << shift;
1383
+ const uint16_t qm2 = 0x0300 << shift;
1384
+ const int32_t v1 = 4 << shift;
1385
+ const int32_t v2 = 1024 << shift;
1379
1386
 
1380
1387
  const uint16_t s_shift1 = 4*ip;
1381
1388
  const uint16_t s_shift2 = s_shift1 + 2*(il/2);
@@ -1384,226 +1391,315 @@ kernel void kernel_mul_mat_q3_K_f32(
1384
1391
  const int q_offset = 32*ip + l0;
1385
1392
  const int y_offset = 128*ip + 32*il + l0;
1386
1393
 
1387
- //float sumf = 0;
1388
- float sumf1 = 0, sumf2 = 0;
1389
- for (int i = tpitg.x; i < nb; i += tptg.x) {
1390
-
1391
- const float d_all = (float)(x[i].d);
1394
+ const int step = sizeof(block_q3_K) * nb / 2;
1392
1395
 
1393
- device const uint8_t * q = x[i].qs + q_offset;
1394
- device const uint8_t * h = x[i].hmask + l0;
1395
- device const float * y = yy + i * QK_K + y_offset;
1396
+ device const float * y1 = yy + ix*QK_K + y_offset;
1396
1397
 
1397
- device const uint16_t * a = (device const uint16_t *)x[i].scales;
1398
- const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
1398
+ float sumf1[2] = {0.f}, sumf2[2] = {0.f};
1399
+ for (int i = ix; i < nb; i += 2) {
1399
1400
 
1400
- float s = 0;
1401
- for (int l = 0; l < n; ++l) {
1402
- s += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & m3) - ((h[l+ 0] & m) ? 0 : m4));
1401
+ for (int l = 0; l < 8; ++l) {
1402
+ yl[l+0] = y1[l+ 0];
1403
+ yl[l+8] = y1[l+16];
1403
1404
  }
1404
- float d = d_all * s;
1405
- sumf1 += d * scales[0];
1406
- sumf2 += d;
1407
- //sumf += d_all * s * (scales[0] - 32);
1408
1405
 
1409
- s = 0;
1410
- for (int l = 0; l < n; ++l) {
1411
- s += y[l+16] * ((int8_t)((q[l+16] >> shift) & m3) - ((h[l+16] & m) ? 0 : m4));
1412
- }
1413
- d = d_all * s;
1414
- sumf1 += d * scales[1];
1415
- sumf2 += d;
1416
- //sumf += d_all * s * (scales[1] - 32);
1406
+ device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
1407
+ device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
1408
+ device const uint16_t * a = (device const uint16_t *)(x[i].scales);
1409
+ device const half * dh = &x[i].d;
1417
1410
 
1418
- }
1411
+ for (int row = 0; row < 2; ++row) {
1419
1412
 
1420
- //sum[ith] = sumf;
1421
- sum[ith] = sumf1 - 32.f*sumf2;
1422
- #else
1423
- const int il = 4 * tpitg.x; // 0, 4, 8, 12
1424
- const int im = il/8; // 0, 0, 1, 1
1425
- const int in = il%8; // 0, 4, 0, 4
1413
+ const float d_all = (float)dh[0];
1414
+ const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
1426
1415
 
1427
- float sumf = 0;
1428
-
1429
- for (int i = tpitg.y; i < nb; i += tptg.y) {
1430
-
1431
- const float d_all = (float)(x[i].d);
1416
+ float s1 = 0, s2 = 0;
1417
+ for (int l = 0; l < n; l += 2) {
1418
+ const uint16_t qs = q[l/2];
1419
+ s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
1420
+ s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
1421
+ }
1422
+ float d = d_all * (s1 + 1.f/256.f * s2);
1423
+ sumf1[row] += d * scales[0];
1424
+ sumf2[row] += d;
1425
+
1426
+ s1 = s2 = 0;
1427
+ for (int l = 0; l < n; l += 2) {
1428
+ const uint16_t qs = q[l/2+8];
1429
+ s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
1430
+ s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
1431
+ }
1432
+ d = d_all * (s1 + 1.f/256.f * s2);
1433
+ sumf1[row] += d * scales[1];
1434
+ sumf2[row] += d;
1432
1435
 
1433
- device const uint8_t * q = x[i].qs + il;
1434
- device const uint8_t * h = x[i].hmask + in;
1435
- device const float * y = yy + i * QK_K + il;
1436
+ q += step;
1437
+ h += step;
1438
+ a += step;
1439
+ dh += step;
1436
1440
 
1437
- const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
1438
- const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
1439
- const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
1440
- const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
1441
-
1442
- for (int l = 0; l < 4; ++l) {
1443
- const uint8_t hm = h[l] >> im;
1444
- sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4))
1445
- + y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4))
1446
- + y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4))
1447
- + y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4));
1448
1441
  }
1449
1442
 
1450
- }
1451
-
1452
- sum[ith] = sumf;
1453
-
1454
- #endif
1443
+ y1 += 2 * QK_K;
1455
1444
 
1456
- //
1457
- // Accumulate the sum from all threads in the threadgroup
1458
- //
1459
- threadgroup_barrier(mem_flags::mem_threadgroup);
1460
- if (ith%4 == 0) {
1461
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
1462
- }
1463
- threadgroup_barrier(mem_flags::mem_threadgroup);
1464
- if (ith%16 == 0) {
1465
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
1466
- }
1467
- threadgroup_barrier(mem_flags::mem_threadgroup);
1468
- if (ith == 0) {
1469
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
1470
- dst[r1*ne0 + r0] = sum[0];
1471
1445
  }
1472
1446
 
1447
+ for (int row = 0; row < 2; ++row) {
1448
+ const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
1449
+ const float tot = simd_sum(sumf);
1450
+ if (tiisg == 0) {
1451
+ dst[r1*ne0 + first_row + row] = tot;
1452
+ }
1453
+ }
1473
1454
  }
1474
-
1475
- kernel void kernel_mul_mat_q4_K_f32(
1455
+ #else
1456
+ kernel void kernel_mul_mat_q3_K_f32(
1476
1457
  device const void * src0,
1477
1458
  device const float * src1,
1478
1459
  device float * dst,
1479
1460
  constant int64_t & ne00,
1480
1461
  constant int64_t & ne10,
1481
1462
  constant int64_t & ne0,
1482
- threadgroup float * sum [[threadgroup(0)]],
1463
+ constant int64_t & ne1,
1483
1464
  uint2 tgpig[[threadgroup_position_in_grid]],
1484
- uint2 tpitg[[thread_position_in_threadgroup]],
1485
- uint2 tptg[[threads_per_threadgroup]]) {
1465
+ uint tiisg[[thread_index_in_simdgroup]],
1466
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1486
1467
 
1487
1468
  const int nb = ne00/QK_K;
1488
1469
 
1489
1470
  const int64_t r0 = tgpig.x;
1490
1471
  const int64_t r1 = tgpig.y;
1491
1472
 
1492
- const int nth = tptg.x*tptg.y;
1493
- const int ith = tptg.y*tpitg.x + tpitg.y;
1473
+ const int row = 2 * r0 + sgitg;
1494
1474
 
1495
- device const block_q4_K * x = (device const block_q4_K *) src0 + r0*nb;
1475
+ device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb;
1496
1476
  device const float * yy = (device const float *) src1 + r1*ne10;
1477
+ const int ix = tiisg/4;
1478
+ const int il = 4 * (tiisg%4);// 0, 4, 8, 12
1479
+ const int im = il/8; // 0, 0, 1, 1
1480
+ const int in = il%8; // 0, 4, 0, 4
1497
1481
 
1498
- float sumf = 0;
1482
+ float2 sum = {0.f, 0.f};
1483
+
1484
+ for (int i = ix; i < nb; i += 8) {
1485
+
1486
+ const float d_all = (float)(x[i].d);
1487
+
1488
+ device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
1489
+ device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
1490
+ device const uint16_t * s = (device const uint16_t *)(x[i].scales);
1491
+ device const float * y = yy + i * QK_K + il;
1492
+
1493
+ const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
1494
+ const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
1495
+ const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
1496
+ const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
1497
+
1498
+ for (int l = 0; l < 4; l += 2) {
1499
+ const uint16_t hm = h[l/2] >> im;
1500
+ sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
1501
+ + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
1502
+ + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
1503
+ + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
1504
+ sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
1505
+ + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
1506
+ + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
1507
+ + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
1508
+ }
1509
+
1510
+ }
1511
+ const float sumf = sum[0] + sum[1] * 1.f/256.f;
1512
+
1513
+ const float tot = simd_sum(sumf);
1514
+ if (tiisg == 0) {
1515
+ dst[r1*ne0 + row] = tot;
1516
+ }
1517
+
1518
+ }
1519
+ #endif
1499
1520
 
1500
1521
  #if QK_K == 256
1522
+ kernel void kernel_mul_mat_q4_K_f32(
1523
+ device const void * src0,
1524
+ device const float * src1,
1525
+ device float * dst,
1526
+ constant int64_t & ne00,
1527
+ constant int64_t & ne10,
1528
+ constant int64_t & ne0,
1529
+ constant int64_t & ne01[[buffer(4)]],
1530
+ uint2 tgpig[[threadgroup_position_in_grid]],
1531
+ uint tiisg[[thread_index_in_simdgroup]],
1532
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1501
1533
 
1502
1534
  const uint16_t kmask1 = 0x3f3f;
1503
1535
  const uint16_t kmask2 = 0x0f0f;
1504
1536
  const uint16_t kmask3 = 0xc0c0;
1505
1537
 
1506
- const int tid = tpitg.y; // 0...16
1507
- const int il = tid/4; // 0...3
1508
- const int ir = tid - 4*il;// 0...3
1509
- const int n = 4;
1538
+ const int ix = tiisg/8; // 0...3
1539
+ const int it = tiisg%8; // 0...7
1540
+ const int im = it/4; // 0 or 1
1541
+ const int ir = it%4; // 0...3
1510
1542
 
1511
- const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
1512
- const int in = il%2;
1543
+ const int nb = ne00/QK_K;
1544
+ const int r0 = tgpig.x;
1545
+ const int r1 = tgpig.y;
1546
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1547
+ const int ib_row = first_row * nb;
1548
+ device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
1549
+ device const float * y = (device const float *) src1 + r1*ne10;
1550
+ float yl[16];
1551
+ float yh[16];
1552
+ float sumf[N_DST]={0.f}, all_sum;
1513
1553
 
1514
- const int l0 = n*(2*ir + in);
1515
- const int q_offset = 32*im + l0;
1516
- const int y_offset = 64*im + l0;
1554
+ const int step = sizeof(block_q4_K) * nb / 2;
1517
1555
 
1518
- uchar2 sc1, sc2, sc3, sc4;
1556
+ device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir;
1519
1557
 
1520
- for (int i = tpitg.x; i < nb; i += tptg.x) {
1558
+ uint16_t sc16[4];
1559
+ thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
1521
1560
 
1522
- device const uint8_t * q1 = (x + i)->qs + q_offset;
1523
- device const uint8_t * q2 = q1 + 64;
1524
- device const float * y1 = yy + i*QK_K + y_offset;
1525
- device const float * y2 = y1 + 128;
1561
+ for (int ib = ix; ib < nb; ib += 4) {
1526
1562
 
1527
- const float dall = (float)((x + i)->d);
1528
- const float dmin = (float)((x + i)->dmin);
1563
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
1564
+ for (int i = 0; i < 8; ++i) {
1565
+ yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
1566
+ yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
1567
+ yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
1568
+ yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
1569
+ }
1529
1570
 
1530
- device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
1531
- sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
1532
- sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
1533
- sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
1534
- sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
1571
+ device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im;
1572
+ device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
1573
+ device const half * dh = &x[ib].d;
1535
1574
 
1536
- float4 s = {0.f, 0.f, 0.f, 0.f};
1537
- float smin = 0;
1538
- for (int l = 0; l < n; ++l) {
1575
+ for (int row = 0; row < N_DST; row++) {
1539
1576
 
1540
- s[0] += y1[l] * (q1[l] & 0xF); s[1] += y1[l+32] * (q1[l] >> 4);
1541
- s[2] += y2[l] * (q2[l] & 0xF); s[3] += y2[l+32] * (q2[l] >> 4);
1542
- smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
1577
+ sc16[0] = sc[0] & kmask1;
1578
+ sc16[1] = sc[2] & kmask1;
1579
+ sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
1580
+ sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
1581
+
1582
+ device const uint16_t * q2 = q1 + 32;
1583
+
1584
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
1585
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
1586
+ for (int i = 0; i < 8; i += 2) {
1587
+ acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
1588
+ acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
1589
+ acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
1590
+ acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
1591
+ acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
1592
+ acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
1593
+ acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
1594
+ acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
1595
+ }
1543
1596
 
1597
+ float dall = dh[0];
1598
+ float dmin = dh[1];
1599
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
1600
+ (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
1601
+ (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
1602
+ (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
1603
+ dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
1604
+
1605
+ q1 += step;
1606
+ sc += step;
1607
+ dh += step;
1544
1608
  }
1545
- sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
1546
1609
 
1610
+ y4 += 4 * QK_K;
1611
+ }
1612
+
1613
+ for (int row = 0; row < N_DST; ++row) {
1614
+ all_sum = simd_sum(sumf[row]);
1615
+ if (tiisg == 0) {
1616
+ dst[r1*ne0 + first_row + row] = all_sum;
1617
+ }
1547
1618
  }
1619
+ }
1548
1620
  #else
1549
- uint16_t aux16[2];
1550
- thread const uint8_t * scales = (thread const uint8_t *)aux16;
1621
+ kernel void kernel_mul_mat_q4_K_f32(
1622
+ device const void * src0,
1623
+ device const float * src1,
1624
+ device float * dst,
1625
+ constant int64_t & ne00,
1626
+ constant int64_t & ne10,
1627
+ constant int64_t & ne0,
1628
+ constant int64_t & ne01[[buffer(4)]],
1629
+ uint2 tgpig[[threadgroup_position_in_grid]],
1630
+ uint tiisg[[thread_index_in_simdgroup]],
1631
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1551
1632
 
1552
- const int il = 4*tpitg.x;
1633
+ const int ix = tiisg/4; // 0...7
1634
+ const int it = tiisg%4; // 0...3
1553
1635
 
1554
- for (int i = tpitg.y; i < nb; i += tptg.y) {
1636
+ const int nb = ne00/QK_K;
1637
+ const int r0 = tgpig.x;
1638
+ const int r1 = tgpig.y;
1639
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1640
+ const int ib_row = first_row * nb;
1641
+ device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
1642
+ device const float * y = (device const float *) src1 + r1*ne10;
1643
+ float yl[8];
1644
+ float yh[8];
1645
+ float sumf[N_DST]={0.f}, all_sum;
1555
1646
 
1556
- device const uint8_t * q = x[i].qs + il;
1557
- device const float * y = yy + i * QK_K + il;
1647
+ const int step = sizeof(block_q4_K) * nb / 2;
1558
1648
 
1559
- const float d = (float)x[i].d[0];
1560
- const float m = (float)x[i].d[1];
1649
+ device const float * y4 = y + ix * QK_K + 8 * it;
1561
1650
 
1562
- device const uint16_t * a = (device const uint16_t *)x[i].scales;
1563
- aux16[0] = a[0] & 0x0f0f;
1564
- aux16[1] = (a[0] >> 4) & 0x0f0f;
1651
+ uint16_t sc16[4];
1565
1652
 
1566
- for (int l = 0; l < 4; ++l) {
1567
- sumf += d * scales[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - m * scales[2] * (y[l+ 0] + y[l+16])
1568
- + d * scales[1] * (y[l+32] * (q[l] >> 4) + y[l+48] * (q[l+16] >> 4)) - m * scales[3] * (y[l+32] + y[l+48]);
1653
+ for (int ib = ix; ib < nb; ib += 8) {
1654
+
1655
+ float2 sumy = {0.f, 0.f};
1656
+ for (int i = 0; i < 8; ++i) {
1657
+ yl[i] = y4[i+ 0]; sumy[0] += yl[i];
1658
+ yh[i] = y4[i+32]; sumy[1] += yh[i];
1569
1659
  }
1570
- }
1571
- #endif
1572
1660
 
1573
- sum[ith] = sumf;
1661
+ device const uint16_t * sc = (device const uint16_t *)x[ib].scales;
1662
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
1663
+ device const half * dh = x[ib].d;
1574
1664
 
1575
- //
1576
- // Accumulate the sum from all threads in the threadgroup
1577
- // This version is slightly faster than the commented out one below,
1578
- // which I copy-pasted from ggerganov's q4_0 dot product for metal.
1579
- //
1580
- threadgroup_barrier(mem_flags::mem_threadgroup);
1581
- if (ith%4 == 0) {
1582
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
1583
- }
1584
- threadgroup_barrier(mem_flags::mem_threadgroup);
1585
- if (ith%16 == 0) {
1586
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
1587
- }
1588
- threadgroup_barrier(mem_flags::mem_threadgroup);
1589
- if (ith == 0) {
1590
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
1591
- dst[r1*ne0 + r0] = sum[0];
1665
+ for (int row = 0; row < N_DST; row++) {
1666
+
1667
+ sc16[0] = sc[0] & 0x000f;
1668
+ sc16[1] = sc[0] & 0x0f00;
1669
+ sc16[2] = sc[0] & 0x00f0;
1670
+ sc16[3] = sc[0] & 0xf000;
1671
+
1672
+ float2 acc1 = {0.f, 0.f};
1673
+ float2 acc2 = {0.f, 0.f};
1674
+ for (int i = 0; i < 8; i += 2) {
1675
+ acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
1676
+ acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
1677
+ acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
1678
+ acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
1679
+ }
1680
+
1681
+ float dall = dh[0];
1682
+ float dmin = dh[1];
1683
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
1684
+ (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
1685
+ dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);
1686
+
1687
+ qs += step;
1688
+ sc += step;
1689
+ dh += step;
1690
+ }
1691
+
1692
+ y4 += 8 * QK_K;
1592
1693
  }
1593
1694
 
1594
- //// accumulate the sum from all threads in the threadgroup
1595
- //threadgroup_barrier(mem_flags::mem_threadgroup);
1596
- //for (uint i = nth/2; i > 0; i /= 2) {
1597
- // if (ith < i) {
1598
- // sum[ith] += sum[ith + i];
1599
- // }
1600
- // threadgroup_barrier(mem_flags::mem_threadgroup);
1601
- //}
1602
-
1603
- //if (ith == 0) {
1604
- // dst[r1*ne0 + r0] = sum[0];
1605
- //}
1695
+ for (int row = 0; row < N_DST; ++row) {
1696
+ all_sum = simd_sum(sumf[row]);
1697
+ if (tiisg == 0) {
1698
+ dst[r1*ne0 + first_row + row] = all_sum;
1699
+ }
1700
+ }
1606
1701
  }
1702
+ #endif
1607
1703
 
1608
1704
  kernel void kernel_mul_mat_q5_K_f32(
1609
1705
  device const void * src0,
@@ -1612,39 +1708,39 @@ kernel void kernel_mul_mat_q5_K_f32(
1612
1708
  constant int64_t & ne00,
1613
1709
  constant int64_t & ne10,
1614
1710
  constant int64_t & ne0,
1615
- threadgroup float * sum [[threadgroup(0)]],
1616
1711
  uint2 tgpig[[threadgroup_position_in_grid]],
1617
- uint2 tpitg[[thread_position_in_threadgroup]],
1618
- uint2 tptg[[threads_per_threadgroup]]) {
1712
+ uint tiisg[[thread_index_in_simdgroup]],
1713
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1619
1714
 
1620
1715
  const int nb = ne00/QK_K;
1621
1716
 
1622
1717
  const int64_t r0 = tgpig.x;
1623
1718
  const int64_t r1 = tgpig.y;
1624
1719
 
1625
- device const block_q5_K * x = (device const block_q5_K *) src0 + r0*nb;
1720
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
1721
+
1722
+ device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb;
1626
1723
  device const float * yy = (device const float *) src1 + r1*ne10;
1627
1724
 
1628
- const int nth = tptg.x*tptg.y;
1629
- const int ith = tptg.y*tpitg.x + tpitg.y;
1725
+ float sumf[2]={0.f};
1630
1726
 
1631
- float sumf = 0;
1727
+ const int step = sizeof(block_q5_K) * nb;
1632
1728
 
1633
1729
  #if QK_K == 256
1730
+ #
1731
+ float yl[16], yh[16];
1634
1732
 
1635
1733
  const uint16_t kmask1 = 0x3f3f;
1636
1734
  const uint16_t kmask2 = 0x0f0f;
1637
1735
  const uint16_t kmask3 = 0xc0c0;
1638
1736
 
1639
- const int tid = tpitg.y; // 0...16
1640
- const int il = tid/4; // 0...3
1641
- const int ir = tid - 4*il;// 0...3
1642
- const int n = 4;
1643
-
1644
- const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
1645
- const int in = il%2;
1737
+ const int tid = tiisg/4;
1738
+ const int ix = tiisg%4;
1739
+ const int im = tid/4;
1740
+ const int ir = tid%4;
1741
+ const int n = 8;
1646
1742
 
1647
- const int l0 = n*(2*ir + in);
1743
+ const int l0 = n*ir;
1648
1744
  const int q_offset = 32*im + l0;
1649
1745
  const int y_offset = 64*im + l0;
1650
1746
 
@@ -1653,78 +1749,113 @@ kernel void kernel_mul_mat_q5_K_f32(
1653
1749
  const uint8_t hm3 = hm1 << 4;
1654
1750
  const uint8_t hm4 = hm2 << 4;
1655
1751
 
1656
- uchar2 sc1, sc2, sc3, sc4;
1752
+ uint16_t sc16[4];
1753
+ thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
1657
1754
 
1658
- for (int i = tpitg.x; i < nb; i += tptg.x) {
1755
+ device const float * y1 = yy + ix*QK_K + y_offset;
1659
1756
 
1660
- device const uint8_t * q1 = (x + i)->qs + q_offset;
1661
- device const uint8_t * q2 = q1 + 64;
1662
- device const uint8_t * qh = (x + i)->qh + l0;
1663
- device const float * y1 = yy + i*QK_K + y_offset;
1664
- device const float * y2 = y1 + 128;
1757
+ for (int i = ix; i < nb; i += 4) {
1665
1758
 
1666
- const float dall = (float)((x + i)->d);
1667
- const float dmin = (float)((x + i)->dmin);
1759
+ device const uint8_t * q1 = x[i].qs + q_offset;
1760
+ device const uint8_t * qh = x[i].qh + l0;
1761
+ device const half * dh = &x[i].d;
1762
+ device const uint16_t * a = (device const uint16_t *)x[i].scales + im;
1668
1763
 
1669
- device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
1670
- sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
1671
- sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
1672
- sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
1673
- sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
1764
+ device const float * y2 = y1 + 128;
1765
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
1766
+ for (int l = 0; l < 8; ++l) {
1767
+ yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
1768
+ yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
1769
+ yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
1770
+ yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
1771
+ }
1674
1772
 
1675
- float4 s = {0.f, 0.f, 0.f, 0.f};
1676
- float smin = 0;
1677
- for (int l = 0; l < n; ++l) {
1773
+ for (int row = 0; row < 2; ++row) {
1774
+
1775
+ device const uint8_t * q2 = q1 + 64;
1776
+
1777
+ sc16[0] = a[0] & kmask1;
1778
+ sc16[1] = a[2] & kmask1;
1779
+ sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
1780
+ sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
1678
1781
 
1679
- s[0] += y1[l+ 0] * ((q1[l] & 0xF) + (qh[l] & hm1 ? 16 : 0));
1680
- s[1] += y1[l+32] * ((q1[l] >> 4) + (qh[l] & hm2 ? 16 : 0));
1681
- s[2] += y2[l+ 0] * ((q2[l] & 0xF) + (qh[l] & hm3 ? 16 : 0));
1682
- s[3] += y2[l+32] * ((q2[l] >> 4) + (qh[l] & hm4 ? 16 : 0));
1683
- smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
1782
+ float4 acc = {0.f, 0.f, 0.f, 0.f};
1783
+ for (int l = 0; l < n; ++l) {
1784
+ uint8_t h = qh[l];
1785
+ acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
1786
+ acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
1787
+ acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
1788
+ acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
1789
+ }
1790
+ const float dall = dh[0];
1791
+ const float dmin = dh[1];
1792
+ sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
1793
+ dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
1794
+
1795
+ q1 += step;
1796
+ qh += step;
1797
+ dh += step/2;
1798
+ a += step/2;
1684
1799
 
1685
1800
  }
1686
- sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
1801
+
1802
+ y1 += 4 * QK_K;
1687
1803
 
1688
1804
  }
1689
1805
  #else
1690
- const int il = 4 * tpitg.x; // 0, 4, 8, 12
1691
- const int im = il/8; // 0, 0, 1, 1
1692
- const int in = il%8; // 0, 4, 0, 4
1806
+ float yl[8], yh[8];
1807
+
1808
+ const int il = 4 * (tiisg/8); // 0, 4, 8, 12
1809
+ const int ix = tiisg%8;
1810
+ const int im = il/8; // 0, 0, 1, 1
1811
+ const int in = il%8; // 0, 4, 0, 4
1693
1812
 
1694
- for (int i = tpitg.y; i < nb; i += tptg.y) {
1813
+ device const float * y = yy + ix*QK_K + il;
1695
1814
 
1696
- const float d = (float)x[i].d;
1815
+ for (int i = ix; i < nb; i += 8) {
1816
+
1817
+ for (int l = 0; l < 4; ++l) {
1818
+ yl[l+0] = y[l+ 0];
1819
+ yl[l+4] = y[l+16];
1820
+ yh[l+0] = y[l+32];
1821
+ yh[l+4] = y[l+48];
1822
+ }
1823
+
1824
+ device const half * dh = &x[i].d;
1697
1825
  device const uint8_t * q = x[i].qs + il;
1698
1826
  device const uint8_t * h = x[i].qh + in;
1699
1827
  device const int8_t * s = x[i].scales;
1700
- device const float * y = yy + i*QK_K + il;
1701
1828
 
1702
- for (int l = 0; l < 4; ++l) {
1703
- const uint8_t hl = h[l] >> im;
1704
- sumf += y[l+ 0] * d * s[0] * ((q[l+ 0] & 0xF) - (hl & 0x01 ? 0 : 16))
1705
- + y[l+16] * d * s[1] * ((q[l+16] & 0xF) - (hl & 0x04 ? 0 : 16))
1706
- + y[l+32] * d * s[2] * ((q[l+ 0] >> 4) - (hl & 0x10 ? 0 : 16))
1707
- + y[l+48] * d * s[3] * ((q[l+16] >> 4) - (hl & 0x40 ? 0 : 16));
1829
+ for (int row = 0; row < 2; ++row) {
1830
+
1831
+ const float d = dh[0];
1832
+
1833
+ float2 acc = {0.f, 0.f};
1834
+ for (int l = 0; l < 4; ++l) {
1835
+ const uint8_t hl = h[l] >> im;
1836
+ acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
1837
+ + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
1838
+ acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
1839
+ + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256));
1840
+ }
1841
+ sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]);
1842
+
1843
+ q += step;
1844
+ h += step;
1845
+ s += step;
1846
+ dh += step/2;
1847
+
1708
1848
  }
1849
+
1850
+ y += 8 * QK_K;
1709
1851
  }
1710
1852
  #endif
1711
- sum[ith] = sumf;
1712
1853
 
1713
- //
1714
- // Accumulate the sum from all threads in the threadgroup
1715
- //
1716
- threadgroup_barrier(mem_flags::mem_threadgroup);
1717
- if (ith%4 == 0) {
1718
- sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
1719
- }
1720
- threadgroup_barrier(mem_flags::mem_threadgroup);
1721
- if (ith%16 == 0) {
1722
- sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
1723
- }
1724
- threadgroup_barrier(mem_flags::mem_threadgroup);
1725
- if (ith == 0) {
1726
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
1727
- dst[r1*ne0 + r0] = sum[0];
1854
+ for (int row = 0; row < 2; ++row) {
1855
+ const float tot = simd_sum(sumf[row]);
1856
+ if (tiisg == 0) {
1857
+ dst[r1*ne0 + first_row + row] = tot;
1858
+ }
1728
1859
  }
1729
1860
 
1730
1861
  }
@@ -1736,10 +1867,9 @@ kernel void kernel_mul_mat_q6_K_f32(
1736
1867
  constant int64_t & ne00,
1737
1868
  constant int64_t & ne10,
1738
1869
  constant int64_t & ne0,
1739
- threadgroup float * sum [[threadgroup(0)]],
1740
1870
  uint2 tgpig[[threadgroup_position_in_grid]],
1741
- uint2 tpitg[[thread_position_in_threadgroup]],
1742
- uint2 tptg[[threads_per_threadgroup]]) {
1871
+ uint tiisg[[thread_index_in_simdgroup]],
1872
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1743
1873
 
1744
1874
  const uint8_t kmask1 = 0x03;
1745
1875
  const uint8_t kmask2 = 0x0C;
@@ -1751,19 +1881,18 @@ kernel void kernel_mul_mat_q6_K_f32(
1751
1881
  const int64_t r0 = tgpig.x;
1752
1882
  const int64_t r1 = tgpig.y;
1753
1883
 
1754
- device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb;
1755
- device const float * yy = (device const float *) src1 + r1*ne10;
1884
+ const int row = 2 * r0 + sgitg;
1756
1885
 
1757
- const int nth = tptg.x*tptg.y;
1758
- const int ith = tptg.y*tpitg.x + tpitg.y;
1886
+ device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb; //r0*nb;
1887
+ device const float * yy = (device const float *) src1 + r1*ne10;
1759
1888
 
1760
1889
  float sumf = 0;
1761
1890
 
1762
1891
  #if QK_K == 256
1763
- // Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
1764
- const int iqs = 16 * tpitg.y;
1765
- const int ip = iqs / 128; // 0 or 1
1766
- const int il = (iqs - 128*ip)/16; // 0...7
1892
+ const int tid = tiisg/2;
1893
+ const int ix = tiisg%2;
1894
+ const int ip = tid/8; // 0 or 1
1895
+ const int il = tid%8;
1767
1896
  const int n = 4;
1768
1897
  const int l0 = n*il;
1769
1898
  const int is = 8*ip + l0/16;
@@ -1772,9 +1901,10 @@ kernel void kernel_mul_mat_q6_K_f32(
1772
1901
  const int q_offset_l = 64*ip + l0;
1773
1902
  const int q_offset_h = 32*ip + l0;
1774
1903
 
1775
- for (int i = tpitg.x; i < nb; i += tptg.x) {
1904
+ for (int i = ix; i < nb; i += 2) {
1776
1905
 
1777
- device const uint8_t * ql = x[i].ql + q_offset_l;
1906
+ device const uint8_t * q1 = x[i].ql + q_offset_l;
1907
+ device const uint8_t * q2 = q1 + 32;
1778
1908
  device const uint8_t * qh = x[i].qh + q_offset_h;
1779
1909
  device const int8_t * sc = x[i].scales + is;
1780
1910
 
@@ -1784,19 +1914,21 @@ kernel void kernel_mul_mat_q6_K_f32(
1784
1914
 
1785
1915
  float4 sums = {0.f, 0.f, 0.f, 0.f};
1786
1916
  for (int l = 0; l < n; ++l) {
1787
- sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
1788
- sums[1] += y[l+32] * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
1789
- sums[2] += y[l+64] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
1790
- sums[3] += y[l+96] * ((int8_t)((ql[l+32] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
1917
+ sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
1918
+ sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
1919
+ sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
1920
+ sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
1791
1921
  }
1792
1922
 
1793
1923
  sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
1794
1924
 
1795
1925
  }
1926
+
1796
1927
  #else
1797
- const int il = 4*tpitg.x; // 0, 4, 8, 12
1928
+ const int ix = tiisg/4;
1929
+ const int il = 4*(tiisg%4);
1798
1930
 
1799
- for (int i = tpitg.y; i < nb; i += tptg.y) {
1931
+ for (int i = ix; i < nb; i += 8) {
1800
1932
  device const float * y = yy + i * QK_K + il;
1801
1933
  device const uint8_t * ql = x[i].ql + il;
1802
1934
  device const uint8_t * qh = x[i].qh + il;
@@ -1816,23 +1948,8 @@ kernel void kernel_mul_mat_q6_K_f32(
1816
1948
 
1817
1949
  #endif
1818
1950
 
1819
- sum[ith] = sumf;
1820
-
1821
- //
1822
- // Accumulate the sum from all threads in the threadgroup
1823
- //
1824
- threadgroup_barrier(mem_flags::mem_threadgroup);
1825
- if (ith%4 == 0) {
1826
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
1951
+ const float tot = simd_sum(sumf);
1952
+ if (tiisg == 0) {
1953
+ dst[r1*ne0 + row] = tot;
1827
1954
  }
1828
- threadgroup_barrier(mem_flags::mem_threadgroup);
1829
- if (ith%16 == 0) {
1830
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
1831
- }
1832
- threadgroup_barrier(mem_flags::mem_threadgroup);
1833
- if (ith == 0) {
1834
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
1835
- dst[r1*ne0 + r0] = sum[0];
1836
- }
1837
-
1838
1955
  }