llama_cpp 0.3.2 → 0.3.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +37 -0
- data/ext/llama_cpp/extconf.rb +9 -0
- data/ext/llama_cpp/llama_cpp.cpp +302 -112
- data/ext/llama_cpp/src/ggml-cuda.cu +677 -118
- data/ext/llama_cpp/src/ggml-metal.h +5 -1
- data/ext/llama_cpp/src/ggml-metal.m +65 -45
- data/ext/llama_cpp/src/ggml-metal.metal +610 -484
- data/ext/llama_cpp/src/ggml-mpi.c +216 -0
- data/ext/llama_cpp/src/ggml-mpi.h +39 -0
- data/ext/llama_cpp/src/ggml.c +1146 -812
- data/ext/llama_cpp/src/ggml.h +77 -19
- data/ext/llama_cpp/src/k_quants.h +8 -0
- data/ext/llama_cpp/src/llama.cpp +289 -104
- data/ext/llama_cpp/src/llama.h +46 -3
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +2 -1
- data/sig/llama_cpp.rbs +14 -1
- metadata +4 -2
@@ -331,26 +331,33 @@ kernel void kernel_rms_norm(
|
|
331
331
|
threadgroup float * sum [[threadgroup(0)]],
|
332
332
|
uint tgpig[[threadgroup_position_in_grid]],
|
333
333
|
uint tpitg[[thread_position_in_threadgroup]],
|
334
|
+
uint sgitg[[simdgroup_index_in_threadgroup]],
|
335
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
334
336
|
uint ntg[[threads_per_threadgroup]]) {
|
335
|
-
device const
|
337
|
+
device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
|
338
|
+
device const float * x_scalar = (device const float *) x;
|
339
|
+
float4 sumf=0;
|
340
|
+
float all_sum=0;
|
336
341
|
|
337
342
|
// parallel sum
|
338
|
-
|
339
|
-
|
340
|
-
|
343
|
+
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
344
|
+
sumf += x[i00] * x[i00];
|
345
|
+
}
|
346
|
+
all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
|
347
|
+
all_sum = simd_sum(all_sum);
|
348
|
+
if (tiisg == 0) {
|
349
|
+
sum[sgitg] = all_sum;
|
341
350
|
}
|
342
351
|
|
343
|
-
// reduce
|
344
352
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
353
|
+
// reduce the per-simdgroup partial sums; the number of simd groups is ntg / 32
|
354
|
+
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
355
|
+
if (tpitg < i) {
|
356
|
+
sum[tpitg] += sum[tpitg + i];
|
357
|
+
}
|
350
358
|
}
|
351
|
-
|
352
|
-
// broadcast
|
353
359
|
if (tpitg == 0) {
|
360
|
+
for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
|
354
361
|
sum[0] /= ne00;
|
355
362
|
}
|
356
363
|
|
@@ -359,147 +366,127 @@ kernel void kernel_rms_norm(
|
|
359
366
|
const float mean = sum[0];
|
360
367
|
const float scale = 1.0f/sqrt(mean + eps);
|
361
368
|
|
362
|
-
device
|
363
|
-
|
369
|
+
device float4 * y = (device float4 *) (dst + tgpig*ne00);
|
370
|
+
device float * y_scalar = (device float *) y;
|
371
|
+
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
364
372
|
y[i00] = x[i00] * scale;
|
365
373
|
}
|
374
|
+
if (tpitg == 0) {
|
375
|
+
for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
|
376
|
+
}
|
366
377
|
}
|
367
378
|
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
379
|
+
// function to calculate the inner product between a q4_0 block and 32 floats (yl); sumy is SUM(yl[i])
|
380
|
+
float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) {
|
381
|
+
float d = qb_curr->d;
|
382
|
+
float4 acc = 0.f;
|
383
|
+
device uint16_t * qs = ((device uint16_t *)qb_curr + 1);
|
384
|
+
for (int i = 0; i < 16; i+=2) {
|
385
|
+
acc[0] += yl[i] * (qs[i / 2] & 0x000F);
|
386
|
+
acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
|
387
|
+
acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
|
388
|
+
acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
|
389
|
+
}
|
390
|
+
return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f);
|
391
|
+
}
|
380
392
|
|
381
|
-
|
382
|
-
|
393
|
+
// function to calculate the inner product between a q4_1 block and 32 floats (yl); sumy is SUM(yl[i])
|
394
|
+
float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) {
|
395
|
+
float d = qb_curr->d;
|
396
|
+
float m = qb_curr->m;
|
397
|
+
float4 acc = 0.f;
|
398
|
+
device uint16_t * qs = ((device uint16_t *)qb_curr + 2);
|
399
|
+
for (int i = 0; i < 16; i+=2) {
|
400
|
+
acc[0] += yl[i] * (qs[i / 2] & 0x000F);
|
401
|
+
acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
|
402
|
+
acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
|
403
|
+
acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
|
404
|
+
}
|
405
|
+
return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m;
|
406
|
+
}
|
383
407
|
|
384
|
-
|
408
|
+
// putting them in the kernel causes a significant performance penalty
|
409
|
+
#define N_DST 4 // each SIMD group works on 4 rows
|
410
|
+
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
|
411
|
+
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
412
|
+
template<typename block_q_type>
|
413
|
+
void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
|
414
|
+
int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
|
415
|
+
uint2 tgpig, uint tiisg, uint sgitg) {
|
416
|
+
const int nb = ne00/QK4_0;
|
417
|
+
const int r0 = tgpig.x;
|
418
|
+
const int r1 = tgpig.y;
|
419
|
+
device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
|
385
420
|
device const float * y = (device const float *) src1 + r1*ne10;
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
|
398
|
-
|
399
|
-
const float d = (float)x[i].d;
|
400
|
-
|
401
|
-
device const uint8_t * xl = x[i].qs + first;
|
402
|
-
device const float * yl = y + i * QK4_0 + first;
|
403
|
-
|
404
|
-
float2 acc = {0.0f, 0.0f};
|
405
|
-
|
406
|
-
for (int j = 0; j < 4; ++j) {
|
407
|
-
|
408
|
-
acc[0] += yl[j] * (xl[j] & 0xF) + yl[j+16] * (xl[j] >> 4);
|
409
|
-
acc[1] += yl[j] + yl[j+16];
|
410
|
-
|
421
|
+
float4 y_curr[8]; // src1 vector cache
|
422
|
+
float sumf[N_DST]={0.f}, all_sum;
|
423
|
+
thread float * yl=(thread float *)y_curr;
|
424
|
+
|
425
|
+
// each thread in a SIMD group deals with 1 block.
|
426
|
+
for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
|
427
|
+
float sumy = 0;
|
428
|
+
for (int i = 0; i < QK4_0 / 4; i++) {
|
429
|
+
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i);
|
430
|
+
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
|
411
431
|
}
|
412
432
|
|
413
|
-
|
433
|
+
for (int row = 0; row < N_DST; row++) {
|
434
|
+
sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl);
|
435
|
+
}
|
414
436
|
}
|
415
437
|
|
416
|
-
|
438
|
+
// from here on, load two rows at a time and 16 blocks per row
|
439
|
+
int ir = tiisg / (N_SIMDWIDTH / 2);
|
440
|
+
int ib = tiisg % (N_SIMDWIDTH / 2);
|
441
|
+
for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) {
|
442
|
+
int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); //where the left blocks start
|
443
|
+
float sumy = 0;
|
444
|
+
for (int i = 0; i < QK4_0 / 4; i++) {
|
445
|
+
y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i);
|
446
|
+
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
|
447
|
+
}
|
417
448
|
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
|
424
|
-
}
|
425
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
426
|
-
if (ith%16 == 0) {
|
427
|
-
sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
|
449
|
+
for (int row = 0; row < N_DST; row+=2) {
|
450
|
+
if (nb_start + ib < nb) {
|
451
|
+
sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl);
|
452
|
+
}
|
453
|
+
}
|
428
454
|
}
|
429
|
-
|
430
|
-
|
431
|
-
|
432
|
-
|
455
|
+
|
456
|
+
for (int row = 0; row < N_DST; ++row) {
|
457
|
+
all_sum = simd_sum(sumf[row]);
|
458
|
+
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
|
459
|
+
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
|
460
|
+
}
|
433
461
|
}
|
434
462
|
}
|
435
463
|
|
436
|
-
kernel void
|
464
|
+
kernel void kernel_mul_mat_q4_0_f32(
|
437
465
|
device const void * src0,
|
438
466
|
device const float * src1,
|
439
467
|
device float * dst,
|
440
468
|
constant int64_t & ne00,
|
441
469
|
constant int64_t & ne10,
|
442
470
|
constant int64_t & ne0,
|
443
|
-
|
471
|
+
constant int64_t & ne01[[buffer(4)]],
|
444
472
|
uint2 tgpig[[threadgroup_position_in_grid]],
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
const int64_t r0 = tgpig.x;
|
450
|
-
const int64_t r1 = tgpig.y;
|
451
|
-
|
452
|
-
device const block_q4_1 * x = (device const block_q4_1 *) src0 + r0*nb;
|
453
|
-
device const float * y = (device const float *) src1 + r1*ne10;
|
454
|
-
|
455
|
-
const uint nth = tptg.x*tptg.y;
|
456
|
-
const uint ith = tptg.y*tpitg.x + tpitg.y;
|
457
|
-
|
458
|
-
const int ix = tpitg.y/4; // 0 or 1
|
459
|
-
const int iy = tpitg.y - 4*ix; // 0...3
|
460
|
-
|
461
|
-
const int first = 4 * iy;
|
462
|
-
|
463
|
-
float sumf = 0;
|
464
|
-
|
465
|
-
for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
|
466
|
-
|
467
|
-
const float d = (float)x[i].d;
|
468
|
-
const float m = (float)x[i].m;
|
469
|
-
|
470
|
-
device const uint8_t * xl = x[i].qs + first;
|
471
|
-
device const float * yl = y + i * QK4_1 + first;
|
472
|
-
|
473
|
-
float2 acc = {0.0f, 0.0f};
|
474
|
-
|
475
|
-
for (int j = 0; j < 4; ++j) {
|
476
|
-
|
477
|
-
acc[0] += yl[j+ 0] * (d * (xl[j] & 0xF) + m);
|
478
|
-
acc[1] += yl[j+16] * (d * (xl[j] >> 4) + m);
|
479
|
-
|
480
|
-
}
|
481
|
-
|
482
|
-
sumf += acc[0] + acc[1];
|
483
|
-
}
|
484
|
-
|
485
|
-
sum[ith] = sumf;
|
473
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
474
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
475
|
+
mul_vec_q_n_f32<block_q4_0>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
|
476
|
+
}
|
486
477
|
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
if (ith == 0) {
|
500
|
-
for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
|
501
|
-
dst[r1*ne0 + r0] = sum[0];
|
502
|
-
}
|
478
|
+
kernel void kernel_mul_mat_q4_1_f32(
|
479
|
+
device const void * src0,
|
480
|
+
device const float * src1,
|
481
|
+
device float * dst,
|
482
|
+
constant int64_t & ne00,
|
483
|
+
constant int64_t & ne10,
|
484
|
+
constant int64_t & ne0,
|
485
|
+
constant int64_t & ne01[[buffer(4)]],
|
486
|
+
uint2 tgpig[[threadgroup_position_in_grid]],
|
487
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
488
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
489
|
+
mul_vec_q_n_f32<block_q4_1>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
|
503
490
|
}
|
504
491
|
|
505
492
|
kernel void kernel_mul_mat_f16_f32(
|
@@ -615,17 +602,19 @@ kernel void kernel_rope(
|
|
615
602
|
constant int & n_past,
|
616
603
|
constant int & n_dims,
|
617
604
|
constant int & mode,
|
605
|
+
constant float & freq_base,
|
606
|
+
constant float & freq_scale,
|
618
607
|
uint3 tpig[[thread_position_in_grid]]) {
|
619
608
|
const int64_t i3 = tpig[2];
|
620
609
|
const int64_t i2 = tpig[1];
|
621
610
|
const int64_t i1 = tpig[0];
|
622
611
|
|
623
612
|
const bool is_neox = mode & 2;
|
624
|
-
const float theta_scale = pow(
|
613
|
+
const float theta_scale = pow(freq_base, -2.0f/n_dims);
|
625
614
|
|
626
615
|
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
|
627
616
|
|
628
|
-
float theta = (float)p;
|
617
|
+
float theta = freq_scale * (float)p;
|
629
618
|
|
630
619
|
if (!is_neox) {
|
631
620
|
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
@@ -1220,111 +1209,137 @@ kernel void kernel_mul_mat_q2_K_f32(
|
|
1220
1209
|
constant int64_t & ne00,
|
1221
1210
|
constant int64_t & ne10,
|
1222
1211
|
constant int64_t & ne0,
|
1223
|
-
|
1212
|
+
constant int64_t & ne01[[buffer(4)]],
|
1224
1213
|
uint2 tgpig[[threadgroup_position_in_grid]],
|
1225
|
-
|
1226
|
-
|
1214
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
1215
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1227
1216
|
|
1228
1217
|
const int nb = ne00/QK_K;
|
1218
|
+
const int r0 = tgpig.x;
|
1219
|
+
const int r1 = tgpig.y;
|
1229
1220
|
|
1230
|
-
const
|
1231
|
-
const
|
1232
|
-
|
1233
|
-
device const
|
1234
|
-
|
1221
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
1222
|
+
const int ib_row = first_row * nb;
|
1223
|
+
device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row;
|
1224
|
+
device const float * y = (device const float *) src1 + r1*ne10;
|
1225
|
+
float yl[32];
|
1226
|
+
float sumf[N_DST]={0.f}, all_sum;
|
1235
1227
|
|
1236
|
-
const int
|
1237
|
-
const int ith = tptg.y*tpitg.x + tpitg.y;
|
1238
|
-
|
1239
|
-
float sumf = 0;
|
1228
|
+
const int step = sizeof(block_q2_K) * nb;
|
1240
1229
|
|
1241
1230
|
#if QK_K == 256
|
1242
|
-
const int
|
1243
|
-
const int
|
1244
|
-
const int
|
1245
|
-
const int
|
1246
|
-
const int
|
1247
|
-
|
1248
|
-
const
|
1249
|
-
|
1250
|
-
|
1251
|
-
|
1252
|
-
|
1253
|
-
|
1254
|
-
|
1255
|
-
|
1256
|
-
|
1257
|
-
|
1258
|
-
|
1259
|
-
uint8_t d1 = scales[0] & 0xF;
|
1260
|
-
uint8_t d2 = scales[2] & 0xF;
|
1261
|
-
uint8_t m1 = scales[0] >> 4;
|
1262
|
-
uint8_t m2 = scales[2] >> 4;
|
1263
|
-
|
1264
|
-
device const float * y = yy + i*QK_K + y_offset;
|
1265
|
-
|
1266
|
-
float2 s = {0.f, 0.f};
|
1267
|
-
float smin = 0;
|
1268
|
-
for (int l = 0; l < n; ++l) {
|
1269
|
-
s[0] += y[l+ 0] * ((q[l] >> shift1) & 3);
|
1270
|
-
s[1] += y[l+32] * ((q[l] >> shift2) & 3);
|
1271
|
-
smin += y[l+ 0] * m1 + y[l+32] * m2;
|
1231
|
+
const int ix = tiisg/8; // 0...3
|
1232
|
+
const int it = tiisg%8; // 0...7
|
1233
|
+
const int im = it/4; // 0 or 1
|
1234
|
+
const int ir = it%4; // 0...3
|
1235
|
+
const int is = (8*ir)/16;// 0 or 1
|
1236
|
+
|
1237
|
+
device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir;
|
1238
|
+
|
1239
|
+
for (int ib = ix; ib < nb; ib += 4) {
|
1240
|
+
|
1241
|
+
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
1242
|
+
for (int i = 0; i < 8; ++i) {
|
1243
|
+
yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
|
1244
|
+
yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
|
1245
|
+
yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
|
1246
|
+
yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
|
1272
1247
|
}
|
1273
1248
|
|
1274
|
-
const
|
1275
|
-
const
|
1276
|
-
|
1277
|
-
|
1249
|
+
device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*im + is;
|
1250
|
+
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
|
1251
|
+
device const half * dh = &x[ib].d;
|
1252
|
+
|
1253
|
+
for (int row = 0; row < N_DST; row++) {
|
1254
|
+
|
1255
|
+
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
|
1256
|
+
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
|
1257
|
+
for (int i = 0; i < 8; i += 2) {
|
1258
|
+
acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
|
1259
|
+
acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
|
1260
|
+
acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
|
1261
|
+
acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
|
1262
|
+
acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
|
1263
|
+
acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
|
1264
|
+
acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
|
1265
|
+
acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
|
1266
|
+
}
|
1267
|
+
float dall = dh[0];
|
1268
|
+
float dmin = dh[1] * 1.f/16.f;
|
1269
|
+
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
|
1270
|
+
(acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
|
1271
|
+
(acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +
|
1272
|
+
(acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
|
1273
|
+
dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
|
1274
|
+
|
1275
|
+
qs += step/2;
|
1276
|
+
sc += step;
|
1277
|
+
dh += step/2;
|
1278
|
+
}
|
1278
1279
|
|
1280
|
+
y4 += 4 * QK_K;
|
1279
1281
|
}
|
1280
1282
|
#else
|
1281
|
-
const int
|
1282
|
-
|
1283
|
-
uint32_t aux[2];
|
1284
|
-
thread const uint8_t * d = (thread const uint8_t *)aux;
|
1285
|
-
thread const uint8_t * m = (thread const uint8_t *)aux + 4;
|
1283
|
+
const int ix = tiisg/2; // 0...15
|
1284
|
+
const int it = tiisg%2; // 0...1
|
1286
1285
|
|
1287
|
-
|
1286
|
+
device const float * y4 = y + ix * QK_K + 8 * it;
|
1288
1287
|
|
1289
|
-
|
1290
|
-
device const float * y = yy + i*QK_K + il;
|
1288
|
+
for (int ib = ix; ib < nb; ib += 16) {
|
1291
1289
|
|
1292
|
-
|
1293
|
-
|
1290
|
+
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
1291
|
+
for (int i = 0; i < 8; ++i) {
|
1292
|
+
yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
|
1293
|
+
yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8];
|
1294
|
+
yl[i+16] = y4[i+32]; sumy[2] += yl[i+16];
|
1295
|
+
yl[i+24] = y4[i+48]; sumy[3] += yl[i+24];
|
1296
|
+
}
|
1294
1297
|
|
1295
|
-
device const
|
1296
|
-
|
1297
|
-
|
1298
|
+
device const uint8_t * sc = (device const uint8_t *)x[ib].scales;
|
1299
|
+
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
|
1300
|
+
device const half * dh = &x[ib].d;
|
1301
|
+
|
1302
|
+
for (int row = 0; row < N_DST; row++) {
|
1303
|
+
|
1304
|
+
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
|
1305
|
+
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
|
1306
|
+
for (int i = 0; i < 8; i += 2) {
|
1307
|
+
acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
|
1308
|
+
acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
|
1309
|
+
acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
|
1310
|
+
acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
|
1311
|
+
acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
|
1312
|
+
acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
|
1313
|
+
acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
|
1314
|
+
acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
|
1315
|
+
}
|
1298
1316
|
|
1299
|
-
|
1300
|
-
|
1301
|
-
|
1302
|
-
|
1303
|
-
|
1317
|
+
float dall = dh[0];
|
1318
|
+
float dmin = dh[1];
|
1319
|
+
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
|
1320
|
+
(acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
|
1321
|
+
(acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
|
1322
|
+
(acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
|
1323
|
+
dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));
|
1324
|
+
|
1325
|
+
qs += step/2;
|
1326
|
+
sc += step;
|
1327
|
+
dh += step/2;
|
1304
1328
|
}
|
1329
|
+
|
1330
|
+
y4 += 16 * QK_K;
|
1305
1331
|
}
|
1306
1332
|
#endif
|
1307
1333
|
|
1308
|
-
|
1309
|
-
|
1310
|
-
|
1311
|
-
|
1312
|
-
|
1313
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1314
|
-
if (ith%4 == 0) {
|
1315
|
-
for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
|
1316
|
-
}
|
1317
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1318
|
-
if (ith%16 == 0) {
|
1319
|
-
for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
|
1320
|
-
}
|
1321
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1322
|
-
if (ith == 0) {
|
1323
|
-
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
|
1324
|
-
dst[r1*ne0 + r0] = sum[0];
|
1334
|
+
for (int row = 0; row < N_DST; ++row) {
|
1335
|
+
all_sum = simd_sum(sumf[row]);
|
1336
|
+
if (tiisg == 0) {
|
1337
|
+
dst[r1*ne0 + first_row + row] = all_sum;
|
1338
|
+
}
|
1325
1339
|
}
|
1326
1340
|
}
|
1327
1341
|
|
1342
|
+
#if QK_K == 256
|
1328
1343
|
kernel void kernel_mul_mat_q3_K_f32(
|
1329
1344
|
device const void * src0,
|
1330
1345
|
device const float * src1,
|
@@ -1333,40 +1348,41 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1333
1348
|
constant int64_t & ne10,
|
1334
1349
|
constant int64_t & ne0,
|
1335
1350
|
constant int64_t & ne1,
|
1336
|
-
threadgroup float * sum [[threadgroup(0)]],
|
1337
1351
|
uint2 tgpig[[threadgroup_position_in_grid]],
|
1338
|
-
|
1339
|
-
|
1352
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
1353
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1340
1354
|
|
1341
1355
|
const int nb = ne00/QK_K;
|
1342
1356
|
|
1343
1357
|
const int64_t r0 = tgpig.x;
|
1344
1358
|
const int64_t r1 = tgpig.y;
|
1345
1359
|
|
1346
|
-
|
1347
|
-
device const float * yy = (device const float *) src1 + r1*ne10;
|
1348
|
-
|
1349
|
-
const int nth = tptg.x*tptg.y;
|
1350
|
-
const int ith = tptg.y*tpitg.x + tpitg.y;
|
1360
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
1351
1361
|
|
1352
|
-
|
1362
|
+
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb;
|
1363
|
+
device const float * yy = (device const float *) src1 + r1*ne10;
|
1353
1364
|
|
1354
|
-
|
1355
|
-
const int8_t m4 = 4;
|
1365
|
+
float yl[16];
|
1356
1366
|
|
1357
1367
|
const uint16_t kmask1 = 0x0303;
|
1358
1368
|
const uint16_t kmask2 = 0x0f0f;
|
1359
1369
|
|
1360
|
-
const int tid =
|
1370
|
+
const int tid = tiisg/2;
|
1371
|
+
const int ix = tiisg%2;
|
1361
1372
|
const int ip = tid/8; // 0 or 1
|
1362
1373
|
const int il = tid/2 - 4*ip; // 0...3
|
1363
1374
|
const int ir = tid%2;
|
1364
1375
|
const int n = 8;
|
1365
1376
|
const int l0 = n*ir;
|
1366
1377
|
|
1367
|
-
const
|
1378
|
+
const uint16_t m1 = 1 << (4*ip + il);
|
1379
|
+
const uint16_t m2 = m1 << 8;
|
1368
1380
|
|
1369
1381
|
const int shift = 2*il;
|
1382
|
+
const uint16_t qm1 = 0x0003 << shift;
|
1383
|
+
const uint16_t qm2 = 0x0300 << shift;
|
1384
|
+
const int32_t v1 = 4 << shift;
|
1385
|
+
const int32_t v2 = 1024 << shift;
|
1370
1386
|
|
1371
1387
|
const uint16_t s_shift1 = 4*ip;
|
1372
1388
|
const uint16_t s_shift2 = s_shift1 + 2*(il/2);
|
@@ -1375,226 +1391,315 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1375
1391
|
const int q_offset = 32*ip + l0;
|
1376
1392
|
const int y_offset = 128*ip + 32*il + l0;
|
1377
1393
|
|
1378
|
-
|
1379
|
-
float sumf1 = 0, sumf2 = 0;
|
1380
|
-
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
1394
|
+
const int step = sizeof(block_q3_K) * nb / 2;
|
1381
1395
|
|
1382
|
-
|
1396
|
+
device const float * y1 = yy + ix*QK_K + y_offset;
|
1383
1397
|
|
1384
|
-
|
1385
|
-
|
1386
|
-
device const float * y = yy + i * QK_K + y_offset;
|
1398
|
+
float sumf1[2] = {0.f}, sumf2[2] = {0.f};
|
1399
|
+
for (int i = ix; i < nb; i += 2) {
|
1387
1400
|
|
1388
|
-
|
1389
|
-
|
1390
|
-
|
1391
|
-
float s = 0;
|
1392
|
-
for (int l = 0; l < n; ++l) {
|
1393
|
-
s += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & m3) - ((h[l+ 0] & m) ? 0 : m4));
|
1394
|
-
}
|
1395
|
-
float d = d_all * s;
|
1396
|
-
sumf1 += d * scales[0];
|
1397
|
-
sumf2 += d;
|
1398
|
-
//sumf += d_all * s * (scales[0] - 32);
|
1399
|
-
|
1400
|
-
s = 0;
|
1401
|
-
for (int l = 0; l < n; ++l) {
|
1402
|
-
s += y[l+16] * ((int8_t)((q[l+16] >> shift) & m3) - ((h[l+16] & m) ? 0 : m4));
|
1401
|
+
for (int l = 0; l < 8; ++l) {
|
1402
|
+
yl[l+0] = y1[l+ 0];
|
1403
|
+
yl[l+8] = y1[l+16];
|
1403
1404
|
}
|
1404
|
-
d = d_all * s;
|
1405
|
-
sumf1 += d * scales[1];
|
1406
|
-
sumf2 += d;
|
1407
|
-
//sumf += d_all * s * (scales[1] - 32);
|
1408
|
-
|
1409
|
-
}
|
1410
1405
|
|
1411
|
-
|
1412
|
-
|
1413
|
-
|
1414
|
-
|
1415
|
-
const int im = il/8; // 0, 0, 1, 1
|
1416
|
-
const int in = il%8; // 0, 4, 0, 4
|
1406
|
+
device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
|
1407
|
+
device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
|
1408
|
+
device const uint16_t * a = (device const uint16_t *)(x[i].scales);
|
1409
|
+
device const half * dh = &x[i].d;
|
1417
1410
|
|
1418
|
-
|
1411
|
+
for (int row = 0; row < 2; ++row) {
|
1419
1412
|
|
1420
|
-
|
1413
|
+
const float d_all = (float)dh[0];
|
1414
|
+
const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
|
1421
1415
|
|
1422
|
-
|
1423
|
-
|
1424
|
-
|
1425
|
-
|
1426
|
-
|
1416
|
+
float s1 = 0, s2 = 0;
|
1417
|
+
for (int l = 0; l < n; l += 2) {
|
1418
|
+
const uint16_t qs = q[l/2];
|
1419
|
+
s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
|
1420
|
+
s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
|
1421
|
+
}
|
1422
|
+
float d = d_all * (s1 + 1.f/256.f * s2);
|
1423
|
+
sumf1[row] += d * scales[0];
|
1424
|
+
sumf2[row] += d;
|
1425
|
+
|
1426
|
+
s1 = s2 = 0;
|
1427
|
+
for (int l = 0; l < n; l += 2) {
|
1428
|
+
const uint16_t qs = q[l/2+8];
|
1429
|
+
s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
|
1430
|
+
s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
|
1431
|
+
}
|
1432
|
+
d = d_all * (s1 + 1.f/256.f * s2);
|
1433
|
+
sumf1[row] += d * scales[1];
|
1434
|
+
sumf2[row] += d;
|
1427
1435
|
|
1428
|
-
|
1429
|
-
|
1430
|
-
|
1431
|
-
|
1436
|
+
q += step;
|
1437
|
+
h += step;
|
1438
|
+
a += step;
|
1439
|
+
dh += step;
|
1432
1440
|
|
1433
|
-
for (int l = 0; l < 4; ++l) {
|
1434
|
-
const uint8_t hm = h[l] >> im;
|
1435
|
-
sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4))
|
1436
|
-
+ y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4))
|
1437
|
-
+ y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4))
|
1438
|
-
+ y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4));
|
1439
1441
|
}
|
1440
1442
|
|
1441
|
-
|
1442
|
-
|
1443
|
-
sum[ith] = sumf;
|
1444
|
-
|
1445
|
-
#endif
|
1443
|
+
y1 += 2 * QK_K;
|
1446
1444
|
|
1447
|
-
//
|
1448
|
-
// Accumulate the sum from all threads in the threadgroup
|
1449
|
-
//
|
1450
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1451
|
-
if (ith%4 == 0) {
|
1452
|
-
for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
|
1453
|
-
}
|
1454
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1455
|
-
if (ith%16 == 0) {
|
1456
|
-
for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
|
1457
|
-
}
|
1458
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1459
|
-
if (ith == 0) {
|
1460
|
-
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
|
1461
|
-
dst[r1*ne0 + r0] = sum[0];
|
1462
1445
|
}
|
1463
1446
|
|
1447
|
+
for (int row = 0; row < 2; ++row) {
|
1448
|
+
const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
|
1449
|
+
const float tot = simd_sum(sumf);
|
1450
|
+
if (tiisg == 0) {
|
1451
|
+
dst[r1*ne0 + first_row + row] = tot;
|
1452
|
+
}
|
1453
|
+
}
|
1464
1454
|
}
|
1465
|
-
|
1466
|
-
kernel void
|
1455
|
+
#else
|
1456
|
+
kernel void kernel_mul_mat_q3_K_f32(
|
1467
1457
|
device const void * src0,
|
1468
1458
|
device const float * src1,
|
1469
1459
|
device float * dst,
|
1470
1460
|
constant int64_t & ne00,
|
1471
1461
|
constant int64_t & ne10,
|
1472
1462
|
constant int64_t & ne0,
|
1473
|
-
|
1463
|
+
constant int64_t & ne1,
|
1474
1464
|
uint2 tgpig[[threadgroup_position_in_grid]],
|
1475
|
-
|
1476
|
-
|
1465
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
1466
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1477
1467
|
|
1478
1468
|
const int nb = ne00/QK_K;
|
1479
1469
|
|
1480
1470
|
const int64_t r0 = tgpig.x;
|
1481
1471
|
const int64_t r1 = tgpig.y;
|
1482
1472
|
|
1483
|
-
const int
|
1484
|
-
const int ith = tptg.y*tpitg.x + tpitg.y;
|
1473
|
+
const int row = 2 * r0 + sgitg;
|
1485
1474
|
|
1486
|
-
device const
|
1475
|
+
device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb;
|
1487
1476
|
device const float * yy = (device const float *) src1 + r1*ne10;
|
1477
|
+
const int ix = tiisg/4;
|
1478
|
+
const int il = 4 * (tiisg%4);// 0, 4, 8, 12
|
1479
|
+
const int im = il/8; // 0, 0, 1, 1
|
1480
|
+
const int in = il%8; // 0, 4, 0, 4
|
1488
1481
|
|
1489
|
-
|
1482
|
+
float2 sum = {0.f, 0.f};
|
1483
|
+
|
1484
|
+
for (int i = ix; i < nb; i += 8) {
|
1485
|
+
|
1486
|
+
const float d_all = (float)(x[i].d);
|
1487
|
+
|
1488
|
+
device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
|
1489
|
+
device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
|
1490
|
+
device const uint16_t * s = (device const uint16_t *)(x[i].scales);
|
1491
|
+
device const float * y = yy + i * QK_K + il;
|
1492
|
+
|
1493
|
+
const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
|
1494
|
+
const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
|
1495
|
+
const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
|
1496
|
+
const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
|
1497
|
+
|
1498
|
+
for (int l = 0; l < 4; l += 2) {
|
1499
|
+
const uint16_t hm = h[l/2] >> im;
|
1500
|
+
sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
|
1501
|
+
+ y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
|
1502
|
+
+ y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
|
1503
|
+
+ y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
|
1504
|
+
sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
|
1505
|
+
+ y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
|
1506
|
+
+ y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
|
1507
|
+
+ y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
|
1508
|
+
}
|
1509
|
+
|
1510
|
+
}
|
1511
|
+
const float sumf = sum[0] + sum[1] * 1.f/256.f;
|
1512
|
+
|
1513
|
+
const float tot = simd_sum(sumf);
|
1514
|
+
if (tiisg == 0) {
|
1515
|
+
dst[r1*ne0 + row] = tot;
|
1516
|
+
}
|
1517
|
+
|
1518
|
+
}
|
1519
|
+
#endif
|
1490
1520
|
|
1491
1521
|
#if QK_K == 256
|
1522
|
+
kernel void kernel_mul_mat_q4_K_f32(
|
1523
|
+
device const void * src0,
|
1524
|
+
device const float * src1,
|
1525
|
+
device float * dst,
|
1526
|
+
constant int64_t & ne00,
|
1527
|
+
constant int64_t & ne10,
|
1528
|
+
constant int64_t & ne0,
|
1529
|
+
constant int64_t & ne01[[buffer(4)]],
|
1530
|
+
uint2 tgpig[[threadgroup_position_in_grid]],
|
1531
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
1532
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1492
1533
|
|
1493
1534
|
const uint16_t kmask1 = 0x3f3f;
|
1494
1535
|
const uint16_t kmask2 = 0x0f0f;
|
1495
1536
|
const uint16_t kmask3 = 0xc0c0;
|
1496
1537
|
|
1497
|
-
const int
|
1498
|
-
const int
|
1499
|
-
const int
|
1500
|
-
const int
|
1538
|
+
const int ix = tiisg/8; // 0...3
|
1539
|
+
const int it = tiisg%8; // 0...7
|
1540
|
+
const int im = it/4; // 0 or 1
|
1541
|
+
const int ir = it%4; // 0...3
|
1501
1542
|
|
1502
|
-
const int
|
1503
|
-
const int
|
1543
|
+
const int nb = ne00/QK_K;
|
1544
|
+
const int r0 = tgpig.x;
|
1545
|
+
const int r1 = tgpig.y;
|
1546
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
1547
|
+
const int ib_row = first_row * nb;
|
1548
|
+
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
|
1549
|
+
device const float * y = (device const float *) src1 + r1*ne10;
|
1550
|
+
float yl[16];
|
1551
|
+
float yh[16];
|
1552
|
+
float sumf[N_DST]={0.f}, all_sum;
|
1504
1553
|
|
1505
|
-
const int
|
1506
|
-
const int q_offset = 32*im + l0;
|
1507
|
-
const int y_offset = 64*im + l0;
|
1554
|
+
const int step = sizeof(block_q4_K) * nb / 2;
|
1508
1555
|
|
1509
|
-
|
1556
|
+
device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir;
|
1510
1557
|
|
1511
|
-
|
1558
|
+
uint16_t sc16[4];
|
1559
|
+
thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
|
1512
1560
|
|
1513
|
-
|
1514
|
-
device const uint8_t * q2 = q1 + 64;
|
1515
|
-
device const float * y1 = yy + i*QK_K + y_offset;
|
1516
|
-
device const float * y2 = y1 + 128;
|
1561
|
+
for (int ib = ix; ib < nb; ib += 4) {
|
1517
1562
|
|
1518
|
-
|
1519
|
-
|
1563
|
+
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
1564
|
+
for (int i = 0; i < 8; ++i) {
|
1565
|
+
yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
|
1566
|
+
yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
|
1567
|
+
yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
|
1568
|
+
yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
|
1569
|
+
}
|
1520
1570
|
|
1521
|
-
device const uint16_t *
|
1522
|
-
|
1523
|
-
|
1524
|
-
|
1525
|
-
|
1571
|
+
device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im;
|
1572
|
+
device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
|
1573
|
+
device const half * dh = &x[ib].d;
|
1574
|
+
|
1575
|
+
for (int row = 0; row < N_DST; row++) {
|
1576
|
+
|
1577
|
+
sc16[0] = sc[0] & kmask1;
|
1578
|
+
sc16[1] = sc[2] & kmask1;
|
1579
|
+
sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
|
1580
|
+
sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
|
1581
|
+
|
1582
|
+
device const uint16_t * q2 = q1 + 32;
|
1583
|
+
|
1584
|
+
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
|
1585
|
+
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
|
1586
|
+
for (int i = 0; i < 8; i += 2) {
|
1587
|
+
acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
|
1588
|
+
acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
|
1589
|
+
acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
|
1590
|
+
acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
|
1591
|
+
acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
|
1592
|
+
acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
|
1593
|
+
acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
|
1594
|
+
acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
|
1595
|
+
}
|
1526
1596
|
|
1527
|
-
|
1528
|
-
|
1529
|
-
|
1597
|
+
float dall = dh[0];
|
1598
|
+
float dmin = dh[1];
|
1599
|
+
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
|
1600
|
+
(acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
|
1601
|
+
(acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
|
1602
|
+
(acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
|
1603
|
+
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
|
1604
|
+
|
1605
|
+
q1 += step;
|
1606
|
+
sc += step;
|
1607
|
+
dh += step;
|
1608
|
+
}
|
1530
1609
|
|
1531
|
-
|
1532
|
-
|
1533
|
-
smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
|
1610
|
+
y4 += 4 * QK_K;
|
1611
|
+
}
|
1534
1612
|
|
1613
|
+
for (int row = 0; row < N_DST; ++row) {
|
1614
|
+
all_sum = simd_sum(sumf[row]);
|
1615
|
+
if (tiisg == 0) {
|
1616
|
+
dst[r1*ne0 + first_row + row] = all_sum;
|
1535
1617
|
}
|
1536
|
-
sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
|
1537
|
-
|
1538
1618
|
}
|
1619
|
+
}
|
1539
1620
|
#else
|
1540
|
-
|
1541
|
-
|
1621
|
+
kernel void kernel_mul_mat_q4_K_f32(
|
1622
|
+
device const void * src0,
|
1623
|
+
device const float * src1,
|
1624
|
+
device float * dst,
|
1625
|
+
constant int64_t & ne00,
|
1626
|
+
constant int64_t & ne10,
|
1627
|
+
constant int64_t & ne0,
|
1628
|
+
constant int64_t & ne01[[buffer(4)]],
|
1629
|
+
uint2 tgpig[[threadgroup_position_in_grid]],
|
1630
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
1631
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1542
1632
|
|
1543
|
-
const int
|
1633
|
+
const int ix = tiisg/4; // 0...7
|
1634
|
+
const int it = tiisg%4; // 0...3
|
1544
1635
|
|
1545
|
-
|
1636
|
+
const int nb = ne00/QK_K;
|
1637
|
+
const int r0 = tgpig.x;
|
1638
|
+
const int r1 = tgpig.y;
|
1639
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
1640
|
+
const int ib_row = first_row * nb;
|
1641
|
+
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
|
1642
|
+
device const float * y = (device const float *) src1 + r1*ne10;
|
1643
|
+
float yl[8];
|
1644
|
+
float yh[8];
|
1645
|
+
float sumf[N_DST]={0.f}, all_sum;
|
1546
1646
|
|
1547
|
-
|
1548
|
-
device const float * y = yy + i * QK_K + il;
|
1647
|
+
const int step = sizeof(block_q4_K) * nb / 2;
|
1549
1648
|
|
1550
|
-
|
1551
|
-
const float m = (float)x[i].d[1];
|
1649
|
+
device const float * y4 = y + ix * QK_K + 8 * it;
|
1552
1650
|
|
1553
|
-
|
1554
|
-
aux16[0] = a[0] & 0x0f0f;
|
1555
|
-
aux16[1] = (a[0] >> 4) & 0x0f0f;
|
1651
|
+
uint16_t sc16[4];
|
1556
1652
|
|
1557
|
-
|
1558
|
-
|
1559
|
-
|
1653
|
+
for (int ib = ix; ib < nb; ib += 8) {
|
1654
|
+
|
1655
|
+
float2 sumy = {0.f, 0.f};
|
1656
|
+
for (int i = 0; i < 8; ++i) {
|
1657
|
+
yl[i] = y4[i+ 0]; sumy[0] += yl[i];
|
1658
|
+
yh[i] = y4[i+32]; sumy[1] += yh[i];
|
1560
1659
|
}
|
1561
|
-
}
|
1562
|
-
#endif
|
1563
1660
|
|
1564
|
-
|
1661
|
+
device const uint16_t * sc = (device const uint16_t *)x[ib].scales;
|
1662
|
+
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
|
1663
|
+
device const half * dh = x[ib].d;
|
1565
1664
|
|
1566
|
-
|
1567
|
-
|
1568
|
-
|
1569
|
-
|
1570
|
-
|
1571
|
-
|
1572
|
-
|
1573
|
-
|
1574
|
-
|
1575
|
-
|
1576
|
-
|
1577
|
-
|
1578
|
-
|
1579
|
-
|
1580
|
-
|
1581
|
-
|
1582
|
-
|
1665
|
+
for (int row = 0; row < N_DST; row++) {
|
1666
|
+
|
1667
|
+
sc16[0] = sc[0] & 0x000f;
|
1668
|
+
sc16[1] = sc[0] & 0x0f00;
|
1669
|
+
sc16[2] = sc[0] & 0x00f0;
|
1670
|
+
sc16[3] = sc[0] & 0xf000;
|
1671
|
+
|
1672
|
+
float2 acc1 = {0.f, 0.f};
|
1673
|
+
float2 acc2 = {0.f, 0.f};
|
1674
|
+
for (int i = 0; i < 8; i += 2) {
|
1675
|
+
acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
|
1676
|
+
acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
|
1677
|
+
acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
|
1678
|
+
acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
|
1679
|
+
}
|
1680
|
+
|
1681
|
+
float dall = dh[0];
|
1682
|
+
float dmin = dh[1];
|
1683
|
+
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
|
1684
|
+
(acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
|
1685
|
+
dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);
|
1686
|
+
|
1687
|
+
qs += step;
|
1688
|
+
sc += step;
|
1689
|
+
dh += step;
|
1690
|
+
}
|
1691
|
+
|
1692
|
+
y4 += 8 * QK_K;
|
1583
1693
|
}
|
1584
1694
|
|
1585
|
-
|
1586
|
-
|
1587
|
-
|
1588
|
-
|
1589
|
-
|
1590
|
-
|
1591
|
-
// threadgroup_barrier(mem_flags::mem_threadgroup);
|
1592
|
-
//}
|
1593
|
-
|
1594
|
-
//if (ith == 0) {
|
1595
|
-
// dst[r1*ne0 + r0] = sum[0];
|
1596
|
-
//}
|
1695
|
+
for (int row = 0; row < N_DST; ++row) {
|
1696
|
+
all_sum = simd_sum(sumf[row]);
|
1697
|
+
if (tiisg == 0) {
|
1698
|
+
dst[r1*ne0 + first_row + row] = all_sum;
|
1699
|
+
}
|
1700
|
+
}
|
1597
1701
|
}
|
1702
|
+
#endif
|
1598
1703
|
|
1599
1704
|
kernel void kernel_mul_mat_q5_K_f32(
|
1600
1705
|
device const void * src0,
|
@@ -1603,39 +1708,39 @@ kernel void kernel_mul_mat_q5_K_f32(
|
|
1603
1708
|
constant int64_t & ne00,
|
1604
1709
|
constant int64_t & ne10,
|
1605
1710
|
constant int64_t & ne0,
|
1606
|
-
threadgroup float * sum [[threadgroup(0)]],
|
1607
1711
|
uint2 tgpig[[threadgroup_position_in_grid]],
|
1608
|
-
|
1609
|
-
|
1712
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
1713
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1610
1714
|
|
1611
1715
|
const int nb = ne00/QK_K;
|
1612
1716
|
|
1613
1717
|
const int64_t r0 = tgpig.x;
|
1614
1718
|
const int64_t r1 = tgpig.y;
|
1615
1719
|
|
1616
|
-
|
1720
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
1721
|
+
|
1722
|
+
device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb;
|
1617
1723
|
device const float * yy = (device const float *) src1 + r1*ne10;
|
1618
1724
|
|
1619
|
-
|
1620
|
-
const int ith = tptg.y*tpitg.x + tpitg.y;
|
1725
|
+
float sumf[2]={0.f};
|
1621
1726
|
|
1622
|
-
|
1727
|
+
const int step = sizeof(block_q5_K) * nb;
|
1623
1728
|
|
1624
1729
|
#if QK_K == 256
|
1730
|
+
#
|
1731
|
+
float yl[16], yh[16];
|
1625
1732
|
|
1626
1733
|
const uint16_t kmask1 = 0x3f3f;
|
1627
1734
|
const uint16_t kmask2 = 0x0f0f;
|
1628
1735
|
const uint16_t kmask3 = 0xc0c0;
|
1629
1736
|
|
1630
|
-
const int tid =
|
1631
|
-
const int
|
1632
|
-
const int
|
1633
|
-
const int
|
1634
|
-
|
1635
|
-
const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
|
1636
|
-
const int in = il%2;
|
1737
|
+
const int tid = tiisg/4;
|
1738
|
+
const int ix = tiisg%4;
|
1739
|
+
const int im = tid/4;
|
1740
|
+
const int ir = tid%4;
|
1741
|
+
const int n = 8;
|
1637
1742
|
|
1638
|
-
const int l0 = n*
|
1743
|
+
const int l0 = n*ir;
|
1639
1744
|
const int q_offset = 32*im + l0;
|
1640
1745
|
const int y_offset = 64*im + l0;
|
1641
1746
|
|
@@ -1644,78 +1749,113 @@ kernel void kernel_mul_mat_q5_K_f32(
|
|
1644
1749
|
const uint8_t hm3 = hm1 << 4;
|
1645
1750
|
const uint8_t hm4 = hm2 << 4;
|
1646
1751
|
|
1647
|
-
|
1752
|
+
uint16_t sc16[4];
|
1753
|
+
thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
|
1648
1754
|
|
1649
|
-
|
1755
|
+
device const float * y1 = yy + ix*QK_K + y_offset;
|
1650
1756
|
|
1651
|
-
|
1652
|
-
device const uint8_t * q2 = q1 + 64;
|
1653
|
-
device const uint8_t * qh = (x + i)->qh + l0;
|
1654
|
-
device const float * y1 = yy + i*QK_K + y_offset;
|
1655
|
-
device const float * y2 = y1 + 128;
|
1757
|
+
for (int i = ix; i < nb; i += 4) {
|
1656
1758
|
|
1657
|
-
const
|
1658
|
-
const
|
1759
|
+
device const uint8_t * q1 = x[i].qs + q_offset;
|
1760
|
+
device const uint8_t * qh = x[i].qh + l0;
|
1761
|
+
device const half * dh = &x[i].d;
|
1762
|
+
device const uint16_t * a = (device const uint16_t *)x[i].scales + im;
|
1659
1763
|
|
1660
|
-
device const
|
1661
|
-
|
1662
|
-
|
1663
|
-
|
1664
|
-
|
1764
|
+
device const float * y2 = y1 + 128;
|
1765
|
+
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
1766
|
+
for (int l = 0; l < 8; ++l) {
|
1767
|
+
yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
|
1768
|
+
yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
|
1769
|
+
yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
|
1770
|
+
yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
|
1771
|
+
}
|
1665
1772
|
|
1666
|
-
|
1667
|
-
|
1668
|
-
|
1773
|
+
for (int row = 0; row < 2; ++row) {
|
1774
|
+
|
1775
|
+
device const uint8_t * q2 = q1 + 64;
|
1669
1776
|
|
1670
|
-
|
1671
|
-
|
1672
|
-
|
1673
|
-
|
1674
|
-
|
1777
|
+
sc16[0] = a[0] & kmask1;
|
1778
|
+
sc16[1] = a[2] & kmask1;
|
1779
|
+
sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
|
1780
|
+
sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
|
1781
|
+
|
1782
|
+
float4 acc = {0.f, 0.f, 0.f, 0.f};
|
1783
|
+
for (int l = 0; l < n; ++l) {
|
1784
|
+
uint8_t h = qh[l];
|
1785
|
+
acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
|
1786
|
+
acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
|
1787
|
+
acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
|
1788
|
+
acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
|
1789
|
+
}
|
1790
|
+
const float dall = dh[0];
|
1791
|
+
const float dmin = dh[1];
|
1792
|
+
sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
|
1793
|
+
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
|
1794
|
+
|
1795
|
+
q1 += step;
|
1796
|
+
qh += step;
|
1797
|
+
dh += step/2;
|
1798
|
+
a += step/2;
|
1675
1799
|
|
1676
1800
|
}
|
1677
|
-
|
1801
|
+
|
1802
|
+
y1 += 4 * QK_K;
|
1678
1803
|
|
1679
1804
|
}
|
1680
1805
|
#else
|
1681
|
-
|
1682
|
-
const int im = il/8; // 0, 0, 1, 1
|
1683
|
-
const int in = il%8; // 0, 4, 0, 4
|
1806
|
+
float yl[8], yh[8];
|
1684
1807
|
|
1685
|
-
|
1808
|
+
const int il = 4 * (tiisg/8); // 0, 4, 8, 12
|
1809
|
+
const int ix = tiisg%8;
|
1810
|
+
const int im = il/8; // 0, 0, 1, 1
|
1811
|
+
const int in = il%8; // 0, 4, 0, 4
|
1686
1812
|
|
1687
|
-
|
1813
|
+
device const float * y = yy + ix*QK_K + il;
|
1814
|
+
|
1815
|
+
for (int i = ix; i < nb; i += 8) {
|
1816
|
+
|
1817
|
+
for (int l = 0; l < 4; ++l) {
|
1818
|
+
yl[l+0] = y[l+ 0];
|
1819
|
+
yl[l+4] = y[l+16];
|
1820
|
+
yh[l+0] = y[l+32];
|
1821
|
+
yh[l+4] = y[l+48];
|
1822
|
+
}
|
1823
|
+
|
1824
|
+
device const half * dh = &x[i].d;
|
1688
1825
|
device const uint8_t * q = x[i].qs + il;
|
1689
1826
|
device const uint8_t * h = x[i].qh + in;
|
1690
1827
|
device const int8_t * s = x[i].scales;
|
1691
|
-
device const float * y = yy + i*QK_K + il;
|
1692
1828
|
|
1693
|
-
for (int
|
1694
|
-
|
1695
|
-
|
1696
|
-
|
1697
|
-
|
1698
|
-
|
1829
|
+
for (int row = 0; row < 2; ++row) {
|
1830
|
+
|
1831
|
+
const float d = dh[0];
|
1832
|
+
|
1833
|
+
float2 acc = {0.f, 0.f};
|
1834
|
+
for (int l = 0; l < 4; ++l) {
|
1835
|
+
const uint8_t hl = h[l] >> im;
|
1836
|
+
acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
|
1837
|
+
+ yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
|
1838
|
+
acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
|
1839
|
+
+ yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256));
|
1840
|
+
}
|
1841
|
+
sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]);
|
1842
|
+
|
1843
|
+
q += step;
|
1844
|
+
h += step;
|
1845
|
+
s += step;
|
1846
|
+
dh += step/2;
|
1847
|
+
|
1699
1848
|
}
|
1849
|
+
|
1850
|
+
y += 8 * QK_K;
|
1700
1851
|
}
|
1701
1852
|
#endif
|
1702
|
-
sum[ith] = sumf;
|
1703
1853
|
|
1704
|
-
|
1705
|
-
|
1706
|
-
|
1707
|
-
|
1708
|
-
|
1709
|
-
sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
|
1710
|
-
}
|
1711
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1712
|
-
if (ith%16 == 0) {
|
1713
|
-
sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
|
1714
|
-
}
|
1715
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1716
|
-
if (ith == 0) {
|
1717
|
-
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
|
1718
|
-
dst[r1*ne0 + r0] = sum[0];
|
1854
|
+
for (int row = 0; row < 2; ++row) {
|
1855
|
+
const float tot = simd_sum(sumf[row]);
|
1856
|
+
if (tiisg == 0) {
|
1857
|
+
dst[r1*ne0 + first_row + row] = tot;
|
1858
|
+
}
|
1719
1859
|
}
|
1720
1860
|
|
1721
1861
|
}
|
@@ -1727,10 +1867,9 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|
1727
1867
|
constant int64_t & ne00,
|
1728
1868
|
constant int64_t & ne10,
|
1729
1869
|
constant int64_t & ne0,
|
1730
|
-
threadgroup float * sum [[threadgroup(0)]],
|
1731
1870
|
uint2 tgpig[[threadgroup_position_in_grid]],
|
1732
|
-
|
1733
|
-
|
1871
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
1872
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1734
1873
|
|
1735
1874
|
const uint8_t kmask1 = 0x03;
|
1736
1875
|
const uint8_t kmask2 = 0x0C;
|
@@ -1742,19 +1881,18 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|
1742
1881
|
const int64_t r0 = tgpig.x;
|
1743
1882
|
const int64_t r1 = tgpig.y;
|
1744
1883
|
|
1745
|
-
|
1746
|
-
device const float * yy = (device const float *) src1 + r1*ne10;
|
1884
|
+
const int row = 2 * r0 + sgitg;
|
1747
1885
|
|
1748
|
-
const
|
1749
|
-
const
|
1886
|
+
device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb; //r0*nb;
|
1887
|
+
device const float * yy = (device const float *) src1 + r1*ne10;
|
1750
1888
|
|
1751
1889
|
float sumf = 0;
|
1752
1890
|
|
1753
1891
|
#if QK_K == 256
|
1754
|
-
|
1755
|
-
const int
|
1756
|
-
const int ip =
|
1757
|
-
const int il =
|
1892
|
+
const int tid = tiisg/2;
|
1893
|
+
const int ix = tiisg%2;
|
1894
|
+
const int ip = tid/8; // 0 or 1
|
1895
|
+
const int il = tid%8;
|
1758
1896
|
const int n = 4;
|
1759
1897
|
const int l0 = n*il;
|
1760
1898
|
const int is = 8*ip + l0/16;
|
@@ -1763,9 +1901,10 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|
1763
1901
|
const int q_offset_l = 64*ip + l0;
|
1764
1902
|
const int q_offset_h = 32*ip + l0;
|
1765
1903
|
|
1766
|
-
for (int i =
|
1904
|
+
for (int i = ix; i < nb; i += 2) {
|
1767
1905
|
|
1768
|
-
device const uint8_t *
|
1906
|
+
device const uint8_t * q1 = x[i].ql + q_offset_l;
|
1907
|
+
device const uint8_t * q2 = q1 + 32;
|
1769
1908
|
device const uint8_t * qh = x[i].qh + q_offset_h;
|
1770
1909
|
device const int8_t * sc = x[i].scales + is;
|
1771
1910
|
|
@@ -1775,19 +1914,21 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|
1775
1914
|
|
1776
1915
|
float4 sums = {0.f, 0.f, 0.f, 0.f};
|
1777
1916
|
for (int l = 0; l < n; ++l) {
|
1778
|
-
sums[0] += y[l+ 0] * ((int8_t)((
|
1779
|
-
sums[1] += y[l+32] * ((int8_t)((
|
1780
|
-
sums[2] += y[l+64] * ((int8_t)((
|
1781
|
-
sums[3] += y[l+96] * ((int8_t)((
|
1917
|
+
sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
|
1918
|
+
sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
|
1919
|
+
sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
|
1920
|
+
sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
|
1782
1921
|
}
|
1783
1922
|
|
1784
1923
|
sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
|
1785
1924
|
|
1786
1925
|
}
|
1926
|
+
|
1787
1927
|
#else
|
1788
|
-
const int
|
1928
|
+
const int ix = tiisg/4;
|
1929
|
+
const int il = 4*(tiisg%4);
|
1789
1930
|
|
1790
|
-
for (int i =
|
1931
|
+
for (int i = ix; i < nb; i += 8) {
|
1791
1932
|
device const float * y = yy + i * QK_K + il;
|
1792
1933
|
device const uint8_t * ql = x[i].ql + il;
|
1793
1934
|
device const uint8_t * qh = x[i].qh + il;
|
@@ -1807,23 +1948,8 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|
1807
1948
|
|
1808
1949
|
#endif
|
1809
1950
|
|
1810
|
-
|
1811
|
-
|
1812
|
-
|
1813
|
-
// Accumulate the sum from all threads in the threadgroup
|
1814
|
-
//
|
1815
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1816
|
-
if (ith%4 == 0) {
|
1817
|
-
for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
|
1818
|
-
}
|
1819
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1820
|
-
if (ith%16 == 0) {
|
1821
|
-
for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
|
1822
|
-
}
|
1823
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1824
|
-
if (ith == 0) {
|
1825
|
-
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
|
1826
|
-
dst[r1*ne0 + r0] = sum[0];
|
1951
|
+
const float tot = simd_sum(sumf);
|
1952
|
+
if (tiisg == 0) {
|
1953
|
+
dst[r1*ne0 + row] = tot;
|
1827
1954
|
}
|
1828
|
-
|
1829
1955
|
}
|