llama_cpp 0.2.0 → 0.2.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -0
- data/examples/README.md +92 -0
- data/examples/chat.rb +195 -0
- data/examples/embedding.rb +37 -0
- data/ext/llama_cpp/llama_cpp.cpp +52 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +1218 -411
- data/ext/llama_cpp/src/ggml-cuda.h +4 -1
- data/ext/llama_cpp/src/ggml-metal.h +5 -1
- data/ext/llama_cpp/src/ggml-metal.m +703 -514
- data/ext/llama_cpp/src/ggml-metal.metal +574 -122
- data/ext/llama_cpp/src/ggml-opencl.cpp +496 -36
- data/ext/llama_cpp/src/ggml-opencl.h +1 -2
- data/ext/llama_cpp/src/ggml.c +2715 -476
- data/ext/llama_cpp/src/ggml.h +266 -11
- data/ext/llama_cpp/src/llama.cpp +266 -135
- data/ext/llama_cpp/src/llama.h +19 -11
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +3 -0
- metadata +5 -2
data/ext/llama_cpp/src/ggml-metal.metal
@@ -256,6 +256,72 @@ kernel void kernel_get_rows_q4_1(
             (device float *) ((device char *) dst + i*nb1), ne00);
 }
 
+kernel void kernel_norm(
+        device const void * src0,
+        device float * dst,
+        constant int64_t & ne00,
+        constant uint64_t & nb01,
+        constant float & eps,
+        threadgroup float * sum [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint ntg[[threads_per_threadgroup]]) {
+    device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
+    // MEAN
+    // parallel sum
+    sum[tpitg] = 0.0f;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        sum[tpitg] += x[i00];
+    }
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg/2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            sum[tpitg] += sum[tpitg + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+    // broadcast
+    if (tpitg == 0) {
+        sum[0] /= ne00;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    const float mean = sum[0];
+
+    // recenter
+    device float * y = dst + tgpig*ne00;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        y[i00] = x[i00] - mean;
+    }
+
+    // VARIANCE
+    // parallel sum
+    sum[tpitg] = 0.0f;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        sum[tpitg] += y[i00] * y[i00];
+    }
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg/2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            sum[tpitg] += sum[tpitg + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+    // broadcast
+    if (tpitg == 0) {
+        sum[0] /= ne00;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    const float variance = sum[0];
+
+    const float scale = 1.0f/sqrt(variance + eps);
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        y[i00] = y[i00] * scale;
+    }
+}
+
+
 kernel void kernel_rms_norm(
         device const void * src0,
         device float * dst,
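
Reading note: the new kernel_norm runs layer normalization as three strided passes over one row per threadgroup: partial sums plus a halving tree reduction give the mean, a second pass recenters, and the same reduction over squares gives the variance (the tree assumes ntg is a power of two). Below is a minimal CPU sketch of that reduction strategy with invented sizes and values; it illustrates my reading of the kernel and is not code from the gem.

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Stand-in for the threadgroup array `sum[]`: halving tree reduction.
    // Assumes the number of partial sums is a power of two, as the shader does.
    static float tree_reduce(std::vector<float> sum) {
        for (size_t i = sum.size()/2; i > 0; i /= 2)
            for (size_t t = 0; t < i; ++t) sum[t] += sum[t + i];
        return sum[0];
    }

    int main() {
        const int ne00 = 8, ntg = 4;   // row length, "threads" per group
        const float eps = 1e-5f;
        std::vector<float> x = {1, 2, 3, 4, 5, 6, 7, 8}, y(ne00);

        std::vector<float> partial(ntg, 0.0f);
        for (int t = 0; t < ntg; ++t)                       // strided, like i00 += ntg
            for (int i = t; i < ne00; i += ntg) partial[t] += x[i];
        const float mean = tree_reduce(partial) / ne00;

        for (int i = 0; i < ne00; ++i) y[i] = x[i] - mean;  // recenter

        std::fill(partial.begin(), partial.end(), 0.0f);
        for (int t = 0; t < ntg; ++t)
            for (int i = t; i < ne00; i += ntg) partial[t] += y[i]*y[i];
        const float variance = tree_reduce(partial) / ne00;

        const float scale = 1.0f/std::sqrt(variance + eps);
        for (int i = 0; i < ne00; ++i) printf("%g ", y[i]*scale);
        printf("\n");
    }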
@@ -304,34 +370,22 @@ kernel void kernel_mul_mat_q4_0_f32(
         device const float * src1,
         device float * dst,
         constant int64_t & ne00,
-        constant int64_t & ne01,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
         constant int64_t & ne10,
-        constant int64_t & ne11,
-        constant uint64_t & nb10,
-        constant uint64_t & nb11,
-        constant uint64_t & nb12,
         constant int64_t & ne0,
-        constant int64_t & ne1,
         threadgroup float * sum [[threadgroup(0)]],
         uint2 tgpig[[threadgroup_position_in_grid]],
-        uint2 tpig[[thread_position_in_grid]],
         uint2 tpitg[[thread_position_in_threadgroup]],
         uint2 tptg[[threads_per_threadgroup]]) {
     const int nb = ne00/QK4_0;
 
-    const int8_t m8 = 8;
-
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
 
     device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb;
     device const float * y = (device const float *) src1 + r1*ne10;
 
-    const
-    const
+    const int nth = tptg.x*tptg.y;
+    const int ith = tptg.y*tpitg.x + tpitg.y;
 
     const int ix = tpitg.y/4;       // 0 or 1
     const int iy = tpitg.y - 4*ix;  // 0...3
@@ -351,47 +405,32 @@ kernel void kernel_mul_mat_q4_0_f32(
 
         for (int j = 0; j < 4; ++j) {
 
-            acc[0] += yl[j
-            acc[1] += yl[j
+            acc[0] += yl[j] * (xl[j] & 0xF) + yl[j+16] * (xl[j] >> 4);
+            acc[1] += yl[j] + yl[j+16];
 
         }
 
-        sumf += d * (acc[0]
+        sumf += d * (acc[0] - 8.f*acc[1]);
     }
 
     sum[ith] = sumf;
 
     //
     // Accumulate the sum from all threads in the threadgroup
-    // This version is slightly faster than the commented out one below,
-    // which I copy-pasted from ggerganov's q4_0 dot product for metal.
     //
     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (ith%4 == 0) {
-
+        sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
     }
     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (ith%16 == 0) {
-
+        sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
     }
     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (ith == 0) {
-        for (
+        for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
         dst[r1*ne0 + r0] = sum[0];
     }
-
-    //// accumulate the sum from all threads in the threadgroup
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
-    //for (uint i = nth/2; i > 0; i /= 2) {
-    //    if (ith < i) {
-    //        sum[ith] += sum[ith + i];
-    //    }
-    //    threadgroup_barrier(mem_flags::mem_threadgroup);
-    //}
-
-    //if (ith == 0) {
-    //    dst[r1*ne0 + r0] = sum[0];
-    //}
 }
 
 kernel void kernel_mul_mat_q4_1_f32(
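
Reading note: the reworked q4_0 inner loop is an algebraic refactor, not a behavior change. A q4_0 weight dequantizes to d*(q - 8) with q in 0..15, so sum_j d*(q_j - 8)*y_j = d*(sum_j q_j*y_j - 8*sum_j y_j); keeping two plain sums (acc[0], acc[1]) and applying the -8 once per block removes the per-element subtraction that the deleted m8 constant used to feed. A small numeric check with invented values:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const float d = 0.25f;                      // block scale
        const uint8_t q[4] = {0, 7, 8, 15};         // 4-bit quants, already unpacked
        const float  y[4] = {1.f, -2.f, 0.5f, 3.f}; // activations

        float direct = 0.f, acc0 = 0.f, acc1 = 0.f;
        for (int j = 0; j < 4; ++j) {
            direct += d * (q[j] - 8) * y[j];        // naive: subtract 8 per element
            acc0   += q[j] * y[j];                  // refactored: two plain sums
            acc1   += y[j];
        }
        const float factored = d * (acc0 - 8.f*acc1);
        printf("%f %f\n", direct, factored);        // prints the same value twice
    }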
@@ -399,20 +438,10 @@ kernel void kernel_mul_mat_q4_1_f32(
         device const float * src1,
         device float * dst,
         constant int64_t & ne00,
-        constant int64_t & ne01,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
         constant int64_t & ne10,
-        constant int64_t & ne11,
-        constant uint64_t & nb10,
-        constant uint64_t & nb11,
-        constant uint64_t & nb12,
         constant int64_t & ne0,
-        constant int64_t & ne1,
         threadgroup float * sum [[threadgroup(0)]],
         uint2 tgpig[[threadgroup_position_in_grid]],
-        uint2 tpig[[thread_position_in_grid]],
         uint2 tpitg[[thread_position_in_threadgroup]],
         uint2 tptg[[threads_per_threadgroup]]) {
     const int nb = ne00/QK4_1;
@@ -460,11 +489,11 @@ kernel void kernel_mul_mat_q4_1_f32(
     //
     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (ith%4 == 0) {
-
+        sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
     }
     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (ith%16 == 0) {
-
+        sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
     }
     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (ith == 0) {
@@ -522,6 +551,48 @@ kernel void kernel_mul_mat_f16_f32(
     }
 }
 
+kernel void kernel_alibi_f32(
+        device const float * src0,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & ne2,
+        constant int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        constant float & m0,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+    float m_k = pow(m0, i2 + 1);
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
+    }
+}
+
 kernel void kernel_rope(
         device const void * src0,
         device float * dst,
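
Reading note: kernel_alibi_f32 applies the ALiBi positional bias in place. Each head i2 gets a slope m_k = pow(m0, i2 + 1), and column i00 of a row of length ne00 receives m_k * (i00 - ne00 + 1): zero at the last position, increasingly negative toward older ones. A tiny sketch with invented parameters (in ALiBi m0 is typically 2^(-8/n_head)):

    #include <cmath>
    #include <cstdio>

    int main() {
        const float m0 = 0.5f;   // base slope (assumed value, for illustration)
        const int ne00 = 5;      // row length (number of key positions)
        for (int i2 = 0; i2 < 2; ++i2) {            // two heads
            const float m_k = std::pow(m0, i2 + 1);
            for (int i00 = 0; i00 < ne00; ++i00)
                printf("%6.3f ", m_k * (i00 - ne00 + 1));
            printf("\n");                           // head 0: -2.0 -1.5 -1.0 -0.5 0.0
        }
    }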
@@ -577,6 +648,47 @@ kernel void kernel_rope(
     }
 }
 
+kernel void kernel_cpy_f16_f16(
+        device const half * src0,
+        device half * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & ne2,
+        constant int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0];
+    }
+}
+
 kernel void kernel_cpy_f32_f16(
         device const float * src0,
         device half * dst,
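
Reading note: like the other cpy kernels, kernel_cpy_f16_f16 linearizes the source position (i03,i02,i01) into a flat index n and then peels n apart with the destination extents, which is what lets one copy kernel double as a reshape between tensors of different shape but equal element count. A quick CPU round-trip check of that decomposition, with assumed extents:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t ne0 = 4, ne1 = 3, ne2 = 2;  // destination extents
        const int64_t n = 17;                     // some flattened index
        const int64_t i3 = n / (ne2*ne1*ne0);
        const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
        const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
        const int64_t i0 =  n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0;
        // recombining must return n
        printf("%lld -> (%lld,%lld,%lld,%lld) -> %lld\n",
               (long long)n, (long long)i3, (long long)i2, (long long)i1,
               (long long)i0, (long long)(((i3*ne2 + i2)*ne1 + i1)*ne0 + i0));
    }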
@@ -671,6 +783,15 @@ typedef struct {
     half d;    // super-block scale for quantized scales
     half dmin; // super-block scale for quantized mins
 } block_q2_k;
+// 84 bytes / block
+
+typedef struct {
+    uint8_t hmask[QK_K/8];     // quants - high bit
+    uint8_t qs[QK_K/4];        // quants - low 2 bits
+    uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
+    half d;                    // super-block scale
+} block_q3_k;
+// 110 bytes / block
 
 typedef struct {
     half d;    // super-block scale for quantized scales
@@ -678,6 +799,16 @@ typedef struct {
     uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
     uint8_t qs[QK_K/2];        // 4--bit quants
 } block_q4_k;
+// 144 bytes / block
+
+typedef struct {
+    half d;                    // super-block scale for quantized scales
+    half dmin;                 // super-block scale for quantized mins
+    uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
+    uint8_t qh[QK_K/8];        // quants, high bit
+    uint8_t qs[QK_K/2];        // quants, low 4 bits
+} block_q5_k;
+// 176 bytes / block
 
 typedef struct {
     uint8_t ql[QK_K/2]; // quants, lower 4 bits
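
Reading note: the new "bytes / block" comments check out against the struct layouts, assuming QK_K = 256 and 2-byte halves (the upstream ggml k-quant layout):

    #include <cstdio>

    int main() {
        const int QK_K = 256, H = 2;  // H = sizeof(half)
        printf("q2_k: %d\n", QK_K/16 + QK_K/4 + 2*H);            // 16+64+4     = 84
        printf("q3_k: %d\n", QK_K/8 + QK_K/4 + 3*QK_K/64 + H);   // 32+64+12+2  = 110
        printf("q4_k: %d\n", 2*H + 3*QK_K/64 + QK_K/2);          // 4+12+128    = 144
        printf("q5_k: %d\n", 2*H + 3*QK_K/64 + QK_K/8 + QK_K/2); // 4+12+32+128 = 176
        printf("q6_k: %d\n", QK_K/2 + QK_K/4 + QK_K/16 + H);     // 128+64+16+2 = 210
    }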
@@ -685,16 +816,19 @@ typedef struct {
     int8_t scales[QK_K/16]; // scales, quantized with 8 bits
     half d;                 // super-block scale
 } block_q6_k;
+// 210 bytes / block
 
 static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
     uchar4 r;
     if (j < 4) {
-        r[0] = q[j+0] & 63;
-        r[2] = q[j+1] & 63;
+        r[0] = q[j+0] & 63;
+        r[2] = q[j+1] & 63;
+        r[1] = q[j+4] & 63;
+        r[3] = q[j+5] & 63;
     } else {
         r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
-        r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
         r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
+        r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
         r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4);
     }
     return r;
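
Reading note: get_scale_min_k4 unpacks the 12-byte k-quant scale block holding eight 6-bit (scale, min) pairs: entries j < 4 sit in the low 6 bits of bytes 0..3 (scales) and 4..7 (mins); entries j >= 4 keep their low 4 bits in bytes 8..11 and their top 2 bits in the spare high bits of the first eight bytes. A CPU pack/unpack round trip; the pack() helper is my own illustrative inverse, not upstream code:

    #include <cstdint>
    #include <cstdio>

    // Pack 8 scales and 8 mins (each < 64) the way the decoder above expects.
    static void pack(const uint8_t sc[8], const uint8_t mn[8], uint8_t q[12]) {
        for (int j = 0; j < 4; ++j) {
            q[j]   = sc[j] | ((sc[j+4] >> 4) << 6);            // low 6 bits + 2 spare
            q[j+4] = mn[j] | ((mn[j+4] >> 4) << 6);
            q[j+8] = (sc[j+4] & 0xF) | ((mn[j+4] & 0xF) << 4); // low nibbles of j >= 4
        }
    }

    int main() {
        const uint8_t sc[8] = {1, 12, 33, 63, 7, 42, 55, 60};
        const uint8_t mn[8] = {0, 5, 17, 31, 9, 63, 2, 44};
        uint8_t q[12]; pack(sc, mn, q);
        for (int j = 0; j < 8; ++j) {       // per-entry mirror of get_scale_min_k4
            int s, m;
            if (j < 4) { s = q[j] & 63;                      m = q[j+4] & 63; }
            else       { s = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
                         m = (q[j+4] >>  4) | ((q[j]   >> 6) << 4); }
            printf("j=%d scale %d==%d min %d==%d\n", j, s, sc[j], m, mn[j]);
        }
    }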
@@ -735,10 +869,65 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
     }
 }
 
+static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask2 = 0x0f0f;
+
+    uint16_t aux[8];
+    thread const int8_t * scales = (thread const int8_t*)aux;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d_all = (float)(x[i].d);
+
+        device const uint8_t * q = x[i].qs;
+        device const uint8_t * h = x[i].hmask;
+        uint8_t m = 1;
+
+        device const uint16_t * a = (device const uint16_t *)x[i].scales;
+        aux[0] = (a[0] & kmask2) | (((a[4] >> 0) & kmask1) << 4);
+        aux[1] = (a[1] & kmask2) | (((a[5] >> 0) & kmask1) << 4);
+        aux[2] = (a[2] & kmask2) | (((a[4] >> 2) & kmask1) << 4);
+        aux[3] = (a[3] & kmask2) | (((a[5] >> 2) & kmask1) << 4);
+        aux[4] = ((a[0] >> 4) & kmask2) | (((a[4] >> 4) & kmask1) << 4);
+        aux[5] = ((a[1] >> 4) & kmask2) | (((a[5] >> 4) & kmask1) << 4);
+        aux[6] = ((a[2] >> 4) & kmask2) | (((a[4] >> 6) & kmask1) << 4);
+        aux[7] = ((a[3] >> 4) & kmask2) | (((a[5] >> 6) & kmask1) << 4);
+
+        int is = 0;
+        float dl;
+        for (int n = 0; n < QK_K; n += 128) {
+            int shift = 0;
+            for (int j = 0; j < 4; ++j) {
+
+                dl = d_all * (scales[is++] - 32);
+                for (int l = 0; l < 16; ++l) {
+                    *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((h[l+ 0] & m) ? 0 : 4));
+                }
+
+                dl = d_all * (scales[is++] - 32);
+                for (int l = 0; l < 16; ++l) {
+                    *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((h[l+16] & m) ? 0 : 4));
+                }
+
+                shift += 2;
+                m <<= 1;
+            }
+            q += 32;
+        }
+
+    }
+
+}
+
 static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
+
     for (int i = 0; i < nb; i++) {
 
         const float d = x[i].d;
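
Reading note: in dequantize_row_q3_k each weight is two low bits from qs plus one high bit from hmask; a cleared hmask bit lowers the value by 4, so the quant spans -4..3 before the bias-32 6-bit scale is applied. One element, decoded the same way with invented numbers:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const float d_all = 0.5f;
        const int8_t scale = 40;               // 6-bit scale, stored with +32 bias
        const float dl = d_all * (scale - 32); // effective scale = 4
        for (int hi = 0; hi <= 1; ++hi)
            for (int lo = 0; lo < 4; ++lo) {
                const int q = lo - (hi ? 0 : 4); // -4..-1 (hi clear) or 0..3 (hi set)
                printf("hi=%d lo=%d -> %g\n", hi, lo, dl * q);
            }
    }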
@@ -760,6 +949,33 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i
     }
 }
 
+static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d = (float)(x[i].d);
+        const float min = (float)(x[i].dmin);
+
+        device const uint8_t * ql = x[i].qs;
+        device const uint8_t * qh = x[i].qh;
+
+        int is = 0;
+        uint8_t u1 = 1, u2 = 2;
+        for (int j = 0; j < QK_K; j += 64) {
+            const uchar4 sc = get_scale_min_k4(is, x[i].scales);
+            const float d1 = d * sc[0]; const float m1 = min * sc[1];
+            const float d2 = d * sc[2]; const float m2 = min * sc[3];
+            for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
+            for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l] >>  4) + (qh[l] & u2 ? 16 : 0)) - m2;
+            ql += 32; is += 2;
+            u1 <<= 2; u2 <<= 2;
+        }
+    }
+
+}
+
 static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, int k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
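
Reading note: q5_k extends the q4_k nibble with one bit from qh, giving quants 0..31, and dequantizes with the affine form d*scale*q - min*m, with scale and m coming from get_scale_min_k4. A single weight with invented values:

    #include <cstdio>

    int main() {
        const float d = 0.1f, min = 0.05f;  // super-block scales
        const int sc = 20, m = 10;          // 6-bit sub-block scale and min
        const int nib = 7;                  // low 4 bits from qs
        for (int hi = 0; hi <= 1; ++hi) {
            const int q = nib + (hi ? 16 : 0);  // full 5-bit quant
            printf("hi=%d q=%d -> %g\n", hi, q, d * sc * q - min * m);
        }
    }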
@@ -808,6 +1024,22 @@ kernel void kernel_get_rows_q2_k(
             (device float *) ((device char *) dst + i*nb1), ne00);
 }
 
+kernel void kernel_get_rows_q3_k(
+        device const void * src0,
+        device const int * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb1,
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    dequantize_row_q3_k(
+            (device const block_q3_k *) ((device char *) src0 + r*nb01),
+            (device float *) ((device char *) dst + i*nb1), ne00);
+}
+
 kernel void kernel_get_rows_q4_k(
         device const void * src0,
         device const int * src1,
@@ -824,6 +1056,22 @@ kernel void kernel_get_rows_q4_k(
             (device float *) ((device char *) dst + i*nb1), ne00);
 }
 
+kernel void kernel_get_rows_q5_k(
+        device const void * src0,
+        device const int * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb1,
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    dequantize_row_q5_k(
+            (device const block_q5_k *) ((device char *) src0 + r*nb01),
+            (device float *) ((device char *) dst + i*nb1), ne00);
+}
+
 kernel void kernel_get_rows_q6_k(
         device const void * src0,
         device const int * src1,
@@ -847,20 +1095,10 @@ kernel void kernel_mul_mat_q2_k_f32(
         device const float * src1,
         device float * dst,
         constant int64_t & ne00,
-        constant int64_t & ne01,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
         constant int64_t & ne10,
-        constant int64_t & ne11,
-        constant uint64_t & nb10,
-        constant uint64_t & nb11,
-        constant uint64_t & nb12,
         constant int64_t & ne0,
-        constant int64_t & ne1,
         threadgroup float * sum [[threadgroup(0)]],
         uint2 tgpig[[threadgroup_position_in_grid]],
-        uint2 tpig[[thread_position_in_grid]], // we don't use this for now
        uint2 tpitg[[thread_position_in_threadgroup]],
         uint2 tptg[[threads_per_threadgroup]]) {
 
@@ -875,7 +1113,6 @@ kernel void kernel_mul_mat_q2_k_f32(
     const int nth = tptg.x*tptg.y;
     const int ith = tptg.y*tpitg.x + tpitg.y;
 
-
     const int tid = tpitg.y; // 0...16
     const int il = tid/4;    // 0...3
     const int ir = tid%4;    // 0...3
@@ -885,35 +1122,54 @@ kernel void kernel_mul_mat_q2_k_f32(
     const int n = 8;
     const int is = 4*il + (n*ir)/16;
 
+    const int y_offset = 64*il + n*ir;
+    const int q_offset = 32*ip + n*ir;
+
     sum[ith] = 0.0f;
 
     float sumf = 0;
     for (int i = tpitg.x; i < nb; i += tptg.x) {
 
-        device const uint8_t * q = x[i].qs +
+        device const uint8_t * q = x[i].qs + q_offset;
         device const uint8_t * scales = x[i].scales + is;
 
         uint8_t d1 = scales[0] & 0xF;
-        uint8_t m1 = scales[0] >> 4;
         uint8_t d2 = scales[2] & 0xF;
+        uint8_t m1 = scales[0] >> 4;
         uint8_t m2 = scales[2] >> 4;
 
-        device const float * y = yy + i*QK_K +
+        device const float * y = yy + i*QK_K + y_offset;
 
-
-
-
-        float4 s = {0.f, 0.f, 0.f, 0.f};
+        //float4 s = {0.f, 0.f, 0.f, 0.f};
+        float2 s = {0.f, 0.f};
+        float smin = 0;
         for (int l = 0; l < n; ++l) {
-            s[0] += y[l+ 0] * ((q[l] >> shift1) & 3);
-            s[
+            s[0] += y[l+ 0] * ((q[l] >> shift1) & 3);
+            s[1] += y[l+32] * ((q[l] >> shift2) & 3);
+            smin += y[l+ 0] * m1 + y[l+32] * m2;
        }
-        sumf += dall * (s[0] * d1 + s[2] * d2) - dmin * (s[1] * m1 + s[3] * m2);
 
+        const float dall = (float)x[i].d;
+        const float dmin = (float)x[i].dmin;
+
+        sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
 
     }
     sum[ith] = sumf;
 
+    //int mask1 = (ith%4 == 0);
+    //int mask2 = (ith%16 == 0);
+
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    //for (int i = 1; i < 4; ++i) sum[ith] += mask1 * sum[ith + i];
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    //for (int i = 4; i < 16; i += 4) sum[ith] += mask2 * sum[ith + i];
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    //if (ith == 0) {
+    //    for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+    //    dst[r1*ne0 + r0] = sum[0];
+    //}
+
     //
     // Accumulate the sum from all threads in the threadgroup
     // This version is slightly faster than the commented out one below,
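
Reading note: the q2_k rewrite folds the "minus min" half of the affine dequantization into one running term. With weights dall*d_b*q - dmin*m_b from two sub-blocks, the dot product splits as dall*(d1*s0 + d2*s1) - dmin*smin, where s0 and s1 are plain quant-times-activation sums and smin = sum(m1*ya + m2*yb); two partial sums plus smin replace the old four. A numeric check with invented values:

    #include <cstdio>

    int main() {
        const float dall = 0.5f, dmin = 0.25f;
        const int d1 = 3, m1 = 2, d2 = 5, m2 = 1;  // 4-bit sub-block codes
        const int qa[2] = {1, 2}, qb[2] = {3, 0};  // 2-bit quants
        const float ya[2] = {0.5f, -1.f}, yb[2] = {2.f, 1.f};

        float direct = 0, s0 = 0, s1 = 0, smin = 0;
        for (int l = 0; l < 2; ++l) {
            direct += (dall*d1*qa[l] - dmin*m1)*ya[l]
                    + (dall*d2*qb[l] - dmin*m2)*yb[l];
            s0 += qa[l]*ya[l]; s1 += qb[l]*yb[l];
            smin += m1*ya[l] + m2*yb[l];
        }
        printf("%f %f\n", direct, dall*(d1*s0 + d2*s1) - dmin*smin);  // equal
    }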
@@ -932,19 +1188,109 @@ kernel void kernel_mul_mat_q2_k_f32(
         for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
         dst[r1*ne0 + r0] = sum[0];
     }
+}
 
-
-
-
-
-
-
-
-
+kernel void kernel_mul_mat_q3_k_f32(
+        device const void * src0,
+        device const float * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne10,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        threadgroup float * sum [[threadgroup(0)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint2 tpitg[[thread_position_in_threadgroup]],
+        uint2 tptg[[threads_per_threadgroup]]) {
+
+    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask2 = 0x0f0f;
+
+    const uint8_t m3 = 3;
+    const int8_t  m4 = 4;
+
+    const int nb = ne00/QK_K;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+
+    device const block_q3_k * x = (device const block_q3_k *) src0 + r0*nb;
+    device const float * yy = (device const float *) src1 + r1*ne10;
+
+    const int nth = tptg.x*tptg.y;
+    const int ith = tptg.y*tpitg.x + tpitg.y;
+
+    const int tid = tpitg.y;      // expecting 16
+    const int ip  = tid/8;        // 0 or 1
+    const int il  = tid/2 - 4*ip; // 0...3
+    const int ir  = tid%2;
+    const int n   = 8;
+    const int l0  = n*ir;
+
+    const uint8_t m = 1 << (4*ip + il);
+
+    const int shift = 2*il;
+
+    const uint16_t s_shift1 = 4*ip;
+    const uint16_t s_shift2 = s_shift1 + 2*(il/2);
+    const int ik = 4 + (il%2);
+
+    const int q_offset = 32*ip + l0;
+    const int y_offset = 128*ip + 32*il + l0;
+
+    //float sumf = 0;
+    float sumf1 = 0, sumf2 = 0;
+    for (int i = tpitg.x; i < nb; i += tptg.x) {
+
+        const float d_all = (float)(x[i].d);
+
+        device const uint8_t * q = x[i].qs + q_offset;
+        device const uint8_t * h = x[i].hmask + l0;
+        device const float   * y = yy + i * QK_K + y_offset;
+
+        device const uint16_t * a = (device const uint16_t *)x[i].scales;
+        const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
+
+        float s = 0;
+        for (int l = 0; l < n; ++l) {
+            s += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & m3) - ((h[l+ 0] & m) ? 0 : m4));
+        }
+        float d = d_all * s;
+        sumf1 += d * scales[0];
+        sumf2 += d;
+        //sumf += d_all * s * (scales[0] - 32);
+
+        s = 0;
+        for (int l = 0; l < n; ++l) {
+            s += y[l+16] * ((int8_t)((q[l+16] >> shift) & m3) - ((h[l+16] & m) ? 0 : m4));
+        }
+        d = d_all * s;
+        sumf1 += d * scales[1];
+        sumf2 += d;
+        //sumf += d_all * s * (scales[1] - 32);
+
+    }
+
+    //sum[ith] = sumf;
+    sum[ith] = sumf1 - 32.f*sumf2;
+
+    //
+    // Accumulate the sum from all threads in the threadgroup
+    //
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%4 == 0) {
+        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%16 == 0) {
+        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith == 0) {
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+        dst[r1*ne0 + r0] = sum[0];
+    }
 
-    //if (ith == 0) {
-    //    dst[r1*ne0 + r0] = sum[0];
-    //}
 }
 
 kernel void kernel_mul_mat_q4_k_f32(
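
Reading note: the new q3_k kernel defers the +32 scale bias to one final subtraction. With d_b = d_all * s_b for each sub-block b, sum_b d_b*(scale_b - 32) = sum_b d_b*scale_b - 32*sum_b d_b = sumf1 - 32*sumf2, which is exactly the `sum[ith] = sumf1 - 32.f*sumf2` line. A numeric check with invented values:

    #include <cstdio>

    int main() {
        const float d[3]     = {0.5f, -1.25f, 2.f};  // d_all * inner dot products
        const int   scale[3] = {40, 30, 33};         // biased 6-bit scales
        float direct = 0, sumf1 = 0, sumf2 = 0;
        for (int b = 0; b < 3; ++b) {
            direct += d[b] * (scale[b] - 32);
            sumf1  += d[b] * scale[b];
            sumf2  += d[b];
        }
        printf("%f %f\n", direct, sumf1 - 32.f*sumf2);  // identical
    }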
@@ -952,23 +1298,17 @@ kernel void kernel_mul_mat_q4_k_f32(
         device const float * src1,
         device float * dst,
         constant int64_t & ne00,
-        constant int64_t & ne01,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
         constant int64_t & ne10,
-        constant int64_t & ne11,
-        constant uint64_t & nb10,
-        constant uint64_t & nb11,
-        constant uint64_t & nb12,
         constant int64_t & ne0,
-        constant int64_t & ne1,
         threadgroup float * sum [[threadgroup(0)]],
         uint2 tgpig[[threadgroup_position_in_grid]],
-        uint2 tpig[[thread_position_in_grid]], // we don't use this for now
         uint2 tpitg[[thread_position_in_threadgroup]],
         uint2 tptg[[threads_per_threadgroup]]) {
 
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
     const int nb = ne00/QK_K;
 
     const int64_t r0 = tgpig.x;
@@ -977,37 +1317,55 @@ kernel void kernel_mul_mat_q4_k_f32(
     device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
     device const float * yy = (device const float *) src1 + r1*ne10;
 
-    const
-    const
+    const int nth = tptg.x*tptg.y;
+    const int ith = tptg.y*tpitg.x + tpitg.y;
 
     const int tid = tpitg.y;   // 0...16
     const int il  = tid/4;     // 0...3
-    const int ir = tid
-    const int n =
-
+    const int ir  = tid - 4*il;// 0...3
+    const int n   = 4;
+
+    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+    const int in = il%2;
+
+    const int l0 = n*(2*ir + in);
+    const int q_offset = 32*im + l0;
+    const int y_offset = 64*im + l0;
 
     sum[ith] = 0.0f;
 
+    uchar2 sc1, sc2, sc3, sc4;
+
     float sumf = 0;
     for (int i = tpitg.x; i < nb; i += tptg.x) {
 
-        device const uint8_t *
-        device const
-        device const
+        device const uint8_t * q1 = (x + i)->qs + q_offset;
+        device const uint8_t * q2 = q1 + 64;
+        device const float   * y1 = yy + i*QK_K + y_offset;
+        device const float   * y2 = y1 + 128;
 
         const float dall = (float)((x + i)->d);
         const float dmin = (float)((x + i)->dmin);
 
-        const
+        device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
+        sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
+        sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
+        sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
+        sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
 
         float4 s = {0.f, 0.f, 0.f, 0.f};
+        float smin = 0;
         for (int l = 0; l < n; ++l) {
-
-            s[
+
+            s[0] += y1[l] * (q1[l] & 0xF);  s[1] += y1[l+32] * (q1[l] >> 4);
+            s[2] += y2[l] * (q2[l] & 0xF);  s[3] += y2[l+32] * (q2[l] >> 4);
+            smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
+
         }
-        sumf += dall * (s[0] *
+        sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
 
     }
+
     sum[ith] = sumf;
 
     //
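
Reading note: the q4_k kernel now pulls two packed 6-bit scales per 16-bit load: kmask1 = 0x3f3f keeps two low-6-bit fields at once, kmask3 = 0xc0c0 keeps the two spare high-bit pairs (shifted down by 2 before merging), and as_type<uchar2> splits the 16-bit result into bytes. A C++ stand-in using memcpy for the bit-cast; it assumes little-endian byte order, as on Apple GPUs:

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    static void split(uint16_t v, uint8_t out[2]) { std::memcpy(out, &v, 2); } // as_type<uchar2>

    int main() {
        // two adjacent scale bytes as the shader loads them (made-up contents)
        const uint8_t b0 = 0xE5, b1 = 0x7A;  // each: 6-bit field + 2 spare bits
        const uint16_t a = (uint16_t)(b0 | (b1 << 8));

        uint8_t lo[2], hi[2];
        split(a & 0x3f3f, lo);               // kmask1: the two 6-bit fields
        split((a & 0xc0c0) >> 2, hi);        // kmask3: spare bits, shifted down
        printf("fields %d %d, spare %d %d\n", lo[0], lo[1], hi[0], hi[1]);
    }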
@@ -1043,25 +1401,114 @@ kernel void kernel_mul_mat_q4_k_f32(
     //}
 }
 
+kernel void kernel_mul_mat_q5_k_f32(
+        device const void * src0,
+        device const float * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne10,
+        constant int64_t & ne0,
+        threadgroup float * sum [[threadgroup(0)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint2 tpitg[[thread_position_in_threadgroup]],
+        uint2 tptg[[threads_per_threadgroup]]) {
+
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
+    const int nb = ne00/QK_K;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+
+    device const block_q5_k * x = (device const block_q5_k *) src0 + r0*nb;
+    device const float * yy = (device const float *) src1 + r1*ne10;
+
+    const int nth = tptg.x*tptg.y;
+    const int ith = tptg.y*tpitg.x + tpitg.y;
+
+    const int tid = tpitg.y;   // 0...16
+    const int il  = tid/4;     // 0...3
+    const int ir  = tid - 4*il;// 0...3
+    const int n   = 4;
+
+    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+    const int in = il%2;
+
+    const int l0 = n*(2*ir + in);
+    const int q_offset = 32*im + l0;
+    const int y_offset = 64*im + l0;
+
+    const uint8_t hm1 = 1u << (2*im);
+    const uint8_t hm2 = hm1 << 1;
+    const uint8_t hm3 = hm1 << 4;
+    const uint8_t hm4 = hm2 << 4;
+
+    uchar2 sc1, sc2, sc3, sc4;
+
+    float sumf = 0;
+    for (int i = tpitg.x; i < nb; i += tptg.x) {
+
+        device const uint8_t * q1 = (x + i)->qs + q_offset;
+        device const uint8_t * q2 = q1 + 64;
+        device const uint8_t * qh = (x + i)->qh + l0;
+        device const float   * y1 = yy + i*QK_K + y_offset;
+        device const float   * y2 = y1 + 128;
+
+        const float dall = (float)((x + i)->d);
+        const float dmin = (float)((x + i)->dmin);
+
+        device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
+        sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
+        sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
+        sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
+        sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
+
+        float4 s = {0.f, 0.f, 0.f, 0.f};
+        float smin = 0;
+        for (int l = 0; l < n; ++l) {
+
+            s[0] += y1[l+ 0] * ((q1[l] & 0xF) + (qh[l] & hm1 ? 16 : 0));
+            s[1] += y1[l+32] * ((q1[l] >>  4) + (qh[l] & hm2 ? 16 : 0));
+            s[2] += y2[l+ 0] * ((q2[l] & 0xF) + (qh[l] & hm3 ? 16 : 0));
+            s[3] += y2[l+32] * ((q2[l] >>  4) + (qh[l] & hm4 ? 16 : 0));
+            smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
+
+        }
+        sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
+
+    }
+    sum[ith] = sumf;
+
+    //
+    // Accumulate the sum from all threads in the threadgroup
+    //
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%4 == 0) {
+        sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%16 == 0) {
+        sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith == 0) {
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+        dst[r1*ne0 + r0] = sum[0];
+    }
+
+}
+
 kernel void kernel_mul_mat_q6_k_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,
         constant int64_t & ne00,
-        constant int64_t & ne01,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
         constant int64_t & ne10,
-        constant int64_t & ne11,
-        constant uint64_t & nb10,
-        constant uint64_t & nb11,
-        constant uint64_t & nb12,
         constant int64_t & ne0,
-        constant int64_t & ne1,
         threadgroup float * sum [[threadgroup(0)]],
         uint2 tgpig[[threadgroup_position_in_grid]],
-        uint2 tpig[[thread_position_in_grid]], // we don't use this for now
         uint2 tpitg[[thread_position_in_threadgroup]],
         uint2 tptg[[threads_per_threadgroup]]) {
 
@@ -1078,24 +1525,29 @@ kernel void kernel_mul_mat_q6_k_f32(
     device const block_q6_k * x = (device const block_q6_k *) src0 + r0*nb;
     device const float * yy = (device const float *) src1 + r1*ne10;
 
-    const
-    const
+    const int nth = tptg.x*tptg.y;
+    const int ith = tptg.y*tpitg.x + tpitg.y;
 
-
-    const int iqs =
+    // Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
+    const int iqs = 16 * tpitg.y;
     const int ip = iqs / 128;         // 0 or 1
     const int il = (iqs - 128*ip)/16; // 0...7
     const int n = 4;
-    const int
+    const int l0 = n*il;
+    const int is = 8*ip + l0/16;
+
+    const int y_offset = 128*ip + l0;
+    const int q_offset_l = 64*ip + l0;
+    const int q_offset_h = 32*ip + l0;
 
     float sumf = 0;
     for (int i = tpitg.x; i < nb; i += tptg.x) {
 
-        device const uint8_t * ql = x[i].ql +
-        device const uint8_t * qh = x[i].qh +
+        device const uint8_t * ql = x[i].ql + q_offset_l;
+        device const uint8_t * qh = x[i].qh + q_offset_h;
         device const int8_t  * sc = x[i].scales + is;
 
-        device const float * y = yy + i * QK_K +
+        device const float * y = yy + i * QK_K + y_offset;
 
         const float dall = x[i].d;
 