llama_cpp 0.3.3 → 0.3.5

@@ -67,6 +67,17 @@ kernel void kernel_add(
  dst[tpig] = src0[tpig] + src1[tpig];
  }
 
+ // assumption: src1 is a row
+ // broadcast src1 into src0
+ kernel void kernel_add_row(
+ device const float * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] + src1[tpig % ne00];
+ }
+
  kernel void kernel_mul(
  device const float * src0,
  device const float * src1,
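Note: the new kernel_add_row broadcasts a single row (src1) across every row of src0 by indexing src1 modulo the row length ne00. A minimal host-side sketch of the same indexing, in C++ with a hypothetical helper name add_row_ref:

    // Hypothetical CPU analogue of kernel_add_row: n is the total element count,
    // ne00 the row length; src1 holds one row that is reused for every row of src0.
    void add_row_ref(const float * src0, const float * src1, float * dst,
                     long n, long ne00) {
        for (long i = 0; i < n; ++i) {
            dst[i] = src0[i] + src1[i % ne00];
        }
    }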
@@ -331,26 +342,33 @@ kernel void kernel_rms_norm(
  threadgroup float * sum [[threadgroup(0)]],
  uint tgpig[[threadgroup_position_in_grid]],
  uint tpitg[[thread_position_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
  uint ntg[[threads_per_threadgroup]]) {
- device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
+ device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
+ device const float * x_scalar = (device const float *) x;
+ float4 sumf=0;
+ float all_sum=0;
 
  // parallel sum
- sum[tpitg] = 0.0f;
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
- sum[tpitg] += x[i00] * x[i00];
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+ sumf += x[i00] * x[i00];
+ }
+ all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
+ all_sum = simd_sum(all_sum);
+ if (tiisg == 0) {
+ sum[sgitg] = all_sum;
  }
 
- // reduce
  threadgroup_barrier(mem_flags::mem_threadgroup);
- for (uint i = ntg/2; i > 0; i /= 2) {
- if (tpitg < i) {
- sum[tpitg] += sum[tpitg + i];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ // reduce the per-SIMD-group partial sums; their count is ntg / 32
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+ if (tpitg < i) {
+ sum[tpitg] += sum[tpitg + i];
+ }
  }
-
- // broadcast
  if (tpitg == 0) {
+ for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
  sum[0] /= ne00;
  }
 
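Note: the rewritten kernel_rms_norm reduces in two stages: each 32-wide SIMD group folds its private float4 partials with simd_sum, lane 0 stores one partial per group, and a short tree reduction over the ntg/32 group partials produces sum[0]; a scalar tail handles row lengths not divisible by 4. A CPU sketch of the sum it computes (sum_of_squares is a hypothetical name):

    // Hypothetical CPU reference for the reduction: 4-wide main loop plus the same
    // scalar tail the kernel runs for elements 4*(ne00/4) .. ne00-1.
    float sum_of_squares(const float * x, int ne00) {
        float s[4] = {0.f, 0.f, 0.f, 0.f};
        for (int i = 0; i + 3 < ne00; i += 4) {
            for (int k = 0; k < 4; ++k) s[k] += x[i+k] * x[i+k];
        }
        float all_sum = s[0] + s[1] + s[2] + s[3];
        for (int i = 4 * (ne00 / 4); i < ne00; ++i) {
            all_sum += x[i] * x[i];
        }
        return all_sum; // the kernel then divides by ne00 and uses 1/sqrt(mean + eps)
    }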
@@ -359,156 +377,130 @@ kernel void kernel_rms_norm(
  const float mean = sum[0];
  const float scale = 1.0f/sqrt(mean + eps);
 
- device float * y = dst + tgpig*ne00;
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ device float4 * y = (device float4 *) (dst + tgpig*ne00);
+ device float * y_scalar = (device float *) y;
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
  y[i00] = x[i00] * scale;
  }
+ if (tpitg == 0) {
+ for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
+ }
+ }
+
+ // function for calculating the inner product between half a q4_0 block and 16 floats (yl); sumy is SUM(yl[i])
+ // il indicates where the q4 quants begin (0 or QK4_0/4)
+ // we assume that the yl's have been multiplied with the appropriate scale factor
+ // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
+ float d = qb_curr->d;
+ float2 acc = 0.f;
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
+ for (int i = 0; i < 8; i+=2) {
+ acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+ + yl[i + 1] * (qs[i / 2] & 0x0F00);
+ acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+ + yl[i + 9] * (qs[i / 2] & 0xF000);
+ }
+ return d * (sumy * -8.f + acc[0] + acc[1]);
+ }
+
+ // function for calculating the inner product between half a q4_1 block and 16 floats (yl); sumy is SUM(yl[i])
+ // il indicates where the q4 quants begin (0 or QK4_0/4)
+ // we assume that the yl's have been multiplied with the appropriate scale factor
+ // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
+ float d = qb_curr->d;
+ float m = qb_curr->m;
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+ float2 acc = 0.f;
+ for (int i = 0; i < 8; i+=2) {
+ acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+ + yl[i + 1] * (qs[i / 2] & 0x0F00);
+ acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+ + yl[i + 9] * (qs[i / 2] & 0xF000);
+ }
+ return d * (acc[0] + acc[1]) + sumy * m;
  }
 
  // putting them in the kernel causes a significant performance penalty
  #define N_DST 4 // each SIMD group works on 4 rows
  #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
  #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
- kernel void kernel_mul_mat_q4_0_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne10,
- constant int64_t & ne0,
- constant int64_t & ne01[[buffer(4)]],
- uint2 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ // Note: This is a template, but strictly speaking it only applies to
+ // quantizations where the block size is 32. It also does not
+ // guard against the number of rows not being divisible by
+ // N_DST, so this is another explicit assumption of the implementation.
+ template<typename block_q_type, int nr, int nsg, int nw>
+ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
+ int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
+ uint2 tgpig, uint tiisg, uint sgitg) {
  const int nb = ne00/QK4_0;
  const int r0 = tgpig.x;
  const int r1 = tgpig.y;
- device const block_q4_0 * x = (device const block_q4_0 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
+ const int first_row = (r0 * nsg + sgitg) * nr;
+ device const block_q_type * x = (device const block_q_type *) src0 + first_row * nb;
  device const float * y = (device const float *) src1 + r1*ne10;
- block_q4_0 qb_curr, qb_next;
- float4 y_curr[8]; // src1 vector cache
- float sumf[N_DST]={0.f}, all_sum;
- thread float * yl=(thread float *)y_curr;
-
- // bootstrap
- qb_curr = x[tiisg];
- // each thread in a SIMD group deals with 1 block.
- for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
-
- for (int i = 0; i < QK4_0 / 4; i++) {
- y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
+ float yl[16]; // src1 vector cache
+ float sumf[nr]={0.f};
+
+ const int ix = tiisg/2;
+ const int il = 8*(tiisg%2);
+
+ device const float * yb = y + ix * QK4_0 + il;
+
+ // each thread in a SIMD group deals with half a block.
+ for (int ib = ix; ib < nb; ib += nw/2) {
+ float sumy = 0;
+ for (int i = 0; i < 8; i += 2) {
+ sumy += yb[i] + yb[i+1];
+ yl[i+0] = yb[i+ 0];
+ yl[i+1] = yb[i+ 1]/256.f;
+ sumy += yb[i+16] + yb[i+17];
+ yl[i+8] = yb[i+16]/16.f;
+ yl[i+9] = yb[i+17]/4096.f;
  }
 
- for (int row = 0; row < N_DST; row++) {
- // prefetch next x block
- qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
-
- // calculate
- float d = qb_curr.d;
- float2 acc = {0.0f, 0.0f};
- for (int i = 0; i < 16; i++) {
- acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
- acc[1] += yl[i] + yl[i+16];
- }
- sumf[row] += d * (acc[0] - 8.f*acc[1]);
- qb_curr = qb_next;
+ for (int row = 0; row < nr; row++) {
+ sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il);
  }
- }
 
- for (int i = 0; i < QK4_0 / 4; i++) {
- y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
+ yb += QK4_0 * 16;
  }
 
- for (int row = 0; row < N_DST; row++) {
- // prefetch next x block
- qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
-
- // calculate
- float d = qb_curr.d;
- float2 acc = {0.0f, 0.0f};
- for (int i = 0; i < 16; i++) {
- acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
- acc[1] += yl[i] + yl[i+16];
- }
- if (tiisg < nb % N_SIMDWIDTH) {
- sumf[row] += d * (acc[0] - 8.f*acc[1]);
- }
- qb_curr = qb_next;
-
- all_sum = simd_sum(sumf[row]);
- if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
- dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
+ for (int row = 0; row < nr; ++row) {
+ const float tot = simd_sum(sumf[row]);
+ if (tiisg == 0 && first_row + row < ne01) {
+ dst[r1*ne0 + first_row + row] = tot;
  }
  }
  }
 
- kernel void kernel_mul_mat_q4_1_f32(
+ kernel void kernel_mul_mat_q4_0_f32(
  device const void * src0,
  device const float * src1,
  device float * dst,
  constant int64_t & ne00,
  constant int64_t & ne10,
  constant int64_t & ne0,
- threadgroup float * sum [[threadgroup(0)]],
+ constant int64_t & ne01[[buffer(4)]],
  uint2 tgpig[[threadgroup_position_in_grid]],
- uint2 tpitg[[thread_position_in_threadgroup]],
- uint2 tptg[[threads_per_threadgroup]]) {
- const int nb = ne00/QK4_1;
-
- const int64_t r0 = tgpig.x;
- const int64_t r1 = tgpig.y;
-
- device const block_q4_1 * x = (device const block_q4_1 *) src0 + r0*nb;
- device const float * y = (device const float *) src1 + r1*ne10;
-
- const uint nth = tptg.x*tptg.y;
- const uint ith = tptg.y*tpitg.x + tpitg.y;
-
- const int ix = tpitg.y/4; // 0 or 1
- const int iy = tpitg.y - 4*ix; // 0...3
-
- const int first = 4 * iy;
-
- float sumf = 0;
-
- for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
-
- const float d = (float)x[i].d;
- const float m = (float)x[i].m;
-
- device const uint8_t * xl = x[i].qs + first;
- device const float * yl = y + i * QK4_1 + first;
-
- float2 acc = {0.0f, 0.0f};
-
- for (int j = 0; j < 4; ++j) {
-
- acc[0] += yl[j+ 0] * (d * (xl[j] & 0xF) + m);
- acc[1] += yl[j+16] * (d * (xl[j] >> 4) + m);
-
- }
-
- sumf += acc[0] + acc[1];
- }
-
- sum[ith] = sumf;
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
+ }
 
- //
- // Accumulate the sum from all threads in the threadgroup
- //
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%4 == 0) {
- sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%16 == 0) {
- sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith == 0) {
- for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
- dst[r1*ne0 + r0] = sum[0];
- }
+ kernel void kernel_mul_mat_q4_1_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne10,
+ constant int64_t & ne0,
+ constant int64_t & ne01[[buffer(4)]],
+ uint2 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
  }
 
  kernel void kernel_mul_mat_f16_f32(
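Note: block_q_n_dot_y avoids per-nibble shifts: a 16-bit load packs four 4-bit quants, and the y values cached in yl are pre-divided by 1, 16, 256, and 4096 so that a plain mask is enough; for q4_0 the constant -8 offset is factored out through sumy. A self-contained sketch of the principle for one 16-bit word (dot4 is a hypothetical name; the kernel's pairing of yl entries with masks differs because the high nibbles belong to elements 16 positions later in the block):

    #include <cstdint>

    // One uint16 holds quants q0..q3, 4 bits each, q0 in the lowest bits.
    // Pre-scaling y by 1, 1/16, 1/256, 1/4096 cancels the position of each nibble,
    // so (qs & mask) * y_scaled == q * y without any shifts.
    float dot4(uint16_t qs, const float y[4]) {
        return (qs & 0x000F) * (y[0] / 1.f)
             + (qs & 0x00F0) * (y[1] / 16.f)
             + (qs & 0x0F00) * (y[2] / 256.f)
             + (qs & 0xF000) * (y[3] / 4096.f);
        // q4_0 then adds sumy * -8.f and multiplies by the block scale d
    }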
@@ -624,17 +616,19 @@ kernel void kernel_rope(
  constant int & n_past,
  constant int & n_dims,
  constant int & mode,
+ constant float & freq_base,
+ constant float & freq_scale,
  uint3 tpig[[thread_position_in_grid]]) {
  const int64_t i3 = tpig[2];
  const int64_t i2 = tpig[1];
  const int64_t i1 = tpig[0];
 
  const bool is_neox = mode & 2;
- const float theta_scale = pow(10000.0, -2.0f/n_dims);
+ const float theta_scale = pow(freq_base, -2.0f/n_dims);
 
  const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
 
- float theta = (float)p;
+ float theta = freq_scale * (float)p;
 
  if (!is_neox) {
  for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
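Note: with the two new parameters, the rotation angle for position p and dimension pair i becomes theta_i = freq_scale * p * freq_base^(-2i/n_dims); the old code hard-coded freq_base = 10000 and freq_scale = 1, so those defaults reproduce the previous behaviour. The kernel builds this incrementally (theta *= theta_scale); the closed form, with a hypothetical helper name:

    #include <math.h>

    // Closed form of the angle the kernel computes incrementally per dimension pair.
    float rope_theta(float p, int i, int n_dims, float freq_base, float freq_scale) {
        return freq_scale * p * powf(freq_base, -2.0f * i / n_dims);
    }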
@@ -1229,111 +1223,137 @@ kernel void kernel_mul_mat_q2_K_f32(
  constant int64_t & ne00,
  constant int64_t & ne10,
  constant int64_t & ne0,
- threadgroup float * sum [[threadgroup(0)]],
+ constant int64_t & ne01[[buffer(4)]],
  uint2 tgpig[[threadgroup_position_in_grid]],
- uint2 tpitg[[thread_position_in_threadgroup]],
- uint2 tptg[[threads_per_threadgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
  const int nb = ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
 
- const int64_t r0 = tgpig.x;
- const int64_t r1 = tgpig.y;
-
- device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb;
- device const float * yy = (device const float *) src1 + r1*ne10;
-
- const int nth = tptg.x*tptg.y;
- const int ith = tptg.y*tpitg.x + tpitg.y;
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int ib_row = first_row * nb;
+ device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row;
+ device const float * y = (device const float *) src1 + r1*ne10;
+ float yl[32];
+ float sumf[N_DST]={0.f}, all_sum;
 
- float sumf = 0;
+ const int step = sizeof(block_q2_K) * nb;
 
  #if QK_K == 256
- const int tid = tpitg.y; // 0...16
- const int il = tid/4; // 0...3
- const int ir = tid%4; // 0...3
- const int ip = il/2; // 0 or 1
- const int shift1 = 4*(il%2);// 0 or 4
- const int shift2 = shift1+2;// 2 or 6
- const int n = 8;
- const int is = 4*il + (n*ir)/16;
-
- const int y_offset = 64*il + n*ir;
- const int q_offset = 32*ip + n*ir;
-
- for (int i = tpitg.x; i < nb; i += tptg.x) {
-
- device const uint8_t * q = x[i].qs + q_offset;
- device const uint8_t * scales = x[i].scales + is;
+ const int ix = tiisg/8; // 0...3
+ const int it = tiisg%8; // 0...7
+ const int im = it/4; // 0 or 1
+ const int ir = it%4; // 0...3
+ const int is = (8*ir)/16;// 0 or 1
+
+ device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir;
+
+ for (int ib = ix; ib < nb; ib += 4) {
+
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
+ for (int i = 0; i < 8; ++i) {
+ yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
+ yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
+ yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
+ yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
+ }
 
- uint8_t d1 = scales[0] & 0xF;
- uint8_t d2 = scales[2] & 0xF;
- uint8_t m1 = scales[0] >> 4;
- uint8_t m2 = scales[2] >> 4;
+ device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*im + is;
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
+ device const half * dh = &x[ib].d;
 
- device const float * y = yy + i*QK_K + y_offset;
+ for (int row = 0; row < N_DST; row++) {
 
- float2 s = {0.f, 0.f};
- float smin = 0;
- for (int l = 0; l < n; ++l) {
- s[0] += y[l+ 0] * ((q[l] >> shift1) & 3);
- s[1] += y[l+32] * ((q[l] >> shift2) & 3);
- smin += y[l+ 0] * m1 + y[l+32] * m2;
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+ for (int i = 0; i < 8; i += 2) {
+ acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
+ acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
+ acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
+ acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
+ acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
+ acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
+ acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
+ acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
+ }
+ float dall = dh[0];
+ float dmin = dh[1] * 1.f/16.f;
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
+ (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
+ (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +
+ (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
+ dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
+
+ qs += step/2;
+ sc += step;
+ dh += step/2;
  }
 
- const float dall = (float)x[i].d;
- const float dmin = (float)x[i].dmin;
-
- sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
-
+ y4 += 4 * QK_K;
  }
  #else
- const int il = 4 * tpitg.x;
+ const int ix = tiisg/2; // 0...15
+ const int it = tiisg%2; // 0...1
 
- uint32_t aux[2];
- thread const uint8_t * d = (thread const uint8_t *)aux;
- thread const uint8_t * m = (thread const uint8_t *)aux + 4;
+ device const float * y4 = y + ix * QK_K + 8 * it;
 
- for (int i = tpitg.y; i < nb; i += tptg.y) {
+ for (int ib = ix; ib < nb; ib += 16) {
 
- device const uint8_t * q = x[i].qs + il;
- device const float * y = yy + i*QK_K + il;
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
+ for (int i = 0; i < 8; ++i) {
+ yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
+ yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8];
+ yl[i+16] = y4[i+32]; sumy[2] += yl[i+16];
+ yl[i+24] = y4[i+48]; sumy[3] += yl[i+24];
+ }
+
+ device const uint8_t * sc = (device const uint8_t *)x[ib].scales;
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
+ device const half * dh = &x[ib].d;
 
- const float dall = (float)x[i].d;
- const float dmin = (float)x[i].dmin;
+ for (int row = 0; row < N_DST; row++) {
 
- device const uint32_t * a = (device const uint32_t *)x[i].scales;
- aux[0] = a[0] & 0x0f0f0f0f;
- aux[1] = (a[0] >> 4) & 0x0f0f0f0f;
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+ for (int i = 0; i < 8; i += 2) {
+ acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
+ acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
+ acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
+ acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
+ acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
+ acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
+ acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
+ acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
+ }
 
- for (int l = 0; l < 4; ++l) {
- sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0])
- + y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1])
- + y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2])
- + y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]);
+ float dall = dh[0];
+ float dmin = dh[1];
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
+ (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
+ (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
+ (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
+ dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));
+
+ qs += step/2;
+ sc += step;
+ dh += step/2;
  }
+
+ y4 += 16 * QK_K;
  }
  #endif
 
- sum[ith] = sumf;
-
- //
- // Accumulate the sum from all threads in the threadgroup
- //
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%4 == 0) {
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%16 == 0) {
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith == 0) {
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
- dst[r1*ne0 + r0] = sum[0];
+ for (int row = 0; row < N_DST; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + first_row + row] = all_sum;
+ }
  }
  }
 
+ #if QK_K == 256
  kernel void kernel_mul_mat_q3_K_f32(
  device const void * src0,
  device const float * src1,
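Note: kernel_mul_mat_q2_K_f32 applies the same mask-instead-of-shift idea to 2-bit quants: masks 0x03, 0x0c, 0x30, 0xc0 pick the four values in the low byte (acc1), the high byte goes through acc2 and is folded in with a 1/256 factor, and the leftover 4^j scaling is cancelled by the 1/1, 1/4, 1/16, 1/64 factors next to the per-group scales (sc[..] & 0xF); the mins in the high nibbles are applied through the sumy partial sums. A sketch of the low-byte unpack (unpack_q2 is a hypothetical name):

    #include <cstdint>

    // Masking a byte with 0x03/0x0c/0x30/0xc0 yields the four 2-bit quants scaled by
    // 1, 4, 16, 64; the kernel leaves them scaled and divides at the scale-multiply step.
    void unpack_q2(uint8_t b, int q[4]) {
        q[0] = (b & 0x03) / 1;
        q[1] = (b & 0x0c) / 4;
        q[2] = (b & 0x30) / 16;
        q[3] = (b & 0xc0) / 64;
    }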
@@ -1342,40 +1362,41 @@ kernel void kernel_mul_mat_q3_K_f32(
  constant int64_t & ne10,
  constant int64_t & ne0,
  constant int64_t & ne1,
- threadgroup float * sum [[threadgroup(0)]],
  uint2 tgpig[[threadgroup_position_in_grid]],
- uint2 tpitg[[thread_position_in_threadgroup]],
- uint2 tptg[[threads_per_threadgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
  const int nb = ne00/QK_K;
 
  const int64_t r0 = tgpig.x;
  const int64_t r1 = tgpig.y;
 
- device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb;
- device const float * yy = (device const float *) src1 + r1*ne10;
-
- const int nth = tptg.x*tptg.y;
- const int ith = tptg.y*tpitg.x + tpitg.y;
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
 
- #if QK_K == 256
+ device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb;
+ device const float * yy = (device const float *) src1 + r1*ne10;
 
- const uint8_t m3 = 3;
- const int8_t m4 = 4;
+ float yl[16];
 
  const uint16_t kmask1 = 0x0303;
  const uint16_t kmask2 = 0x0f0f;
 
- const int tid = tpitg.y; // expecting 16
+ const int tid = tiisg/2;
+ const int ix = tiisg%2;
  const int ip = tid/8; // 0 or 1
  const int il = tid/2 - 4*ip; // 0...3
  const int ir = tid%2;
  const int n = 8;
  const int l0 = n*ir;
 
- const uint8_t m = 1 << (4*ip + il);
+ const uint16_t m1 = 1 << (4*ip + il);
+ const uint16_t m2 = m1 << 8;
 
  const int shift = 2*il;
+ const uint16_t qm1 = 0x0003 << shift;
+ const uint16_t qm2 = 0x0300 << shift;
+ const int32_t v1 = 4 << shift;
+ const int32_t v2 = 1024 << shift;
 
  const uint16_t s_shift1 = 4*ip;
  const uint16_t s_shift2 = s_shift1 + 2*(il/2);
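Note: the q3_K kernel now reads qs, hmask, and scales through uint16 pointers, so one load covers two adjacent bytes: m2 = m1 << 8 and qm2 = qm1 << 8 test the same bit and field in the high byte, v1/v2 are the sign offsets pre-shifted to the field position, the high-byte partial s2 is folded in with 1/256, and the final sum is divided by (1 << shift) to undo the in-place masking. A sketch of the paired-byte extraction (field_pair is a hypothetical name):

    #include <cstdint>

    // Extract the same 2-bit field (bit offset `shift`) from both bytes of a uint16.
    // The results keep their in-place scaling: lo is field0 << shift, hi is
    // field1 << (shift + 8); the kernel compensates with a 1/256 factor for the
    // high byte and a final division by (1 << shift).
    void field_pair(uint16_t w, int shift, int & lo, int & hi) {
        const uint16_t qm1 = 0x0003 << shift;   // field in the low byte
        const uint16_t qm2 = 0x0300 << shift;   // same field in the high byte
        lo = w & qm1;
        hi = w & qm2;
    }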
@@ -1384,226 +1405,315 @@ kernel void kernel_mul_mat_q3_K_f32(
  const int q_offset = 32*ip + l0;
  const int y_offset = 128*ip + 32*il + l0;
 
- //float sumf = 0;
- float sumf1 = 0, sumf2 = 0;
- for (int i = tpitg.x; i < nb; i += tptg.x) {
+ const int step = sizeof(block_q3_K) * nb / 2;
 
- const float d_all = (float)(x[i].d);
-
- device const uint8_t * q = x[i].qs + q_offset;
- device const uint8_t * h = x[i].hmask + l0;
- device const float * y = yy + i * QK_K + y_offset;
+ device const float * y1 = yy + ix*QK_K + y_offset;
 
- device const uint16_t * a = (device const uint16_t *)x[i].scales;
- const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
+ float sumf1[2] = {0.f}, sumf2[2] = {0.f};
+ for (int i = ix; i < nb; i += 2) {
 
- float s = 0;
- for (int l = 0; l < n; ++l) {
- s += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & m3) - ((h[l+ 0] & m) ? 0 : m4));
- }
- float d = d_all * s;
- sumf1 += d * scales[0];
- sumf2 += d;
- //sumf += d_all * s * (scales[0] - 32);
-
- s = 0;
- for (int l = 0; l < n; ++l) {
- s += y[l+16] * ((int8_t)((q[l+16] >> shift) & m3) - ((h[l+16] & m) ? 0 : m4));
+ for (int l = 0; l < 8; ++l) {
+ yl[l+0] = y1[l+ 0];
+ yl[l+8] = y1[l+16];
  }
- d = d_all * s;
- sumf1 += d * scales[1];
- sumf2 += d;
- //sumf += d_all * s * (scales[1] - 32);
-
- }
-
- //sum[ith] = sumf;
- sum[ith] = sumf1 - 32.f*sumf2;
- #else
- const int il = 4 * tpitg.x; // 0, 4, 8, 12
- const int im = il/8; // 0, 0, 1, 1
- const int in = il%8; // 0, 4, 0, 4
 
- float sumf = 0;
+ device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
+ device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
+ device const uint16_t * a = (device const uint16_t *)(x[i].scales);
+ device const half * dh = &x[i].d;
 
- for (int i = tpitg.y; i < nb; i += tptg.y) {
+ for (int row = 0; row < 2; ++row) {
 
- const float d_all = (float)(x[i].d);
+ const float d_all = (float)dh[0];
+ const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
 
- device const uint8_t * q = x[i].qs + il;
- device const uint8_t * h = x[i].hmask + in;
- device const float * y = yy + i * QK_K + il;
+ float s1 = 0, s2 = 0;
+ for (int l = 0; l < n; l += 2) {
+ const uint16_t qs = q[l/2];
+ s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
+ s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
+ }
+ float d = d_all * (s1 + 1.f/256.f * s2);
+ sumf1[row] += d * scales[0];
+ sumf2[row] += d;
+
+ s1 = s2 = 0;
+ for (int l = 0; l < n; l += 2) {
+ const uint16_t qs = q[l/2+8];
+ s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
+ s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
+ }
+ d = d_all * (s1 + 1.f/256.f * s2);
+ sumf1[row] += d * scales[1];
+ sumf2[row] += d;
 
- const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
- const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
- const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
- const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
+ q += step;
+ h += step;
+ a += step;
+ dh += step;
 
- for (int l = 0; l < 4; ++l) {
- const uint8_t hm = h[l] >> im;
- sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4))
- + y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4))
- + y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4))
- + y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4));
  }
 
- }
-
- sum[ith] = sumf;
+ y1 += 2 * QK_K;
 
- #endif
-
- //
- // Accumulate the sum from all threads in the threadgroup
- //
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%4 == 0) {
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%16 == 0) {
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith == 0) {
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
- dst[r1*ne0 + r0] = sum[0];
  }
 
+ for (int row = 0; row < 2; ++row) {
+ const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
+ const float tot = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst[r1*ne0 + first_row + row] = tot;
+ }
+ }
  }
-
- kernel void kernel_mul_mat_q4_K_f32(
+ #else
+ kernel void kernel_mul_mat_q3_K_f32(
  device const void * src0,
  device const float * src1,
  device float * dst,
  constant int64_t & ne00,
  constant int64_t & ne10,
  constant int64_t & ne0,
- threadgroup float * sum [[threadgroup(0)]],
+ constant int64_t & ne1,
  uint2 tgpig[[threadgroup_position_in_grid]],
- uint2 tpitg[[thread_position_in_threadgroup]],
- uint2 tptg[[threads_per_threadgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
  const int nb = ne00/QK_K;
 
  const int64_t r0 = tgpig.x;
  const int64_t r1 = tgpig.y;
 
- const int nth = tptg.x*tptg.y;
- const int ith = tptg.y*tpitg.x + tpitg.y;
+ const int row = 2 * r0 + sgitg;
 
- device const block_q4_K * x = (device const block_q4_K *) src0 + r0*nb;
+ device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb;
  device const float * yy = (device const float *) src1 + r1*ne10;
+ const int ix = tiisg/4;
+ const int il = 4 * (tiisg%4);// 0, 4, 8, 12
+ const int im = il/8; // 0, 0, 1, 1
+ const int in = il%8; // 0, 4, 0, 4
 
- float sumf = 0;
+ float2 sum = {0.f, 0.f};
+
+ for (int i = ix; i < nb; i += 8) {
+
+ const float d_all = (float)(x[i].d);
+
+ device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
+ device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
+ device const uint16_t * s = (device const uint16_t *)(x[i].scales);
+ device const float * y = yy + i * QK_K + il;
+
+ const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
+ const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
+ const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
+ const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
+
+ for (int l = 0; l < 4; l += 2) {
+ const uint16_t hm = h[l/2] >> im;
+ sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
+ + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
+ + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
+ + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
+ sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
+ + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
+ + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
+ + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
+ }
+
+ }
+ const float sumf = sum[0] + sum[1] * 1.f/256.f;
+
+ const float tot = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst[r1*ne0 + row] = tot;
+ }
+
+ }
+ #endif
 
  #if QK_K == 256
+ kernel void kernel_mul_mat_q4_K_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne10,
+ constant int64_t & ne0,
+ constant int64_t & ne01[[buffer(4)]],
+ uint2 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
  const uint16_t kmask1 = 0x3f3f;
  const uint16_t kmask2 = 0x0f0f;
  const uint16_t kmask3 = 0xc0c0;
 
- const int tid = tpitg.y; // 0...16
- const int il = tid/4; // 0...3
- const int ir = tid - 4*il;// 0...3
- const int n = 4;
+ const int ix = tiisg/8; // 0...3
+ const int it = tiisg%8; // 0...7
+ const int im = it/4; // 0 or 1
+ const int ir = it%4; // 0...3
 
- const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
- const int in = il%2;
+ const int nb = ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int ib_row = first_row * nb;
+ device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
+ device const float * y = (device const float *) src1 + r1*ne10;
+ float yl[16];
+ float yh[16];
+ float sumf[N_DST]={0.f}, all_sum;
 
- const int l0 = n*(2*ir + in);
- const int q_offset = 32*im + l0;
- const int y_offset = 64*im + l0;
+ const int step = sizeof(block_q4_K) * nb / 2;
 
- uchar2 sc1, sc2, sc3, sc4;
+ device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir;
 
- for (int i = tpitg.x; i < nb; i += tptg.x) {
+ uint16_t sc16[4];
+ thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
 
- device const uint8_t * q1 = (x + i)->qs + q_offset;
- device const uint8_t * q2 = q1 + 64;
- device const float * y1 = yy + i*QK_K + y_offset;
- device const float * y2 = y1 + 128;
+ for (int ib = ix; ib < nb; ib += 4) {
 
- const float dall = (float)((x + i)->d);
- const float dmin = (float)((x + i)->dmin);
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
+ for (int i = 0; i < 8; ++i) {
+ yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
+ yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
+ yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
+ yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
+ }
 
- device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
- sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
- sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
- sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
- sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
+ device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im;
+ device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
+ device const half * dh = &x[ib].d;
 
- float4 s = {0.f, 0.f, 0.f, 0.f};
- float smin = 0;
- for (int l = 0; l < n; ++l) {
+ for (int row = 0; row < N_DST; row++) {
 
- s[0] += y1[l] * (q1[l] & 0xF); s[1] += y1[l+32] * (q1[l] >> 4);
- s[2] += y2[l] * (q2[l] & 0xF); s[3] += y2[l+32] * (q2[l] >> 4);
- smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
+ sc16[0] = sc[0] & kmask1;
+ sc16[1] = sc[2] & kmask1;
+ sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
+ sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
+
+ device const uint16_t * q2 = q1 + 32;
+
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+ for (int i = 0; i < 8; i += 2) {
+ acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
+ acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
+ acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
+ acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
+ acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
+ acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
+ acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
+ acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
+ }
 
+ float dall = dh[0];
+ float dmin = dh[1];
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
+ (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
+ (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
+ (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
+ dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
+
+ q1 += step;
+ sc += step;
+ dh += step;
  }
- sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
 
+ y4 += 4 * QK_K;
+ }
+
+ for (int row = 0; row < N_DST; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + first_row + row] = all_sum;
+ }
  }
+ }
  #else
- uint16_t aux16[2];
- thread const uint8_t * scales = (thread const uint8_t *)aux16;
+ kernel void kernel_mul_mat_q4_K_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne10,
+ constant int64_t & ne0,
+ constant int64_t & ne01[[buffer(4)]],
+ uint2 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
- const int il = 4*tpitg.x;
+ const int ix = tiisg/4; // 0...7
+ const int it = tiisg%4; // 0...3
 
- for (int i = tpitg.y; i < nb; i += tptg.y) {
+ const int nb = ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int ib_row = first_row * nb;
+ device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
+ device const float * y = (device const float *) src1 + r1*ne10;
+ float yl[8];
+ float yh[8];
+ float sumf[N_DST]={0.f}, all_sum;
 
- device const uint8_t * q = x[i].qs + il;
- device const float * y = yy + i * QK_K + il;
+ const int step = sizeof(block_q4_K) * nb / 2;
 
- const float d = (float)x[i].d[0];
- const float m = (float)x[i].d[1];
+ device const float * y4 = y + ix * QK_K + 8 * it;
 
- device const uint16_t * a = (device const uint16_t *)x[i].scales;
- aux16[0] = a[0] & 0x0f0f;
- aux16[1] = (a[0] >> 4) & 0x0f0f;
+ uint16_t sc16[4];
 
- for (int l = 0; l < 4; ++l) {
- sumf += d * scales[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - m * scales[2] * (y[l+ 0] + y[l+16])
- + d * scales[1] * (y[l+32] * (q[l] >> 4) + y[l+48] * (q[l+16] >> 4)) - m * scales[3] * (y[l+32] + y[l+48]);
+ for (int ib = ix; ib < nb; ib += 8) {
+
+ float2 sumy = {0.f, 0.f};
+ for (int i = 0; i < 8; ++i) {
+ yl[i] = y4[i+ 0]; sumy[0] += yl[i];
+ yh[i] = y4[i+32]; sumy[1] += yh[i];
  }
- }
- #endif
 
- sum[ith] = sumf;
+ device const uint16_t * sc = (device const uint16_t *)x[ib].scales;
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
+ device const half * dh = x[ib].d;
 
- //
- // Accumulate the sum from all threads in the threadgroup
- // This version is slightly faster than the commented out one below,
- // which I copy-pasted from ggerganov's q4_0 dot product for metal.
- //
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%4 == 0) {
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%16 == 0) {
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith == 0) {
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
- dst[r1*ne0 + r0] = sum[0];
+ for (int row = 0; row < N_DST; row++) {
+
+ sc16[0] = sc[0] & 0x000f;
+ sc16[1] = sc[0] & 0x0f00;
+ sc16[2] = sc[0] & 0x00f0;
+ sc16[3] = sc[0] & 0xf000;
+
+ float2 acc1 = {0.f, 0.f};
+ float2 acc2 = {0.f, 0.f};
+ for (int i = 0; i < 8; i += 2) {
+ acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
+ acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
+ acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
+ acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
+ }
+
+ float dall = dh[0];
+ float dmin = dh[1];
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
+ (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
+ dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);
+
+ qs += step;
+ sc += step;
+ dh += step;
+ }
+
+ y4 += 8 * QK_K;
  }
 
- //// accumulate the sum from all threads in the threadgroup
- //threadgroup_barrier(mem_flags::mem_threadgroup);
- //for (uint i = nth/2; i > 0; i /= 2) {
- // if (ith < i) {
- // sum[ith] += sum[ith + i];
- // }
- // threadgroup_barrier(mem_flags::mem_threadgroup);
- //}
-
- //if (ith == 0) {
- // dst[r1*ne0 + r0] = sum[0];
- //}
+ for (int row = 0; row < N_DST; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + first_row + row] = all_sum;
+ }
+ }
  }
+ #endif
 
  kernel void kernel_mul_mat_q5_K_f32(
  device const void * src0,
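Note: the rewritten q4_K path unpacks the block's packed 6-bit (scale, min) pairs with three masks instead of per-byte shifts: kmask1 keeps two 6-bit values sitting in a 16-bit word, while kmask2/kmask3 stitch together the values whose bits are split across words; the result is viewed as bytes through sc8. A sketch of that unpack, assuming the block_q4_K scale layout used here (unpack_scales is a hypothetical name):

    #include <cstdint>

    // a points at the packed scale bytes viewed as uint16; sc16 viewed as bytes then
    // gives (scale0, scale1), (min0, min1), (scale2, scale3), (min2, min3).
    void unpack_scales(const uint16_t * a, uint16_t sc16[4]) {
        const uint16_t kmask1 = 0x3f3f, kmask2 = 0x0f0f, kmask3 = 0xc0c0;
        sc16[0] = a[0] & kmask1;
        sc16[1] = a[2] & kmask1;
        sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
        sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
    }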
@@ -1612,39 +1722,39 @@ kernel void kernel_mul_mat_q5_K_f32(
  constant int64_t & ne00,
  constant int64_t & ne10,
  constant int64_t & ne0,
- threadgroup float * sum [[threadgroup(0)]],
  uint2 tgpig[[threadgroup_position_in_grid]],
- uint2 tpitg[[thread_position_in_threadgroup]],
- uint2 tptg[[threads_per_threadgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
  const int nb = ne00/QK_K;
 
  const int64_t r0 = tgpig.x;
  const int64_t r1 = tgpig.y;
 
- device const block_q5_K * x = (device const block_q5_K *) src0 + r0*nb;
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
+
+ device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb;
  device const float * yy = (device const float *) src1 + r1*ne10;
 
- const int nth = tptg.x*tptg.y;
- const int ith = tptg.y*tpitg.x + tpitg.y;
+ float sumf[2]={0.f};
 
- float sumf = 0;
+ const int step = sizeof(block_q5_K) * nb;
 
  #if QK_K == 256
+ #
+ float yl[16], yh[16];
 
  const uint16_t kmask1 = 0x3f3f;
  const uint16_t kmask2 = 0x0f0f;
  const uint16_t kmask3 = 0xc0c0;
 
- const int tid = tpitg.y; // 0...16
- const int il = tid/4; // 0...3
- const int ir = tid - 4*il;// 0...3
- const int n = 4;
-
- const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
- const int in = il%2;
+ const int tid = tiisg/4;
+ const int ix = tiisg%4;
+ const int im = tid/4;
+ const int ir = tid%4;
+ const int n = 8;
 
- const int l0 = n*(2*ir + in);
+ const int l0 = n*ir;
  const int q_offset = 32*im + l0;
  const int y_offset = 64*im + l0;
 
@@ -1653,78 +1763,113 @@ kernel void kernel_mul_mat_q5_K_f32(
  const uint8_t hm3 = hm1 << 4;
  const uint8_t hm4 = hm2 << 4;
 
- uchar2 sc1, sc2, sc3, sc4;
+ uint16_t sc16[4];
+ thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
 
- for (int i = tpitg.x; i < nb; i += tptg.x) {
+ device const float * y1 = yy + ix*QK_K + y_offset;
 
- device const uint8_t * q1 = (x + i)->qs + q_offset;
- device const uint8_t * q2 = q1 + 64;
- device const uint8_t * qh = (x + i)->qh + l0;
- device const float * y1 = yy + i*QK_K + y_offset;
- device const float * y2 = y1 + 128;
+ for (int i = ix; i < nb; i += 4) {
 
- const float dall = (float)((x + i)->d);
- const float dmin = (float)((x + i)->dmin);
+ device const uint8_t * q1 = x[i].qs + q_offset;
+ device const uint8_t * qh = x[i].qh + l0;
+ device const half * dh = &x[i].d;
+ device const uint16_t * a = (device const uint16_t *)x[i].scales + im;
 
- device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
- sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
- sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
- sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
- sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
+ device const float * y2 = y1 + 128;
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
+ for (int l = 0; l < 8; ++l) {
+ yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
+ yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
+ yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
+ yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
+ }
 
- float4 s = {0.f, 0.f, 0.f, 0.f};
- float smin = 0;
- for (int l = 0; l < n; ++l) {
+ for (int row = 0; row < 2; ++row) {
+
+ device const uint8_t * q2 = q1 + 64;
 
- s[0] += y1[l+ 0] * ((q1[l] & 0xF) + (qh[l] & hm1 ? 16 : 0));
- s[1] += y1[l+32] * ((q1[l] >> 4) + (qh[l] & hm2 ? 16 : 0));
- s[2] += y2[l+ 0] * ((q2[l] & 0xF) + (qh[l] & hm3 ? 16 : 0));
- s[3] += y2[l+32] * ((q2[l] >> 4) + (qh[l] & hm4 ? 16 : 0));
- smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
+ sc16[0] = a[0] & kmask1;
+ sc16[1] = a[2] & kmask1;
+ sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
+ sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
+
+ float4 acc = {0.f, 0.f, 0.f, 0.f};
+ for (int l = 0; l < n; ++l) {
+ uint8_t h = qh[l];
+ acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
+ acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
+ acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
+ acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
+ }
+ const float dall = dh[0];
+ const float dmin = dh[1];
+ sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
+ dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
+
+ q1 += step;
+ qh += step;
+ dh += step/2;
+ a += step/2;
 
  }
- sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
+
+ y1 += 4 * QK_K;
 
  }
  #else
- const int il = 4 * tpitg.x; // 0, 4, 8, 12
- const int im = il/8; // 0, 0, 1, 1
- const int in = il%8; // 0, 4, 0, 4
+ float yl[8], yh[8];
+
+ const int il = 4 * (tiisg/8); // 0, 4, 8, 12
+ const int ix = tiisg%8;
+ const int im = il/8; // 0, 0, 1, 1
+ const int in = il%8; // 0, 4, 0, 4
 
- for (int i = tpitg.y; i < nb; i += tptg.y) {
+ device const float * y = yy + ix*QK_K + il;
 
- const float d = (float)x[i].d;
+ for (int i = ix; i < nb; i += 8) {
+
+ for (int l = 0; l < 4; ++l) {
+ yl[l+0] = y[l+ 0];
+ yl[l+4] = y[l+16];
+ yh[l+0] = y[l+32];
+ yh[l+4] = y[l+48];
+ }
+
+ device const half * dh = &x[i].d;
  device const uint8_t * q = x[i].qs + il;
  device const uint8_t * h = x[i].qh + in;
  device const int8_t * s = x[i].scales;
- device const float * y = yy + i*QK_K + il;
 
- for (int l = 0; l < 4; ++l) {
- const uint8_t hl = h[l] >> im;
- sumf += y[l+ 0] * d * s[0] * ((q[l+ 0] & 0xF) - (hl & 0x01 ? 0 : 16))
- + y[l+16] * d * s[1] * ((q[l+16] & 0xF) - (hl & 0x04 ? 0 : 16))
- + y[l+32] * d * s[2] * ((q[l+ 0] >> 4) - (hl & 0x10 ? 0 : 16))
- + y[l+48] * d * s[3] * ((q[l+16] >> 4) - (hl & 0x40 ? 0 : 16));
+ for (int row = 0; row < 2; ++row) {
+
+ const float d = dh[0];
+
+ float2 acc = {0.f, 0.f};
+ for (int l = 0; l < 4; ++l) {
+ const uint8_t hl = h[l] >> im;
+ acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
+ + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
+ acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
+ + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256));
+ }
+ sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]);
+
+ q += step;
+ h += step;
+ s += step;
+ dh += step/2;
+
  }
+
+ y += 8 * QK_K;
  }
  #endif
- sum[ith] = sumf;
 
- //
- // Accumulate the sum from all threads in the threadgroup
- //
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%4 == 0) {
- sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%16 == 0) {
- sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith == 0) {
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
- dst[r1*ne0 + r0] = sum[0];
+ for (int row = 0; row < 2; ++row) {
+ const float tot = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + first_row + row] = tot;
+ }
  }
 
  }
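Note: q5_K adds the fifth bit from qh on top of the q4-style nibbles; because the high nibble is left in place (pre-scaled by 16), its fifth bit contributes 256 rather than 16, and the deferred 1/16 factors next to sc8[1] and sc8[5] rescale that lane at the end. A sketch of the pair reconstruction (unpack_q5_pair is a hypothetical name):

    #include <cstdint>

    // Rebuild the two 5-bit quants stored as one byte of nibbles plus two bits of qh.
    // hm_lo/hm_hi select the respective fifth bits; the high value stays scaled by 16.
    void unpack_q5_pair(uint8_t q, uint8_t qh, uint8_t hm_lo, uint8_t hm_hi,
                        int & lo, int & hi) {
        lo = (q & 0x0F) + ((qh & hm_lo) ? 16  : 0);  // true 5-bit value
        hi = (q & 0xF0) + ((qh & hm_hi) ? 256 : 0);  // 16 x the 5-bit value
    }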
@@ -1736,10 +1881,9 @@ kernel void kernel_mul_mat_q6_K_f32(
  constant int64_t & ne00,
  constant int64_t & ne10,
  constant int64_t & ne0,
- threadgroup float * sum [[threadgroup(0)]],
  uint2 tgpig[[threadgroup_position_in_grid]],
- uint2 tpitg[[thread_position_in_threadgroup]],
- uint2 tptg[[threads_per_threadgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
  const uint8_t kmask1 = 0x03;
  const uint8_t kmask2 = 0x0C;
@@ -1751,19 +1895,18 @@ kernel void kernel_mul_mat_q6_K_f32(
  const int64_t r0 = tgpig.x;
  const int64_t r1 = tgpig.y;
 
- device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb;
- device const float * yy = (device const float *) src1 + r1*ne10;
+ const int row = 2 * r0 + sgitg;
 
- const int nth = tptg.x*tptg.y;
- const int ith = tptg.y*tpitg.x + tpitg.y;
+ device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb; //r0*nb;
+ device const float * yy = (device const float *) src1 + r1*ne10;
 
  float sumf = 0;
 
  #if QK_K == 256
- // Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
- const int iqs = 16 * tpitg.y;
- const int ip = iqs / 128; // 0 or 1
- const int il = (iqs - 128*ip)/16; // 0...7
+ const int tid = tiisg/2;
+ const int ix = tiisg%2;
+ const int ip = tid/8; // 0 or 1
+ const int il = tid%8;
  const int n = 4;
  const int l0 = n*il;
  const int is = 8*ip + l0/16;
@@ -1772,9 +1915,10 @@ kernel void kernel_mul_mat_q6_K_f32(
  const int q_offset_l = 64*ip + l0;
  const int q_offset_h = 32*ip + l0;
 
- for (int i = tpitg.x; i < nb; i += tptg.x) {
+ for (int i = ix; i < nb; i += 2) {
 
- device const uint8_t * ql = x[i].ql + q_offset_l;
+ device const uint8_t * q1 = x[i].ql + q_offset_l;
+ device const uint8_t * q2 = q1 + 32;
  device const uint8_t * qh = x[i].qh + q_offset_h;
  device const int8_t * sc = x[i].scales + is;
 
@@ -1784,19 +1928,21 @@ kernel void kernel_mul_mat_q6_K_f32(
 
  float4 sums = {0.f, 0.f, 0.f, 0.f};
  for (int l = 0; l < n; ++l) {
- sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
- sums[1] += y[l+32] * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
- sums[2] += y[l+64] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
- sums[3] += y[l+96] * ((int8_t)((ql[l+32] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
+ sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
+ sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
+ sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
+ sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
  }
 
  sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
 
  }
+
  #else
- const int il = 4*tpitg.x; // 0, 4, 8, 12
+ const int ix = tiisg/4;
+ const int il = 4*(tiisg%4);
 
- for (int i = tpitg.y; i < nb; i += tptg.y) {
+ for (int i = ix; i < nb; i += 8) {
  device const float * y = yy + i * QK_K + il;
  device const uint8_t * ql = x[i].ql + il;
  device const uint8_t * qh = x[i].qh + il;
@@ -1816,23 +1962,8 @@ kernel void kernel_mul_mat_q6_K_f32(
 
  #endif
 
- sum[ith] = sumf;
-
- //
- // Accumulate the sum from all threads in the threadgroup
- //
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%4 == 0) {
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+ const float tot = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst[r1*ne0 + row] = tot;
  }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%16 == 0) {
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith == 0) {
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
- dst[r1*ne0 + r0] = sum[0];
- }
-
  }
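Note: the common thread across these kernels is visible in every hunk above: the old code staged per-thread partial sums in a threadgroup float * sum buffer and reduced them with barrier-heavy trees, while the new code keeps partials in registers per SIMD group and finishes with a single simd_sum, lane 0 writing the result. The shared finishing pattern, as it appears in the kernels:

    for (int row = 0; row < nr; ++row) {
        const float tot = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst[r1*ne0 + first_row + row] = tot;
        }
    }

This removes the threadgroup buffer argument and all of the threadgroup_barrier calls from the dot-product hot path.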