llama_cpp 0.3.3 → 0.3.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +31 -0
- data/ext/llama_cpp/extconf.rb +1 -0
- data/ext/llama_cpp/llama_cpp.cpp +439 -9
- data/ext/llama_cpp/src/ggml-cuda.cu +759 -136
- data/ext/llama_cpp/src/ggml-metal.h +7 -0
- data/ext/llama_cpp/src/ggml-metal.m +250 -111
- data/ext/llama_cpp/src/ggml-metal.metal +614 -483
- data/ext/llama_cpp/src/ggml.c +793 -1032
- data/ext/llama_cpp/src/ggml.h +95 -18
- data/ext/llama_cpp/src/k_quants.c +327 -3
- data/ext/llama_cpp/src/k_quants.h +8 -0
- data/ext/llama_cpp/src/llama.cpp +626 -166
- data/ext/llama_cpp/src/llama.h +94 -10
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +1 -0
- data/sig/llama_cpp.rbs +36 -1
- metadata +2 -2
@@ -67,6 +67,17 @@ kernel void kernel_add(
|
|
67
67
|
dst[tpig] = src0[tpig] + src1[tpig];
|
68
68
|
}
|
69
69
|
|
70
|
+
// assumption: src1 is a row
|
71
|
+
// broadcast src1 into src0
|
72
|
+
kernel void kernel_add_row(
|
73
|
+
device const float * src0,
|
74
|
+
device const float * src1,
|
75
|
+
device float * dst,
|
76
|
+
constant int64_t & ne00,
|
77
|
+
uint tpig[[thread_position_in_grid]]) {
|
78
|
+
dst[tpig] = src0[tpig] + src1[tpig % ne00];
|
79
|
+
}
|
80
|
+
|
70
81
|
kernel void kernel_mul(
|
71
82
|
device const float * src0,
|
72
83
|
device const float * src1,
|
@@ -331,26 +342,33 @@ kernel void kernel_rms_norm(
|
|
331
342
|
threadgroup float * sum [[threadgroup(0)]],
|
332
343
|
uint tgpig[[threadgroup_position_in_grid]],
|
333
344
|
uint tpitg[[thread_position_in_threadgroup]],
|
345
|
+
uint sgitg[[simdgroup_index_in_threadgroup]],
|
346
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
334
347
|
uint ntg[[threads_per_threadgroup]]) {
|
335
|
-
device const
|
348
|
+
device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
|
349
|
+
device const float * x_scalar = (device const float *) x;
|
350
|
+
float4 sumf=0;
|
351
|
+
float all_sum=0;
|
336
352
|
|
337
353
|
// parallel sum
|
338
|
-
|
339
|
-
|
340
|
-
|
354
|
+
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
355
|
+
sumf += x[i00] * x[i00];
|
356
|
+
}
|
357
|
+
all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
|
358
|
+
all_sum = simd_sum(all_sum);
|
359
|
+
if (tiisg == 0) {
|
360
|
+
sum[sgitg] = all_sum;
|
341
361
|
}
|
342
362
|
|
343
|
-
// reduce
|
344
363
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
364
|
+
// broadcast, simd group number is ntg / 32
|
365
|
+
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
366
|
+
if (tpitg < i) {
|
367
|
+
sum[tpitg] += sum[tpitg + i];
|
368
|
+
}
|
350
369
|
}
|
351
|
-
|
352
|
-
// broadcast
|
353
370
|
if (tpitg == 0) {
|
371
|
+
for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
|
354
372
|
sum[0] /= ne00;
|
355
373
|
}
|
356
374
|
|
@@ -359,156 +377,130 @@ kernel void kernel_rms_norm(
|
|
359
377
|
const float mean = sum[0];
|
360
378
|
const float scale = 1.0f/sqrt(mean + eps);
|
361
379
|
|
362
|
-
device
|
363
|
-
|
380
|
+
device float4 * y = (device float4 *) (dst + tgpig*ne00);
|
381
|
+
device float * y_scalar = (device float *) y;
|
382
|
+
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
364
383
|
y[i00] = x[i00] * scale;
|
365
384
|
}
|
385
|
+
if (tpitg == 0) {
|
386
|
+
for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
|
387
|
+
}
|
388
|
+
}
|
389
|
+
|
390
|
+
// function to calculate the inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
|
391
|
+
// il indicates where the q4 quants begin (0 or QK4_0/4)
|
392
|
+
// we assume that the yl's have been multiplied with the appropriate scale factor
|
393
|
+
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
|
394
|
+
inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
|
395
|
+
float d = qb_curr->d;
|
396
|
+
float2 acc = 0.f;
|
397
|
+
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
|
398
|
+
for (int i = 0; i < 8; i+=2) {
|
399
|
+
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
|
400
|
+
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
|
401
|
+
acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
|
402
|
+
+ yl[i + 9] * (qs[i / 2] & 0xF000);
|
403
|
+
}
|
404
|
+
return d * (sumy * -8.f + acc[0] + acc[1]);
|
405
|
+
}
|
406
|
+
|
407
|
+
// function to calculate the inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i])
|
408
|
+
// il indicates where the q4 quants begin (0 or QK4_0/4)
|
409
|
+
// we assume that the yl's have been multiplied with the appropriate scale factor
|
410
|
+
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
|
411
|
+
inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
|
412
|
+
float d = qb_curr->d;
|
413
|
+
float m = qb_curr->m;
|
414
|
+
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
|
415
|
+
float2 acc = 0.f;
|
416
|
+
for (int i = 0; i < 8; i+=2) {
|
417
|
+
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
|
418
|
+
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
|
419
|
+
acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
|
420
|
+
+ yl[i + 9] * (qs[i / 2] & 0xF000);
|
421
|
+
}
|
422
|
+
return d * (acc[0] + acc[1]) + sumy * m;
|
366
423
|
}
|
367
424
|
|
368
425
|
// putting them in the kernel causes a significant performance penalty
|
369
426
|
#define N_DST 4 // each SIMD group works on 4 rows
|
370
427
|
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
|
371
428
|
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
uint2 tgpig[[threadgroup_position_in_grid]],
|
381
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
382
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
429
|
+
//Note: This is a template, but strictly speaking it only applies to
|
430
|
+
// quantizations where the block size is 32. It also does not
|
431
|
+
// guard against the number of rows not being divisible by
|
432
|
+
// N_DST, so this is another explicit assumption of the implementation.
|
433
|
+
template<typename block_q_type, int nr, int nsg, int nw>
|
434
|
+
void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
|
435
|
+
int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
|
436
|
+
uint2 tgpig, uint tiisg, uint sgitg) {
|
383
437
|
const int nb = ne00/QK4_0;
|
384
438
|
const int r0 = tgpig.x;
|
385
439
|
const int r1 = tgpig.y;
|
386
|
-
|
440
|
+
const int first_row = (r0 * nsg + sgitg) * nr;
|
441
|
+
device const block_q_type * x = (device const block_q_type *) src0 + first_row * nb;
|
387
442
|
device const float * y = (device const float *) src1 + r1*ne10;
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
443
|
+
float yl[16]; // src1 vector cache
|
444
|
+
float sumf[nr]={0.f};
|
445
|
+
|
446
|
+
const int ix = tiisg/2;
|
447
|
+
const int il = 8*(tiisg%2);
|
448
|
+
|
449
|
+
device const float * yb = y + ix * QK4_0 + il;
|
450
|
+
|
451
|
+
// each thread in a SIMD group deals with half a block.
|
452
|
+
for (int ib = ix; ib < nb; ib += nw/2) {
|
453
|
+
float sumy = 0;
|
454
|
+
for (int i = 0; i < 8; i += 2) {
|
455
|
+
sumy += yb[i] + yb[i+1];
|
456
|
+
yl[i+0] = yb[i+ 0];
|
457
|
+
yl[i+1] = yb[i+ 1]/256.f;
|
458
|
+
sumy += yb[i+16] + yb[i+17];
|
459
|
+
yl[i+8] = yb[i+16]/16.f;
|
460
|
+
yl[i+9] = yb[i+17]/4096.f;
|
400
461
|
}
|
401
462
|
|
402
|
-
for (int row = 0; row <
|
403
|
-
|
404
|
-
qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
|
405
|
-
|
406
|
-
// calculate
|
407
|
-
float d = qb_curr.d;
|
408
|
-
float2 acc = {0.0f, 0.0f};
|
409
|
-
for (int i = 0; i < 16; i++) {
|
410
|
-
acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
|
411
|
-
acc[1] += yl[i] + yl[i+16];
|
412
|
-
}
|
413
|
-
sumf[row] += d * (acc[0] - 8.f*acc[1]);
|
414
|
-
qb_curr = qb_next;
|
463
|
+
for (int row = 0; row < nr; row++) {
|
464
|
+
sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il);
|
415
465
|
}
|
416
|
-
}
|
417
466
|
|
418
|
-
|
419
|
-
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
|
467
|
+
yb += QK4_0 * 16;
|
420
468
|
}
|
421
469
|
|
422
|
-
for (int row = 0; row <
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
// calculate
|
427
|
-
float d = qb_curr.d;
|
428
|
-
float2 acc = {0.0f, 0.0f};
|
429
|
-
for (int i = 0; i < 16; i++) {
|
430
|
-
acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
|
431
|
-
acc[1] += yl[i] + yl[i+16];
|
432
|
-
}
|
433
|
-
if (tiisg < nb % N_SIMDWIDTH) {
|
434
|
-
sumf[row] += d * (acc[0] - 8.f*acc[1]);
|
435
|
-
}
|
436
|
-
qb_curr = qb_next;
|
437
|
-
|
438
|
-
all_sum = simd_sum(sumf[row]);
|
439
|
-
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
|
440
|
-
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
|
470
|
+
for (int row = 0; row < nr; ++row) {
|
471
|
+
const float tot = simd_sum(sumf[row]);
|
472
|
+
if (tiisg == 0 && first_row + row < ne01) {
|
473
|
+
dst[r1*ne0 + first_row + row] = tot;
|
441
474
|
}
|
442
475
|
}
|
443
476
|
}
|
444
477
|
|
445
|
-
kernel void
|
478
|
+
kernel void kernel_mul_mat_q4_0_f32(
|
446
479
|
device const void * src0,
|
447
480
|
device const float * src1,
|
448
481
|
device float * dst,
|
449
482
|
constant int64_t & ne00,
|
450
483
|
constant int64_t & ne10,
|
451
484
|
constant int64_t & ne0,
|
452
|
-
|
485
|
+
constant int64_t & ne01[[buffer(4)]],
|
453
486
|
uint2 tgpig[[threadgroup_position_in_grid]],
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
const int64_t r0 = tgpig.x;
|
459
|
-
const int64_t r1 = tgpig.y;
|
460
|
-
|
461
|
-
device const block_q4_1 * x = (device const block_q4_1 *) src0 + r0*nb;
|
462
|
-
device const float * y = (device const float *) src1 + r1*ne10;
|
463
|
-
|
464
|
-
const uint nth = tptg.x*tptg.y;
|
465
|
-
const uint ith = tptg.y*tpitg.x + tpitg.y;
|
466
|
-
|
467
|
-
const int ix = tpitg.y/4; // 0 or 1
|
468
|
-
const int iy = tpitg.y - 4*ix; // 0...3
|
469
|
-
|
470
|
-
const int first = 4 * iy;
|
471
|
-
|
472
|
-
float sumf = 0;
|
473
|
-
|
474
|
-
for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
|
475
|
-
|
476
|
-
const float d = (float)x[i].d;
|
477
|
-
const float m = (float)x[i].m;
|
478
|
-
|
479
|
-
device const uint8_t * xl = x[i].qs + first;
|
480
|
-
device const float * yl = y + i * QK4_1 + first;
|
481
|
-
|
482
|
-
float2 acc = {0.0f, 0.0f};
|
483
|
-
|
484
|
-
for (int j = 0; j < 4; ++j) {
|
485
|
-
|
486
|
-
acc[0] += yl[j+ 0] * (d * (xl[j] & 0xF) + m);
|
487
|
-
acc[1] += yl[j+16] * (d * (xl[j] >> 4) + m);
|
488
|
-
|
489
|
-
}
|
490
|
-
|
491
|
-
sumf += acc[0] + acc[1];
|
492
|
-
}
|
493
|
-
|
494
|
-
sum[ith] = sumf;
|
487
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
488
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
489
|
+
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
|
490
|
+
}
|
495
491
|
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
|
504
|
-
|
505
|
-
|
506
|
-
|
507
|
-
|
508
|
-
if (ith == 0) {
|
509
|
-
for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
|
510
|
-
dst[r1*ne0 + r0] = sum[0];
|
511
|
-
}
|
492
|
+
kernel void kernel_mul_mat_q4_1_f32(
|
493
|
+
device const void * src0,
|
494
|
+
device const float * src1,
|
495
|
+
device float * dst,
|
496
|
+
constant int64_t & ne00,
|
497
|
+
constant int64_t & ne10,
|
498
|
+
constant int64_t & ne0,
|
499
|
+
constant int64_t & ne01[[buffer(4)]],
|
500
|
+
uint2 tgpig[[threadgroup_position_in_grid]],
|
501
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
502
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
503
|
+
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
|
512
504
|
}
|
513
505
|
|
514
506
|
kernel void kernel_mul_mat_f16_f32(
|
@@ -624,17 +616,19 @@ kernel void kernel_rope(
|
|
624
616
|
constant int & n_past,
|
625
617
|
constant int & n_dims,
|
626
618
|
constant int & mode,
|
619
|
+
constant float & freq_base,
|
620
|
+
constant float & freq_scale,
|
627
621
|
uint3 tpig[[thread_position_in_grid]]) {
|
628
622
|
const int64_t i3 = tpig[2];
|
629
623
|
const int64_t i2 = tpig[1];
|
630
624
|
const int64_t i1 = tpig[0];
|
631
625
|
|
632
626
|
const bool is_neox = mode & 2;
|
633
|
-
const float theta_scale = pow(
|
627
|
+
const float theta_scale = pow(freq_base, -2.0f/n_dims);
|
634
628
|
|
635
629
|
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
|
636
630
|
|
637
|
-
float theta = (float)p;
|
631
|
+
float theta = freq_scale * (float)p;
|
638
632
|
|
639
633
|
if (!is_neox) {
|
640
634
|
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
@@ -1229,111 +1223,137 @@ kernel void kernel_mul_mat_q2_K_f32(
|
|
1229
1223
|
constant int64_t & ne00,
|
1230
1224
|
constant int64_t & ne10,
|
1231
1225
|
constant int64_t & ne0,
|
1232
|
-
|
1226
|
+
constant int64_t & ne01[[buffer(4)]],
|
1233
1227
|
uint2 tgpig[[threadgroup_position_in_grid]],
|
1234
|
-
|
1235
|
-
|
1228
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
1229
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1236
1230
|
|
1237
1231
|
const int nb = ne00/QK_K;
|
1232
|
+
const int r0 = tgpig.x;
|
1233
|
+
const int r1 = tgpig.y;
|
1238
1234
|
|
1239
|
-
const
|
1240
|
-
const
|
1241
|
-
|
1242
|
-
device const
|
1243
|
-
|
1244
|
-
|
1245
|
-
const int nth = tptg.x*tptg.y;
|
1246
|
-
const int ith = tptg.y*tpitg.x + tpitg.y;
|
1235
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
1236
|
+
const int ib_row = first_row * nb;
|
1237
|
+
device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row;
|
1238
|
+
device const float * y = (device const float *) src1 + r1*ne10;
|
1239
|
+
float yl[32];
|
1240
|
+
float sumf[N_DST]={0.f}, all_sum;
|
1247
1241
|
|
1248
|
-
|
1242
|
+
const int step = sizeof(block_q2_K) * nb;
|
1249
1243
|
|
1250
1244
|
#if QK_K == 256
|
1251
|
-
const int
|
1252
|
-
const int
|
1253
|
-
const int
|
1254
|
-
const int
|
1255
|
-
const int
|
1256
|
-
|
1257
|
-
const
|
1258
|
-
|
1259
|
-
|
1260
|
-
|
1261
|
-
|
1262
|
-
|
1263
|
-
|
1264
|
-
|
1265
|
-
|
1266
|
-
|
1245
|
+
const int ix = tiisg/8; // 0...3
|
1246
|
+
const int it = tiisg%8; // 0...7
|
1247
|
+
const int im = it/4; // 0 or 1
|
1248
|
+
const int ir = it%4; // 0...3
|
1249
|
+
const int is = (8*ir)/16;// 0 or 1
|
1250
|
+
|
1251
|
+
device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir;
|
1252
|
+
|
1253
|
+
for (int ib = ix; ib < nb; ib += 4) {
|
1254
|
+
|
1255
|
+
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
1256
|
+
for (int i = 0; i < 8; ++i) {
|
1257
|
+
yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
|
1258
|
+
yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
|
1259
|
+
yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
|
1260
|
+
yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
|
1261
|
+
}
|
1267
1262
|
|
1268
|
-
uint8_t
|
1269
|
-
|
1270
|
-
|
1271
|
-
uint8_t m2 = scales[2] >> 4;
|
1263
|
+
device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*im + is;
|
1264
|
+
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
|
1265
|
+
device const half * dh = &x[ib].d;
|
1272
1266
|
|
1273
|
-
|
1267
|
+
for (int row = 0; row < N_DST; row++) {
|
1274
1268
|
|
1275
|
-
|
1276
|
-
|
1277
|
-
|
1278
|
-
|
1279
|
-
|
1280
|
-
|
1269
|
+
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
|
1270
|
+
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
|
1271
|
+
for (int i = 0; i < 8; i += 2) {
|
1272
|
+
acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
|
1273
|
+
acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
|
1274
|
+
acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
|
1275
|
+
acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
|
1276
|
+
acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
|
1277
|
+
acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
|
1278
|
+
acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
|
1279
|
+
acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
|
1280
|
+
}
|
1281
|
+
float dall = dh[0];
|
1282
|
+
float dmin = dh[1] * 1.f/16.f;
|
1283
|
+
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
|
1284
|
+
(acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
|
1285
|
+
(acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +
|
1286
|
+
(acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
|
1287
|
+
dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
|
1288
|
+
|
1289
|
+
qs += step/2;
|
1290
|
+
sc += step;
|
1291
|
+
dh += step/2;
|
1281
1292
|
}
|
1282
1293
|
|
1283
|
-
|
1284
|
-
const float dmin = (float)x[i].dmin;
|
1285
|
-
|
1286
|
-
sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
|
1287
|
-
|
1294
|
+
y4 += 4 * QK_K;
|
1288
1295
|
}
|
1289
1296
|
#else
|
1290
|
-
const int
|
1297
|
+
const int ix = tiisg/2; // 0...15
|
1298
|
+
const int it = tiisg%2; // 0...1
|
1291
1299
|
|
1292
|
-
|
1293
|
-
thread const uint8_t * d = (thread const uint8_t *)aux;
|
1294
|
-
thread const uint8_t * m = (thread const uint8_t *)aux + 4;
|
1300
|
+
device const float * y4 = y + ix * QK_K + 8 * it;
|
1295
1301
|
|
1296
|
-
for (int
|
1302
|
+
for (int ib = ix; ib < nb; ib += 16) {
|
1297
1303
|
|
1298
|
-
|
1299
|
-
|
1304
|
+
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
1305
|
+
for (int i = 0; i < 8; ++i) {
|
1306
|
+
yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
|
1307
|
+
yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8];
|
1308
|
+
yl[i+16] = y4[i+32]; sumy[2] += yl[i+16];
|
1309
|
+
yl[i+24] = y4[i+48]; sumy[3] += yl[i+24];
|
1310
|
+
}
|
1311
|
+
|
1312
|
+
device const uint8_t * sc = (device const uint8_t *)x[ib].scales;
|
1313
|
+
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
|
1314
|
+
device const half * dh = &x[ib].d;
|
1300
1315
|
|
1301
|
-
|
1302
|
-
const float dmin = (float)x[i].dmin;
|
1316
|
+
for (int row = 0; row < N_DST; row++) {
|
1303
1317
|
|
1304
|
-
|
1305
|
-
|
1306
|
-
|
1318
|
+
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
|
1319
|
+
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
|
1320
|
+
for (int i = 0; i < 8; i += 2) {
|
1321
|
+
acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
|
1322
|
+
acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
|
1323
|
+
acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
|
1324
|
+
acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
|
1325
|
+
acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
|
1326
|
+
acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
|
1327
|
+
acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
|
1328
|
+
acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
|
1329
|
+
}
|
1307
1330
|
|
1308
|
-
|
1309
|
-
|
1310
|
-
|
1311
|
-
|
1312
|
-
|
1331
|
+
float dall = dh[0];
|
1332
|
+
float dmin = dh[1];
|
1333
|
+
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
|
1334
|
+
(acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
|
1335
|
+
(acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
|
1336
|
+
(acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
|
1337
|
+
dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));
|
1338
|
+
|
1339
|
+
qs += step/2;
|
1340
|
+
sc += step;
|
1341
|
+
dh += step/2;
|
1313
1342
|
}
|
1343
|
+
|
1344
|
+
y4 += 16 * QK_K;
|
1314
1345
|
}
|
1315
1346
|
#endif
|
1316
1347
|
|
1317
|
-
|
1318
|
-
|
1319
|
-
|
1320
|
-
|
1321
|
-
|
1322
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1323
|
-
if (ith%4 == 0) {
|
1324
|
-
for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
|
1325
|
-
}
|
1326
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1327
|
-
if (ith%16 == 0) {
|
1328
|
-
for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
|
1329
|
-
}
|
1330
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1331
|
-
if (ith == 0) {
|
1332
|
-
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
|
1333
|
-
dst[r1*ne0 + r0] = sum[0];
|
1348
|
+
for (int row = 0; row < N_DST; ++row) {
|
1349
|
+
all_sum = simd_sum(sumf[row]);
|
1350
|
+
if (tiisg == 0) {
|
1351
|
+
dst[r1*ne0 + first_row + row] = all_sum;
|
1352
|
+
}
|
1334
1353
|
}
|
1335
1354
|
}
|
1336
1355
|
|
1356
|
+
#if QK_K == 256
|
1337
1357
|
kernel void kernel_mul_mat_q3_K_f32(
|
1338
1358
|
device const void * src0,
|
1339
1359
|
device const float * src1,
|
@@ -1342,40 +1362,41 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1342
1362
|
constant int64_t & ne10,
|
1343
1363
|
constant int64_t & ne0,
|
1344
1364
|
constant int64_t & ne1,
|
1345
|
-
threadgroup float * sum [[threadgroup(0)]],
|
1346
1365
|
uint2 tgpig[[threadgroup_position_in_grid]],
|
1347
|
-
|
1348
|
-
|
1366
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
1367
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1349
1368
|
|
1350
1369
|
const int nb = ne00/QK_K;
|
1351
1370
|
|
1352
1371
|
const int64_t r0 = tgpig.x;
|
1353
1372
|
const int64_t r1 = tgpig.y;
|
1354
1373
|
|
1355
|
-
|
1356
|
-
device const float * yy = (device const float *) src1 + r1*ne10;
|
1357
|
-
|
1358
|
-
const int nth = tptg.x*tptg.y;
|
1359
|
-
const int ith = tptg.y*tpitg.x + tpitg.y;
|
1374
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
1360
1375
|
|
1361
|
-
|
1376
|
+
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb;
|
1377
|
+
device const float * yy = (device const float *) src1 + r1*ne10;
|
1362
1378
|
|
1363
|
-
|
1364
|
-
const int8_t m4 = 4;
|
1379
|
+
float yl[16];
|
1365
1380
|
|
1366
1381
|
const uint16_t kmask1 = 0x0303;
|
1367
1382
|
const uint16_t kmask2 = 0x0f0f;
|
1368
1383
|
|
1369
|
-
const int tid =
|
1384
|
+
const int tid = tiisg/2;
|
1385
|
+
const int ix = tiisg%2;
|
1370
1386
|
const int ip = tid/8; // 0 or 1
|
1371
1387
|
const int il = tid/2 - 4*ip; // 0...3
|
1372
1388
|
const int ir = tid%2;
|
1373
1389
|
const int n = 8;
|
1374
1390
|
const int l0 = n*ir;
|
1375
1391
|
|
1376
|
-
const
|
1392
|
+
const uint16_t m1 = 1 << (4*ip + il);
|
1393
|
+
const uint16_t m2 = m1 << 8;
|
1377
1394
|
|
1378
1395
|
const int shift = 2*il;
|
1396
|
+
const uint16_t qm1 = 0x0003 << shift;
|
1397
|
+
const uint16_t qm2 = 0x0300 << shift;
|
1398
|
+
const int32_t v1 = 4 << shift;
|
1399
|
+
const int32_t v2 = 1024 << shift;
|
1379
1400
|
|
1380
1401
|
const uint16_t s_shift1 = 4*ip;
|
1381
1402
|
const uint16_t s_shift2 = s_shift1 + 2*(il/2);
|
@@ -1384,226 +1405,315 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1384
1405
|
const int q_offset = 32*ip + l0;
|
1385
1406
|
const int y_offset = 128*ip + 32*il + l0;
|
1386
1407
|
|
1387
|
-
|
1388
|
-
float sumf1 = 0, sumf2 = 0;
|
1389
|
-
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
1408
|
+
const int step = sizeof(block_q3_K) * nb / 2;
|
1390
1409
|
|
1391
|
-
|
1392
|
-
|
1393
|
-
device const uint8_t * q = x[i].qs + q_offset;
|
1394
|
-
device const uint8_t * h = x[i].hmask + l0;
|
1395
|
-
device const float * y = yy + i * QK_K + y_offset;
|
1410
|
+
device const float * y1 = yy + ix*QK_K + y_offset;
|
1396
1411
|
|
1397
|
-
|
1398
|
-
|
1412
|
+
float sumf1[2] = {0.f}, sumf2[2] = {0.f};
|
1413
|
+
for (int i = ix; i < nb; i += 2) {
|
1399
1414
|
|
1400
|
-
|
1401
|
-
|
1402
|
-
|
1403
|
-
}
|
1404
|
-
float d = d_all * s;
|
1405
|
-
sumf1 += d * scales[0];
|
1406
|
-
sumf2 += d;
|
1407
|
-
//sumf += d_all * s * (scales[0] - 32);
|
1408
|
-
|
1409
|
-
s = 0;
|
1410
|
-
for (int l = 0; l < n; ++l) {
|
1411
|
-
s += y[l+16] * ((int8_t)((q[l+16] >> shift) & m3) - ((h[l+16] & m) ? 0 : m4));
|
1415
|
+
for (int l = 0; l < 8; ++l) {
|
1416
|
+
yl[l+0] = y1[l+ 0];
|
1417
|
+
yl[l+8] = y1[l+16];
|
1412
1418
|
}
|
1413
|
-
d = d_all * s;
|
1414
|
-
sumf1 += d * scales[1];
|
1415
|
-
sumf2 += d;
|
1416
|
-
//sumf += d_all * s * (scales[1] - 32);
|
1417
|
-
|
1418
|
-
}
|
1419
|
-
|
1420
|
-
//sum[ith] = sumf;
|
1421
|
-
sum[ith] = sumf1 - 32.f*sumf2;
|
1422
|
-
#else
|
1423
|
-
const int il = 4 * tpitg.x; // 0, 4, 8, 12
|
1424
|
-
const int im = il/8; // 0, 0, 1, 1
|
1425
|
-
const int in = il%8; // 0, 4, 0, 4
|
1426
1419
|
|
1427
|
-
|
1420
|
+
device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
|
1421
|
+
device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
|
1422
|
+
device const uint16_t * a = (device const uint16_t *)(x[i].scales);
|
1423
|
+
device const half * dh = &x[i].d;
|
1428
1424
|
|
1429
|
-
|
1425
|
+
for (int row = 0; row < 2; ++row) {
|
1430
1426
|
|
1431
|
-
|
1427
|
+
const float d_all = (float)dh[0];
|
1428
|
+
const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
|
1432
1429
|
|
1433
|
-
|
1434
|
-
|
1435
|
-
|
1430
|
+
float s1 = 0, s2 = 0;
|
1431
|
+
for (int l = 0; l < n; l += 2) {
|
1432
|
+
const uint16_t qs = q[l/2];
|
1433
|
+
s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
|
1434
|
+
s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
|
1435
|
+
}
|
1436
|
+
float d = d_all * (s1 + 1.f/256.f * s2);
|
1437
|
+
sumf1[row] += d * scales[0];
|
1438
|
+
sumf2[row] += d;
|
1439
|
+
|
1440
|
+
s1 = s2 = 0;
|
1441
|
+
for (int l = 0; l < n; l += 2) {
|
1442
|
+
const uint16_t qs = q[l/2+8];
|
1443
|
+
s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
|
1444
|
+
s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
|
1445
|
+
}
|
1446
|
+
d = d_all * (s1 + 1.f/256.f * s2);
|
1447
|
+
sumf1[row] += d * scales[1];
|
1448
|
+
sumf2[row] += d;
|
1436
1449
|
|
1437
|
-
|
1438
|
-
|
1439
|
-
|
1440
|
-
|
1450
|
+
q += step;
|
1451
|
+
h += step;
|
1452
|
+
a += step;
|
1453
|
+
dh += step;
|
1441
1454
|
|
1442
|
-
for (int l = 0; l < 4; ++l) {
|
1443
|
-
const uint8_t hm = h[l] >> im;
|
1444
|
-
sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4))
|
1445
|
-
+ y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4))
|
1446
|
-
+ y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4))
|
1447
|
-
+ y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4));
|
1448
1455
|
}
|
1449
1456
|
|
1450
|
-
|
1451
|
-
|
1452
|
-
sum[ith] = sumf;
|
1457
|
+
y1 += 2 * QK_K;
|
1453
1458
|
|
1454
|
-
#endif
|
1455
|
-
|
1456
|
-
//
|
1457
|
-
// Accumulate the sum from all threads in the threadgroup
|
1458
|
-
//
|
1459
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1460
|
-
if (ith%4 == 0) {
|
1461
|
-
for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
|
1462
|
-
}
|
1463
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1464
|
-
if (ith%16 == 0) {
|
1465
|
-
for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
|
1466
|
-
}
|
1467
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1468
|
-
if (ith == 0) {
|
1469
|
-
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
|
1470
|
-
dst[r1*ne0 + r0] = sum[0];
|
1471
1459
|
}
|
1472
1460
|
|
1461
|
+
for (int row = 0; row < 2; ++row) {
|
1462
|
+
const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
|
1463
|
+
const float tot = simd_sum(sumf);
|
1464
|
+
if (tiisg == 0) {
|
1465
|
+
dst[r1*ne0 + first_row + row] = tot;
|
1466
|
+
}
|
1467
|
+
}
|
1473
1468
|
}
|
1474
|
-
|
1475
|
-
kernel void
|
1469
|
+
#else
|
1470
|
+
kernel void kernel_mul_mat_q3_K_f32(
|
1476
1471
|
device const void * src0,
|
1477
1472
|
device const float * src1,
|
1478
1473
|
device float * dst,
|
1479
1474
|
constant int64_t & ne00,
|
1480
1475
|
constant int64_t & ne10,
|
1481
1476
|
constant int64_t & ne0,
|
1482
|
-
|
1477
|
+
constant int64_t & ne1,
|
1483
1478
|
uint2 tgpig[[threadgroup_position_in_grid]],
|
1484
|
-
|
1485
|
-
|
1479
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
1480
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1486
1481
|
|
1487
1482
|
const int nb = ne00/QK_K;
|
1488
1483
|
|
1489
1484
|
const int64_t r0 = tgpig.x;
|
1490
1485
|
const int64_t r1 = tgpig.y;
|
1491
1486
|
|
1492
|
-
const int
|
1493
|
-
const int ith = tptg.y*tpitg.x + tpitg.y;
|
1487
|
+
const int row = 2 * r0 + sgitg;
|
1494
1488
|
|
1495
|
-
device const
|
1489
|
+
device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb;
|
1496
1490
|
device const float * yy = (device const float *) src1 + r1*ne10;
|
1491
|
+
const int ix = tiisg/4;
|
1492
|
+
const int il = 4 * (tiisg%4);// 0, 4, 8, 12
|
1493
|
+
const int im = il/8; // 0, 0, 1, 1
|
1494
|
+
const int in = il%8; // 0, 4, 0, 4
|
1497
1495
|
|
1498
|
-
|
1496
|
+
float2 sum = {0.f, 0.f};
|
1497
|
+
|
1498
|
+
for (int i = ix; i < nb; i += 8) {
|
1499
|
+
|
1500
|
+
const float d_all = (float)(x[i].d);
|
1501
|
+
|
1502
|
+
device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
|
1503
|
+
device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
|
1504
|
+
device const uint16_t * s = (device const uint16_t *)(x[i].scales);
|
1505
|
+
device const float * y = yy + i * QK_K + il;
|
1506
|
+
|
1507
|
+
const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
|
1508
|
+
const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
|
1509
|
+
const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
|
1510
|
+
const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
|
1511
|
+
|
1512
|
+
for (int l = 0; l < 4; l += 2) {
|
1513
|
+
const uint16_t hm = h[l/2] >> im;
|
1514
|
+
sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
|
1515
|
+
+ y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
|
1516
|
+
+ y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
|
1517
|
+
+ y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
|
1518
|
+
sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
|
1519
|
+
+ y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
|
1520
|
+
+ y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
|
1521
|
+
+ y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
|
1522
|
+
}
|
1523
|
+
|
1524
|
+
}
|
1525
|
+
const float sumf = sum[0] + sum[1] * 1.f/256.f;
|
1526
|
+
|
1527
|
+
const float tot = simd_sum(sumf);
|
1528
|
+
if (tiisg == 0) {
|
1529
|
+
dst[r1*ne0 + row] = tot;
|
1530
|
+
}
|
1531
|
+
|
1532
|
+
}
|
1533
|
+
#endif
|
1499
1534
|
|
1500
1535
|
#if QK_K == 256
|
1536
|
+
kernel void kernel_mul_mat_q4_K_f32(
|
1537
|
+
device const void * src0,
|
1538
|
+
device const float * src1,
|
1539
|
+
device float * dst,
|
1540
|
+
constant int64_t & ne00,
|
1541
|
+
constant int64_t & ne10,
|
1542
|
+
constant int64_t & ne0,
|
1543
|
+
constant int64_t & ne01[[buffer(4)]],
|
1544
|
+
uint2 tgpig[[threadgroup_position_in_grid]],
|
1545
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
1546
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1501
1547
|
|
1502
1548
|
const uint16_t kmask1 = 0x3f3f;
|
1503
1549
|
const uint16_t kmask2 = 0x0f0f;
|
1504
1550
|
const uint16_t kmask3 = 0xc0c0;
|
1505
1551
|
|
1506
|
-
const int
|
1507
|
-
const int
|
1508
|
-
const int
|
1509
|
-
const int
|
1552
|
+
const int ix = tiisg/8; // 0...3
|
1553
|
+
const int it = tiisg%8; // 0...7
|
1554
|
+
const int im = it/4; // 0 or 1
|
1555
|
+
const int ir = it%4; // 0...3
|
1510
1556
|
|
1511
|
-
const int
|
1512
|
-
const int
|
1557
|
+
const int nb = ne00/QK_K;
|
1558
|
+
const int r0 = tgpig.x;
|
1559
|
+
const int r1 = tgpig.y;
|
1560
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
1561
|
+
const int ib_row = first_row * nb;
|
1562
|
+
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
|
1563
|
+
device const float * y = (device const float *) src1 + r1*ne10;
|
1564
|
+
float yl[16];
|
1565
|
+
float yh[16];
|
1566
|
+
float sumf[N_DST]={0.f}, all_sum;
|
1513
1567
|
|
1514
|
-
const int
|
1515
|
-
const int q_offset = 32*im + l0;
|
1516
|
-
const int y_offset = 64*im + l0;
|
1568
|
+
const int step = sizeof(block_q4_K) * nb / 2;
|
1517
1569
|
|
1518
|
-
|
1570
|
+
device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir;
|
1519
1571
|
|
1520
|
-
|
1572
|
+
uint16_t sc16[4];
|
1573
|
+
thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
|
1521
1574
|
|
1522
|
-
|
1523
|
-
device const uint8_t * q2 = q1 + 64;
|
1524
|
-
device const float * y1 = yy + i*QK_K + y_offset;
|
1525
|
-
device const float * y2 = y1 + 128;
|
1575
|
+
for (int ib = ix; ib < nb; ib += 4) {
|
1526
1576
|
|
1527
|
-
|
1528
|
-
|
1577
|
+
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
1578
|
+
for (int i = 0; i < 8; ++i) {
|
1579
|
+
yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
|
1580
|
+
yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
|
1581
|
+
yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
|
1582
|
+
yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
|
1583
|
+
}
|
1529
1584
|
|
1530
|
-
device const uint16_t *
|
1531
|
-
|
1532
|
-
|
1533
|
-
sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
|
1534
|
-
sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
|
1585
|
+
device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im;
|
1586
|
+
device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
|
1587
|
+
device const half * dh = &x[ib].d;
|
1535
1588
|
|
1536
|
-
|
1537
|
-
float smin = 0;
|
1538
|
-
for (int l = 0; l < n; ++l) {
|
1589
|
+
for (int row = 0; row < N_DST; row++) {
|
1539
1590
|
|
1540
|
-
|
1541
|
-
|
1542
|
-
|
1591
|
+
sc16[0] = sc[0] & kmask1;
|
1592
|
+
sc16[1] = sc[2] & kmask1;
|
1593
|
+
sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
|
1594
|
+
sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
|
1595
|
+
|
1596
|
+
device const uint16_t * q2 = q1 + 32;
|
1597
|
+
|
1598
|
+
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
|
1599
|
+
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
|
1600
|
+
for (int i = 0; i < 8; i += 2) {
|
1601
|
+
acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
|
1602
|
+
acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
|
1603
|
+
acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
|
1604
|
+
acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
|
1605
|
+
acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
|
1606
|
+
acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
|
1607
|
+
acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
|
1608
|
+
acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
|
1609
|
+
}
|
1543
1610
|
|
1611
|
+
float dall = dh[0];
|
1612
|
+
float dmin = dh[1];
|
1613
|
+
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
|
1614
|
+
(acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
|
1615
|
+
(acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
|
1616
|
+
(acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
|
1617
|
+
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
|
1618
|
+
|
1619
|
+
q1 += step;
|
1620
|
+
sc += step;
|
1621
|
+
dh += step;
|
1544
1622
|
}
|
1545
|
-
sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
|
1546
1623
|
|
1624
|
+
y4 += 4 * QK_K;
|
1625
|
+
}
|
1626
|
+
|
1627
|
+
for (int row = 0; row < N_DST; ++row) {
|
1628
|
+
all_sum = simd_sum(sumf[row]);
|
1629
|
+
if (tiisg == 0) {
|
1630
|
+
dst[r1*ne0 + first_row + row] = all_sum;
|
1631
|
+
}
|
1547
1632
|
}
|
1633
|
+
}
|
1548
1634
|
#else
|
1549
|
-
|
1550
|
-
|
1635
|
+
kernel void kernel_mul_mat_q4_K_f32(
|
1636
|
+
device const void * src0,
|
1637
|
+
device const float * src1,
|
1638
|
+
device float * dst,
|
1639
|
+
constant int64_t & ne00,
|
1640
|
+
constant int64_t & ne10,
|
1641
|
+
constant int64_t & ne0,
|
1642
|
+
constant int64_t & ne01[[buffer(4)]],
|
1643
|
+
uint2 tgpig[[threadgroup_position_in_grid]],
|
1644
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
1645
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1551
1646
|
|
1552
|
-
const int
|
1647
|
+
const int ix = tiisg/4; // 0...7
|
1648
|
+
const int it = tiisg%4; // 0...3
|
1553
1649
|
|
1554
|
-
|
1650
|
+
const int nb = ne00/QK_K;
|
1651
|
+
const int r0 = tgpig.x;
|
1652
|
+
const int r1 = tgpig.y;
|
1653
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
1654
|
+
const int ib_row = first_row * nb;
|
1655
|
+
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
|
1656
|
+
device const float * y = (device const float *) src1 + r1*ne10;
|
1657
|
+
float yl[8];
|
1658
|
+
float yh[8];
|
1659
|
+
float sumf[N_DST]={0.f}, all_sum;
|
1555
1660
|
|
1556
|
-
|
1557
|
-
device const float * y = yy + i * QK_K + il;
|
1661
|
+
const int step = sizeof(block_q4_K) * nb / 2;
|
1558
1662
|
|
1559
|
-
|
1560
|
-
const float m = (float)x[i].d[1];
|
1663
|
+
device const float * y4 = y + ix * QK_K + 8 * it;
|
1561
1664
|
|
1562
|
-
|
1563
|
-
aux16[0] = a[0] & 0x0f0f;
|
1564
|
-
aux16[1] = (a[0] >> 4) & 0x0f0f;
|
1665
|
+
uint16_t sc16[4];
|
1565
1666
|
|
1566
|
-
|
1567
|
-
|
1568
|
-
|
1667
|
+
for (int ib = ix; ib < nb; ib += 8) {
|
1668
|
+
|
1669
|
+
float2 sumy = {0.f, 0.f};
|
1670
|
+
for (int i = 0; i < 8; ++i) {
|
1671
|
+
yl[i] = y4[i+ 0]; sumy[0] += yl[i];
|
1672
|
+
yh[i] = y4[i+32]; sumy[1] += yh[i];
|
1569
1673
|
}
|
1570
|
-
}
|
1571
|
-
#endif
|
1572
1674
|
|
1573
|
-
|
1675
|
+
device const uint16_t * sc = (device const uint16_t *)x[ib].scales;
|
1676
|
+
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
|
1677
|
+
device const half * dh = x[ib].d;
|
1574
1678
|
|
1575
|
-
|
1576
|
-
|
1577
|
-
|
1578
|
-
|
1579
|
-
|
1580
|
-
|
1581
|
-
|
1582
|
-
|
1583
|
-
|
1584
|
-
|
1585
|
-
|
1586
|
-
|
1587
|
-
|
1588
|
-
|
1589
|
-
|
1590
|
-
|
1591
|
-
|
1679
|
+
for (int row = 0; row < N_DST; row++) {
|
1680
|
+
|
1681
|
+
sc16[0] = sc[0] & 0x000f;
|
1682
|
+
sc16[1] = sc[0] & 0x0f00;
|
1683
|
+
sc16[2] = sc[0] & 0x00f0;
|
1684
|
+
sc16[3] = sc[0] & 0xf000;
|
1685
|
+
|
1686
|
+
float2 acc1 = {0.f, 0.f};
|
1687
|
+
float2 acc2 = {0.f, 0.f};
|
1688
|
+
for (int i = 0; i < 8; i += 2) {
|
1689
|
+
acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
|
1690
|
+
acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
|
1691
|
+
acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
|
1692
|
+
acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
|
1693
|
+
}
|
1694
|
+
|
1695
|
+
float dall = dh[0];
|
1696
|
+
float dmin = dh[1];
|
1697
|
+
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
|
1698
|
+
(acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
|
1699
|
+
dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);
|
1700
|
+
|
1701
|
+
qs += step;
|
1702
|
+
sc += step;
|
1703
|
+
dh += step;
|
1704
|
+
}
|
1705
|
+
|
1706
|
+
y4 += 8 * QK_K;
|
1592
1707
|
}
|
1593
1708
|
|
1594
|
-
|
1595
|
-
|
1596
|
-
|
1597
|
-
|
1598
|
-
|
1599
|
-
|
1600
|
-
// threadgroup_barrier(mem_flags::mem_threadgroup);
|
1601
|
-
//}
|
1602
|
-
|
1603
|
-
//if (ith == 0) {
|
1604
|
-
// dst[r1*ne0 + r0] = sum[0];
|
1605
|
-
//}
|
1709
|
+
for (int row = 0; row < N_DST; ++row) {
|
1710
|
+
all_sum = simd_sum(sumf[row]);
|
1711
|
+
if (tiisg == 0) {
|
1712
|
+
dst[r1*ne0 + first_row + row] = all_sum;
|
1713
|
+
}
|
1714
|
+
}
|
1606
1715
|
}
|
1716
|
+
#endif
|
1607
1717
|
|
1608
1718
|
kernel void kernel_mul_mat_q5_K_f32(
|
1609
1719
|
device const void * src0,
|
@@ -1612,39 +1722,39 @@ kernel void kernel_mul_mat_q5_K_f32(
|
|
1612
1722
|
constant int64_t & ne00,
|
1613
1723
|
constant int64_t & ne10,
|
1614
1724
|
constant int64_t & ne0,
|
1615
|
-
threadgroup float * sum [[threadgroup(0)]],
|
1616
1725
|
uint2 tgpig[[threadgroup_position_in_grid]],
|
1617
|
-
|
1618
|
-
|
1726
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
1727
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1619
1728
|
|
1620
1729
|
const int nb = ne00/QK_K;
|
1621
1730
|
|
1622
1731
|
const int64_t r0 = tgpig.x;
|
1623
1732
|
const int64_t r1 = tgpig.y;
|
1624
1733
|
|
1625
|
-
|
1734
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
1735
|
+
|
1736
|
+
device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb;
|
1626
1737
|
device const float * yy = (device const float *) src1 + r1*ne10;
|
1627
1738
|
|
1628
|
-
|
1629
|
-
const int ith = tptg.y*tpitg.x + tpitg.y;
|
1739
|
+
float sumf[2]={0.f};
|
1630
1740
|
|
1631
|
-
|
1741
|
+
const int step = sizeof(block_q5_K) * nb;
|
1632
1742
|
|
1633
1743
|
#if QK_K == 256
|
1744
|
+
#
|
1745
|
+
float yl[16], yh[16];
|
1634
1746
|
|
1635
1747
|
const uint16_t kmask1 = 0x3f3f;
|
1636
1748
|
const uint16_t kmask2 = 0x0f0f;
|
1637
1749
|
const uint16_t kmask3 = 0xc0c0;
|
1638
1750
|
|
1639
|
-
const int tid =
|
1640
|
-
const int
|
1641
|
-
const int
|
1642
|
-
const int
|
1643
|
-
|
1644
|
-
const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
|
1645
|
-
const int in = il%2;
|
1751
|
+
const int tid = tiisg/4;
|
1752
|
+
const int ix = tiisg%4;
|
1753
|
+
const int im = tid/4;
|
1754
|
+
const int ir = tid%4;
|
1755
|
+
const int n = 8;
|
1646
1756
|
|
1647
|
-
const int l0 = n*
|
1757
|
+
const int l0 = n*ir;
|
1648
1758
|
const int q_offset = 32*im + l0;
|
1649
1759
|
const int y_offset = 64*im + l0;
|
1650
1760
|
|
@@ -1653,78 +1763,113 @@ kernel void kernel_mul_mat_q5_K_f32(
|
|
1653
1763
|
const uint8_t hm3 = hm1 << 4;
|
1654
1764
|
const uint8_t hm4 = hm2 << 4;
|
1655
1765
|
|
1656
|
-
|
1766
|
+
uint16_t sc16[4];
|
1767
|
+
thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
|
1657
1768
|
|
1658
|
-
|
1769
|
+
device const float * y1 = yy + ix*QK_K + y_offset;
|
1659
1770
|
|
1660
|
-
|
1661
|
-
device const uint8_t * q2 = q1 + 64;
|
1662
|
-
device const uint8_t * qh = (x + i)->qh + l0;
|
1663
|
-
device const float * y1 = yy + i*QK_K + y_offset;
|
1664
|
-
device const float * y2 = y1 + 128;
|
1771
|
+
for (int i = ix; i < nb; i += 4) {
|
1665
1772
|
|
1666
|
-
const
|
1667
|
-
const
|
1773
|
+
device const uint8_t * q1 = x[i].qs + q_offset;
|
1774
|
+
device const uint8_t * qh = x[i].qh + l0;
|
1775
|
+
device const half * dh = &x[i].d;
|
1776
|
+
device const uint16_t * a = (device const uint16_t *)x[i].scales + im;
|
1668
1777
|
|
1669
|
-
device const
|
1670
|
-
|
1671
|
-
|
1672
|
-
|
1673
|
-
|
1778
|
+
device const float * y2 = y1 + 128;
|
1779
|
+
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
1780
|
+
for (int l = 0; l < 8; ++l) {
|
1781
|
+
yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
|
1782
|
+
yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
|
1783
|
+
yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
|
1784
|
+
yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
|
1785
|
+
}
|
1674
1786
|
|
1675
|
-
|
1676
|
-
|
1677
|
-
|
1787
|
+
for (int row = 0; row < 2; ++row) {
|
1788
|
+
|
1789
|
+
device const uint8_t * q2 = q1 + 64;
|
1678
1790
|
|
1679
|
-
|
1680
|
-
|
1681
|
-
|
1682
|
-
|
1683
|
-
|
1791
|
+
sc16[0] = a[0] & kmask1;
|
1792
|
+
sc16[1] = a[2] & kmask1;
|
1793
|
+
sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
|
1794
|
+
sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
|
1795
|
+
|
1796
|
+
float4 acc = {0.f, 0.f, 0.f, 0.f};
|
1797
|
+
for (int l = 0; l < n; ++l) {
|
1798
|
+
uint8_t h = qh[l];
|
1799
|
+
acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
|
1800
|
+
acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
|
1801
|
+
acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
|
1802
|
+
acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
|
1803
|
+
}
|
1804
|
+
const float dall = dh[0];
|
1805
|
+
const float dmin = dh[1];
|
1806
|
+
sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
|
1807
|
+
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
|
1808
|
+
|
1809
|
+
q1 += step;
|
1810
|
+
qh += step;
|
1811
|
+
dh += step/2;
|
1812
|
+
a += step/2;
|
1684
1813
|
|
1685
1814
|
}
|
1686
|
-
|
1815
|
+
|
1816
|
+
y1 += 4 * QK_K;
|
1687
1817
|
|
1688
1818
|
}
|
1689
1819
|
#else
|
1690
|
-
|
1691
|
-
|
1692
|
-
const int
|
1820
|
+
float yl[8], yh[8];
|
1821
|
+
|
1822
|
+
const int il = 4 * (tiisg/8); // 0, 4, 8, 12
|
1823
|
+
const int ix = tiisg%8;
|
1824
|
+
const int im = il/8; // 0, 0, 1, 1
|
1825
|
+
const int in = il%8; // 0, 4, 0, 4
|
1693
1826
|
|
1694
|
-
|
1827
|
+
device const float * y = yy + ix*QK_K + il;
|
1695
1828
|
|
1696
|
-
|
1829
|
+
for (int i = ix; i < nb; i += 8) {
|
1830
|
+
|
1831
|
+
for (int l = 0; l < 4; ++l) {
|
1832
|
+
yl[l+0] = y[l+ 0];
|
1833
|
+
yl[l+4] = y[l+16];
|
1834
|
+
yh[l+0] = y[l+32];
|
1835
|
+
yh[l+4] = y[l+48];
|
1836
|
+
}
|
1837
|
+
|
1838
|
+
device const half * dh = &x[i].d;
|
1697
1839
|
device const uint8_t * q = x[i].qs + il;
|
1698
1840
|
device const uint8_t * h = x[i].qh + in;
|
1699
1841
|
device const int8_t * s = x[i].scales;
|
1700
|
-
device const float * y = yy + i*QK_K + il;
|
1701
1842
|
|
1702
|
-
for (int
|
1703
|
-
|
1704
|
-
|
1705
|
-
|
1706
|
-
|
1707
|
-
|
1843
|
+
for (int row = 0; row < 2; ++row) {
|
1844
|
+
|
1845
|
+
const float d = dh[0];
|
1846
|
+
|
1847
|
+
float2 acc = {0.f, 0.f};
|
1848
|
+
for (int l = 0; l < 4; ++l) {
|
1849
|
+
const uint8_t hl = h[l] >> im;
|
1850
|
+
acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
|
1851
|
+
+ yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
|
1852
|
+
acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
|
1853
|
+
+ yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256));
|
1854
|
+
}
|
1855
|
+
sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]);
|
1856
|
+
|
1857
|
+
q += step;
|
1858
|
+
h += step;
|
1859
|
+
s += step;
|
1860
|
+
dh += step/2;
|
1861
|
+
|
1708
1862
|
}
|
1863
|
+
|
1864
|
+
y += 8 * QK_K;
|
1709
1865
|
}
|
1710
1866
|
#endif
|
1711
|
-
sum[ith] = sumf;
|
1712
1867
|
|
1713
|
-
|
1714
|
-
|
1715
|
-
|
1716
|
-
|
1717
|
-
|
1718
|
-
sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
|
1719
|
-
}
|
1720
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1721
|
-
if (ith%16 == 0) {
|
1722
|
-
sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
|
1723
|
-
}
|
1724
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1725
|
-
if (ith == 0) {
|
1726
|
-
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
|
1727
|
-
dst[r1*ne0 + r0] = sum[0];
|
1868
|
+
for (int row = 0; row < 2; ++row) {
|
1869
|
+
const float tot = simd_sum(sumf[row]);
|
1870
|
+
if (tiisg == 0) {
|
1871
|
+
dst[r1*ne0 + first_row + row] = tot;
|
1872
|
+
}
|
1728
1873
|
}
|
1729
1874
|
|
1730
1875
|
}
|
@@ -1736,10 +1881,9 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|
1736
1881
|
constant int64_t & ne00,
|
1737
1882
|
constant int64_t & ne10,
|
1738
1883
|
constant int64_t & ne0,
|
1739
|
-
threadgroup float * sum [[threadgroup(0)]],
|
1740
1884
|
uint2 tgpig[[threadgroup_position_in_grid]],
|
1741
|
-
|
1742
|
-
|
1885
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
1886
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1743
1887
|
|
1744
1888
|
const uint8_t kmask1 = 0x03;
|
1745
1889
|
const uint8_t kmask2 = 0x0C;
|
@@ -1751,19 +1895,18 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|
1751
1895
|
const int64_t r0 = tgpig.x;
|
1752
1896
|
const int64_t r1 = tgpig.y;
|
1753
1897
|
|
1754
|
-
|
1755
|
-
device const float * yy = (device const float *) src1 + r1*ne10;
|
1898
|
+
const int row = 2 * r0 + sgitg;
|
1756
1899
|
|
1757
|
-
const
|
1758
|
-
const
|
1900
|
+
device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb; //r0*nb;
|
1901
|
+
device const float * yy = (device const float *) src1 + r1*ne10;
|
1759
1902
|
|
1760
1903
|
float sumf = 0;
|
1761
1904
|
|
1762
1905
|
#if QK_K == 256
|
1763
|
-
|
1764
|
-
const int
|
1765
|
-
const int ip =
|
1766
|
-
const int il =
|
1906
|
+
const int tid = tiisg/2;
|
1907
|
+
const int ix = tiisg%2;
|
1908
|
+
const int ip = tid/8; // 0 or 1
|
1909
|
+
const int il = tid%8;
|
1767
1910
|
const int n = 4;
|
1768
1911
|
const int l0 = n*il;
|
1769
1912
|
const int is = 8*ip + l0/16;
|
@@ -1772,9 +1915,10 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|
1772
1915
|
const int q_offset_l = 64*ip + l0;
|
1773
1916
|
const int q_offset_h = 32*ip + l0;
|
1774
1917
|
|
1775
|
-
for (int i =
|
1918
|
+
for (int i = ix; i < nb; i += 2) {
|
1776
1919
|
|
1777
|
-
device const uint8_t *
|
1920
|
+
device const uint8_t * q1 = x[i].ql + q_offset_l;
|
1921
|
+
device const uint8_t * q2 = q1 + 32;
|
1778
1922
|
device const uint8_t * qh = x[i].qh + q_offset_h;
|
1779
1923
|
device const int8_t * sc = x[i].scales + is;
|
1780
1924
|
|
@@ -1784,19 +1928,21 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|
1784
1928
|
|
1785
1929
|
float4 sums = {0.f, 0.f, 0.f, 0.f};
|
1786
1930
|
for (int l = 0; l < n; ++l) {
|
1787
|
-
sums[0] += y[l+ 0] * ((int8_t)((
|
1788
|
-
sums[1] += y[l+32] * ((int8_t)((
|
1789
|
-
sums[2] += y[l+64] * ((int8_t)((
|
1790
|
-
sums[3] += y[l+96] * ((int8_t)((
|
1931
|
+
sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
|
1932
|
+
sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
|
1933
|
+
sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
|
1934
|
+
sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
|
1791
1935
|
}
|
1792
1936
|
|
1793
1937
|
sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
|
1794
1938
|
|
1795
1939
|
}
|
1940
|
+
|
1796
1941
|
#else
|
1797
|
-
const int
|
1942
|
+
const int ix = tiisg/4;
|
1943
|
+
const int il = 4*(tiisg%4);
|
1798
1944
|
|
1799
|
-
for (int i =
|
1945
|
+
for (int i = ix; i < nb; i += 8) {
|
1800
1946
|
device const float * y = yy + i * QK_K + il;
|
1801
1947
|
device const uint8_t * ql = x[i].ql + il;
|
1802
1948
|
device const uint8_t * qh = x[i].qh + il;
|
@@ -1816,23 +1962,8 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|
1816
1962
|
|
1817
1963
|
#endif
|
1818
1964
|
|
1819
|
-
|
1820
|
-
|
1821
|
-
|
1822
|
-
// Accumulate the sum from all threads in the threadgroup
|
1823
|
-
//
|
1824
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1825
|
-
if (ith%4 == 0) {
|
1826
|
-
for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
|
1965
|
+
const float tot = simd_sum(sumf);
|
1966
|
+
if (tiisg == 0) {
|
1967
|
+
dst[r1*ne0 + row] = tot;
|
1827
1968
|
}
|
1828
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1829
|
-
if (ith%16 == 0) {
|
1830
|
-
for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
|
1831
|
-
}
|
1832
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1833
|
-
if (ith == 0) {
|
1834
|
-
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
|
1835
|
-
dst[r1*ne0 + r0] = sum[0];
|
1836
|
-
}
|
1837
|
-
|
1838
1969
|
}
|