llama_cpp 0.3.3 → 0.3.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +24 -0
- data/ext/llama_cpp/llama_cpp.cpp +146 -9
- data/ext/llama_cpp/src/ggml-cuda.cu +485 -67
- data/ext/llama_cpp/src/ggml-metal.m +52 -43
- data/ext/llama_cpp/src/ggml-metal.metal +587 -470
- data/ext/llama_cpp/src/ggml.c +105 -79
- data/ext/llama_cpp/src/ggml.h +13 -1
- data/ext/llama_cpp/src/k_quants.h +8 -0
- data/ext/llama_cpp/src/llama.cpp +123 -66
- data/ext/llama_cpp/src/llama.h +34 -5
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +1 -0
- data/sig/llama_cpp.rbs +12 -1
- metadata +2 -2
@@ -331,26 +331,33 @@ kernel void kernel_rms_norm(
|
|
331
331
|
threadgroup float * sum [[threadgroup(0)]],
|
332
332
|
uint tgpig[[threadgroup_position_in_grid]],
|
333
333
|
uint tpitg[[thread_position_in_threadgroup]],
|
334
|
+
uint sgitg[[simdgroup_index_in_threadgroup]],
|
335
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
334
336
|
uint ntg[[threads_per_threadgroup]]) {
|
335
|
-
device const
|
337
|
+
device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
|
338
|
+
device const float * x_scalar = (device const float *) x;
|
339
|
+
float4 sumf=0;
|
340
|
+
float all_sum=0;
|
336
341
|
|
337
342
|
// parallel sum
|
338
|
-
|
339
|
-
|
340
|
-
|
343
|
+
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
344
|
+
sumf += x[i00] * x[i00];
|
345
|
+
}
|
346
|
+
all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
|
347
|
+
all_sum = simd_sum(all_sum);
|
348
|
+
if (tiisg == 0) {
|
349
|
+
sum[sgitg] = all_sum;
|
341
350
|
}
|
342
351
|
|
343
|
-
// reduce
|
344
352
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
353
|
+
// broadcast, simd group number is ntg / 32
|
354
|
+
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
355
|
+
if (tpitg < i) {
|
356
|
+
sum[tpitg] += sum[tpitg + i];
|
357
|
+
}
|
350
358
|
}
|
351
|
-
|
352
|
-
// broadcast
|
353
359
|
if (tpitg == 0) {
|
360
|
+
for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
|
354
361
|
sum[0] /= ne00;
|
355
362
|
}
|
356
363
|
|
@@ -359,82 +366,94 @@ kernel void kernel_rms_norm(
|
|
359
366
|
const float mean = sum[0];
|
360
367
|
const float scale = 1.0f/sqrt(mean + eps);
|
361
368
|
|
362
|
-
device
|
363
|
-
|
369
|
+
device float4 * y = (device float4 *) (dst + tgpig*ne00);
|
370
|
+
device float * y_scalar = (device float *) y;
|
371
|
+
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
364
372
|
y[i00] = x[i00] * scale;
|
365
373
|
}
|
374
|
+
if (tpitg == 0) {
|
375
|
+
for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
|
376
|
+
}
|
377
|
+
}
|
378
|
+
|
379
|
+
// function for calculate inner product between a q4_0 block and 32 floats (yl), sumy is SUM(yl[i])
|
380
|
+
float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) {
|
381
|
+
float d = qb_curr->d;
|
382
|
+
float4 acc = 0.f;
|
383
|
+
device uint16_t * qs = ((device uint16_t *)qb_curr + 1);
|
384
|
+
for (int i = 0; i < 16; i+=2) {
|
385
|
+
acc[0] += yl[i] * (qs[i / 2] & 0x000F);
|
386
|
+
acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
|
387
|
+
acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
|
388
|
+
acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
|
389
|
+
}
|
390
|
+
return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f);
|
391
|
+
}
|
392
|
+
|
393
|
+
// function for calculate inner product between a q4_1 block and 32 floats (yl), sumy is SUM(yl[i])
|
394
|
+
float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) {
|
395
|
+
float d = qb_curr->d;
|
396
|
+
float m = qb_curr->m;
|
397
|
+
float4 acc = 0.f;
|
398
|
+
device uint16_t * qs = ((device uint16_t *)qb_curr + 2);
|
399
|
+
for (int i = 0; i < 16; i+=2) {
|
400
|
+
acc[0] += yl[i] * (qs[i / 2] & 0x000F);
|
401
|
+
acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
|
402
|
+
acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
|
403
|
+
acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
|
404
|
+
}
|
405
|
+
return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m;
|
366
406
|
}
|
367
407
|
|
368
408
|
// putting them in the kernel cause a significant performance penalty
|
369
409
|
#define N_DST 4 // each SIMD group works on 4 rows
|
370
410
|
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
|
371
411
|
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
constant int64_t & ne00,
|
377
|
-
constant int64_t & ne10,
|
378
|
-
constant int64_t & ne0,
|
379
|
-
constant int64_t & ne01[[buffer(4)]],
|
380
|
-
uint2 tgpig[[threadgroup_position_in_grid]],
|
381
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
382
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
412
|
+
template<typename block_q_type>
|
413
|
+
void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
|
414
|
+
int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
|
415
|
+
uint2 tgpig, uint tiisg, uint sgitg) {
|
383
416
|
const int nb = ne00/QK4_0;
|
384
417
|
const int r0 = tgpig.x;
|
385
418
|
const int r1 = tgpig.y;
|
386
|
-
device const
|
419
|
+
device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
|
387
420
|
device const float * y = (device const float *) src1 + r1*ne10;
|
388
|
-
block_q4_0 qb_curr, qb_next;
|
389
421
|
float4 y_curr[8]; // src1 vector cache
|
390
422
|
float sumf[N_DST]={0.f}, all_sum;
|
391
423
|
thread float * yl=(thread float *)y_curr;
|
392
424
|
|
393
|
-
// bootstrap
|
394
|
-
qb_curr = x[tiisg];
|
395
425
|
// each thread in a SIMD group deals with 1 block.
|
396
426
|
for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
|
397
|
-
|
427
|
+
float sumy = 0;
|
398
428
|
for (int i = 0; i < QK4_0 / 4; i++) {
|
399
|
-
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) +
|
429
|
+
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i);
|
430
|
+
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
|
400
431
|
}
|
401
432
|
|
402
433
|
for (int row = 0; row < N_DST; row++) {
|
403
|
-
|
404
|
-
qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
|
405
|
-
|
406
|
-
// calculate
|
407
|
-
float d = qb_curr.d;
|
408
|
-
float2 acc = {0.0f, 0.0f};
|
409
|
-
for (int i = 0; i < 16; i++) {
|
410
|
-
acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
|
411
|
-
acc[1] += yl[i] + yl[i+16];
|
412
|
-
}
|
413
|
-
sumf[row] += d * (acc[0] - 8.f*acc[1]);
|
414
|
-
qb_curr = qb_next;
|
434
|
+
sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl);
|
415
435
|
}
|
416
436
|
}
|
417
437
|
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
float d = qb_curr.d;
|
428
|
-
float2 acc = {0.0f, 0.0f};
|
429
|
-
for (int i = 0; i < 16; i++) {
|
430
|
-
acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
|
431
|
-
acc[1] += yl[i] + yl[i+16];
|
438
|
+
// from now loads two rows every time and 16 blocks per row
|
439
|
+
int ir = tiisg / (N_SIMDWIDTH / 2);
|
440
|
+
int ib = tiisg % (N_SIMDWIDTH / 2);
|
441
|
+
for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) {
|
442
|
+
int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); //where the left blocks start
|
443
|
+
float sumy = 0;
|
444
|
+
for (int i = 0; i < QK4_0 / 4; i++) {
|
445
|
+
y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i);
|
446
|
+
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
|
432
447
|
}
|
433
|
-
|
434
|
-
|
448
|
+
|
449
|
+
for (int row = 0; row < N_DST; row+=2) {
|
450
|
+
if (nb_start + ib < nb) {
|
451
|
+
sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl);
|
452
|
+
}
|
435
453
|
}
|
436
|
-
|
454
|
+
}
|
437
455
|
|
456
|
+
for (int row = 0; row < N_DST; ++row) {
|
438
457
|
all_sum = simd_sum(sumf[row]);
|
439
458
|
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
|
440
459
|
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
|
@@ -442,73 +461,32 @@ kernel void kernel_mul_mat_q4_0_f32(
|
|
442
461
|
}
|
443
462
|
}
|
444
463
|
|
445
|
-
kernel void
|
464
|
+
kernel void kernel_mul_mat_q4_0_f32(
|
446
465
|
device const void * src0,
|
447
466
|
device const float * src1,
|
448
467
|
device float * dst,
|
449
468
|
constant int64_t & ne00,
|
450
469
|
constant int64_t & ne10,
|
451
470
|
constant int64_t & ne0,
|
452
|
-
|
471
|
+
constant int64_t & ne01[[buffer(4)]],
|
453
472
|
uint2 tgpig[[threadgroup_position_in_grid]],
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
const int64_t r0 = tgpig.x;
|
459
|
-
const int64_t r1 = tgpig.y;
|
460
|
-
|
461
|
-
device const block_q4_1 * x = (device const block_q4_1 *) src0 + r0*nb;
|
462
|
-
device const float * y = (device const float *) src1 + r1*ne10;
|
463
|
-
|
464
|
-
const uint nth = tptg.x*tptg.y;
|
465
|
-
const uint ith = tptg.y*tpitg.x + tpitg.y;
|
466
|
-
|
467
|
-
const int ix = tpitg.y/4; // 0 or 1
|
468
|
-
const int iy = tpitg.y - 4*ix; // 0...3
|
469
|
-
|
470
|
-
const int first = 4 * iy;
|
471
|
-
|
472
|
-
float sumf = 0;
|
473
|
-
|
474
|
-
for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
|
475
|
-
|
476
|
-
const float d = (float)x[i].d;
|
477
|
-
const float m = (float)x[i].m;
|
478
|
-
|
479
|
-
device const uint8_t * xl = x[i].qs + first;
|
480
|
-
device const float * yl = y + i * QK4_1 + first;
|
481
|
-
|
482
|
-
float2 acc = {0.0f, 0.0f};
|
483
|
-
|
484
|
-
for (int j = 0; j < 4; ++j) {
|
485
|
-
|
486
|
-
acc[0] += yl[j+ 0] * (d * (xl[j] & 0xF) + m);
|
487
|
-
acc[1] += yl[j+16] * (d * (xl[j] >> 4) + m);
|
488
|
-
|
489
|
-
}
|
490
|
-
|
491
|
-
sumf += acc[0] + acc[1];
|
492
|
-
}
|
493
|
-
|
494
|
-
sum[ith] = sumf;
|
473
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
474
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
475
|
+
mul_vec_q_n_f32<block_q4_0>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
|
476
|
+
}
|
495
477
|
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
|
504
|
-
|
505
|
-
|
506
|
-
|
507
|
-
|
508
|
-
if (ith == 0) {
|
509
|
-
for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
|
510
|
-
dst[r1*ne0 + r0] = sum[0];
|
511
|
-
}
|
478
|
+
kernel void kernel_mul_mat_q4_1_f32(
|
479
|
+
device const void * src0,
|
480
|
+
device const float * src1,
|
481
|
+
device float * dst,
|
482
|
+
constant int64_t & ne00,
|
483
|
+
constant int64_t & ne10,
|
484
|
+
constant int64_t & ne0,
|
485
|
+
constant int64_t & ne01[[buffer(4)]],
|
486
|
+
uint2 tgpig[[threadgroup_position_in_grid]],
|
487
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
488
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
489
|
+
mul_vec_q_n_f32<block_q4_1>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
|
512
490
|
}
|
513
491
|
|
514
492
|
kernel void kernel_mul_mat_f16_f32(
|
@@ -624,17 +602,19 @@ kernel void kernel_rope(
|
|
624
602
|
constant int & n_past,
|
625
603
|
constant int & n_dims,
|
626
604
|
constant int & mode,
|
605
|
+
constant float & freq_base,
|
606
|
+
constant float & freq_scale,
|
627
607
|
uint3 tpig[[thread_position_in_grid]]) {
|
628
608
|
const int64_t i3 = tpig[2];
|
629
609
|
const int64_t i2 = tpig[1];
|
630
610
|
const int64_t i1 = tpig[0];
|
631
611
|
|
632
612
|
const bool is_neox = mode & 2;
|
633
|
-
const float theta_scale = pow(
|
613
|
+
const float theta_scale = pow(freq_base, -2.0f/n_dims);
|
634
614
|
|
635
615
|
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
|
636
616
|
|
637
|
-
float theta = (float)p;
|
617
|
+
float theta = freq_scale * (float)p;
|
638
618
|
|
639
619
|
if (!is_neox) {
|
640
620
|
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
@@ -1229,111 +1209,137 @@ kernel void kernel_mul_mat_q2_K_f32(
|
|
1229
1209
|
constant int64_t & ne00,
|
1230
1210
|
constant int64_t & ne10,
|
1231
1211
|
constant int64_t & ne0,
|
1232
|
-
|
1212
|
+
constant int64_t & ne01[[buffer(4)]],
|
1233
1213
|
uint2 tgpig[[threadgroup_position_in_grid]],
|
1234
|
-
|
1235
|
-
|
1214
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
1215
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1236
1216
|
|
1237
1217
|
const int nb = ne00/QK_K;
|
1218
|
+
const int r0 = tgpig.x;
|
1219
|
+
const int r1 = tgpig.y;
|
1238
1220
|
|
1239
|
-
const
|
1240
|
-
const
|
1241
|
-
|
1242
|
-
device const
|
1243
|
-
|
1244
|
-
|
1245
|
-
const int nth = tptg.x*tptg.y;
|
1246
|
-
const int ith = tptg.y*tpitg.x + tpitg.y;
|
1221
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
1222
|
+
const int ib_row = first_row * nb;
|
1223
|
+
device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row;
|
1224
|
+
device const float * y = (device const float *) src1 + r1*ne10;
|
1225
|
+
float yl[32];
|
1226
|
+
float sumf[N_DST]={0.f}, all_sum;
|
1247
1227
|
|
1248
|
-
|
1228
|
+
const int step = sizeof(block_q2_K) * nb;
|
1249
1229
|
|
1250
1230
|
#if QK_K == 256
|
1251
|
-
const int
|
1252
|
-
const int
|
1253
|
-
const int
|
1254
|
-
const int
|
1255
|
-
const int
|
1256
|
-
|
1257
|
-
const
|
1258
|
-
|
1259
|
-
|
1260
|
-
|
1261
|
-
|
1262
|
-
|
1263
|
-
|
1264
|
-
|
1265
|
-
|
1266
|
-
|
1231
|
+
const int ix = tiisg/8; // 0...3
|
1232
|
+
const int it = tiisg%8; // 0...7
|
1233
|
+
const int im = it/4; // 0 or 1
|
1234
|
+
const int ir = it%4; // 0...3
|
1235
|
+
const int is = (8*ir)/16;// 0 or 1
|
1236
|
+
|
1237
|
+
device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir;
|
1238
|
+
|
1239
|
+
for (int ib = ix; ib < nb; ib += 4) {
|
1240
|
+
|
1241
|
+
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
1242
|
+
for (int i = 0; i < 8; ++i) {
|
1243
|
+
yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
|
1244
|
+
yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
|
1245
|
+
yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
|
1246
|
+
yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
|
1247
|
+
}
|
1267
1248
|
|
1268
|
-
uint8_t
|
1269
|
-
|
1270
|
-
|
1271
|
-
uint8_t m2 = scales[2] >> 4;
|
1249
|
+
device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*im + is;
|
1250
|
+
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
|
1251
|
+
device const half * dh = &x[ib].d;
|
1272
1252
|
|
1273
|
-
|
1253
|
+
for (int row = 0; row < N_DST; row++) {
|
1274
1254
|
|
1275
|
-
|
1276
|
-
|
1277
|
-
|
1278
|
-
|
1279
|
-
|
1280
|
-
|
1255
|
+
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
|
1256
|
+
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
|
1257
|
+
for (int i = 0; i < 8; i += 2) {
|
1258
|
+
acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
|
1259
|
+
acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
|
1260
|
+
acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
|
1261
|
+
acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
|
1262
|
+
acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
|
1263
|
+
acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
|
1264
|
+
acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
|
1265
|
+
acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
|
1266
|
+
}
|
1267
|
+
float dall = dh[0];
|
1268
|
+
float dmin = dh[1] * 1.f/16.f;
|
1269
|
+
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
|
1270
|
+
(acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
|
1271
|
+
(acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +
|
1272
|
+
(acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
|
1273
|
+
dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
|
1274
|
+
|
1275
|
+
qs += step/2;
|
1276
|
+
sc += step;
|
1277
|
+
dh += step/2;
|
1281
1278
|
}
|
1282
1279
|
|
1283
|
-
|
1284
|
-
const float dmin = (float)x[i].dmin;
|
1285
|
-
|
1286
|
-
sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
|
1287
|
-
|
1280
|
+
y4 += 4 * QK_K;
|
1288
1281
|
}
|
1289
1282
|
#else
|
1290
|
-
const int
|
1283
|
+
const int ix = tiisg/2; // 0...15
|
1284
|
+
const int it = tiisg%2; // 0...1
|
1291
1285
|
|
1292
|
-
|
1293
|
-
thread const uint8_t * d = (thread const uint8_t *)aux;
|
1294
|
-
thread const uint8_t * m = (thread const uint8_t *)aux + 4;
|
1286
|
+
device const float * y4 = y + ix * QK_K + 8 * it;
|
1295
1287
|
|
1296
|
-
for (int
|
1288
|
+
for (int ib = ix; ib < nb; ib += 16) {
|
1297
1289
|
|
1298
|
-
|
1299
|
-
|
1290
|
+
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
1291
|
+
for (int i = 0; i < 8; ++i) {
|
1292
|
+
yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
|
1293
|
+
yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8];
|
1294
|
+
yl[i+16] = y4[i+32]; sumy[2] += yl[i+16];
|
1295
|
+
yl[i+24] = y4[i+48]; sumy[3] += yl[i+24];
|
1296
|
+
}
|
1300
1297
|
|
1301
|
-
const
|
1302
|
-
const
|
1298
|
+
device const uint8_t * sc = (device const uint8_t *)x[ib].scales;
|
1299
|
+
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
|
1300
|
+
device const half * dh = &x[ib].d;
|
1303
1301
|
|
1304
|
-
|
1305
|
-
aux[0] = a[0] & 0x0f0f0f0f;
|
1306
|
-
aux[1] = (a[0] >> 4) & 0x0f0f0f0f;
|
1302
|
+
for (int row = 0; row < N_DST; row++) {
|
1307
1303
|
|
1308
|
-
|
1309
|
-
|
1310
|
-
|
1311
|
-
|
1312
|
-
|
1304
|
+
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
|
1305
|
+
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
|
1306
|
+
for (int i = 0; i < 8; i += 2) {
|
1307
|
+
acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
|
1308
|
+
acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
|
1309
|
+
acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
|
1310
|
+
acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
|
1311
|
+
acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
|
1312
|
+
acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
|
1313
|
+
acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
|
1314
|
+
acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
|
1315
|
+
}
|
1316
|
+
|
1317
|
+
float dall = dh[0];
|
1318
|
+
float dmin = dh[1];
|
1319
|
+
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
|
1320
|
+
(acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
|
1321
|
+
(acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
|
1322
|
+
(acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
|
1323
|
+
dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));
|
1324
|
+
|
1325
|
+
qs += step/2;
|
1326
|
+
sc += step;
|
1327
|
+
dh += step/2;
|
1313
1328
|
}
|
1329
|
+
|
1330
|
+
y4 += 16 * QK_K;
|
1314
1331
|
}
|
1315
1332
|
#endif
|
1316
1333
|
|
1317
|
-
|
1318
|
-
|
1319
|
-
|
1320
|
-
|
1321
|
-
|
1322
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1323
|
-
if (ith%4 == 0) {
|
1324
|
-
for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
|
1325
|
-
}
|
1326
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1327
|
-
if (ith%16 == 0) {
|
1328
|
-
for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
|
1329
|
-
}
|
1330
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1331
|
-
if (ith == 0) {
|
1332
|
-
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
|
1333
|
-
dst[r1*ne0 + r0] = sum[0];
|
1334
|
+
for (int row = 0; row < N_DST; ++row) {
|
1335
|
+
all_sum = simd_sum(sumf[row]);
|
1336
|
+
if (tiisg == 0) {
|
1337
|
+
dst[r1*ne0 + first_row + row] = all_sum;
|
1338
|
+
}
|
1334
1339
|
}
|
1335
1340
|
}
|
1336
1341
|
|
1342
|
+
#if QK_K == 256
|
1337
1343
|
kernel void kernel_mul_mat_q3_K_f32(
|
1338
1344
|
device const void * src0,
|
1339
1345
|
device const float * src1,
|
@@ -1342,40 +1348,41 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1342
1348
|
constant int64_t & ne10,
|
1343
1349
|
constant int64_t & ne0,
|
1344
1350
|
constant int64_t & ne1,
|
1345
|
-
threadgroup float * sum [[threadgroup(0)]],
|
1346
1351
|
uint2 tgpig[[threadgroup_position_in_grid]],
|
1347
|
-
|
1348
|
-
|
1352
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
1353
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1349
1354
|
|
1350
1355
|
const int nb = ne00/QK_K;
|
1351
1356
|
|
1352
1357
|
const int64_t r0 = tgpig.x;
|
1353
1358
|
const int64_t r1 = tgpig.y;
|
1354
1359
|
|
1355
|
-
|
1356
|
-
device const float * yy = (device const float *) src1 + r1*ne10;
|
1360
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
1357
1361
|
|
1358
|
-
const
|
1359
|
-
const
|
1360
|
-
|
1361
|
-
#if QK_K == 256
|
1362
|
+
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb;
|
1363
|
+
device const float * yy = (device const float *) src1 + r1*ne10;
|
1362
1364
|
|
1363
|
-
|
1364
|
-
const int8_t m4 = 4;
|
1365
|
+
float yl[16];
|
1365
1366
|
|
1366
1367
|
const uint16_t kmask1 = 0x0303;
|
1367
1368
|
const uint16_t kmask2 = 0x0f0f;
|
1368
1369
|
|
1369
|
-
const int tid =
|
1370
|
+
const int tid = tiisg/2;
|
1371
|
+
const int ix = tiisg%2;
|
1370
1372
|
const int ip = tid/8; // 0 or 1
|
1371
1373
|
const int il = tid/2 - 4*ip; // 0...3
|
1372
1374
|
const int ir = tid%2;
|
1373
1375
|
const int n = 8;
|
1374
1376
|
const int l0 = n*ir;
|
1375
1377
|
|
1376
|
-
const
|
1378
|
+
const uint16_t m1 = 1 << (4*ip + il);
|
1379
|
+
const uint16_t m2 = m1 << 8;
|
1377
1380
|
|
1378
1381
|
const int shift = 2*il;
|
1382
|
+
const uint16_t qm1 = 0x0003 << shift;
|
1383
|
+
const uint16_t qm2 = 0x0300 << shift;
|
1384
|
+
const int32_t v1 = 4 << shift;
|
1385
|
+
const int32_t v2 = 1024 << shift;
|
1379
1386
|
|
1380
1387
|
const uint16_t s_shift1 = 4*ip;
|
1381
1388
|
const uint16_t s_shift2 = s_shift1 + 2*(il/2);
|
@@ -1384,226 +1391,315 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1384
1391
|
const int q_offset = 32*ip + l0;
|
1385
1392
|
const int y_offset = 128*ip + 32*il + l0;
|
1386
1393
|
|
1387
|
-
|
1388
|
-
float sumf1 = 0, sumf2 = 0;
|
1389
|
-
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
1390
|
-
|
1391
|
-
const float d_all = (float)(x[i].d);
|
1394
|
+
const int step = sizeof(block_q3_K) * nb / 2;
|
1392
1395
|
|
1393
|
-
|
1394
|
-
device const uint8_t * h = x[i].hmask + l0;
|
1395
|
-
device const float * y = yy + i * QK_K + y_offset;
|
1396
|
+
device const float * y1 = yy + ix*QK_K + y_offset;
|
1396
1397
|
|
1397
|
-
|
1398
|
-
|
1398
|
+
float sumf1[2] = {0.f}, sumf2[2] = {0.f};
|
1399
|
+
for (int i = ix; i < nb; i += 2) {
|
1399
1400
|
|
1400
|
-
|
1401
|
-
|
1402
|
-
|
1401
|
+
for (int l = 0; l < 8; ++l) {
|
1402
|
+
yl[l+0] = y1[l+ 0];
|
1403
|
+
yl[l+8] = y1[l+16];
|
1403
1404
|
}
|
1404
|
-
float d = d_all * s;
|
1405
|
-
sumf1 += d * scales[0];
|
1406
|
-
sumf2 += d;
|
1407
|
-
//sumf += d_all * s * (scales[0] - 32);
|
1408
1405
|
|
1409
|
-
|
1410
|
-
|
1411
|
-
|
1412
|
-
|
1413
|
-
d = d_all * s;
|
1414
|
-
sumf1 += d * scales[1];
|
1415
|
-
sumf2 += d;
|
1416
|
-
//sumf += d_all * s * (scales[1] - 32);
|
1406
|
+
device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
|
1407
|
+
device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
|
1408
|
+
device const uint16_t * a = (device const uint16_t *)(x[i].scales);
|
1409
|
+
device const half * dh = &x[i].d;
|
1417
1410
|
|
1418
|
-
|
1411
|
+
for (int row = 0; row < 2; ++row) {
|
1419
1412
|
|
1420
|
-
|
1421
|
-
|
1422
|
-
#else
|
1423
|
-
const int il = 4 * tpitg.x; // 0, 4, 8, 12
|
1424
|
-
const int im = il/8; // 0, 0, 1, 1
|
1425
|
-
const int in = il%8; // 0, 4, 0, 4
|
1413
|
+
const float d_all = (float)dh[0];
|
1414
|
+
const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
|
1426
1415
|
|
1427
|
-
|
1428
|
-
|
1429
|
-
|
1430
|
-
|
1431
|
-
|
1416
|
+
float s1 = 0, s2 = 0;
|
1417
|
+
for (int l = 0; l < n; l += 2) {
|
1418
|
+
const uint16_t qs = q[l/2];
|
1419
|
+
s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
|
1420
|
+
s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
|
1421
|
+
}
|
1422
|
+
float d = d_all * (s1 + 1.f/256.f * s2);
|
1423
|
+
sumf1[row] += d * scales[0];
|
1424
|
+
sumf2[row] += d;
|
1425
|
+
|
1426
|
+
s1 = s2 = 0;
|
1427
|
+
for (int l = 0; l < n; l += 2) {
|
1428
|
+
const uint16_t qs = q[l/2+8];
|
1429
|
+
s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
|
1430
|
+
s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
|
1431
|
+
}
|
1432
|
+
d = d_all * (s1 + 1.f/256.f * s2);
|
1433
|
+
sumf1[row] += d * scales[1];
|
1434
|
+
sumf2[row] += d;
|
1432
1435
|
|
1433
|
-
|
1434
|
-
|
1435
|
-
|
1436
|
+
q += step;
|
1437
|
+
h += step;
|
1438
|
+
a += step;
|
1439
|
+
dh += step;
|
1436
1440
|
|
1437
|
-
const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
|
1438
|
-
const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
|
1439
|
-
const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
|
1440
|
-
const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
|
1441
|
-
|
1442
|
-
for (int l = 0; l < 4; ++l) {
|
1443
|
-
const uint8_t hm = h[l] >> im;
|
1444
|
-
sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4))
|
1445
|
-
+ y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4))
|
1446
|
-
+ y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4))
|
1447
|
-
+ y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4));
|
1448
1441
|
}
|
1449
1442
|
|
1450
|
-
|
1451
|
-
|
1452
|
-
sum[ith] = sumf;
|
1453
|
-
|
1454
|
-
#endif
|
1443
|
+
y1 += 2 * QK_K;
|
1455
1444
|
|
1456
|
-
//
|
1457
|
-
// Accumulate the sum from all threads in the threadgroup
|
1458
|
-
//
|
1459
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1460
|
-
if (ith%4 == 0) {
|
1461
|
-
for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
|
1462
|
-
}
|
1463
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1464
|
-
if (ith%16 == 0) {
|
1465
|
-
for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
|
1466
|
-
}
|
1467
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1468
|
-
if (ith == 0) {
|
1469
|
-
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
|
1470
|
-
dst[r1*ne0 + r0] = sum[0];
|
1471
1445
|
}
|
1472
1446
|
|
1447
|
+
for (int row = 0; row < 2; ++row) {
|
1448
|
+
const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
|
1449
|
+
const float tot = simd_sum(sumf);
|
1450
|
+
if (tiisg == 0) {
|
1451
|
+
dst[r1*ne0 + first_row + row] = tot;
|
1452
|
+
}
|
1453
|
+
}
|
1473
1454
|
}
|
1474
|
-
|
1475
|
-
kernel void
|
1455
|
+
#else
|
1456
|
+
kernel void kernel_mul_mat_q3_K_f32(
|
1476
1457
|
device const void * src0,
|
1477
1458
|
device const float * src1,
|
1478
1459
|
device float * dst,
|
1479
1460
|
constant int64_t & ne00,
|
1480
1461
|
constant int64_t & ne10,
|
1481
1462
|
constant int64_t & ne0,
|
1482
|
-
|
1463
|
+
constant int64_t & ne1,
|
1483
1464
|
uint2 tgpig[[threadgroup_position_in_grid]],
|
1484
|
-
|
1485
|
-
|
1465
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
1466
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1486
1467
|
|
1487
1468
|
const int nb = ne00/QK_K;
|
1488
1469
|
|
1489
1470
|
const int64_t r0 = tgpig.x;
|
1490
1471
|
const int64_t r1 = tgpig.y;
|
1491
1472
|
|
1492
|
-
const int
|
1493
|
-
const int ith = tptg.y*tpitg.x + tpitg.y;
|
1473
|
+
const int row = 2 * r0 + sgitg;
|
1494
1474
|
|
1495
|
-
device const
|
1475
|
+
device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb;
|
1496
1476
|
device const float * yy = (device const float *) src1 + r1*ne10;
|
1477
|
+
const int ix = tiisg/4;
|
1478
|
+
const int il = 4 * (tiisg%4);// 0, 4, 8, 12
|
1479
|
+
const int im = il/8; // 0, 0, 1, 1
|
1480
|
+
const int in = il%8; // 0, 4, 0, 4
|
1497
1481
|
|
1498
|
-
|
1482
|
+
float2 sum = {0.f, 0.f};
|
1483
|
+
|
1484
|
+
for (int i = ix; i < nb; i += 8) {
|
1485
|
+
|
1486
|
+
const float d_all = (float)(x[i].d);
|
1487
|
+
|
1488
|
+
device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
|
1489
|
+
device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
|
1490
|
+
device const uint16_t * s = (device const uint16_t *)(x[i].scales);
|
1491
|
+
device const float * y = yy + i * QK_K + il;
|
1492
|
+
|
1493
|
+
const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
|
1494
|
+
const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
|
1495
|
+
const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
|
1496
|
+
const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
|
1497
|
+
|
1498
|
+
for (int l = 0; l < 4; l += 2) {
|
1499
|
+
const uint16_t hm = h[l/2] >> im;
|
1500
|
+
sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
|
1501
|
+
+ y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
|
1502
|
+
+ y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
|
1503
|
+
+ y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
|
1504
|
+
sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
|
1505
|
+
+ y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
|
1506
|
+
+ y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
|
1507
|
+
+ y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
|
1508
|
+
}
|
1509
|
+
|
1510
|
+
}
|
1511
|
+
const float sumf = sum[0] + sum[1] * 1.f/256.f;
|
1512
|
+
|
1513
|
+
const float tot = simd_sum(sumf);
|
1514
|
+
if (tiisg == 0) {
|
1515
|
+
dst[r1*ne0 + row] = tot;
|
1516
|
+
}
|
1517
|
+
|
1518
|
+
}
|
1519
|
+
#endif
|
1499
1520
|
|
1500
1521
|
#if QK_K == 256
|
1522
|
+
kernel void kernel_mul_mat_q4_K_f32(
|
1523
|
+
device const void * src0,
|
1524
|
+
device const float * src1,
|
1525
|
+
device float * dst,
|
1526
|
+
constant int64_t & ne00,
|
1527
|
+
constant int64_t & ne10,
|
1528
|
+
constant int64_t & ne0,
|
1529
|
+
constant int64_t & ne01[[buffer(4)]],
|
1530
|
+
uint2 tgpig[[threadgroup_position_in_grid]],
|
1531
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
1532
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1501
1533
|
|
1502
1534
|
const uint16_t kmask1 = 0x3f3f;
|
1503
1535
|
const uint16_t kmask2 = 0x0f0f;
|
1504
1536
|
const uint16_t kmask3 = 0xc0c0;
|
1505
1537
|
|
1506
|
-
const int
|
1507
|
-
const int
|
1508
|
-
const int
|
1509
|
-
const int
|
1538
|
+
const int ix = tiisg/8; // 0...3
|
1539
|
+
const int it = tiisg%8; // 0...7
|
1540
|
+
const int im = it/4; // 0 or 1
|
1541
|
+
const int ir = it%4; // 0...3
|
1510
1542
|
|
1511
|
-
const int
|
1512
|
-
const int
|
1543
|
+
const int nb = ne00/QK_K;
|
1544
|
+
const int r0 = tgpig.x;
|
1545
|
+
const int r1 = tgpig.y;
|
1546
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
1547
|
+
const int ib_row = first_row * nb;
|
1548
|
+
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
|
1549
|
+
device const float * y = (device const float *) src1 + r1*ne10;
|
1550
|
+
float yl[16];
|
1551
|
+
float yh[16];
|
1552
|
+
float sumf[N_DST]={0.f}, all_sum;
|
1513
1553
|
|
1514
|
-
const int
|
1515
|
-
const int q_offset = 32*im + l0;
|
1516
|
-
const int y_offset = 64*im + l0;
|
1554
|
+
const int step = sizeof(block_q4_K) * nb / 2;
|
1517
1555
|
|
1518
|
-
|
1556
|
+
device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir;
|
1519
1557
|
|
1520
|
-
|
1558
|
+
uint16_t sc16[4];
|
1559
|
+
thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
|
1521
1560
|
|
1522
|
-
|
1523
|
-
device const uint8_t * q2 = q1 + 64;
|
1524
|
-
device const float * y1 = yy + i*QK_K + y_offset;
|
1525
|
-
device const float * y2 = y1 + 128;
|
1561
|
+
for (int ib = ix; ib < nb; ib += 4) {
|
1526
1562
|
|
1527
|
-
|
1528
|
-
|
1563
|
+
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
1564
|
+
for (int i = 0; i < 8; ++i) {
|
1565
|
+
yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
|
1566
|
+
yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
|
1567
|
+
yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
|
1568
|
+
yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
|
1569
|
+
}
|
1529
1570
|
|
1530
|
-
device const uint16_t *
|
1531
|
-
|
1532
|
-
|
1533
|
-
sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
|
1534
|
-
sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
|
1571
|
+
device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im;
|
1572
|
+
device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
|
1573
|
+
device const half * dh = &x[ib].d;
|
1535
1574
|
|
1536
|
-
|
1537
|
-
float smin = 0;
|
1538
|
-
for (int l = 0; l < n; ++l) {
|
1575
|
+
for (int row = 0; row < N_DST; row++) {
|
1539
1576
|
|
1540
|
-
|
1541
|
-
|
1542
|
-
|
1577
|
+
sc16[0] = sc[0] & kmask1;
|
1578
|
+
sc16[1] = sc[2] & kmask1;
|
1579
|
+
sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
|
1580
|
+
sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
|
1581
|
+
|
1582
|
+
device const uint16_t * q2 = q1 + 32;
|
1583
|
+
|
1584
|
+
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
|
1585
|
+
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
|
1586
|
+
for (int i = 0; i < 8; i += 2) {
|
1587
|
+
acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
|
1588
|
+
acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
|
1589
|
+
acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
|
1590
|
+
acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
|
1591
|
+
acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
|
1592
|
+
acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
|
1593
|
+
acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
|
1594
|
+
acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
|
1595
|
+
}
|
1543
1596
|
|
1597
|
+
float dall = dh[0];
|
1598
|
+
float dmin = dh[1];
|
1599
|
+
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
|
1600
|
+
(acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
|
1601
|
+
(acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
|
1602
|
+
(acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
|
1603
|
+
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
|
1604
|
+
|
1605
|
+
q1 += step;
|
1606
|
+
sc += step;
|
1607
|
+
dh += step;
|
1544
1608
|
}
|
1545
|
-
sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
|
1546
1609
|
|
1610
|
+
y4 += 4 * QK_K;
|
1611
|
+
}
|
1612
|
+
|
1613
|
+
for (int row = 0; row < N_DST; ++row) {
|
1614
|
+
all_sum = simd_sum(sumf[row]);
|
1615
|
+
if (tiisg == 0) {
|
1616
|
+
dst[r1*ne0 + first_row + row] = all_sum;
|
1617
|
+
}
|
1547
1618
|
}
|
1619
|
+
}
|
1548
1620
|
#else
|
1549
|
-
|
1550
|
-
|
1621
|
+
kernel void kernel_mul_mat_q4_K_f32(
|
1622
|
+
device const void * src0,
|
1623
|
+
device const float * src1,
|
1624
|
+
device float * dst,
|
1625
|
+
constant int64_t & ne00,
|
1626
|
+
constant int64_t & ne10,
|
1627
|
+
constant int64_t & ne0,
|
1628
|
+
constant int64_t & ne01[[buffer(4)]],
|
1629
|
+
uint2 tgpig[[threadgroup_position_in_grid]],
|
1630
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
1631
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1551
1632
|
|
1552
|
-
const int
|
1633
|
+
const int ix = tiisg/4; // 0...7
|
1634
|
+
const int it = tiisg%4; // 0...3
|
1553
1635
|
|
1554
|
-
|
1636
|
+
const int nb = ne00/QK_K;
|
1637
|
+
const int r0 = tgpig.x;
|
1638
|
+
const int r1 = tgpig.y;
|
1639
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
1640
|
+
const int ib_row = first_row * nb;
|
1641
|
+
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
|
1642
|
+
device const float * y = (device const float *) src1 + r1*ne10;
|
1643
|
+
float yl[8];
|
1644
|
+
float yh[8];
|
1645
|
+
float sumf[N_DST]={0.f}, all_sum;
|
1555
1646
|
|
1556
|
-
|
1557
|
-
device const float * y = yy + i * QK_K + il;
|
1647
|
+
const int step = sizeof(block_q4_K) * nb / 2;
|
1558
1648
|
|
1559
|
-
|
1560
|
-
const float m = (float)x[i].d[1];
|
1649
|
+
device const float * y4 = y + ix * QK_K + 8 * it;
|
1561
1650
|
|
1562
|
-
|
1563
|
-
aux16[0] = a[0] & 0x0f0f;
|
1564
|
-
aux16[1] = (a[0] >> 4) & 0x0f0f;
|
1651
|
+
uint16_t sc16[4];
|
1565
1652
|
|
1566
|
-
|
1567
|
-
|
1568
|
-
|
1653
|
+
for (int ib = ix; ib < nb; ib += 8) {
|
1654
|
+
|
1655
|
+
float2 sumy = {0.f, 0.f};
|
1656
|
+
for (int i = 0; i < 8; ++i) {
|
1657
|
+
yl[i] = y4[i+ 0]; sumy[0] += yl[i];
|
1658
|
+
yh[i] = y4[i+32]; sumy[1] += yh[i];
|
1569
1659
|
}
|
1570
|
-
}
|
1571
|
-
#endif
|
1572
1660
|
|
1573
|
-
|
1661
|
+
device const uint16_t * sc = (device const uint16_t *)x[ib].scales;
|
1662
|
+
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
|
1663
|
+
device const half * dh = x[ib].d;
|
1574
1664
|
|
1575
|
-
|
1576
|
-
|
1577
|
-
|
1578
|
-
|
1579
|
-
|
1580
|
-
|
1581
|
-
|
1582
|
-
|
1583
|
-
|
1584
|
-
|
1585
|
-
|
1586
|
-
|
1587
|
-
|
1588
|
-
|
1589
|
-
|
1590
|
-
|
1591
|
-
|
1665
|
+
for (int row = 0; row < N_DST; row++) {
|
1666
|
+
|
1667
|
+
sc16[0] = sc[0] & 0x000f;
|
1668
|
+
sc16[1] = sc[0] & 0x0f00;
|
1669
|
+
sc16[2] = sc[0] & 0x00f0;
|
1670
|
+
sc16[3] = sc[0] & 0xf000;
|
1671
|
+
|
1672
|
+
float2 acc1 = {0.f, 0.f};
|
1673
|
+
float2 acc2 = {0.f, 0.f};
|
1674
|
+
for (int i = 0; i < 8; i += 2) {
|
1675
|
+
acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
|
1676
|
+
acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
|
1677
|
+
acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
|
1678
|
+
acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
|
1679
|
+
}
|
1680
|
+
|
1681
|
+
float dall = dh[0];
|
1682
|
+
float dmin = dh[1];
|
1683
|
+
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
|
1684
|
+
(acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
|
1685
|
+
dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);
|
1686
|
+
|
1687
|
+
qs += step;
|
1688
|
+
sc += step;
|
1689
|
+
dh += step;
|
1690
|
+
}
|
1691
|
+
|
1692
|
+
y4 += 8 * QK_K;
|
1592
1693
|
}
|
1593
1694
|
|
1594
|
-
|
1595
|
-
|
1596
|
-
|
1597
|
-
|
1598
|
-
|
1599
|
-
|
1600
|
-
// threadgroup_barrier(mem_flags::mem_threadgroup);
|
1601
|
-
//}
|
1602
|
-
|
1603
|
-
//if (ith == 0) {
|
1604
|
-
// dst[r1*ne0 + r0] = sum[0];
|
1605
|
-
//}
|
1695
|
+
for (int row = 0; row < N_DST; ++row) {
|
1696
|
+
all_sum = simd_sum(sumf[row]);
|
1697
|
+
if (tiisg == 0) {
|
1698
|
+
dst[r1*ne0 + first_row + row] = all_sum;
|
1699
|
+
}
|
1700
|
+
}
|
1606
1701
|
}
|
1702
|
+
#endif
|
1607
1703
|
|
1608
1704
|
kernel void kernel_mul_mat_q5_K_f32(
|
1609
1705
|
device const void * src0,
|
@@ -1612,39 +1708,39 @@ kernel void kernel_mul_mat_q5_K_f32(
|
|
1612
1708
|
constant int64_t & ne00,
|
1613
1709
|
constant int64_t & ne10,
|
1614
1710
|
constant int64_t & ne0,
|
1615
|
-
threadgroup float * sum [[threadgroup(0)]],
|
1616
1711
|
uint2 tgpig[[threadgroup_position_in_grid]],
|
1617
|
-
|
1618
|
-
|
1712
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
1713
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1619
1714
|
|
1620
1715
|
const int nb = ne00/QK_K;
|
1621
1716
|
|
1622
1717
|
const int64_t r0 = tgpig.x;
|
1623
1718
|
const int64_t r1 = tgpig.y;
|
1624
1719
|
|
1625
|
-
|
1720
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
1721
|
+
|
1722
|
+
device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb;
|
1626
1723
|
device const float * yy = (device const float *) src1 + r1*ne10;
|
1627
1724
|
|
1628
|
-
|
1629
|
-
const int ith = tptg.y*tpitg.x + tpitg.y;
|
1725
|
+
float sumf[2]={0.f};
|
1630
1726
|
|
1631
|
-
|
1727
|
+
const int step = sizeof(block_q5_K) * nb;
|
1632
1728
|
|
1633
1729
|
#if QK_K == 256
|
1730
|
+
#
|
1731
|
+
float yl[16], yh[16];
|
1634
1732
|
|
1635
1733
|
const uint16_t kmask1 = 0x3f3f;
|
1636
1734
|
const uint16_t kmask2 = 0x0f0f;
|
1637
1735
|
const uint16_t kmask3 = 0xc0c0;
|
1638
1736
|
|
1639
|
-
const int tid =
|
1640
|
-
const int
|
1641
|
-
const int
|
1642
|
-
const int
|
1643
|
-
|
1644
|
-
const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
|
1645
|
-
const int in = il%2;
|
1737
|
+
const int tid = tiisg/4;
|
1738
|
+
const int ix = tiisg%4;
|
1739
|
+
const int im = tid/4;
|
1740
|
+
const int ir = tid%4;
|
1741
|
+
const int n = 8;
|
1646
1742
|
|
1647
|
-
const int l0 = n*
|
1743
|
+
const int l0 = n*ir;
|
1648
1744
|
const int q_offset = 32*im + l0;
|
1649
1745
|
const int y_offset = 64*im + l0;
|
1650
1746
|
|
@@ -1653,78 +1749,113 @@ kernel void kernel_mul_mat_q5_K_f32(
|
|
1653
1749
|
const uint8_t hm3 = hm1 << 4;
|
1654
1750
|
const uint8_t hm4 = hm2 << 4;
|
1655
1751
|
|
1656
|
-
|
1752
|
+
uint16_t sc16[4];
|
1753
|
+
thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
|
1657
1754
|
|
1658
|
-
|
1755
|
+
device const float * y1 = yy + ix*QK_K + y_offset;
|
1659
1756
|
|
1660
|
-
|
1661
|
-
device const uint8_t * q2 = q1 + 64;
|
1662
|
-
device const uint8_t * qh = (x + i)->qh + l0;
|
1663
|
-
device const float * y1 = yy + i*QK_K + y_offset;
|
1664
|
-
device const float * y2 = y1 + 128;
|
1757
|
+
for (int i = ix; i < nb; i += 4) {
|
1665
1758
|
|
1666
|
-
const
|
1667
|
-
const
|
1759
|
+
device const uint8_t * q1 = x[i].qs + q_offset;
|
1760
|
+
device const uint8_t * qh = x[i].qh + l0;
|
1761
|
+
device const half * dh = &x[i].d;
|
1762
|
+
device const uint16_t * a = (device const uint16_t *)x[i].scales + im;
|
1668
1763
|
|
1669
|
-
device const
|
1670
|
-
|
1671
|
-
|
1672
|
-
|
1673
|
-
|
1764
|
+
device const float * y2 = y1 + 128;
|
1765
|
+
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
1766
|
+
for (int l = 0; l < 8; ++l) {
|
1767
|
+
yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
|
1768
|
+
yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
|
1769
|
+
yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
|
1770
|
+
yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
|
1771
|
+
}
|
1674
1772
|
|
1675
|
-
|
1676
|
-
|
1677
|
-
|
1773
|
+
for (int row = 0; row < 2; ++row) {
|
1774
|
+
|
1775
|
+
device const uint8_t * q2 = q1 + 64;
|
1776
|
+
|
1777
|
+
sc16[0] = a[0] & kmask1;
|
1778
|
+
sc16[1] = a[2] & kmask1;
|
1779
|
+
sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
|
1780
|
+
sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
|
1678
1781
|
|
1679
|
-
|
1680
|
-
|
1681
|
-
|
1682
|
-
|
1683
|
-
|
1782
|
+
float4 acc = {0.f, 0.f, 0.f, 0.f};
|
1783
|
+
for (int l = 0; l < n; ++l) {
|
1784
|
+
uint8_t h = qh[l];
|
1785
|
+
acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
|
1786
|
+
acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
|
1787
|
+
acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
|
1788
|
+
acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
|
1789
|
+
}
|
1790
|
+
const float dall = dh[0];
|
1791
|
+
const float dmin = dh[1];
|
1792
|
+
sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
|
1793
|
+
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
|
1794
|
+
|
1795
|
+
q1 += step;
|
1796
|
+
qh += step;
|
1797
|
+
dh += step/2;
|
1798
|
+
a += step/2;
|
1684
1799
|
|
1685
1800
|
}
|
1686
|
-
|
1801
|
+
|
1802
|
+
y1 += 4 * QK_K;
|
1687
1803
|
|
1688
1804
|
}
|
1689
1805
|
#else
|
1690
|
-
|
1691
|
-
|
1692
|
-
const int
|
1806
|
+
float yl[8], yh[8];
|
1807
|
+
|
1808
|
+
const int il = 4 * (tiisg/8); // 0, 4, 8, 12
|
1809
|
+
const int ix = tiisg%8;
|
1810
|
+
const int im = il/8; // 0, 0, 1, 1
|
1811
|
+
const int in = il%8; // 0, 4, 0, 4
|
1693
1812
|
|
1694
|
-
|
1813
|
+
device const float * y = yy + ix*QK_K + il;
|
1695
1814
|
|
1696
|
-
|
1815
|
+
for (int i = ix; i < nb; i += 8) {
|
1816
|
+
|
1817
|
+
for (int l = 0; l < 4; ++l) {
|
1818
|
+
yl[l+0] = y[l+ 0];
|
1819
|
+
yl[l+4] = y[l+16];
|
1820
|
+
yh[l+0] = y[l+32];
|
1821
|
+
yh[l+4] = y[l+48];
|
1822
|
+
}
|
1823
|
+
|
1824
|
+
device const half * dh = &x[i].d;
|
1697
1825
|
device const uint8_t * q = x[i].qs + il;
|
1698
1826
|
device const uint8_t * h = x[i].qh + in;
|
1699
1827
|
device const int8_t * s = x[i].scales;
|
1700
|
-
device const float * y = yy + i*QK_K + il;
|
1701
1828
|
|
1702
|
-
for (int
|
1703
|
-
|
1704
|
-
|
1705
|
-
|
1706
|
-
|
1707
|
-
|
1829
|
+
for (int row = 0; row < 2; ++row) {
|
1830
|
+
|
1831
|
+
const float d = dh[0];
|
1832
|
+
|
1833
|
+
float2 acc = {0.f, 0.f};
|
1834
|
+
for (int l = 0; l < 4; ++l) {
|
1835
|
+
const uint8_t hl = h[l] >> im;
|
1836
|
+
acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
|
1837
|
+
+ yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
|
1838
|
+
acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
|
1839
|
+
+ yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256));
|
1840
|
+
}
|
1841
|
+
sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]);
|
1842
|
+
|
1843
|
+
q += step;
|
1844
|
+
h += step;
|
1845
|
+
s += step;
|
1846
|
+
dh += step/2;
|
1847
|
+
|
1708
1848
|
}
|
1849
|
+
|
1850
|
+
y += 8 * QK_K;
|
1709
1851
|
}
|
1710
1852
|
#endif
|
1711
|
-
sum[ith] = sumf;
|
1712
1853
|
|
1713
|
-
|
1714
|
-
|
1715
|
-
|
1716
|
-
|
1717
|
-
|
1718
|
-
sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
|
1719
|
-
}
|
1720
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1721
|
-
if (ith%16 == 0) {
|
1722
|
-
sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
|
1723
|
-
}
|
1724
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1725
|
-
if (ith == 0) {
|
1726
|
-
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
|
1727
|
-
dst[r1*ne0 + r0] = sum[0];
|
1854
|
+
for (int row = 0; row < 2; ++row) {
|
1855
|
+
const float tot = simd_sum(sumf[row]);
|
1856
|
+
if (tiisg == 0) {
|
1857
|
+
dst[r1*ne0 + first_row + row] = tot;
|
1858
|
+
}
|
1728
1859
|
}
|
1729
1860
|
|
1730
1861
|
}
|
@@ -1736,10 +1867,9 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|
1736
1867
|
constant int64_t & ne00,
|
1737
1868
|
constant int64_t & ne10,
|
1738
1869
|
constant int64_t & ne0,
|
1739
|
-
threadgroup float * sum [[threadgroup(0)]],
|
1740
1870
|
uint2 tgpig[[threadgroup_position_in_grid]],
|
1741
|
-
|
1742
|
-
|
1871
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
1872
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1743
1873
|
|
1744
1874
|
const uint8_t kmask1 = 0x03;
|
1745
1875
|
const uint8_t kmask2 = 0x0C;
|
@@ -1751,19 +1881,18 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|
1751
1881
|
const int64_t r0 = tgpig.x;
|
1752
1882
|
const int64_t r1 = tgpig.y;
|
1753
1883
|
|
1754
|
-
|
1755
|
-
device const float * yy = (device const float *) src1 + r1*ne10;
|
1884
|
+
const int row = 2 * r0 + sgitg;
|
1756
1885
|
|
1757
|
-
const
|
1758
|
-
const
|
1886
|
+
device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb; //r0*nb;
|
1887
|
+
device const float * yy = (device const float *) src1 + r1*ne10;
|
1759
1888
|
|
1760
1889
|
float sumf = 0;
|
1761
1890
|
|
1762
1891
|
#if QK_K == 256
|
1763
|
-
|
1764
|
-
const int
|
1765
|
-
const int ip =
|
1766
|
-
const int il =
|
1892
|
+
const int tid = tiisg/2;
|
1893
|
+
const int ix = tiisg%2;
|
1894
|
+
const int ip = tid/8; // 0 or 1
|
1895
|
+
const int il = tid%8;
|
1767
1896
|
const int n = 4;
|
1768
1897
|
const int l0 = n*il;
|
1769
1898
|
const int is = 8*ip + l0/16;
|
@@ -1772,9 +1901,10 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|
1772
1901
|
const int q_offset_l = 64*ip + l0;
|
1773
1902
|
const int q_offset_h = 32*ip + l0;
|
1774
1903
|
|
1775
|
-
for (int i =
|
1904
|
+
for (int i = ix; i < nb; i += 2) {
|
1776
1905
|
|
1777
|
-
device const uint8_t *
|
1906
|
+
device const uint8_t * q1 = x[i].ql + q_offset_l;
|
1907
|
+
device const uint8_t * q2 = q1 + 32;
|
1778
1908
|
device const uint8_t * qh = x[i].qh + q_offset_h;
|
1779
1909
|
device const int8_t * sc = x[i].scales + is;
|
1780
1910
|
|
@@ -1784,19 +1914,21 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|
1784
1914
|
|
1785
1915
|
float4 sums = {0.f, 0.f, 0.f, 0.f};
|
1786
1916
|
for (int l = 0; l < n; ++l) {
|
1787
|
-
sums[0] += y[l+ 0] * ((int8_t)((
|
1788
|
-
sums[1] += y[l+32] * ((int8_t)((
|
1789
|
-
sums[2] += y[l+64] * ((int8_t)((
|
1790
|
-
sums[3] += y[l+96] * ((int8_t)((
|
1917
|
+
sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
|
1918
|
+
sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
|
1919
|
+
sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
|
1920
|
+
sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
|
1791
1921
|
}
|
1792
1922
|
|
1793
1923
|
sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
|
1794
1924
|
|
1795
1925
|
}
|
1926
|
+
|
1796
1927
|
#else
|
1797
|
-
const int
|
1928
|
+
const int ix = tiisg/4;
|
1929
|
+
const int il = 4*(tiisg%4);
|
1798
1930
|
|
1799
|
-
for (int i =
|
1931
|
+
for (int i = ix; i < nb; i += 8) {
|
1800
1932
|
device const float * y = yy + i * QK_K + il;
|
1801
1933
|
device const uint8_t * ql = x[i].ql + il;
|
1802
1934
|
device const uint8_t * qh = x[i].qh + il;
|
@@ -1816,23 +1948,8 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|
1816
1948
|
|
1817
1949
|
#endif
|
1818
1950
|
|
1819
|
-
|
1820
|
-
|
1821
|
-
|
1822
|
-
// Accumulate the sum from all threads in the threadgroup
|
1823
|
-
//
|
1824
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1825
|
-
if (ith%4 == 0) {
|
1826
|
-
for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
|
1951
|
+
const float tot = simd_sum(sumf);
|
1952
|
+
if (tiisg == 0) {
|
1953
|
+
dst[r1*ne0 + row] = tot;
|
1827
1954
|
}
|
1828
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1829
|
-
if (ith%16 == 0) {
|
1830
|
-
for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
|
1831
|
-
}
|
1832
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1833
|
-
if (ith == 0) {
|
1834
|
-
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
|
1835
|
-
dst[r1*ne0 + r0] = sum[0];
|
1836
|
-
}
|
1837
|
-
|
1838
1955
|
}
|