llama_cpp 0.3.3 → 0.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -67,6 +67,17 @@ kernel void kernel_add(
  dst[tpig] = src0[tpig] + src1[tpig];
  }

+ // assumption: src1 is a row
+ // broadcast src1 into src0
+ kernel void kernel_add_row(
+ device const float * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] + src1[tpig % ne00];
+ }
+
  kernel void kernel_mul(
  device const float * src0,
  device const float * src1,
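The new `kernel_add_row` gets broadcasting entirely from index arithmetic: thread `tpig` reads `src1[tpig % ne00]`, so one `ne00`-wide row is reused across every row of `src0`. A minimal host-side C++ model of the same indexing (illustrative only, not code from the package):

```cpp
#include <cassert>
#include <vector>

// CPU model of kernel_add_row: src1 is a single row of ne00 elements that is
// broadcast across every row of src0 via the tpig % ne00 indexing.
int main() {
    const int ne00 = 4, nrows = 3;
    std::vector<float> src0(ne00 * nrows, 1.f), src1 = {10.f, 20.f, 30.f, 40.f};
    std::vector<float> dst(src0.size());
    for (size_t tpig = 0; tpig < src0.size(); ++tpig)
        dst[tpig] = src0[tpig] + src1[tpig % ne00];  // same math as the Metal kernel
    assert(dst[0] == 11.f && dst[ne00] == 11.f);     // rows 0 and 1 see the same src1
    return 0;
}
```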
@@ -331,26 +342,33 @@ kernel void kernel_rms_norm(
  threadgroup float * sum [[threadgroup(0)]],
  uint tgpig[[threadgroup_position_in_grid]],
  uint tpitg[[thread_position_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
  uint ntg[[threads_per_threadgroup]]) {
- device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
+ device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
+ device const float * x_scalar = (device const float *) x;
+ float4 sumf=0;
+ float all_sum=0;

  // parallel sum
- sum[tpitg] = 0.0f;
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
- sum[tpitg] += x[i00] * x[i00];
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+ sumf += x[i00] * x[i00];
+ }
+ all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
+ all_sum = simd_sum(all_sum);
+ if (tiisg == 0) {
+ sum[sgitg] = all_sum;
  }

- // reduce
  threadgroup_barrier(mem_flags::mem_threadgroup);
- for (uint i = ntg/2; i > 0; i /= 2) {
- if (tpitg < i) {
- sum[tpitg] += sum[tpitg + i];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ // broadcast, simd group number is ntg / 32
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+ if (tpitg < i) {
+ sum[tpitg] += sum[tpitg + i];
+ }
  }
-
- // broadcast
  if (tpitg == 0) {
+ for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
  sum[0] /= ne00;
  }

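The rewritten reduction has two levels: each 32-lane SIMD group folds its `float4` partial sums with `simd_sum` and lane 0 stores one float, then a short threadgroup tree runs over only `ntg / 32` entries; a scalar tail handles `ne00` not divisible by 4. A CPU sketch of the two-level shape (assuming `ntg = 256`; illustrative, not the shipped Metal code):

```cpp
#include <cstdio>
#include <vector>

// CPU sketch of the new two-level reduction in kernel_rms_norm: each 32-lane
// SIMD group reduces its partial sums (simd_sum), lane 0 writes one float per
// group, and a short threadgroup tree then folds ntg/32 partials instead of ntg.
int main() {
    const int ntg = 256;               // threads per threadgroup (assumed)
    const int nsg = ntg / 32;          // number of SIMD groups
    std::vector<float> lane_sum(ntg, 1.0f);  // stand-in per-thread partial sums

    // level 1: one simd_sum per 32-lane group
    std::vector<float> sum(nsg, 0.f);
    for (int g = 0; g < nsg; ++g)
        for (int l = 0; l < 32; ++l) sum[g] += lane_sum[32*g + l];

    // level 2: tree over the nsg partials (the "ntg / 32 / 2" loop in the kernel)
    for (int i = nsg/2; i > 0; i /= 2)
        for (int t = 0; t < i; ++t) sum[t] += sum[t + i];

    printf("total = %f (expected %d)\n", sum[0], ntg);
    return 0;
}
```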
@@ -359,156 +377,130 @@ kernel void kernel_rms_norm(
  const float mean = sum[0];
  const float scale = 1.0f/sqrt(mean + eps);

- device float * y = dst + tgpig*ne00;
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ device float4 * y = (device float4 *) (dst + tgpig*ne00);
+ device float * y_scalar = (device float *) y;
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
  y[i00] = x[i00] * scale;
  }
+ if (tpitg == 0) {
+ for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
+ }
+ }
+
+ // function for calculating the inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
+ // il indicates where the q4 quants begin (0 or QK4_0/4)
+ // we assume that the yl's have been multiplied with the appropriate scale factor
+ // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
+ float d = qb_curr->d;
+ float2 acc = 0.f;
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
+ for (int i = 0; i < 8; i+=2) {
+ acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+ + yl[i + 1] * (qs[i / 2] & 0x0F00);
+ acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+ + yl[i + 9] * (qs[i / 2] & 0xF000);
+ }
+ return d * (sumy * -8.f + acc[0] + acc[1]);
+ }
+
+ // function for calculating the inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i])
+ // il indicates where the q4 quants begin (0 or QK4_0/4)
+ // we assume that the yl's have been multiplied with the appropriate scale factor
+ // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
+ float d = qb_curr->d;
+ float m = qb_curr->m;
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+ float2 acc = 0.f;
+ for (int i = 0; i < 8; i+=2) {
+ acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+ + yl[i + 1] * (qs[i / 2] & 0x0F00);
+ acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+ + yl[i + 9] * (qs[i / 2] & 0xF000);
+ }
+ return d * (acc[0] + acc[1]) + sumy * m;
  }

  // putting them in the kernel cause a significant performance penalty
  #define N_DST 4 // each SIMD group works on 4 rows
  #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
  #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
- kernel void kernel_mul_mat_q4_0_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne10,
- constant int64_t & ne0,
- constant int64_t & ne01[[buffer(4)]],
- uint2 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ //Note: This is a template, but strictly speaking it only applies to
+ // quantizations where the block size is 32. It also does not
+ // guard against the number of rows not being divisible by
+ // N_DST, so this is another explicit assumption of the implementation.
+ template<typename block_q_type, int nr, int nsg, int nw>
+ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
+ int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
+ uint2 tgpig, uint tiisg, uint sgitg) {
  const int nb = ne00/QK4_0;
  const int r0 = tgpig.x;
  const int r1 = tgpig.y;
- device const block_q4_0 * x = (device const block_q4_0 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
+ const int first_row = (r0 * nsg + sgitg) * nr;
+ device const block_q_type * x = (device const block_q_type *) src0 + first_row * nb;
  device const float * y = (device const float *) src1 + r1*ne10;
- block_q4_0 qb_curr, qb_next;
- float4 y_curr[8]; // src1 vector cache
- float sumf[N_DST]={0.f}, all_sum;
- thread float * yl=(thread float *)y_curr;
-
- // bootstrap
- qb_curr = x[tiisg];
- // each thread in a SIMD group deals with 1 block.
- for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
-
- for (int i = 0; i < QK4_0 / 4; i++) {
- y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
+ float yl[16]; // src1 vector cache
+ float sumf[nr]={0.f};
+
+ const int ix = tiisg/2;
+ const int il = 8*(tiisg%2);
+
+ device const float * yb = y + ix * QK4_0 + il;
+
+ // each thread in a SIMD group deals with half a block.
+ for (int ib = ix; ib < nb; ib += nw/2) {
+ float sumy = 0;
+ for (int i = 0; i < 8; i += 2) {
+ sumy += yb[i] + yb[i+1];
+ yl[i+0] = yb[i+ 0];
+ yl[i+1] = yb[i+ 1]/256.f;
+ sumy += yb[i+16] + yb[i+17];
+ yl[i+8] = yb[i+16]/16.f;
+ yl[i+9] = yb[i+17]/4096.f;
  }

- for (int row = 0; row < N_DST; row++) {
- // prefetch next x block
- qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
-
- // calculate
- float d = qb_curr.d;
- float2 acc = {0.0f, 0.0f};
- for (int i = 0; i < 16; i++) {
- acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
- acc[1] += yl[i] + yl[i+16];
- }
- sumf[row] += d * (acc[0] - 8.f*acc[1]);
- qb_curr = qb_next;
+ for (int row = 0; row < nr; row++) {
+ sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il);
  }
- }

- for (int i = 0; i < QK4_0 / 4; i++) {
- y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
+ yb += QK4_0 * 16;
  }

- for (int row = 0; row < N_DST; row++) {
- // prefetch next x block
- qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
-
- // calculate
- float d = qb_curr.d;
- float2 acc = {0.0f, 0.0f};
- for (int i = 0; i < 16; i++) {
- acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
- acc[1] += yl[i] + yl[i+16];
- }
- if (tiisg < nb % N_SIMDWIDTH) {
- sumf[row] += d * (acc[0] - 8.f*acc[1]);
- }
- qb_curr = qb_next;
-
- all_sum = simd_sum(sumf[row]);
- if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
- dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
+ for (int row = 0; row < nr; ++row) {
+ const float tot = simd_sum(sumf[row]);
+ if (tiisg == 0 && first_row + row < ne01) {
+ dst[r1*ne0 + first_row + row] = tot;
  }
  }
  }

- kernel void kernel_mul_mat_q4_1_f32(
+ kernel void kernel_mul_mat_q4_0_f32(
  device const void * src0,
  device const float * src1,
  device float * dst,
  constant int64_t & ne00,
  constant int64_t & ne10,
  constant int64_t & ne0,
- threadgroup float * sum [[threadgroup(0)]],
+ constant int64_t & ne01[[buffer(4)]],
  uint2 tgpig[[threadgroup_position_in_grid]],
- uint2 tpitg[[thread_position_in_threadgroup]],
- uint2 tptg[[threads_per_threadgroup]]) {
- const int nb = ne00/QK4_1;
-
- const int64_t r0 = tgpig.x;
- const int64_t r1 = tgpig.y;
-
- device const block_q4_1 * x = (device const block_q4_1 *) src0 + r0*nb;
- device const float * y = (device const float *) src1 + r1*ne10;
-
- const uint nth = tptg.x*tptg.y;
- const uint ith = tptg.y*tpitg.x + tpitg.y;
-
- const int ix = tpitg.y/4; // 0 or 1
- const int iy = tpitg.y - 4*ix; // 0...3
-
- const int first = 4 * iy;
-
- float sumf = 0;
-
- for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
-
- const float d = (float)x[i].d;
- const float m = (float)x[i].m;
-
- device const uint8_t * xl = x[i].qs + first;
- device const float * yl = y + i * QK4_1 + first;
-
- float2 acc = {0.0f, 0.0f};
-
- for (int j = 0; j < 4; ++j) {
-
- acc[0] += yl[j+ 0] * (d * (xl[j] & 0xF) + m);
- acc[1] += yl[j+16] * (d * (xl[j] >> 4) + m);
-
- }
-
- sumf += acc[0] + acc[1];
- }
-
- sum[ith] = sumf;
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
+ }

- //
- // Accumulate the sum from all threads in the threadgroup
- //
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%4 == 0) {
- sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%16 == 0) {
- sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith == 0) {
- for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
- dst[r1*ne0 + r0] = sum[0];
- }
+ kernel void kernel_mul_mat_q4_1_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne10,
+ constant int64_t & ne0,
+ constant int64_t & ne01[[buffer(4)]],
+ uint2 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
  }

  kernel void kernel_mul_mat_f16_f32(
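`block_q_n_dot_y` never shifts the nibbles into place: it masks them where they sit in a 16-bit word and relies on the caller having pre-divided the activations by the factors those positions imply (1, 256, 16, 4096), as done in `mul_vec_q_n_f32` above. A host-side reference check of that identity for half a q4_0 block (hypothetical test code, little-endian host assumed, as on the Apple targets this shader runs on):

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Checks that the masked, pre-scaled accumulation used by block_q_n_dot_y
// matches a naive nibble-by-nibble q4_0 dot product.
int main() {
    // 4 uint16 words = 8 bytes of q4_0 data; byte j packs quant j (low nibble)
    // and quant j+16 (high nibble)
    const uint16_t q16[4] = {0x801F, 0x053C, 0x64A7, 0xEB92};
    const uint8_t * qs = (const uint8_t *) q16;

    float y[32];
    for (int k = 0; k < 32; ++k) y[k] = 0.25f * k - 3.f;  // arbitrary activations
    const float d = 0.5f;                                 // block scale

    // naive: unpack each nibble and apply the q4_0 offset of -8
    float ref = 0.f, sumy = 0.f;
    for (int j = 0; j < 8; ++j) {
        ref  += y[j]    * d * ((qs[j] & 0x0F) - 8);
        ref  += y[j+16] * d * ((qs[j] >>   4) - 8);
        sumy += y[j] + y[j+16];
    }

    // trick: mask nibbles in place, pre-divide y by the implied factor
    float acc = 0.f;
    for (int i = 0; i < 8; i += 2) {
        acc += (y[i+ 0]        ) * (q16[i/2] & 0x000F)
             + (y[i+ 1]/256.f  ) * (q16[i/2] & 0x0F00)
             + (y[i+16]/16.f   ) * (q16[i/2] & 0x00F0)
             + (y[i+17]/4096.f ) * (q16[i/2] & 0xF000);
    }
    const float got = d * (acc - 8.f * sumy);  // sumy * -8.f term from the kernel

    printf("naive = %.6f  masked = %.6f\n", ref, got);
    assert(std::fabs(ref - got) < 1e-3f);
    return 0;
}
```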
@@ -624,17 +616,19 @@ kernel void kernel_rope(
  constant int & n_past,
  constant int & n_dims,
  constant int & mode,
+ constant float & freq_base,
+ constant float & freq_scale,
  uint3 tpig[[thread_position_in_grid]]) {
  const int64_t i3 = tpig[2];
  const int64_t i2 = tpig[1];
  const int64_t i1 = tpig[0];

  const bool is_neox = mode & 2;
- const float theta_scale = pow(10000.0, -2.0f/n_dims);
+ const float theta_scale = pow(freq_base, -2.0f/n_dims);

  const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);

- float theta = (float)p;
+ float theta = freq_scale * (float)p;

  if (!is_neox) {
  for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
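The two new arguments generalize the hard-coded RoPE frequency base. With `theta` seeded as `freq_scale * p` and multiplied by `theta_scale = pow(freq_base, -2.0f/n_dims)` once per dimension pair, the rotation angle for pair `i` is

```latex
\theta_i(p) = \mathrm{freq\_scale} \cdot p \cdot \mathrm{freq\_base}^{-2i/n_{\mathrm{dims}}}
```

so `freq_base = 10000`, `freq_scale = 1` reproduces the previous behavior exactly.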
@@ -1229,111 +1223,137 @@ kernel void kernel_mul_mat_q2_K_f32(
  constant int64_t & ne00,
  constant int64_t & ne10,
  constant int64_t & ne0,
- threadgroup float * sum [[threadgroup(0)]],
+ constant int64_t & ne01[[buffer(4)]],
  uint2 tgpig[[threadgroup_position_in_grid]],
- uint2 tpitg[[thread_position_in_threadgroup]],
- uint2 tptg[[threads_per_threadgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {

  const int nb = ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;

- const int64_t r0 = tgpig.x;
- const int64_t r1 = tgpig.y;
-
- device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb;
- device const float * yy = (device const float *) src1 + r1*ne10;
-
- const int nth = tptg.x*tptg.y;
- const int ith = tptg.y*tpitg.x + tpitg.y;
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int ib_row = first_row * nb;
+ device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row;
+ device const float * y = (device const float *) src1 + r1*ne10;
+ float yl[32];
+ float sumf[N_DST]={0.f}, all_sum;

- float sumf = 0;
+ const int step = sizeof(block_q2_K) * nb;

  #if QK_K == 256
- const int tid = tpitg.y; // 0...16
- const int il = tid/4; // 0...3
- const int ir = tid%4; // 0...3
- const int ip = il/2; // 0 or 1
- const int shift1 = 4*(il%2);// 0 or 4
- const int shift2 = shift1+2;// 2 or 6
- const int n = 8;
- const int is = 4*il + (n*ir)/16;
-
- const int y_offset = 64*il + n*ir;
- const int q_offset = 32*ip + n*ir;
-
- for (int i = tpitg.x; i < nb; i += tptg.x) {
-
- device const uint8_t * q = x[i].qs + q_offset;
- device const uint8_t * scales = x[i].scales + is;
+ const int ix = tiisg/8; // 0...3
+ const int it = tiisg%8; // 0...7
+ const int im = it/4; // 0 or 1
+ const int ir = it%4; // 0...3
+ const int is = (8*ir)/16;// 0 or 1
+
+ device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir;
+
+ for (int ib = ix; ib < nb; ib += 4) {
+
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
+ for (int i = 0; i < 8; ++i) {
+ yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
+ yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
+ yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
+ yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
+ }

- uint8_t d1 = scales[0] & 0xF;
- uint8_t d2 = scales[2] & 0xF;
- uint8_t m1 = scales[0] >> 4;
- uint8_t m2 = scales[2] >> 4;
+ device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*im + is;
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
+ device const half * dh = &x[ib].d;

- device const float * y = yy + i*QK_K + y_offset;
+ for (int row = 0; row < N_DST; row++) {

- float2 s = {0.f, 0.f};
- float smin = 0;
- for (int l = 0; l < n; ++l) {
- s[0] += y[l+ 0] * ((q[l] >> shift1) & 3);
- s[1] += y[l+32] * ((q[l] >> shift2) & 3);
- smin += y[l+ 0] * m1 + y[l+32] * m2;
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+ for (int i = 0; i < 8; i += 2) {
+ acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
+ acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
+ acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
+ acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
+ acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
+ acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
+ acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
+ acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
+ }
+ float dall = dh[0];
+ float dmin = dh[1] * 1.f/16.f;
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
+ (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
+ (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +
+ (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
+ dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
+
+ qs += step/2;
+ sc += step;
+ dh += step/2;
  }

- const float dall = (float)x[i].d;
- const float dmin = (float)x[i].dmin;
-
- sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
-
+ y4 += 4 * QK_K;
  }
  #else
- const int il = 4 * tpitg.x;
+ const int ix = tiisg/2; // 0...15
+ const int it = tiisg%2; // 0...1

- uint32_t aux[2];
- thread const uint8_t * d = (thread const uint8_t *)aux;
- thread const uint8_t * m = (thread const uint8_t *)aux + 4;
+ device const float * y4 = y + ix * QK_K + 8 * it;

- for (int i = tpitg.y; i < nb; i += tptg.y) {
+ for (int ib = ix; ib < nb; ib += 16) {

- device const uint8_t * q = x[i].qs + il;
- device const float * y = yy + i*QK_K + il;
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
+ for (int i = 0; i < 8; ++i) {
+ yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
+ yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8];
+ yl[i+16] = y4[i+32]; sumy[2] += yl[i+16];
+ yl[i+24] = y4[i+48]; sumy[3] += yl[i+24];
+ }
+
+ device const uint8_t * sc = (device const uint8_t *)x[ib].scales;
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
+ device const half * dh = &x[ib].d;

- const float dall = (float)x[i].d;
- const float dmin = (float)x[i].dmin;
+ for (int row = 0; row < N_DST; row++) {

- device const uint32_t * a = (device const uint32_t *)x[i].scales;
- aux[0] = a[0] & 0x0f0f0f0f;
- aux[1] = (a[0] >> 4) & 0x0f0f0f0f;
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+ for (int i = 0; i < 8; i += 2) {
+ acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
+ acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
+ acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
+ acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
+ acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
+ acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
+ acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
+ acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
+ }

- for (int l = 0; l < 4; ++l) {
- sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0])
- + y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1])
- + y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2])
- + y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]);
+ float dall = dh[0];
+ float dmin = dh[1];
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
+ (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
+ (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
+ (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
+ dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));
+
+ qs += step/2;
+ sc += step;
+ dh += step/2;
  }
+
+ y4 += 16 * QK_K;
  }
  #endif

- sum[ith] = sumf;
-
- //
- // Accumulate the sum from all threads in the threadgroup
- //
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%4 == 0) {
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%16 == 0) {
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith == 0) {
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
- dst[r1*ne0 + r0] = sum[0];
+ for (int row = 0; row < N_DST; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + first_row + row] = all_sum;
+ }
  }
  }

+ #if QK_K == 256
  kernel void kernel_mul_mat_q3_K_f32(
  device const void * src0,
  device const float * src1,
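The q2_K kernel's bookkeeping follows from the super-block format: each 16-value sub-block `j` carries a 4-bit scale `s_j` (low nibble of `scales[j]`) and a 4-bit minimum `m_j` (high nibble), applied as

```latex
w^{(j)}_i = d\, s_j\, q^{(j)}_i - d_{\min}\, m_j
\quad\Rightarrow\quad
\sum_{j,i} y^{(j)}_i w^{(j)}_i
 = d \sum_j s_j \Big(\sum_i y^{(j)}_i q^{(j)}_i\Big)
 - d_{\min} \sum_j m_j \Big(\sum_i y^{(j)}_i\Big)
```

which is why `sumy` is tracked per sub-block. The `1/4`, `1/16`, `1/64` weights undo the 2-bit quants being masked in place at bit positions 2, 4 and 6, the `1/256` recombines the high-byte lanes, and folding `1/16` into `dmin` lets the high nibble `sc[..] & 0xF0` be used without shifting.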
@@ -1342,40 +1362,41 @@ kernel void kernel_mul_mat_q3_K_f32(
  constant int64_t & ne10,
  constant int64_t & ne0,
  constant int64_t & ne1,
- threadgroup float * sum [[threadgroup(0)]],
  uint2 tgpig[[threadgroup_position_in_grid]],
- uint2 tpitg[[thread_position_in_threadgroup]],
- uint2 tptg[[threads_per_threadgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {

  const int nb = ne00/QK_K;

  const int64_t r0 = tgpig.x;
  const int64_t r1 = tgpig.y;

- device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb;
- device const float * yy = (device const float *) src1 + r1*ne10;
-
- const int nth = tptg.x*tptg.y;
- const int ith = tptg.y*tpitg.x + tpitg.y;
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;

- #if QK_K == 256
+ device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb;
+ device const float * yy = (device const float *) src1 + r1*ne10;

- const uint8_t m3 = 3;
- const int8_t m4 = 4;
+ float yl[16];

  const uint16_t kmask1 = 0x0303;
  const uint16_t kmask2 = 0x0f0f;

- const int tid = tpitg.y; // expecting 16
+ const int tid = tiisg/2;
+ const int ix = tiisg%2;
  const int ip = tid/8; // 0 or 1
  const int il = tid/2 - 4*ip; // 0...3
  const int ir = tid%2;
  const int n = 8;
  const int l0 = n*ir;

- const uint8_t m = 1 << (4*ip + il);
+ const uint16_t m1 = 1 << (4*ip + il);
+ const uint16_t m2 = m1 << 8;

  const int shift = 2*il;
+ const uint16_t qm1 = 0x0003 << shift;
+ const uint16_t qm2 = 0x0300 << shift;
+ const int32_t v1 = 4 << shift;
+ const int32_t v2 = 1024 << shift;

  const uint16_t s_shift1 = 4*ip;
  const uint16_t s_shift2 = s_shift1 + 2*(il/2);
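The masks set up here keep the 2-bit fields in place: `qm1`/`qm2` select bits `shift`..`shift+1` of the low and high byte of each 16-bit word, and `v1`/`v2` are the high-bit offset 4 pre-scaled to the same positions. Each partial therefore carries a common factor of `2^shift` (times 256 on the high-byte lane):

```latex
q \,\&\, (3 \ll s) = 2^{s}\,\hat q,\qquad v_1 = 4\cdot 2^{s}
\;\Rightarrow\;
s_1 + \tfrac{s_2}{256} = 2^{s}\sum_l y_l \big(\hat q_l - 4\,[h_l = 0]\big)
```

In the loop that follows, each row's total is divided by `1 << shift` to remove that factor, and `sumf1 - 32.f*sumf2` applies the 6-bit scales' offset of 32 once at the very end.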
@@ -1384,226 +1405,315 @@ kernel void kernel_mul_mat_q3_K_f32(
  const int q_offset = 32*ip + l0;
  const int y_offset = 128*ip + 32*il + l0;

- //float sumf = 0;
- float sumf1 = 0, sumf2 = 0;
- for (int i = tpitg.x; i < nb; i += tptg.x) {
+ const int step = sizeof(block_q3_K) * nb / 2;

- const float d_all = (float)(x[i].d);
-
- device const uint8_t * q = x[i].qs + q_offset;
- device const uint8_t * h = x[i].hmask + l0;
- device const float * y = yy + i * QK_K + y_offset;
+ device const float * y1 = yy + ix*QK_K + y_offset;

- device const uint16_t * a = (device const uint16_t *)x[i].scales;
- const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
+ float sumf1[2] = {0.f}, sumf2[2] = {0.f};
+ for (int i = ix; i < nb; i += 2) {

- float s = 0;
- for (int l = 0; l < n; ++l) {
- s += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & m3) - ((h[l+ 0] & m) ? 0 : m4));
- }
- float d = d_all * s;
- sumf1 += d * scales[0];
- sumf2 += d;
- //sumf += d_all * s * (scales[0] - 32);
-
- s = 0;
- for (int l = 0; l < n; ++l) {
- s += y[l+16] * ((int8_t)((q[l+16] >> shift) & m3) - ((h[l+16] & m) ? 0 : m4));
+ for (int l = 0; l < 8; ++l) {
+ yl[l+0] = y1[l+ 0];
+ yl[l+8] = y1[l+16];
  }
- d = d_all * s;
- sumf1 += d * scales[1];
- sumf2 += d;
- //sumf += d_all * s * (scales[1] - 32);
-
- }
-
- //sum[ith] = sumf;
- sum[ith] = sumf1 - 32.f*sumf2;
- #else
- const int il = 4 * tpitg.x; // 0, 4, 8, 12
- const int im = il/8; // 0, 0, 1, 1
- const int in = il%8; // 0, 4, 0, 4

- float sumf = 0;
+ device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
+ device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
+ device const uint16_t * a = (device const uint16_t *)(x[i].scales);
+ device const half * dh = &x[i].d;

- for (int i = tpitg.y; i < nb; i += tptg.y) {
+ for (int row = 0; row < 2; ++row) {

- const float d_all = (float)(x[i].d);
+ const float d_all = (float)dh[0];
+ const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));

- device const uint8_t * q = x[i].qs + il;
- device const uint8_t * h = x[i].hmask + in;
- device const float * y = yy + i * QK_K + il;
+ float s1 = 0, s2 = 0;
+ for (int l = 0; l < n; l += 2) {
+ const uint16_t qs = q[l/2];
+ s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
+ s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
+ }
+ float d = d_all * (s1 + 1.f/256.f * s2);
+ sumf1[row] += d * scales[0];
+ sumf2[row] += d;
+
+ s1 = s2 = 0;
+ for (int l = 0; l < n; l += 2) {
+ const uint16_t qs = q[l/2+8];
+ s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
+ s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
+ }
+ d = d_all * (s1 + 1.f/256.f * s2);
+ sumf1[row] += d * scales[1];
+ sumf2[row] += d;

- const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
- const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
- const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
- const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
+ q += step;
+ h += step;
+ a += step;
+ dh += step;

- for (int l = 0; l < 4; ++l) {
- const uint8_t hm = h[l] >> im;
- sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4))
- + y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4))
- + y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4))
- + y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4));
  }

- }
-
- sum[ith] = sumf;
+ y1 += 2 * QK_K;

- #endif
-
- //
- // Accumulate the sum from all threads in the threadgroup
- //
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%4 == 0) {
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%16 == 0) {
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith == 0) {
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
- dst[r1*ne0 + r0] = sum[0];
  }

+ for (int row = 0; row < 2; ++row) {
+ const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
+ const float tot = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst[r1*ne0 + first_row + row] = tot;
+ }
+ }
  }
-
- kernel void kernel_mul_mat_q4_K_f32(
+ #else
+ kernel void kernel_mul_mat_q3_K_f32(
  device const void * src0,
  device const float * src1,
  device float * dst,
  constant int64_t & ne00,
  constant int64_t & ne10,
  constant int64_t & ne0,
- threadgroup float * sum [[threadgroup(0)]],
+ constant int64_t & ne1,
  uint2 tgpig[[threadgroup_position_in_grid]],
- uint2 tpitg[[thread_position_in_threadgroup]],
- uint2 tptg[[threads_per_threadgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {

  const int nb = ne00/QK_K;

  const int64_t r0 = tgpig.x;
  const int64_t r1 = tgpig.y;

- const int nth = tptg.x*tptg.y;
- const int ith = tptg.y*tpitg.x + tpitg.y;
+ const int row = 2 * r0 + sgitg;

- device const block_q4_K * x = (device const block_q4_K *) src0 + r0*nb;
+ device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb;
  device const float * yy = (device const float *) src1 + r1*ne10;
+ const int ix = tiisg/4;
+ const int il = 4 * (tiisg%4);// 0, 4, 8, 12
+ const int im = il/8; // 0, 0, 1, 1
+ const int in = il%8; // 0, 4, 0, 4

- float sumf = 0;
+ float2 sum = {0.f, 0.f};
+
+ for (int i = ix; i < nb; i += 8) {
+
+ const float d_all = (float)(x[i].d);
+
+ device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
+ device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
+ device const uint16_t * s = (device const uint16_t *)(x[i].scales);
+ device const float * y = yy + i * QK_K + il;
+
+ const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
+ const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
+ const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
+ const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
+
+ for (int l = 0; l < 4; l += 2) {
+ const uint16_t hm = h[l/2] >> im;
+ sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
+ + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
+ + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
+ + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
+ sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
+ + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
+ + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
+ + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
+ }
+
+ }
+ const float sumf = sum[0] + sum[1] * 1.f/256.f;
+
+ const float tot = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst[r1*ne0 + row] = tot;
+ }
+
+ }
+ #endif

  #if QK_K == 256
+ kernel void kernel_mul_mat_q4_K_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne10,
+ constant int64_t & ne0,
+ constant int64_t & ne01[[buffer(4)]],
+ uint2 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {

  const uint16_t kmask1 = 0x3f3f;
  const uint16_t kmask2 = 0x0f0f;
  const uint16_t kmask3 = 0xc0c0;

- const int tid = tpitg.y; // 0...16
- const int il = tid/4; // 0...3
- const int ir = tid - 4*il;// 0...3
- const int n = 4;
+ const int ix = tiisg/8; // 0...3
+ const int it = tiisg%8; // 0...7
+ const int im = it/4; // 0 or 1
+ const int ir = it%4; // 0...3

- const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
- const int in = il%2;
+ const int nb = ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int ib_row = first_row * nb;
+ device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
+ device const float * y = (device const float *) src1 + r1*ne10;
+ float yl[16];
+ float yh[16];
+ float sumf[N_DST]={0.f}, all_sum;

- const int l0 = n*(2*ir + in);
- const int q_offset = 32*im + l0;
- const int y_offset = 64*im + l0;
+ const int step = sizeof(block_q4_K) * nb / 2;

- uchar2 sc1, sc2, sc3, sc4;
+ device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir;

- for (int i = tpitg.x; i < nb; i += tptg.x) {
+ uint16_t sc16[4];
+ thread const uint8_t * sc8 = (thread const uint8_t *)sc16;

- device const uint8_t * q1 = (x + i)->qs + q_offset;
- device const uint8_t * q2 = q1 + 64;
- device const float * y1 = yy + i*QK_K + y_offset;
- device const float * y2 = y1 + 128;
+ for (int ib = ix; ib < nb; ib += 4) {

- const float dall = (float)((x + i)->d);
- const float dmin = (float)((x + i)->dmin);
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
+ for (int i = 0; i < 8; ++i) {
+ yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
+ yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
+ yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
+ yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
+ }

- device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
- sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
- sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
- sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
- sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
+ device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im;
+ device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
+ device const half * dh = &x[ib].d;

- float4 s = {0.f, 0.f, 0.f, 0.f};
- float smin = 0;
- for (int l = 0; l < n; ++l) {
+ for (int row = 0; row < N_DST; row++) {

- s[0] += y1[l] * (q1[l] & 0xF); s[1] += y1[l+32] * (q1[l] >> 4);
- s[2] += y2[l] * (q2[l] & 0xF); s[3] += y2[l+32] * (q2[l] >> 4);
- smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
+ sc16[0] = sc[0] & kmask1;
+ sc16[1] = sc[2] & kmask1;
+ sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
+ sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
+
+ device const uint16_t * q2 = q1 + 32;
+
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+ for (int i = 0; i < 8; i += 2) {
+ acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
+ acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
+ acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
+ acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
+ acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
+ acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
+ acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
+ acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
+ }

+ float dall = dh[0];
+ float dmin = dh[1];
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
+ (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
+ (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
+ (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
+ dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
+
+ q1 += step;
+ sc += step;
+ dh += step;
  }
- sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;

+ y4 += 4 * QK_K;
+ }
+
+ for (int row = 0; row < N_DST; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + first_row + row] = all_sum;
+ }
  }
+ }
  #else
- uint16_t aux16[2];
- thread const uint8_t * scales = (thread const uint8_t *)aux16;
+ kernel void kernel_mul_mat_q4_K_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne10,
+ constant int64_t & ne0,
+ constant int64_t & ne01[[buffer(4)]],
+ uint2 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {

- const int il = 4*tpitg.x;
+ const int ix = tiisg/4; // 0...7
+ const int it = tiisg%4; // 0...3

- for (int i = tpitg.y; i < nb; i += tptg.y) {
+ const int nb = ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int ib_row = first_row * nb;
+ device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
+ device const float * y = (device const float *) src1 + r1*ne10;
+ float yl[8];
+ float yh[8];
+ float sumf[N_DST]={0.f}, all_sum;

- device const uint8_t * q = x[i].qs + il;
- device const float * y = yy + i * QK_K + il;
+ const int step = sizeof(block_q4_K) * nb / 2;

- const float d = (float)x[i].d[0];
- const float m = (float)x[i].d[1];
+ device const float * y4 = y + ix * QK_K + 8 * it;

- device const uint16_t * a = (device const uint16_t *)x[i].scales;
- aux16[0] = a[0] & 0x0f0f;
- aux16[1] = (a[0] >> 4) & 0x0f0f;
+ uint16_t sc16[4];

- for (int l = 0; l < 4; ++l) {
- sumf += d * scales[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - m * scales[2] * (y[l+ 0] + y[l+16])
- + d * scales[1] * (y[l+32] * (q[l] >> 4) + y[l+48] * (q[l+16] >> 4)) - m * scales[3] * (y[l+32] + y[l+48]);
+ for (int ib = ix; ib < nb; ib += 8) {
+
+ float2 sumy = {0.f, 0.f};
+ for (int i = 0; i < 8; ++i) {
+ yl[i] = y4[i+ 0]; sumy[0] += yl[i];
+ yh[i] = y4[i+32]; sumy[1] += yh[i];
  }
- }
- #endif

- sum[ith] = sumf;
+ device const uint16_t * sc = (device const uint16_t *)x[ib].scales;
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
+ device const half * dh = x[ib].d;

- //
- // Accumulate the sum from all threads in the threadgroup
- // This version is slightly faster than the commented out one below,
- // which I copy-pasted from ggerganov's q4_0 dot product for metal.
- //
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%4 == 0) {
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%16 == 0) {
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith == 0) {
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
- dst[r1*ne0 + r0] = sum[0];
+ for (int row = 0; row < N_DST; row++) {
+
+ sc16[0] = sc[0] & 0x000f;
+ sc16[1] = sc[0] & 0x0f00;
+ sc16[2] = sc[0] & 0x00f0;
+ sc16[3] = sc[0] & 0xf000;
+
+ float2 acc1 = {0.f, 0.f};
+ float2 acc2 = {0.f, 0.f};
+ for (int i = 0; i < 8; i += 2) {
+ acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
+ acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
+ acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
+ acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
+ }
+
+ float dall = dh[0];
+ float dmin = dh[1];
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
+ (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
+ dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);
+
+ qs += step;
+ sc += step;
+ dh += step;
+ }
+
+ y4 += 8 * QK_K;
  }

- //// accumulate the sum from all threads in the threadgroup
- //threadgroup_barrier(mem_flags::mem_threadgroup);
- //for (uint i = nth/2; i > 0; i /= 2) {
- // if (ith < i) {
- // sum[ith] += sum[ith + i];
- // }
- // threadgroup_barrier(mem_flags::mem_threadgroup);
- //}
-
- //if (ith == 0) {
- // dst[r1*ne0 + r0] = sum[0];
- //}
+ for (int row = 0; row < N_DST; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + first_row + row] = all_sum;
+ }
+ }
  }
+ #endif

  kernel void kernel_mul_mat_q5_K_f32(
  device const void * src0,
@@ -1612,39 +1722,39 @@ kernel void kernel_mul_mat_q5_K_f32(
  constant int64_t & ne00,
  constant int64_t & ne10,
  constant int64_t & ne0,
- threadgroup float * sum [[threadgroup(0)]],
  uint2 tgpig[[threadgroup_position_in_grid]],
- uint2 tpitg[[thread_position_in_threadgroup]],
- uint2 tptg[[threads_per_threadgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {

  const int nb = ne00/QK_K;

  const int64_t r0 = tgpig.x;
  const int64_t r1 = tgpig.y;

- device const block_q5_K * x = (device const block_q5_K *) src0 + r0*nb;
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
+
+ device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb;
  device const float * yy = (device const float *) src1 + r1*ne10;

- const int nth = tptg.x*tptg.y;
- const int ith = tptg.y*tpitg.x + tpitg.y;
+ float sumf[2]={0.f};

- float sumf = 0;
+ const int step = sizeof(block_q5_K) * nb;

  #if QK_K == 256
+ #
+ float yl[16], yh[16];

  const uint16_t kmask1 = 0x3f3f;
  const uint16_t kmask2 = 0x0f0f;
  const uint16_t kmask3 = 0xc0c0;

- const int tid = tpitg.y; // 0...16
- const int il = tid/4; // 0...3
- const int ir = tid - 4*il;// 0...3
- const int n = 4;
-
- const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
- const int in = il%2;
+ const int tid = tiisg/4;
+ const int ix = tiisg%4;
+ const int im = tid/4;
+ const int ir = tid%4;
+ const int n = 8;

- const int l0 = n*(2*ir + in);
+ const int l0 = n*ir;
  const int q_offset = 32*im + l0;
  const int y_offset = 64*im + l0;

@@ -1653,78 +1763,113 @@ kernel void kernel_mul_mat_q5_K_f32(
  const uint8_t hm3 = hm1 << 4;
  const uint8_t hm4 = hm2 << 4;

- uchar2 sc1, sc2, sc3, sc4;
+ uint16_t sc16[4];
+ thread const uint8_t * sc8 = (thread const uint8_t *)sc16;

- for (int i = tpitg.x; i < nb; i += tptg.x) {
+ device const float * y1 = yy + ix*QK_K + y_offset;

- device const uint8_t * q1 = (x + i)->qs + q_offset;
- device const uint8_t * q2 = q1 + 64;
- device const uint8_t * qh = (x + i)->qh + l0;
- device const float * y1 = yy + i*QK_K + y_offset;
- device const float * y2 = y1 + 128;
+ for (int i = ix; i < nb; i += 4) {

- const float dall = (float)((x + i)->d);
- const float dmin = (float)((x + i)->dmin);
+ device const uint8_t * q1 = x[i].qs + q_offset;
+ device const uint8_t * qh = x[i].qh + l0;
+ device const half * dh = &x[i].d;
+ device const uint16_t * a = (device const uint16_t *)x[i].scales + im;

- device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
- sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
- sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
- sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
- sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
+ device const float * y2 = y1 + 128;
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
+ for (int l = 0; l < 8; ++l) {
+ yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
+ yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
+ yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
+ yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
+ }

- float4 s = {0.f, 0.f, 0.f, 0.f};
- float smin = 0;
- for (int l = 0; l < n; ++l) {
+ for (int row = 0; row < 2; ++row) {
+
+ device const uint8_t * q2 = q1 + 64;

- s[0] += y1[l+ 0] * ((q1[l] & 0xF) + (qh[l] & hm1 ? 16 : 0));
- s[1] += y1[l+32] * ((q1[l] >> 4) + (qh[l] & hm2 ? 16 : 0));
- s[2] += y2[l+ 0] * ((q2[l] & 0xF) + (qh[l] & hm3 ? 16 : 0));
- s[3] += y2[l+32] * ((q2[l] >> 4) + (qh[l] & hm4 ? 16 : 0));
- smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
+ sc16[0] = a[0] & kmask1;
+ sc16[1] = a[2] & kmask1;
+ sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
+ sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
+
+ float4 acc = {0.f, 0.f, 0.f, 0.f};
+ for (int l = 0; l < n; ++l) {
+ uint8_t h = qh[l];
+ acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
+ acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
+ acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
+ acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
+ }
+ const float dall = dh[0];
+ const float dmin = dh[1];
+ sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
+ dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
+
+ q1 += step;
+ qh += step;
+ dh += step/2;
+ a += step/2;

  }
- sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
+
+ y1 += 4 * QK_K;

  }
  #else
- const int il = 4 * tpitg.x; // 0, 4, 8, 12
- const int im = il/8; // 0, 0, 1, 1
- const int in = il%8; // 0, 4, 0, 4
+ float yl[8], yh[8];
+
+ const int il = 4 * (tiisg/8); // 0, 4, 8, 12
+ const int ix = tiisg%8;
+ const int im = il/8; // 0, 0, 1, 1
+ const int in = il%8; // 0, 4, 0, 4

- for (int i = tpitg.y; i < nb; i += tptg.y) {
+ device const float * y = yy + ix*QK_K + il;

- const float d = (float)x[i].d;
+ for (int i = ix; i < nb; i += 8) {
+
+ for (int l = 0; l < 4; ++l) {
+ yl[l+0] = y[l+ 0];
+ yl[l+4] = y[l+16];
+ yh[l+0] = y[l+32];
+ yh[l+4] = y[l+48];
+ }
+
+ device const half * dh = &x[i].d;
  device const uint8_t * q = x[i].qs + il;
  device const uint8_t * h = x[i].qh + in;
  device const int8_t * s = x[i].scales;
- device const float * y = yy + i*QK_K + il;

- for (int l = 0; l < 4; ++l) {
- const uint8_t hl = h[l] >> im;
- sumf += y[l+ 0] * d * s[0] * ((q[l+ 0] & 0xF) - (hl & 0x01 ? 0 : 16))
- + y[l+16] * d * s[1] * ((q[l+16] & 0xF) - (hl & 0x04 ? 0 : 16))
- + y[l+32] * d * s[2] * ((q[l+ 0] >> 4) - (hl & 0x10 ? 0 : 16))
- + y[l+48] * d * s[3] * ((q[l+16] >> 4) - (hl & 0x40 ? 0 : 16));
+ for (int row = 0; row < 2; ++row) {
+
+ const float d = dh[0];
+
+ float2 acc = {0.f, 0.f};
+ for (int l = 0; l < 4; ++l) {
+ const uint8_t hl = h[l] >> im;
+ acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
+ + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
+ acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
+ + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256));
+ }
+ sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]);
+
+ q += step;
+ h += step;
+ s += step;
+ dh += step/2;
+
  }
+
+ y += 8 * QK_K;
  }
  #endif
- sum[ith] = sumf;

- //
- // Accumulate the sum from all threads in the threadgroup
- //
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%4 == 0) {
- sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%16 == 0) {
- sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith == 0) {
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
- dst[r1*ne0 + r0] = sum[0];
+ for (int row = 0; row < 2; ++row) {
+ const float tot = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + first_row + row] = tot;
+ }
  }

  }
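The q5_K accumulation uses the same unshifted-nibble trick plus a fifth bit from `qh`: a high-nibble term `(q & 0xF0) + 256*h` equals `16*(qhat + 16*h)`, so weighting it by `sc8[..] * 1.f/16.f` restores the intended

```latex
w_i = d\, s_j \,(\hat q_i + 16\, h_i) - d_{\min}\, m_j
```

with the per-sub-block minima again folded into the `sumy` terms.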
@@ -1736,10 +1881,9 @@ kernel void kernel_mul_mat_q6_K_f32(
  constant int64_t & ne00,
  constant int64_t & ne10,
  constant int64_t & ne0,
- threadgroup float * sum [[threadgroup(0)]],
  uint2 tgpig[[threadgroup_position_in_grid]],
- uint2 tpitg[[thread_position_in_threadgroup]],
- uint2 tptg[[threads_per_threadgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {

  const uint8_t kmask1 = 0x03;
  const uint8_t kmask2 = 0x0C;
@@ -1751,19 +1895,18 @@ kernel void kernel_mul_mat_q6_K_f32(
  const int64_t r0 = tgpig.x;
  const int64_t r1 = tgpig.y;

- device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb;
- device const float * yy = (device const float *) src1 + r1*ne10;
+ const int row = 2 * r0 + sgitg;

- const int nth = tptg.x*tptg.y;
- const int ith = tptg.y*tpitg.x + tpitg.y;
+ device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb; //r0*nb;
+ device const float * yy = (device const float *) src1 + r1*ne10;

  float sumf = 0;

  #if QK_K == 256
- // Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
- const int iqs = 16 * tpitg.y;
- const int ip = iqs / 128; // 0 or 1
- const int il = (iqs - 128*ip)/16; // 0...7
+ const int tid = tiisg/2;
+ const int ix = tiisg%2;
+ const int ip = tid/8; // 0 or 1
+ const int il = tid%8;
  const int n = 4;
  const int l0 = n*il;
  const int is = 8*ip + l0/16;
@@ -1772,9 +1915,10 @@ kernel void kernel_mul_mat_q6_K_f32(
  const int q_offset_l = 64*ip + l0;
  const int q_offset_h = 32*ip + l0;

- for (int i = tpitg.x; i < nb; i += tptg.x) {
+ for (int i = ix; i < nb; i += 2) {

- device const uint8_t * ql = x[i].ql + q_offset_l;
+ device const uint8_t * q1 = x[i].ql + q_offset_l;
+ device const uint8_t * q2 = q1 + 32;
  device const uint8_t * qh = x[i].qh + q_offset_h;
  device const int8_t * sc = x[i].scales + is;

@@ -1784,19 +1928,21 @@ kernel void kernel_mul_mat_q6_K_f32(

  float4 sums = {0.f, 0.f, 0.f, 0.f};
  for (int l = 0; l < n; ++l) {
- sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
- sums[1] += y[l+32] * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
- sums[2] += y[l+64] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
- sums[3] += y[l+96] * ((int8_t)((ql[l+32] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
+ sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
+ sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
+ sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
+ sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
  }

  sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);

  }
+
  #else
- const int il = 4*tpitg.x; // 0, 4, 8, 12
+ const int ix = tiisg/4;
+ const int il = 4*(tiisg%4);

- for (int i = tpitg.y; i < nb; i += tptg.y) {
+ for (int i = ix; i < nb; i += 8) {
  device const float * y = yy + i * QK_K + il;
  device const uint8_t * ql = x[i].ql + il;
  device const uint8_t * qh = x[i].qh + il;
@@ -1816,23 +1962,8 @@ kernel void kernel_mul_mat_q6_K_f32(

  #endif

- sum[ith] = sumf;
-
- //
- // Accumulate the sum from all threads in the threadgroup
- //
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%4 == 0) {
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+ const float tot = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst[r1*ne0 + row] = tot;
  }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%16 == 0) {
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith == 0) {
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
- dst[r1*ne0 + r0] = sum[0];
- }
-
  }
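For reference, the q6_K kernels above reconstruct each 6-bit quant from four low bits in `ql` and two high bits in `qh` (the `kmask1..kmask4` selections shifted into bits 4 and 5), with a signed 8-bit scale per 16-value group; this is the identity behind the `sums[k] * sc[..]` accumulation:

```latex
q_i = \big((ql_i \,\&\, \mathrm{0xF})\ \vert\ (\text{two bits of } qh_i) \ll 4\big) - 32,
\qquad
w_i = d \cdot sc_j \cdot q_i
```

As in the other kernels, the final per-row total is produced by a single `simd_sum` with lane 0 writing `dst`, replacing the old barrier-heavy threadgroup tree.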