llama_cpp 0.3.3 → 0.3.4
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +24 -0
- data/ext/llama_cpp/llama_cpp.cpp +146 -9
- data/ext/llama_cpp/src/ggml-cuda.cu +485 -67
- data/ext/llama_cpp/src/ggml-metal.m +52 -43
- data/ext/llama_cpp/src/ggml-metal.metal +587 -470
- data/ext/llama_cpp/src/ggml.c +105 -79
- data/ext/llama_cpp/src/ggml.h +13 -1
- data/ext/llama_cpp/src/k_quants.h +8 -0
- data/ext/llama_cpp/src/llama.cpp +123 -66
- data/ext/llama_cpp/src/llama.h +34 -5
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +1 -0
- data/sig/llama_cpp.rbs +12 -1
- metadata +2 -2
@@ -331,26 +331,33 @@ kernel void kernel_rms_norm(
|
|
331
331
|
threadgroup float * sum [[threadgroup(0)]],
|
332
332
|
uint tgpig[[threadgroup_position_in_grid]],
|
333
333
|
uint tpitg[[thread_position_in_threadgroup]],
|
334
|
+
uint sgitg[[simdgroup_index_in_threadgroup]],
|
335
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
334
336
|
uint ntg[[threads_per_threadgroup]]) {
|
335
|
-
device const
|
337
|
+
device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
|
338
|
+
device const float * x_scalar = (device const float *) x;
|
339
|
+
float4 sumf=0;
|
340
|
+
float all_sum=0;
|
336
341
|
|
337
342
|
// parallel sum
|
338
|
-
|
339
|
-
|
340
|
-
|
343
|
+
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
344
|
+
sumf += x[i00] * x[i00];
|
345
|
+
}
|
346
|
+
all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
|
347
|
+
all_sum = simd_sum(all_sum);
|
348
|
+
if (tiisg == 0) {
|
349
|
+
sum[sgitg] = all_sum;
|
341
350
|
}
|
342
351
|
|
343
|
-
// reduce
|
344
352
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
353
|
+
// broadcast; the number of SIMD groups is ntg / 32
|
354
|
+
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
355
|
+
if (tpitg < i) {
|
356
|
+
sum[tpitg] += sum[tpitg + i];
|
357
|
+
}
|
350
358
|
}
|
351
|
-
|
352
|
-
// broadcast
|
353
359
|
if (tpitg == 0) {
|
360
|
+
for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
|
354
361
|
sum[0] /= ne00;
|
355
362
|
}
|
356
363
|
|
@@ -359,82 +366,94 @@ kernel void kernel_rms_norm(
|
|
359
366
|
const float mean = sum[0];
|
360
367
|
const float scale = 1.0f/sqrt(mean + eps);
|
361
368
|
|
362
|
-
device
|
363
|
-
|
369
|
+
device float4 * y = (device float4 *) (dst + tgpig*ne00);
|
370
|
+
device float * y_scalar = (device float *) y;
|
371
|
+
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
364
372
|
y[i00] = x[i00] * scale;
|
365
373
|
}
|
374
|
+
if (tpitg == 0) {
|
375
|
+
for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
|
376
|
+
}
|
377
|
+
}
|
378
|
+
|
379
|
+
// function to calculate the inner product between a q4_0 block and 32 floats (yl); sumy is SUM(yl[i])
|
380
|
+
float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) {
|
381
|
+
float d = qb_curr->d;
|
382
|
+
float4 acc = 0.f;
|
383
|
+
device uint16_t * qs = ((device uint16_t *)qb_curr + 1);
|
384
|
+
for (int i = 0; i < 16; i+=2) {
|
385
|
+
acc[0] += yl[i] * (qs[i / 2] & 0x000F);
|
386
|
+
acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
|
387
|
+
acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
|
388
|
+
acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
|
389
|
+
}
|
390
|
+
return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f);
|
391
|
+
}
|
392
|
+
|
393
|
+
// function to calculate the inner product between a q4_1 block and 32 floats (yl); sumy is SUM(yl[i])
|
394
|
+
float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) {
|
395
|
+
float d = qb_curr->d;
|
396
|
+
float m = qb_curr->m;
|
397
|
+
float4 acc = 0.f;
|
398
|
+
device uint16_t * qs = ((device uint16_t *)qb_curr + 2);
|
399
|
+
for (int i = 0; i < 16; i+=2) {
|
400
|
+
acc[0] += yl[i] * (qs[i / 2] & 0x000F);
|
401
|
+
acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
|
402
|
+
acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
|
403
|
+
acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
|
404
|
+
}
|
405
|
+
return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m;
|
366
406
|
}
|
367
407
|
|
368
408
|
// putting them in the kernel causes a significant performance penalty
|
369
409
|
#define N_DST 4 // each SIMD group works on 4 rows
|
370
410
|
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
|
371
411
|
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
constant int64_t & ne00,
|
377
|
-
constant int64_t & ne10,
|
378
|
-
constant int64_t & ne0,
|
379
|
-
constant int64_t & ne01[[buffer(4)]],
|
380
|
-
uint2 tgpig[[threadgroup_position_in_grid]],
|
381
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
382
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
412
|
+
template<typename block_q_type>
|
413
|
+
void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
|
414
|
+
int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
|
415
|
+
uint2 tgpig, uint tiisg, uint sgitg) {
|
383
416
|
const int nb = ne00/QK4_0;
|
384
417
|
const int r0 = tgpig.x;
|
385
418
|
const int r1 = tgpig.y;
|
386
|
-
device const
|
419
|
+
device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
|
387
420
|
device const float * y = (device const float *) src1 + r1*ne10;
|
388
|
-
block_q4_0 qb_curr, qb_next;
|
389
421
|
float4 y_curr[8]; // src1 vector cache
|
390
422
|
float sumf[N_DST]={0.f}, all_sum;
|
391
423
|
thread float * yl=(thread float *)y_curr;
|
392
424
|
|
393
|
-
// bootstrap
|
394
|
-
qb_curr = x[tiisg];
|
395
425
|
// each thread in a SIMD group deals with 1 block.
|
396
426
|
for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
|
397
|
-
|
427
|
+
float sumy = 0;
|
398
428
|
for (int i = 0; i < QK4_0 / 4; i++) {
|
399
|
-
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) +
|
429
|
+
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i);
|
430
|
+
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
|
400
431
|
}
|
401
432
|
|
402
433
|
for (int row = 0; row < N_DST; row++) {
|
403
|
-
|
404
|
-
qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
|
405
|
-
|
406
|
-
// calculate
|
407
|
-
float d = qb_curr.d;
|
408
|
-
float2 acc = {0.0f, 0.0f};
|
409
|
-
for (int i = 0; i < 16; i++) {
|
410
|
-
acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
|
411
|
-
acc[1] += yl[i] + yl[i+16];
|
412
|
-
}
|
413
|
-
sumf[row] += d * (acc[0] - 8.f*acc[1]);
|
414
|
-
qb_curr = qb_next;
|
434
|
+
sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl);
|
415
435
|
}
|
416
436
|
}
|
417
437
|
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
float d = qb_curr.d;
|
428
|
-
float2 acc = {0.0f, 0.0f};
|
429
|
-
for (int i = 0; i < 16; i++) {
|
430
|
-
acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
|
431
|
-
acc[1] += yl[i] + yl[i+16];
|
438
|
+
// from here on, load two rows at a time and 16 blocks per row
|
439
|
+
int ir = tiisg / (N_SIMDWIDTH / 2);
|
440
|
+
int ib = tiisg % (N_SIMDWIDTH / 2);
|
441
|
+
for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) {
|
442
|
+
int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); //where the left blocks start
|
443
|
+
float sumy = 0;
|
444
|
+
for (int i = 0; i < QK4_0 / 4; i++) {
|
445
|
+
y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i);
|
446
|
+
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
|
432
447
|
}
|
433
|
-
|
434
|
-
|
448
|
+
|
449
|
+
for (int row = 0; row < N_DST; row+=2) {
|
450
|
+
if (nb_start + ib < nb) {
|
451
|
+
sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl);
|
452
|
+
}
|
435
453
|
}
|
436
|
-
|
454
|
+
}
|
437
455
|
|
456
|
+
for (int row = 0; row < N_DST; ++row) {
|
438
457
|
all_sum = simd_sum(sumf[row]);
|
439
458
|
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
|
440
459
|
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
|
@@ -442,73 +461,32 @@ kernel void kernel_mul_mat_q4_0_f32(
|
|
442
461
|
}
|
443
462
|
}
|
444
463
|
|
445
|
-
kernel void
|
464
|
+
kernel void kernel_mul_mat_q4_0_f32(
|
446
465
|
device const void * src0,
|
447
466
|
device const float * src1,
|
448
467
|
device float * dst,
|
449
468
|
constant int64_t & ne00,
|
450
469
|
constant int64_t & ne10,
|
451
470
|
constant int64_t & ne0,
|
452
|
-
|
471
|
+
constant int64_t & ne01[[buffer(4)]],
|
453
472
|
uint2 tgpig[[threadgroup_position_in_grid]],
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
const int64_t r0 = tgpig.x;
|
459
|
-
const int64_t r1 = tgpig.y;
|
460
|
-
|
461
|
-
device const block_q4_1 * x = (device const block_q4_1 *) src0 + r0*nb;
|
462
|
-
device const float * y = (device const float *) src1 + r1*ne10;
|
463
|
-
|
464
|
-
const uint nth = tptg.x*tptg.y;
|
465
|
-
const uint ith = tptg.y*tpitg.x + tpitg.y;
|
466
|
-
|
467
|
-
const int ix = tpitg.y/4; // 0 or 1
|
468
|
-
const int iy = tpitg.y - 4*ix; // 0...3
|
469
|
-
|
470
|
-
const int first = 4 * iy;
|
471
|
-
|
472
|
-
float sumf = 0;
|
473
|
-
|
474
|
-
for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
|
475
|
-
|
476
|
-
const float d = (float)x[i].d;
|
477
|
-
const float m = (float)x[i].m;
|
478
|
-
|
479
|
-
device const uint8_t * xl = x[i].qs + first;
|
480
|
-
device const float * yl = y + i * QK4_1 + first;
|
481
|
-
|
482
|
-
float2 acc = {0.0f, 0.0f};
|
483
|
-
|
484
|
-
for (int j = 0; j < 4; ++j) {
|
485
|
-
|
486
|
-
acc[0] += yl[j+ 0] * (d * (xl[j] & 0xF) + m);
|
487
|
-
acc[1] += yl[j+16] * (d * (xl[j] >> 4) + m);
|
488
|
-
|
489
|
-
}
|
490
|
-
|
491
|
-
sumf += acc[0] + acc[1];
|
492
|
-
}
|
493
|
-
|
494
|
-
sum[ith] = sumf;
|
473
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
474
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
475
|
+
mul_vec_q_n_f32<block_q4_0>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
|
476
|
+
}
|
495
477
|
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
|
504
|
-
|
505
|
-
|
506
|
-
|
507
|
-
|
508
|
-
if (ith == 0) {
|
509
|
-
for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
|
510
|
-
dst[r1*ne0 + r0] = sum[0];
|
511
|
-
}
|
478
|
+
kernel void kernel_mul_mat_q4_1_f32(
|
479
|
+
device const void * src0,
|
480
|
+
device const float * src1,
|
481
|
+
device float * dst,
|
482
|
+
constant int64_t & ne00,
|
483
|
+
constant int64_t & ne10,
|
484
|
+
constant int64_t & ne0,
|
485
|
+
constant int64_t & ne01[[buffer(4)]],
|
486
|
+
uint2 tgpig[[threadgroup_position_in_grid]],
|
487
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
488
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
489
|
+
mul_vec_q_n_f32<block_q4_1>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
|
512
490
|
}
|
513
491
|
|
514
492
|
kernel void kernel_mul_mat_f16_f32(
|
@@ -624,17 +602,19 @@ kernel void kernel_rope(
|
|
624
602
|
constant int & n_past,
|
625
603
|
constant int & n_dims,
|
626
604
|
constant int & mode,
|
605
|
+
constant float & freq_base,
|
606
|
+
constant float & freq_scale,
|
627
607
|
uint3 tpig[[thread_position_in_grid]]) {
|
628
608
|
const int64_t i3 = tpig[2];
|
629
609
|
const int64_t i2 = tpig[1];
|
630
610
|
const int64_t i1 = tpig[0];
|
631
611
|
|
632
612
|
const bool is_neox = mode & 2;
|
633
|
-
const float theta_scale = pow(
|
613
|
+
const float theta_scale = pow(freq_base, -2.0f/n_dims);
|
634
614
|
|
635
615
|
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
|
636
616
|
|
637
|
-
float theta = (float)p;
|
617
|
+
float theta = freq_scale * (float)p;
|
638
618
|
|
639
619
|
if (!is_neox) {
|
640
620
|
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
@@ -1229,111 +1209,137 @@ kernel void kernel_mul_mat_q2_K_f32(
|
|
1229
1209
|
constant int64_t & ne00,
|
1230
1210
|
constant int64_t & ne10,
|
1231
1211
|
constant int64_t & ne0,
|
1232
|
-
|
1212
|
+
constant int64_t & ne01[[buffer(4)]],
|
1233
1213
|
uint2 tgpig[[threadgroup_position_in_grid]],
|
1234
|
-
|
1235
|
-
|
1214
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
1215
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1236
1216
|
|
1237
1217
|
const int nb = ne00/QK_K;
|
1218
|
+
const int r0 = tgpig.x;
|
1219
|
+
const int r1 = tgpig.y;
|
1238
1220
|
|
1239
|
-
const
|
1240
|
-
const
|
1241
|
-
|
1242
|
-
device const
|
1243
|
-
|
1244
|
-
|
1245
|
-
const int nth = tptg.x*tptg.y;
|
1246
|
-
const int ith = tptg.y*tpitg.x + tpitg.y;
|
1221
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
1222
|
+
const int ib_row = first_row * nb;
|
1223
|
+
device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row;
|
1224
|
+
device const float * y = (device const float *) src1 + r1*ne10;
|
1225
|
+
float yl[32];
|
1226
|
+
float sumf[N_DST]={0.f}, all_sum;
|
1247
1227
|
|
1248
|
-
|
1228
|
+
const int step = sizeof(block_q2_K) * nb;
|
1249
1229
|
|
1250
1230
|
#if QK_K == 256
|
1251
|
-
const int
|
1252
|
-
const int
|
1253
|
-
const int
|
1254
|
-
const int
|
1255
|
-
const int
|
1256
|
-
|
1257
|
-
const
|
1258
|
-
|
1259
|
-
|
1260
|
-
|
1261
|
-
|
1262
|
-
|
1263
|
-
|
1264
|
-
|
1265
|
-
|
1266
|
-
|
1231
|
+
const int ix = tiisg/8; // 0...3
|
1232
|
+
const int it = tiisg%8; // 0...7
|
1233
|
+
const int im = it/4; // 0 or 1
|
1234
|
+
const int ir = it%4; // 0...3
|
1235
|
+
const int is = (8*ir)/16;// 0 or 1
|
1236
|
+
|
1237
|
+
device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir;
|
1238
|
+
|
1239
|
+
for (int ib = ix; ib < nb; ib += 4) {
|
1240
|
+
|
1241
|
+
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
1242
|
+
for (int i = 0; i < 8; ++i) {
|
1243
|
+
yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
|
1244
|
+
yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
|
1245
|
+
yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
|
1246
|
+
yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
|
1247
|
+
}
|
1267
1248
|
|
1268
|
-
uint8_t
|
1269
|
-
|
1270
|
-
|
1271
|
-
uint8_t m2 = scales[2] >> 4;
|
1249
|
+
device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*im + is;
|
1250
|
+
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
|
1251
|
+
device const half * dh = &x[ib].d;
|
1272
1252
|
|
1273
|
-
|
1253
|
+
for (int row = 0; row < N_DST; row++) {
|
1274
1254
|
|
1275
|
-
|
1276
|
-
|
1277
|
-
|
1278
|
-
|
1279
|
-
|
1280
|
-
|
1255
|
+
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
|
1256
|
+
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
|
1257
|
+
for (int i = 0; i < 8; i += 2) {
|
1258
|
+
acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
|
1259
|
+
acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
|
1260
|
+
acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
|
1261
|
+
acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
|
1262
|
+
acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
|
1263
|
+
acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
|
1264
|
+
acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
|
1265
|
+
acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
|
1266
|
+
}
|
1267
|
+
float dall = dh[0];
|
1268
|
+
float dmin = dh[1] * 1.f/16.f;
|
1269
|
+
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
|
1270
|
+
(acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
|
1271
|
+
(acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +
|
1272
|
+
(acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
|
1273
|
+
dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
|
1274
|
+
|
1275
|
+
qs += step/2;
|
1276
|
+
sc += step;
|
1277
|
+
dh += step/2;
|
1281
1278
|
}
|
1282
1279
|
|
1283
|
-
|
1284
|
-
const float dmin = (float)x[i].dmin;
|
1285
|
-
|
1286
|
-
sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
|
1287
|
-
|
1280
|
+
y4 += 4 * QK_K;
|
1288
1281
|
}
|
1289
1282
|
#else
|
1290
|
-
const int
|
1283
|
+
const int ix = tiisg/2; // 0...15
|
1284
|
+
const int it = tiisg%2; // 0...1
|
1291
1285
|
|
1292
|
-
|
1293
|
-
thread const uint8_t * d = (thread const uint8_t *)aux;
|
1294
|
-
thread const uint8_t * m = (thread const uint8_t *)aux + 4;
|
1286
|
+
device const float * y4 = y + ix * QK_K + 8 * it;
|
1295
1287
|
|
1296
|
-
for (int
|
1288
|
+
for (int ib = ix; ib < nb; ib += 16) {
|
1297
1289
|
|
1298
|
-
|
1299
|
-
|
1290
|
+
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
1291
|
+
for (int i = 0; i < 8; ++i) {
|
1292
|
+
yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
|
1293
|
+
yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8];
|
1294
|
+
yl[i+16] = y4[i+32]; sumy[2] += yl[i+16];
|
1295
|
+
yl[i+24] = y4[i+48]; sumy[3] += yl[i+24];
|
1296
|
+
}
|
1300
1297
|
|
1301
|
-
const
|
1302
|
-
const
|
1298
|
+
device const uint8_t * sc = (device const uint8_t *)x[ib].scales;
|
1299
|
+
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
|
1300
|
+
device const half * dh = &x[ib].d;
|
1303
1301
|
|
1304
|
-
|
1305
|
-
aux[0] = a[0] & 0x0f0f0f0f;
|
1306
|
-
aux[1] = (a[0] >> 4) & 0x0f0f0f0f;
|
1302
|
+
for (int row = 0; row < N_DST; row++) {
|
1307
1303
|
|
1308
|
-
|
1309
|
-
|
1310
|
-
|
1311
|
-
|
1312
|
-
|
1304
|
+
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
|
1305
|
+
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
|
1306
|
+
for (int i = 0; i < 8; i += 2) {
|
1307
|
+
acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
|
1308
|
+
acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
|
1309
|
+
acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
|
1310
|
+
acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
|
1311
|
+
acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
|
1312
|
+
acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
|
1313
|
+
acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
|
1314
|
+
acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
|
1315
|
+
}
|
1316
|
+
|
1317
|
+
float dall = dh[0];
|
1318
|
+
float dmin = dh[1];
|
1319
|
+
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
|
1320
|
+
(acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
|
1321
|
+
(acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
|
1322
|
+
(acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
|
1323
|
+
dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));
|
1324
|
+
|
1325
|
+
qs += step/2;
|
1326
|
+
sc += step;
|
1327
|
+
dh += step/2;
|
1313
1328
|
}
|
1329
|
+
|
1330
|
+
y4 += 16 * QK_K;
|
1314
1331
|
}
|
1315
1332
|
#endif
|
1316
1333
|
|
1317
|
-
|
1318
|
-
|
1319
|
-
|
1320
|
-
|
1321
|
-
|
1322
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1323
|
-
if (ith%4 == 0) {
|
1324
|
-
for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
|
1325
|
-
}
|
1326
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1327
|
-
if (ith%16 == 0) {
|
1328
|
-
for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
|
1329
|
-
}
|
1330
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1331
|
-
if (ith == 0) {
|
1332
|
-
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
|
1333
|
-
dst[r1*ne0 + r0] = sum[0];
|
1334
|
+
for (int row = 0; row < N_DST; ++row) {
|
1335
|
+
all_sum = simd_sum(sumf[row]);
|
1336
|
+
if (tiisg == 0) {
|
1337
|
+
dst[r1*ne0 + first_row + row] = all_sum;
|
1338
|
+
}
|
1334
1339
|
}
|
1335
1340
|
}
|
1336
1341
|
|
1342
|
+
#if QK_K == 256
|
1337
1343
|
kernel void kernel_mul_mat_q3_K_f32(
|
1338
1344
|
device const void * src0,
|
1339
1345
|
device const float * src1,
|
@@ -1342,40 +1348,41 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1342
1348
|
constant int64_t & ne10,
|
1343
1349
|
constant int64_t & ne0,
|
1344
1350
|
constant int64_t & ne1,
|
1345
|
-
threadgroup float * sum [[threadgroup(0)]],
|
1346
1351
|
uint2 tgpig[[threadgroup_position_in_grid]],
|
1347
|
-
|
1348
|
-
|
1352
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
1353
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1349
1354
|
|
1350
1355
|
const int nb = ne00/QK_K;
|
1351
1356
|
|
1352
1357
|
const int64_t r0 = tgpig.x;
|
1353
1358
|
const int64_t r1 = tgpig.y;
|
1354
1359
|
|
1355
|
-
|
1356
|
-
device const float * yy = (device const float *) src1 + r1*ne10;
|
1360
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
1357
1361
|
|
1358
|
-
const
|
1359
|
-
const
|
1360
|
-
|
1361
|
-
#if QK_K == 256
|
1362
|
+
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb;
|
1363
|
+
device const float * yy = (device const float *) src1 + r1*ne10;
|
1362
1364
|
|
1363
|
-
|
1364
|
-
const int8_t m4 = 4;
|
1365
|
+
float yl[16];
|
1365
1366
|
|
1366
1367
|
const uint16_t kmask1 = 0x0303;
|
1367
1368
|
const uint16_t kmask2 = 0x0f0f;
|
1368
1369
|
|
1369
|
-
const int tid =
|
1370
|
+
const int tid = tiisg/2;
|
1371
|
+
const int ix = tiisg%2;
|
1370
1372
|
const int ip = tid/8; // 0 or 1
|
1371
1373
|
const int il = tid/2 - 4*ip; // 0...3
|
1372
1374
|
const int ir = tid%2;
|
1373
1375
|
const int n = 8;
|
1374
1376
|
const int l0 = n*ir;
|
1375
1377
|
|
1376
|
-
const
|
1378
|
+
const uint16_t m1 = 1 << (4*ip + il);
|
1379
|
+
const uint16_t m2 = m1 << 8;
|
1377
1380
|
|
1378
1381
|
const int shift = 2*il;
|
1382
|
+
const uint16_t qm1 = 0x0003 << shift;
|
1383
|
+
const uint16_t qm2 = 0x0300 << shift;
|
1384
|
+
const int32_t v1 = 4 << shift;
|
1385
|
+
const int32_t v2 = 1024 << shift;
|
1379
1386
|
|
1380
1387
|
const uint16_t s_shift1 = 4*ip;
|
1381
1388
|
const uint16_t s_shift2 = s_shift1 + 2*(il/2);
|
@@ -1384,226 +1391,315 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1384
1391
|
const int q_offset = 32*ip + l0;
|
1385
1392
|
const int y_offset = 128*ip + 32*il + l0;
|
1386
1393
|
|
1387
|
-
|
1388
|
-
float sumf1 = 0, sumf2 = 0;
|
1389
|
-
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
1390
|
-
|
1391
|
-
const float d_all = (float)(x[i].d);
|
1394
|
+
const int step = sizeof(block_q3_K) * nb / 2;
|
1392
1395
|
|
1393
|
-
|
1394
|
-
device const uint8_t * h = x[i].hmask + l0;
|
1395
|
-
device const float * y = yy + i * QK_K + y_offset;
|
1396
|
+
device const float * y1 = yy + ix*QK_K + y_offset;
|
1396
1397
|
|
1397
|
-
|
1398
|
-
|
1398
|
+
float sumf1[2] = {0.f}, sumf2[2] = {0.f};
|
1399
|
+
for (int i = ix; i < nb; i += 2) {
|
1399
1400
|
|
1400
|
-
|
1401
|
-
|
1402
|
-
|
1401
|
+
for (int l = 0; l < 8; ++l) {
|
1402
|
+
yl[l+0] = y1[l+ 0];
|
1403
|
+
yl[l+8] = y1[l+16];
|
1403
1404
|
}
|
1404
|
-
float d = d_all * s;
|
1405
|
-
sumf1 += d * scales[0];
|
1406
|
-
sumf2 += d;
|
1407
|
-
//sumf += d_all * s * (scales[0] - 32);
|
1408
1405
|
|
1409
|
-
|
1410
|
-
|
1411
|
-
|
1412
|
-
|
1413
|
-
d = d_all * s;
|
1414
|
-
sumf1 += d * scales[1];
|
1415
|
-
sumf2 += d;
|
1416
|
-
//sumf += d_all * s * (scales[1] - 32);
|
1406
|
+
device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
|
1407
|
+
device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
|
1408
|
+
device const uint16_t * a = (device const uint16_t *)(x[i].scales);
|
1409
|
+
device const half * dh = &x[i].d;
|
1417
1410
|
|
1418
|
-
|
1411
|
+
for (int row = 0; row < 2; ++row) {
|
1419
1412
|
|
1420
|
-
|
1421
|
-
|
1422
|
-
#else
|
1423
|
-
const int il = 4 * tpitg.x; // 0, 4, 8, 12
|
1424
|
-
const int im = il/8; // 0, 0, 1, 1
|
1425
|
-
const int in = il%8; // 0, 4, 0, 4
|
1413
|
+
const float d_all = (float)dh[0];
|
1414
|
+
const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
|
1426
1415
|
|
1427
|
-
|
1428
|
-
|
1429
|
-
|
1430
|
-
|
1431
|
-
|
1416
|
+
float s1 = 0, s2 = 0;
|
1417
|
+
for (int l = 0; l < n; l += 2) {
|
1418
|
+
const uint16_t qs = q[l/2];
|
1419
|
+
s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
|
1420
|
+
s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
|
1421
|
+
}
|
1422
|
+
float d = d_all * (s1 + 1.f/256.f * s2);
|
1423
|
+
sumf1[row] += d * scales[0];
|
1424
|
+
sumf2[row] += d;
|
1425
|
+
|
1426
|
+
s1 = s2 = 0;
|
1427
|
+
for (int l = 0; l < n; l += 2) {
|
1428
|
+
const uint16_t qs = q[l/2+8];
|
1429
|
+
s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
|
1430
|
+
s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
|
1431
|
+
}
|
1432
|
+
d = d_all * (s1 + 1.f/256.f * s2);
|
1433
|
+
sumf1[row] += d * scales[1];
|
1434
|
+
sumf2[row] += d;
|
1432
1435
|
|
1433
|
-
|
1434
|
-
|
1435
|
-
|
1436
|
+
q += step;
|
1437
|
+
h += step;
|
1438
|
+
a += step;
|
1439
|
+
dh += step;
|
1436
1440
|
|
1437
|
-
const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
|
1438
|
-
const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
|
1439
|
-
const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
|
1440
|
-
const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
|
1441
|
-
|
1442
|
-
for (int l = 0; l < 4; ++l) {
|
1443
|
-
const uint8_t hm = h[l] >> im;
|
1444
|
-
sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4))
|
1445
|
-
+ y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4))
|
1446
|
-
+ y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4))
|
1447
|
-
+ y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4));
|
1448
1441
|
}
|
1449
1442
|
|
1450
|
-
|
1451
|
-
|
1452
|
-
sum[ith] = sumf;
|
1453
|
-
|
1454
|
-
#endif
|
1443
|
+
y1 += 2 * QK_K;
|
1455
1444
|
|
1456
|
-
//
|
1457
|
-
// Accumulate the sum from all threads in the threadgroup
|
1458
|
-
//
|
1459
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1460
|
-
if (ith%4 == 0) {
|
1461
|
-
for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
|
1462
|
-
}
|
1463
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1464
|
-
if (ith%16 == 0) {
|
1465
|
-
for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
|
1466
|
-
}
|
1467
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1468
|
-
if (ith == 0) {
|
1469
|
-
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
|
1470
|
-
dst[r1*ne0 + r0] = sum[0];
|
1471
1445
|
}
|
1472
1446
|
|
1447
|
+
for (int row = 0; row < 2; ++row) {
|
1448
|
+
const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
|
1449
|
+
const float tot = simd_sum(sumf);
|
1450
|
+
if (tiisg == 0) {
|
1451
|
+
dst[r1*ne0 + first_row + row] = tot;
|
1452
|
+
}
|
1453
|
+
}
|
1473
1454
|
}
|
1474
|
-
|
1475
|
-
kernel void
|
1455
|
+
#else
|
1456
|
+
kernel void kernel_mul_mat_q3_K_f32(
|
1476
1457
|
device const void * src0,
|
1477
1458
|
device const float * src1,
|
1478
1459
|
device float * dst,
|
1479
1460
|
constant int64_t & ne00,
|
1480
1461
|
constant int64_t & ne10,
|
1481
1462
|
constant int64_t & ne0,
|
1482
|
-
|
1463
|
+
constant int64_t & ne1,
|
1483
1464
|
uint2 tgpig[[threadgroup_position_in_grid]],
|
1484
|
-
|
1485
|
-
|
1465
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
1466
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1486
1467
|
|
1487
1468
|
const int nb = ne00/QK_K;
|
1488
1469
|
|
1489
1470
|
const int64_t r0 = tgpig.x;
|
1490
1471
|
const int64_t r1 = tgpig.y;
|
1491
1472
|
|
1492
|
-
const int
|
1493
|
-
const int ith = tptg.y*tpitg.x + tpitg.y;
|
1473
|
+
const int row = 2 * r0 + sgitg;
|
1494
1474
|
|
1495
|
-
device const
|
1475
|
+
device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb;
|
1496
1476
|
device const float * yy = (device const float *) src1 + r1*ne10;
|
1477
|
+
const int ix = tiisg/4;
|
1478
|
+
const int il = 4 * (tiisg%4);// 0, 4, 8, 12
|
1479
|
+
const int im = il/8; // 0, 0, 1, 1
|
1480
|
+
const int in = il%8; // 0, 4, 0, 4
|
1497
1481
|
|
1498
|
-
|
1482
|
+
float2 sum = {0.f, 0.f};
|
1483
|
+
|
1484
|
+
for (int i = ix; i < nb; i += 8) {
|
1485
|
+
|
1486
|
+
const float d_all = (float)(x[i].d);
|
1487
|
+
|
1488
|
+
device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
|
1489
|
+
device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
|
1490
|
+
device const uint16_t * s = (device const uint16_t *)(x[i].scales);
|
1491
|
+
device const float * y = yy + i * QK_K + il;
|
1492
|
+
|
1493
|
+
const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
|
1494
|
+
const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
|
1495
|
+
const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
|
1496
|
+
const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
|
1497
|
+
|
1498
|
+
for (int l = 0; l < 4; l += 2) {
|
1499
|
+
const uint16_t hm = h[l/2] >> im;
|
1500
|
+
sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
|
1501
|
+
+ y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
|
1502
|
+
+ y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
|
1503
|
+
+ y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
|
1504
|
+
sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
|
1505
|
+
+ y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
|
1506
|
+
+ y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
|
1507
|
+
+ y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
|
1508
|
+
}
|
1509
|
+
|
1510
|
+
}
|
1511
|
+
const float sumf = sum[0] + sum[1] * 1.f/256.f;
|
1512
|
+
|
1513
|
+
const float tot = simd_sum(sumf);
|
1514
|
+
if (tiisg == 0) {
|
1515
|
+
dst[r1*ne0 + row] = tot;
|
1516
|
+
}
|
1517
|
+
|
1518
|
+
}
|
1519
|
+
#endif
|
1499
1520
|
|
1500
1521
|
#if QK_K == 256
|
1522
|
+
kernel void kernel_mul_mat_q4_K_f32(
|
1523
|
+
device const void * src0,
|
1524
|
+
device const float * src1,
|
1525
|
+
device float * dst,
|
1526
|
+
constant int64_t & ne00,
|
1527
|
+
constant int64_t & ne10,
|
1528
|
+
constant int64_t & ne0,
|
1529
|
+
constant int64_t & ne01[[buffer(4)]],
|
1530
|
+
uint2 tgpig[[threadgroup_position_in_grid]],
|
1531
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
1532
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1501
1533
|
|
1502
1534
|
const uint16_t kmask1 = 0x3f3f;
|
1503
1535
|
const uint16_t kmask2 = 0x0f0f;
|
1504
1536
|
const uint16_t kmask3 = 0xc0c0;
|
1505
1537
|
|
1506
|
-
const int
|
1507
|
-
const int
|
1508
|
-
const int
|
1509
|
-
const int
|
1538
|
+
const int ix = tiisg/8; // 0...3
|
1539
|
+
const int it = tiisg%8; // 0...7
|
1540
|
+
const int im = it/4; // 0 or 1
|
1541
|
+
const int ir = it%4; // 0...3
|
1510
1542
|
|
1511
|
-
const int
|
1512
|
-
const int
|
1543
|
+
const int nb = ne00/QK_K;
|
1544
|
+
const int r0 = tgpig.x;
|
1545
|
+
const int r1 = tgpig.y;
|
1546
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
1547
|
+
const int ib_row = first_row * nb;
|
1548
|
+
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
|
1549
|
+
device const float * y = (device const float *) src1 + r1*ne10;
|
1550
|
+
float yl[16];
|
1551
|
+
float yh[16];
|
1552
|
+
float sumf[N_DST]={0.f}, all_sum;
|
1513
1553
|
|
1514
|
-
const int
|
1515
|
-
const int q_offset = 32*im + l0;
|
1516
|
-
const int y_offset = 64*im + l0;
|
1554
|
+
const int step = sizeof(block_q4_K) * nb / 2;
|
1517
1555
|
|
1518
|
-
|
1556
|
+
device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir;
|
1519
1557
|
|
1520
|
-
|
1558
|
+
uint16_t sc16[4];
|
1559
|
+
thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
|
1521
1560
|
|
1522
|
-
|
1523
|
-
device const uint8_t * q2 = q1 + 64;
|
1524
|
-
device const float * y1 = yy + i*QK_K + y_offset;
|
1525
|
-
device const float * y2 = y1 + 128;
|
1561
|
+
for (int ib = ix; ib < nb; ib += 4) {
|
1526
1562
|
|
1527
|
-
|
1528
|
-
|
1563
|
+
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
1564
|
+
for (int i = 0; i < 8; ++i) {
|
1565
|
+
yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
|
1566
|
+
yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
|
1567
|
+
yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
|
1568
|
+
yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
|
1569
|
+
}
|
1529
1570
|
|
1530
|
-
device const uint16_t *
|
1531
|
-
|
1532
|
-
|
1533
|
-
sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
|
1534
|
-
sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
|
1571
|
+
device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im;
|
1572
|
+
device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
|
1573
|
+
device const half * dh = &x[ib].d;
|
1535
1574
|
|
1536
|
-
|
1537
|
-
float smin = 0;
|
1538
|
-
for (int l = 0; l < n; ++l) {
|
1575
|
+
for (int row = 0; row < N_DST; row++) {
|
1539
1576
|
|
1540
|
-
|
1541
|
-
|
1542
|
-
|
1577
|
+
sc16[0] = sc[0] & kmask1;
|
1578
|
+
sc16[1] = sc[2] & kmask1;
|
1579
|
+
sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
|
1580
|
+
sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
|
1581
|
+
|
1582
|
+
device const uint16_t * q2 = q1 + 32;
|
1583
|
+
|
1584
|
+
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
|
1585
|
+
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
|
1586
|
+
for (int i = 0; i < 8; i += 2) {
|
1587
|
+
acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
|
1588
|
+
acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
|
1589
|
+
acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
|
1590
|
+
acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
|
1591
|
+
acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
|
1592
|
+
acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
|
1593
|
+
acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
|
1594
|
+
acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
|
1595
|
+
}
|
1543
1596
|
|
1597
|
+
float dall = dh[0];
|
1598
|
+
float dmin = dh[1];
|
1599
|
+
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
|
1600
|
+
(acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
|
1601
|
+
(acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
|
1602
|
+
(acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
|
1603
|
+
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
|
1604
|
+
|
1605
|
+
q1 += step;
|
1606
|
+
sc += step;
|
1607
|
+
dh += step;
|
1544
1608
|
}
|
1545
|
-
sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
|
1546
1609
|
|
1610
|
+
y4 += 4 * QK_K;
|
1611
|
+
}
|
1612
|
+
|
1613
|
+
for (int row = 0; row < N_DST; ++row) {
|
1614
|
+
all_sum = simd_sum(sumf[row]);
|
1615
|
+
if (tiisg == 0) {
|
1616
|
+
dst[r1*ne0 + first_row + row] = all_sum;
|
1617
|
+
}
|
1547
1618
|
}
|
1619
|
+
}
|
1548
1620
|
#else
|
1549
|
-
|
1550
|
-
|
1621
|
+
kernel void kernel_mul_mat_q4_K_f32(
|
1622
|
+
device const void * src0,
|
1623
|
+
device const float * src1,
|
1624
|
+
device float * dst,
|
1625
|
+
constant int64_t & ne00,
|
1626
|
+
constant int64_t & ne10,
|
1627
|
+
constant int64_t & ne0,
|
1628
|
+
constant int64_t & ne01[[buffer(4)]],
|
1629
|
+
uint2 tgpig[[threadgroup_position_in_grid]],
|
1630
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
1631
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1551
1632
|
|
1552
|
-
const int
|
1633
|
+
const int ix = tiisg/4; // 0...7
|
1634
|
+
const int it = tiisg%4; // 0...3
|
1553
1635
|
|
1554
|
-
|
1636
|
+
const int nb = ne00/QK_K;
|
1637
|
+
const int r0 = tgpig.x;
|
1638
|
+
const int r1 = tgpig.y;
|
1639
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
1640
|
+
const int ib_row = first_row * nb;
|
1641
|
+
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
|
1642
|
+
device const float * y = (device const float *) src1 + r1*ne10;
|
1643
|
+
float yl[8];
|
1644
|
+
float yh[8];
|
1645
|
+
float sumf[N_DST]={0.f}, all_sum;
|
1555
1646
|
|
1556
|
-
|
1557
|
-
device const float * y = yy + i * QK_K + il;
|
1647
|
+
const int step = sizeof(block_q4_K) * nb / 2;
|
1558
1648
|
|
1559
|
-
|
1560
|
-
const float m = (float)x[i].d[1];
|
1649
|
+
device const float * y4 = y + ix * QK_K + 8 * it;
|
1561
1650
|
|
1562
|
-
|
1563
|
-
aux16[0] = a[0] & 0x0f0f;
|
1564
|
-
aux16[1] = (a[0] >> 4) & 0x0f0f;
|
1651
|
+
uint16_t sc16[4];
|
1565
1652
|
|
1566
|
-
|
1567
|
-
|
1568
|
-
|
1653
|
+
for (int ib = ix; ib < nb; ib += 8) {
|
1654
|
+
|
1655
|
+
float2 sumy = {0.f, 0.f};
|
1656
|
+
for (int i = 0; i < 8; ++i) {
|
1657
|
+
yl[i] = y4[i+ 0]; sumy[0] += yl[i];
|
1658
|
+
yh[i] = y4[i+32]; sumy[1] += yh[i];
|
1569
1659
|
}
|
1570
|
-
}
|
1571
|
-
#endif
|
1572
1660
|
|
1573
|
-
|
1661
|
+
device const uint16_t * sc = (device const uint16_t *)x[ib].scales;
|
1662
|
+
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
|
1663
|
+
device const half * dh = x[ib].d;
|
1574
1664
|
|
1575
|
-
|
1576
|
-
|
1577
|
-
|
1578
|
-
|
1579
|
-
|
1580
|
-
|
1581
|
-
|
1582
|
-
|
1583
|
-
|
1584
|
-
|
1585
|
-
|
1586
|
-
|
1587
|
-
|
1588
|
-
|
1589
|
-
|
1590
|
-
|
1591
|
-
|
1665
|
+
for (int row = 0; row < N_DST; row++) {
|
1666
|
+
|
1667
|
+
sc16[0] = sc[0] & 0x000f;
|
1668
|
+
sc16[1] = sc[0] & 0x0f00;
|
1669
|
+
sc16[2] = sc[0] & 0x00f0;
|
1670
|
+
sc16[3] = sc[0] & 0xf000;
|
1671
|
+
|
1672
|
+
float2 acc1 = {0.f, 0.f};
|
1673
|
+
float2 acc2 = {0.f, 0.f};
|
1674
|
+
for (int i = 0; i < 8; i += 2) {
|
1675
|
+
acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
|
1676
|
+
acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
|
1677
|
+
acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
|
1678
|
+
acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
|
1679
|
+
}
|
1680
|
+
|
1681
|
+
float dall = dh[0];
|
1682
|
+
float dmin = dh[1];
|
1683
|
+
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
|
1684
|
+
(acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
|
1685
|
+
dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);
|
1686
|
+
|
1687
|
+
qs += step;
|
1688
|
+
sc += step;
|
1689
|
+
dh += step;
|
1690
|
+
}
|
1691
|
+
|
1692
|
+
y4 += 8 * QK_K;
|
1592
1693
|
}
|
1593
1694
|
|
1594
|
-
|
1595
|
-
|
1596
|
-
|
1597
|
-
|
1598
|
-
|
1599
|
-
|
1600
|
-
// threadgroup_barrier(mem_flags::mem_threadgroup);
|
1601
|
-
//}
|
1602
|
-
|
1603
|
-
//if (ith == 0) {
|
1604
|
-
// dst[r1*ne0 + r0] = sum[0];
|
1605
|
-
//}
|
1695
|
+
for (int row = 0; row < N_DST; ++row) {
|
1696
|
+
all_sum = simd_sum(sumf[row]);
|
1697
|
+
if (tiisg == 0) {
|
1698
|
+
dst[r1*ne0 + first_row + row] = all_sum;
|
1699
|
+
}
|
1700
|
+
}
|
1606
1701
|
}
|
1702
|
+
#endif
|
1607
1703
|
|
1608
1704
|
kernel void kernel_mul_mat_q5_K_f32(
|
1609
1705
|
device const void * src0,
|
@@ -1612,39 +1708,39 @@ kernel void kernel_mul_mat_q5_K_f32(
|
|
1612
1708
|
constant int64_t & ne00,
|
1613
1709
|
constant int64_t & ne10,
|
1614
1710
|
constant int64_t & ne0,
|
1615
|
-
threadgroup float * sum [[threadgroup(0)]],
|
1616
1711
|
uint2 tgpig[[threadgroup_position_in_grid]],
|
1617
|
-
|
1618
|
-
|
1712
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
1713
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1619
1714
|
|
1620
1715
|
const int nb = ne00/QK_K;
|
1621
1716
|
|
1622
1717
|
const int64_t r0 = tgpig.x;
|
1623
1718
|
const int64_t r1 = tgpig.y;
|
1624
1719
|
|
1625
|
-
|
1720
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
1721
|
+
|
1722
|
+
device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb;
|
1626
1723
|
device const float * yy = (device const float *) src1 + r1*ne10;
|
1627
1724
|
|
1628
|
-
|
1629
|
-
const int ith = tptg.y*tpitg.x + tpitg.y;
|
1725
|
+
float sumf[2]={0.f};
|
1630
1726
|
|
1631
|
-
|
1727
|
+
const int step = sizeof(block_q5_K) * nb;
|
1632
1728
|
|
1633
1729
|
#if QK_K == 256
|
1730
|
+
#
|
1731
|
+
float yl[16], yh[16];
|
1634
1732
|
|
1635
1733
|
const uint16_t kmask1 = 0x3f3f;
|
1636
1734
|
const uint16_t kmask2 = 0x0f0f;
|
1637
1735
|
const uint16_t kmask3 = 0xc0c0;
|
1638
1736
|
|
1639
|
-
const int tid =
|
1640
|
-
const int
|
1641
|
-
const int
|
1642
|
-
const int
|
1643
|
-
|
1644
|
-
const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
|
1645
|
-
const int in = il%2;
|
1737
|
+
const int tid = tiisg/4;
|
1738
|
+
const int ix = tiisg%4;
|
1739
|
+
const int im = tid/4;
|
1740
|
+
const int ir = tid%4;
|
1741
|
+
const int n = 8;
|
1646
1742
|
|
1647
|
-
const int l0 = n*
|
1743
|
+
const int l0 = n*ir;
|
1648
1744
|
const int q_offset = 32*im + l0;
|
1649
1745
|
const int y_offset = 64*im + l0;
|
1650
1746
|
|
@@ -1653,78 +1749,113 @@ kernel void kernel_mul_mat_q5_K_f32(
|
|
1653
1749
|
const uint8_t hm3 = hm1 << 4;
|
1654
1750
|
const uint8_t hm4 = hm2 << 4;
|
1655
1751
|
|
1656
|
-
|
1752
|
+
uint16_t sc16[4];
|
1753
|
+
thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
|
1657
1754
|
|
1658
|
-
|
1755
|
+
device const float * y1 = yy + ix*QK_K + y_offset;
|
1659
1756
|
|
1660
|
-
|
1661
|
-
device const uint8_t * q2 = q1 + 64;
|
1662
|
-
device const uint8_t * qh = (x + i)->qh + l0;
|
1663
|
-
device const float * y1 = yy + i*QK_K + y_offset;
|
1664
|
-
device const float * y2 = y1 + 128;
|
1757
|
+
for (int i = ix; i < nb; i += 4) {
|
1665
1758
|
|
1666
|
-
const
|
1667
|
-
const
|
1759
|
+
device const uint8_t * q1 = x[i].qs + q_offset;
|
1760
|
+
device const uint8_t * qh = x[i].qh + l0;
|
1761
|
+
device const half * dh = &x[i].d;
|
1762
|
+
device const uint16_t * a = (device const uint16_t *)x[i].scales + im;
|
1668
1763
|
|
1669
|
-
device const
|
1670
|
-
|
1671
|
-
|
1672
|
-
|
1673
|
-
|
1764
|
+
device const float * y2 = y1 + 128;
|
1765
|
+
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
1766
|
+
for (int l = 0; l < 8; ++l) {
|
1767
|
+
yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
|
1768
|
+
yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
|
1769
|
+
yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
|
1770
|
+
yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
|
1771
|
+
}
|
1674
1772
|
|
1675
|
-
|
1676
|
-
|
1677
|
-
|
1773
|
+
for (int row = 0; row < 2; ++row) {
|
1774
|
+
|
1775
|
+
device const uint8_t * q2 = q1 + 64;
|
1776
|
+
|
1777
|
+
sc16[0] = a[0] & kmask1;
|
1778
|
+
sc16[1] = a[2] & kmask1;
|
1779
|
+
sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
|
1780
|
+
sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
|
1678
1781
|
|
1679
|
-
|
1680
|
-
|
1681
|
-
|
1682
|
-
|
1683
|
-
|
1782
|
+
float4 acc = {0.f, 0.f, 0.f, 0.f};
|
1783
|
+
for (int l = 0; l < n; ++l) {
|
1784
|
+
uint8_t h = qh[l];
|
1785
|
+
acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
|
1786
|
+
acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
|
1787
|
+
acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
|
1788
|
+
acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
|
1789
|
+
}
|
1790
|
+
const float dall = dh[0];
|
1791
|
+
const float dmin = dh[1];
|
1792
|
+
sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
|
1793
|
+
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
|
1794
|
+
|
1795
|
+
q1 += step;
|
1796
|
+
qh += step;
|
1797
|
+
dh += step/2;
|
1798
|
+
a += step/2;
|
1684
1799
|
|
1685
1800
|
}
|
1686
|
-
|
1801
|
+
|
1802
|
+
y1 += 4 * QK_K;
|
1687
1803
|
|
1688
1804
|
}
|
1689
1805
|
#else
|
1690
|
-
|
1691
|
-
|
1692
|
-
const int
|
1806
|
+
float yl[8], yh[8];
|
1807
|
+
|
1808
|
+
const int il = 4 * (tiisg/8); // 0, 4, 8, 12
|
1809
|
+
const int ix = tiisg%8;
|
1810
|
+
const int im = il/8; // 0, 0, 1, 1
|
1811
|
+
const int in = il%8; // 0, 4, 0, 4
|
1693
1812
|
|
1694
|
-
|
1813
|
+
device const float * y = yy + ix*QK_K + il;
|
1695
1814
|
|
1696
|
-
|
1815
|
+
for (int i = ix; i < nb; i += 8) {
|
1816
|
+
|
1817
|
+
for (int l = 0; l < 4; ++l) {
|
1818
|
+
yl[l+0] = y[l+ 0];
|
1819
|
+
yl[l+4] = y[l+16];
|
1820
|
+
yh[l+0] = y[l+32];
|
1821
|
+
yh[l+4] = y[l+48];
|
1822
|
+
}
|
1823
|
+
|
1824
|
+
device const half * dh = &x[i].d;
|
1697
1825
|
device const uint8_t * q = x[i].qs + il;
|
1698
1826
|
device const uint8_t * h = x[i].qh + in;
|
1699
1827
|
device const int8_t * s = x[i].scales;
|
1700
|
-
device const float * y = yy + i*QK_K + il;
|
1701
1828
|
|
1702
|
-
for (int
|
1703
|
-
|
1704
|
-
|
1705
|
-
|
1706
|
-
|
1707
|
-
|
1829
|
+
for (int row = 0; row < 2; ++row) {
|
1830
|
+
|
1831
|
+
const float d = dh[0];
|
1832
|
+
|
1833
|
+
float2 acc = {0.f, 0.f};
|
1834
|
+
for (int l = 0; l < 4; ++l) {
|
1835
|
+
const uint8_t hl = h[l] >> im;
|
1836
|
+
acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
|
1837
|
+
+ yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
|
1838
|
+
acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
|
1839
|
+
+ yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256));
|
1840
|
+
}
|
1841
|
+
sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]);
|
1842
|
+
|
1843
|
+
q += step;
|
1844
|
+
h += step;
|
1845
|
+
s += step;
|
1846
|
+
dh += step/2;
|
1847
|
+
|
1708
1848
|
}
|
1849
|
+
|
1850
|
+
y += 8 * QK_K;
|
1709
1851
|
}
|
1710
1852
|
#endif
|
1711
|
-
sum[ith] = sumf;
|
1712
1853
|
|
1713
|
-
|
1714
|
-
|
1715
|
-
|
1716
|
-
|
1717
|
-
|
1718
|
-
sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
|
1719
|
-
}
|
1720
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1721
|
-
if (ith%16 == 0) {
|
1722
|
-
sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
|
1723
|
-
}
|
1724
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1725
|
-
if (ith == 0) {
|
1726
|
-
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
|
1727
|
-
dst[r1*ne0 + r0] = sum[0];
|
1854
|
+
for (int row = 0; row < 2; ++row) {
|
1855
|
+
const float tot = simd_sum(sumf[row]);
|
1856
|
+
if (tiisg == 0) {
|
1857
|
+
dst[r1*ne0 + first_row + row] = tot;
|
1858
|
+
}
|
1728
1859
|
}
|
1729
1860
|
|
1730
1861
|
}
|
@@ -1736,10 +1867,9 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|
1736
1867
|
constant int64_t & ne00,
|
1737
1868
|
constant int64_t & ne10,
|
1738
1869
|
constant int64_t & ne0,
|
1739
|
-
threadgroup float * sum [[threadgroup(0)]],
|
1740
1870
|
uint2 tgpig[[threadgroup_position_in_grid]],
|
1741
|
-
|
1742
|
-
|
1871
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
1872
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1743
1873
|
|
1744
1874
|
const uint8_t kmask1 = 0x03;
|
1745
1875
|
const uint8_t kmask2 = 0x0C;
|
@@ -1751,19 +1881,18 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|
1751
1881
|
const int64_t r0 = tgpig.x;
|
1752
1882
|
const int64_t r1 = tgpig.y;
|
1753
1883
|
|
1754
|
-
|
1755
|
-
device const float * yy = (device const float *) src1 + r1*ne10;
|
1884
|
+
const int row = 2 * r0 + sgitg;
|
1756
1885
|
|
1757
|
-
const
|
1758
|
-
const
|
1886
|
+
device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb; //r0*nb;
|
1887
|
+
device const float * yy = (device const float *) src1 + r1*ne10;
|
1759
1888
|
|
1760
1889
|
float sumf = 0;
|
1761
1890
|
|
1762
1891
|
#if QK_K == 256
|
1763
|
-
|
1764
|
-
const int
|
1765
|
-
const int ip =
|
1766
|
-
const int il =
|
1892
|
+
const int tid = tiisg/2;
|
1893
|
+
const int ix = tiisg%2;
|
1894
|
+
const int ip = tid/8; // 0 or 1
|
1895
|
+
const int il = tid%8;
|
1767
1896
|
const int n = 4;
|
1768
1897
|
const int l0 = n*il;
|
1769
1898
|
const int is = 8*ip + l0/16;
|
@@ -1772,9 +1901,10 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|
1772
1901
|
const int q_offset_l = 64*ip + l0;
|
1773
1902
|
const int q_offset_h = 32*ip + l0;
|
1774
1903
|
|
1775
|
-
for (int i =
|
1904
|
+
for (int i = ix; i < nb; i += 2) {
|
1776
1905
|
|
1777
|
-
device const uint8_t *
|
1906
|
+
device const uint8_t * q1 = x[i].ql + q_offset_l;
|
1907
|
+
device const uint8_t * q2 = q1 + 32;
|
1778
1908
|
device const uint8_t * qh = x[i].qh + q_offset_h;
|
1779
1909
|
device const int8_t * sc = x[i].scales + is;
|
1780
1910
|
|
@@ -1784,19 +1914,21 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|
1784
1914
|
|
1785
1915
|
float4 sums = {0.f, 0.f, 0.f, 0.f};
|
1786
1916
|
for (int l = 0; l < n; ++l) {
|
1787
|
-
sums[0] += y[l+ 0] * ((int8_t)((
|
1788
|
-
sums[1] += y[l+32] * ((int8_t)((
|
1789
|
-
sums[2] += y[l+64] * ((int8_t)((
|
1790
|
-
sums[3] += y[l+96] * ((int8_t)((
|
1917
|
+
sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
|
1918
|
+
sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
|
1919
|
+
sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
|
1920
|
+
sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
|
1791
1921
|
}
|
1792
1922
|
|
1793
1923
|
sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
|
1794
1924
|
|
1795
1925
|
}
|
1926
|
+
|
1796
1927
|
#else
|
1797
|
-
const int
|
1928
|
+
const int ix = tiisg/4;
|
1929
|
+
const int il = 4*(tiisg%4);
|
1798
1930
|
|
1799
|
-
for (int i =
|
1931
|
+
for (int i = ix; i < nb; i += 8) {
|
1800
1932
|
device const float * y = yy + i * QK_K + il;
|
1801
1933
|
device const uint8_t * ql = x[i].ql + il;
|
1802
1934
|
device const uint8_t * qh = x[i].qh + il;
|
@@ -1816,23 +1948,8 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|
1816
1948
|
|
1817
1949
|
#endif
|
1818
1950
|
|
1819
|
-
|
1820
|
-
|
1821
|
-
|
1822
|
-
// Accumulate the sum from all threads in the threadgroup
|
1823
|
-
//
|
1824
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1825
|
-
if (ith%4 == 0) {
|
1826
|
-
for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
|
1951
|
+
const float tot = simd_sum(sumf);
|
1952
|
+
if (tiisg == 0) {
|
1953
|
+
dst[r1*ne0 + row] = tot;
|
1827
1954
|
}
|
1828
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1829
|
-
if (ith%16 == 0) {
|
1830
|
-
for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
|
1831
|
-
}
|
1832
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1833
|
-
if (ith == 0) {
|
1834
|
-
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
|
1835
|
-
dst[r1*ne0 + r0] = sum[0];
|
1836
|
-
}
|
1837
|
-
|
1838
1955
|
}
|