llama_cpp 0.2.0 → 0.2.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +7 -0
- data/examples/README.md +60 -0
- data/examples/chat.rb +195 -0
- data/ext/llama_cpp/llama_cpp.cpp +52 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +697 -130
- data/ext/llama_cpp/src/ggml-cuda.h +4 -1
- data/ext/llama_cpp/src/ggml-metal.h +1 -0
- data/ext/llama_cpp/src/ggml-metal.m +548 -497
- data/ext/llama_cpp/src/ggml-metal.metal +425 -122
- data/ext/llama_cpp/src/ggml-opencl.cpp +3 -32
- data/ext/llama_cpp/src/ggml-opencl.h +1 -2
- data/ext/llama_cpp/src/ggml.c +1904 -303
- data/ext/llama_cpp/src/ggml.h +126 -2
- data/ext/llama_cpp/src/llama.cpp +212 -108
- data/ext/llama_cpp/src/llama.h +12 -3
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +3 -0
- metadata +4 -2
@@ -304,34 +304,22 @@ kernel void kernel_mul_mat_q4_0_f32(
|
|
304
304
|
device const float * src1,
|
305
305
|
device float * dst,
|
306
306
|
constant int64_t & ne00,
|
307
|
-
constant int64_t & ne01,
|
308
|
-
constant uint64_t & nb00,
|
309
|
-
constant uint64_t & nb01,
|
310
|
-
constant uint64_t & nb02,
|
311
307
|
constant int64_t & ne10,
|
312
|
-
constant int64_t & ne11,
|
313
|
-
constant uint64_t & nb10,
|
314
|
-
constant uint64_t & nb11,
|
315
|
-
constant uint64_t & nb12,
|
316
308
|
constant int64_t & ne0,
|
317
|
-
constant int64_t & ne1,
|
318
309
|
threadgroup float * sum [[threadgroup(0)]],
|
319
310
|
uint2 tgpig[[threadgroup_position_in_grid]],
|
320
|
-
uint2 tpig[[thread_position_in_grid]],
|
321
311
|
uint2 tpitg[[thread_position_in_threadgroup]],
|
322
312
|
uint2 tptg[[threads_per_threadgroup]]) {
|
323
313
|
const int nb = ne00/QK4_0;
|
324
314
|
|
325
|
-
const int8_t m8 = 8;
|
326
|
-
|
327
315
|
const int64_t r0 = tgpig.x;
|
328
316
|
const int64_t r1 = tgpig.y;
|
329
317
|
|
330
318
|
device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb;
|
331
319
|
device const float * y = (device const float *) src1 + r1*ne10;
|
332
320
|
|
333
|
-
const
|
334
|
-
const
|
321
|
+
const int nth = tptg.x*tptg.y;
|
322
|
+
const int ith = tptg.y*tpitg.x + tpitg.y;
|
335
323
|
|
336
324
|
const int ix = tpitg.y/4; // 0 or 1
|
337
325
|
const int iy = tpitg.y - 4*ix; // 0...3
|
@@ -351,47 +339,32 @@ kernel void kernel_mul_mat_q4_0_f32(
|
|
351
339
|
|
352
340
|
for (int j = 0; j < 4; ++j) {
|
353
341
|
|
354
|
-
acc[0] += yl[j
|
355
|
-
acc[1] += yl[j
|
342
|
+
acc[0] += yl[j] * (xl[j] & 0xF) + yl[j+16] * (xl[j] >> 4);
|
343
|
+
acc[1] += yl[j] + yl[j+16];
|
356
344
|
|
357
345
|
}
|
358
346
|
|
359
|
-
sumf += d * (acc[0]
|
347
|
+
sumf += d * (acc[0] - 8.f*acc[1]);
|
360
348
|
}
|
361
349
|
|
362
350
|
sum[ith] = sumf;
|
363
351
|
|
364
352
|
//
|
365
353
|
// Accumulate the sum from all threads in the threadgroup
|
366
|
-
// This version is slightly faster than the commented out one below,
|
367
|
-
// which I copy-pasted from ggerganov's q4_0 dot product for metal.
|
368
354
|
//
|
369
355
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
370
356
|
if (ith%4 == 0) {
|
371
|
-
|
357
|
+
sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
|
372
358
|
}
|
373
359
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
374
360
|
if (ith%16 == 0) {
|
375
|
-
|
361
|
+
sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
|
376
362
|
}
|
377
363
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
378
364
|
if (ith == 0) {
|
379
|
-
for (
|
365
|
+
for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
|
380
366
|
dst[r1*ne0 + r0] = sum[0];
|
381
367
|
}
|
382
|
-
|
383
|
-
//// accumulate the sum from all threads in the threadgroup
|
384
|
-
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
385
|
-
//for (uint i = nth/2; i > 0; i /= 2) {
|
386
|
-
// if (ith < i) {
|
387
|
-
// sum[ith] += sum[ith + i];
|
388
|
-
// }
|
389
|
-
// threadgroup_barrier(mem_flags::mem_threadgroup);
|
390
|
-
//}
|
391
|
-
|
392
|
-
//if (ith == 0) {
|
393
|
-
// dst[r1*ne0 + r0] = sum[0];
|
394
|
-
//}
|
395
368
|
}
|
396
369
|
|
397
370
|
kernel void kernel_mul_mat_q4_1_f32(
|
@@ -399,20 +372,10 @@ kernel void kernel_mul_mat_q4_1_f32(
|
|
399
372
|
device const float * src1,
|
400
373
|
device float * dst,
|
401
374
|
constant int64_t & ne00,
|
402
|
-
constant int64_t & ne01,
|
403
|
-
constant uint64_t & nb00,
|
404
|
-
constant uint64_t & nb01,
|
405
|
-
constant uint64_t & nb02,
|
406
375
|
constant int64_t & ne10,
|
407
|
-
constant int64_t & ne11,
|
408
|
-
constant uint64_t & nb10,
|
409
|
-
constant uint64_t & nb11,
|
410
|
-
constant uint64_t & nb12,
|
411
376
|
constant int64_t & ne0,
|
412
|
-
constant int64_t & ne1,
|
413
377
|
threadgroup float * sum [[threadgroup(0)]],
|
414
378
|
uint2 tgpig[[threadgroup_position_in_grid]],
|
415
|
-
uint2 tpig[[thread_position_in_grid]],
|
416
379
|
uint2 tpitg[[thread_position_in_threadgroup]],
|
417
380
|
uint2 tptg[[threads_per_threadgroup]]) {
|
418
381
|
const int nb = ne00/QK4_1;
|
@@ -460,11 +423,11 @@ kernel void kernel_mul_mat_q4_1_f32(
|
|
460
423
|
//
|
461
424
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
462
425
|
if (ith%4 == 0) {
|
463
|
-
|
426
|
+
sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
|
464
427
|
}
|
465
428
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
466
429
|
if (ith%16 == 0) {
|
467
|
-
|
430
|
+
sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
|
468
431
|
}
|
469
432
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
470
433
|
if (ith == 0) {
|
@@ -671,6 +634,15 @@ typedef struct {
|
|
671
634
|
half d; // super-block scale for quantized scales
|
672
635
|
half dmin; // super-block scale for quantized mins
|
673
636
|
} block_q2_k;
|
637
|
+
// 84 bytes / block
|
638
|
+
|
639
|
+
typedef struct {
|
640
|
+
uint8_t hmask[QK_K/8]; // quants - high bit
|
641
|
+
uint8_t qs[QK_K/4]; // quants - low 2 bits
|
642
|
+
uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
|
643
|
+
half d; // super-block scale
|
644
|
+
} block_q3_k;
|
645
|
+
// 110 bytes / block
|
674
646
|
|
675
647
|
typedef struct {
|
676
648
|
half d; // super-block scale for quantized scales
|
@@ -678,6 +650,16 @@ typedef struct {
|
|
678
650
|
uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
|
679
651
|
uint8_t qs[QK_K/2]; // 4--bit quants
|
680
652
|
} block_q4_k;
|
653
|
+
// 144 bytes / block
|
654
|
+
|
655
|
+
typedef struct {
|
656
|
+
half d; // super-block scale for quantized scales
|
657
|
+
half dmin; // super-block scale for quantized mins
|
658
|
+
uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
|
659
|
+
uint8_t qh[QK_K/8]; // quants, high bit
|
660
|
+
uint8_t qs[QK_K/2]; // quants, low 4 bits
|
661
|
+
} block_q5_k;
|
662
|
+
// 176 bytes / block
|
681
663
|
|
682
664
|
typedef struct {
|
683
665
|
uint8_t ql[QK_K/2]; // quants, lower 4 bits
|
@@ -685,16 +667,19 @@ typedef struct {
|
|
685
667
|
int8_t scales[QK_K/16]; // scales, quantized with 8 bits
|
686
668
|
half d; // super-block scale
|
687
669
|
} block_q6_k;
|
670
|
+
// 210 bytes / block
|
688
671
|
|
689
672
|
static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
|
690
673
|
uchar4 r;
|
691
674
|
if (j < 4) {
|
692
|
-
r[0] = q[j+0] & 63;
|
693
|
-
r[2] = q[j+1] & 63;
|
675
|
+
r[0] = q[j+0] & 63;
|
676
|
+
r[2] = q[j+1] & 63;
|
677
|
+
r[1] = q[j+4] & 63;
|
678
|
+
r[3] = q[j+5] & 63;
|
694
679
|
} else {
|
695
680
|
r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
|
696
|
-
r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
|
697
681
|
r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
|
682
|
+
r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
|
698
683
|
r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4);
|
699
684
|
}
|
700
685
|
return r;
|
@@ -735,10 +720,65 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
|
|
735
720
|
}
|
736
721
|
}
|
737
722
|
|
723
|
+
static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, int k) {
|
724
|
+
assert(k % QK_K == 0);
|
725
|
+
const int nb = k / QK_K;
|
726
|
+
|
727
|
+
const uint16_t kmask1 = 0x0303;
|
728
|
+
const uint16_t kmask2 = 0x0f0f;
|
729
|
+
|
730
|
+
uint16_t aux[8];
|
731
|
+
thread const int8_t * scales = (thread const int8_t*)aux;
|
732
|
+
|
733
|
+
for (int i = 0; i < nb; i++) {
|
734
|
+
|
735
|
+
const float d_all = (float)(x[i].d);
|
736
|
+
|
737
|
+
device const uint8_t * q = x[i].qs;
|
738
|
+
device const uint8_t * h = x[i].hmask;
|
739
|
+
uint8_t m = 1;
|
740
|
+
|
741
|
+
device const uint16_t * a = (device const uint16_t *)x[i].scales;
|
742
|
+
aux[0] = (a[0] & kmask2) | (((a[4] >> 0) & kmask1) << 4);
|
743
|
+
aux[1] = (a[1] & kmask2) | (((a[5] >> 0) & kmask1) << 4);
|
744
|
+
aux[2] = (a[2] & kmask2) | (((a[4] >> 2) & kmask1) << 4);
|
745
|
+
aux[3] = (a[3] & kmask2) | (((a[5] >> 2) & kmask1) << 4);
|
746
|
+
aux[4] = ((a[0] >> 4) & kmask2) | (((a[4] >> 4) & kmask1) << 4);
|
747
|
+
aux[5] = ((a[1] >> 4) & kmask2) | (((a[5] >> 4) & kmask1) << 4);
|
748
|
+
aux[6] = ((a[2] >> 4) & kmask2) | (((a[4] >> 6) & kmask1) << 4);
|
749
|
+
aux[7] = ((a[3] >> 4) & kmask2) | (((a[5] >> 6) & kmask1) << 4);
|
750
|
+
|
751
|
+
int is = 0;
|
752
|
+
float dl;
|
753
|
+
for (int n = 0; n < QK_K; n += 128) {
|
754
|
+
int shift = 0;
|
755
|
+
for (int j = 0; j < 4; ++j) {
|
756
|
+
|
757
|
+
dl = d_all * (scales[is++] - 32);
|
758
|
+
for (int l = 0; l < 16; ++l) {
|
759
|
+
*y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((h[l+ 0] & m) ? 0 : 4));
|
760
|
+
}
|
761
|
+
|
762
|
+
dl = d_all * (scales[is++] - 32);
|
763
|
+
for (int l = 0; l < 16; ++l) {
|
764
|
+
*y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((h[l+16] & m) ? 0 : 4));
|
765
|
+
}
|
766
|
+
|
767
|
+
shift += 2;
|
768
|
+
m <<= 1;
|
769
|
+
}
|
770
|
+
q += 32;
|
771
|
+
}
|
772
|
+
|
773
|
+
}
|
774
|
+
|
775
|
+
}
|
776
|
+
|
738
777
|
static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) {
|
739
778
|
assert(k % QK_K == 0);
|
740
779
|
const int nb = k / QK_K;
|
741
780
|
|
781
|
+
|
742
782
|
for (int i = 0; i < nb; i++) {
|
743
783
|
|
744
784
|
const float d = x[i].d;
|
@@ -760,6 +800,33 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i
|
|
760
800
|
}
|
761
801
|
}
|
762
802
|
|
803
|
+
static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, int k) {
|
804
|
+
assert(k % QK_K == 0);
|
805
|
+
const int nb = k / QK_K;
|
806
|
+
|
807
|
+
for (int i = 0; i < nb; i++) {
|
808
|
+
|
809
|
+
const float d = (float)(x[i].d);
|
810
|
+
const float min = (float)(x[i].dmin);
|
811
|
+
|
812
|
+
device const uint8_t * ql = x[i].qs;
|
813
|
+
device const uint8_t * qh = x[i].qh;
|
814
|
+
|
815
|
+
int is = 0;
|
816
|
+
uint8_t u1 = 1, u2 = 2;
|
817
|
+
for (int j = 0; j < QK_K; j += 64) {
|
818
|
+
const uchar4 sc = get_scale_min_k4(is, x[i].scales);
|
819
|
+
const float d1 = d * sc[0]; const float m1 = min * sc[1];
|
820
|
+
const float d2 = d * sc[2]; const float m2 = min * sc[3];
|
821
|
+
for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
|
822
|
+
for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
|
823
|
+
ql += 32; is += 2;
|
824
|
+
u1 <<= 2; u2 <<= 2;
|
825
|
+
}
|
826
|
+
}
|
827
|
+
|
828
|
+
}
|
829
|
+
|
763
830
|
static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, int k) {
|
764
831
|
assert(k % QK_K == 0);
|
765
832
|
const int nb = k / QK_K;
|
@@ -808,6 +875,22 @@ kernel void kernel_get_rows_q2_k(
|
|
808
875
|
(device float *) ((device char *) dst + i*nb1), ne00);
|
809
876
|
}
|
810
877
|
|
878
|
+
kernel void kernel_get_rows_q3_k(
|
879
|
+
device const void * src0,
|
880
|
+
device const int * src1,
|
881
|
+
device float * dst,
|
882
|
+
constant int64_t & ne00,
|
883
|
+
constant uint64_t & nb01,
|
884
|
+
constant uint64_t & nb1,
|
885
|
+
uint tpig[[thread_position_in_grid]]) {
|
886
|
+
const int i = tpig;
|
887
|
+
const int r = ((device int32_t *) src1)[i];
|
888
|
+
|
889
|
+
dequantize_row_q3_k(
|
890
|
+
(device const block_q3_k *) ((device char *) src0 + r*nb01),
|
891
|
+
(device float *) ((device char *) dst + i*nb1), ne00);
|
892
|
+
}
|
893
|
+
|
811
894
|
kernel void kernel_get_rows_q4_k(
|
812
895
|
device const void * src0,
|
813
896
|
device const int * src1,
|
@@ -824,6 +907,22 @@ kernel void kernel_get_rows_q4_k(
|
|
824
907
|
(device float *) ((device char *) dst + i*nb1), ne00);
|
825
908
|
}
|
826
909
|
|
910
|
+
kernel void kernel_get_rows_q5_k(
|
911
|
+
device const void * src0,
|
912
|
+
device const int * src1,
|
913
|
+
device float * dst,
|
914
|
+
constant int64_t & ne00,
|
915
|
+
constant uint64_t & nb01,
|
916
|
+
constant uint64_t & nb1,
|
917
|
+
uint tpig[[thread_position_in_grid]]) {
|
918
|
+
const int i = tpig;
|
919
|
+
const int r = ((device int32_t *) src1)[i];
|
920
|
+
|
921
|
+
dequantize_row_q5_k(
|
922
|
+
(device const block_q5_k *) ((device char *) src0 + r*nb01),
|
923
|
+
(device float *) ((device char *) dst + i*nb1), ne00);
|
924
|
+
}
|
925
|
+
|
827
926
|
kernel void kernel_get_rows_q6_k(
|
828
927
|
device const void * src0,
|
829
928
|
device const int * src1,
|
@@ -847,20 +946,10 @@ kernel void kernel_mul_mat_q2_k_f32(
|
|
847
946
|
device const float * src1,
|
848
947
|
device float * dst,
|
849
948
|
constant int64_t & ne00,
|
850
|
-
constant int64_t & ne01,
|
851
|
-
constant uint64_t & nb00,
|
852
|
-
constant uint64_t & nb01,
|
853
|
-
constant uint64_t & nb02,
|
854
949
|
constant int64_t & ne10,
|
855
|
-
constant int64_t & ne11,
|
856
|
-
constant uint64_t & nb10,
|
857
|
-
constant uint64_t & nb11,
|
858
|
-
constant uint64_t & nb12,
|
859
950
|
constant int64_t & ne0,
|
860
|
-
constant int64_t & ne1,
|
861
951
|
threadgroup float * sum [[threadgroup(0)]],
|
862
952
|
uint2 tgpig[[threadgroup_position_in_grid]],
|
863
|
-
uint2 tpig[[thread_position_in_grid]], // we don't use this for now
|
864
953
|
uint2 tpitg[[thread_position_in_threadgroup]],
|
865
954
|
uint2 tptg[[threads_per_threadgroup]]) {
|
866
955
|
|
@@ -875,7 +964,6 @@ kernel void kernel_mul_mat_q2_k_f32(
|
|
875
964
|
const int nth = tptg.x*tptg.y;
|
876
965
|
const int ith = tptg.y*tpitg.x + tpitg.y;
|
877
966
|
|
878
|
-
|
879
967
|
const int tid = tpitg.y; // 0...16
|
880
968
|
const int il = tid/4; // 0...3
|
881
969
|
const int ir = tid%4; // 0...3
|
@@ -885,35 +973,54 @@ kernel void kernel_mul_mat_q2_k_f32(
|
|
885
973
|
const int n = 8;
|
886
974
|
const int is = 4*il + (n*ir)/16;
|
887
975
|
|
976
|
+
const int y_offset = 64*il + n*ir;
|
977
|
+
const int q_offset = 32*ip + n*ir;
|
978
|
+
|
888
979
|
sum[ith] = 0.0f;
|
889
980
|
|
890
981
|
float sumf = 0;
|
891
982
|
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
892
983
|
|
893
|
-
device const uint8_t * q = x[i].qs +
|
984
|
+
device const uint8_t * q = x[i].qs + q_offset;
|
894
985
|
device const uint8_t * scales = x[i].scales + is;
|
895
986
|
|
896
987
|
uint8_t d1 = scales[0] & 0xF;
|
897
|
-
uint8_t m1 = scales[0] >> 4;
|
898
988
|
uint8_t d2 = scales[2] & 0xF;
|
989
|
+
uint8_t m1 = scales[0] >> 4;
|
899
990
|
uint8_t m2 = scales[2] >> 4;
|
900
991
|
|
901
|
-
device const float * y = yy + i*QK_K +
|
902
|
-
|
903
|
-
const float dall = (float)x[i].d;
|
904
|
-
const float dmin = (float)x[i].dmin;
|
992
|
+
device const float * y = yy + i*QK_K + y_offset;
|
905
993
|
|
906
|
-
float4 s = {0.f, 0.f, 0.f, 0.f};
|
994
|
+
//float4 s = {0.f, 0.f, 0.f, 0.f};
|
995
|
+
float2 s = {0.f, 0.f};
|
996
|
+
float smin = 0;
|
907
997
|
for (int l = 0; l < n; ++l) {
|
908
|
-
s[0] += y[l+ 0] * ((q[l] >> shift1) & 3);
|
909
|
-
s[
|
998
|
+
s[0] += y[l+ 0] * ((q[l] >> shift1) & 3);
|
999
|
+
s[1] += y[l+32] * ((q[l] >> shift2) & 3);
|
1000
|
+
smin += y[l+ 0] * m1 + y[l+32] * m2;
|
910
1001
|
}
|
911
|
-
sumf += dall * (s[0] * d1 + s[2] * d2) - dmin * (s[1] * m1 + s[3] * m2);
|
912
1002
|
|
1003
|
+
const float dall = (float)x[i].d;
|
1004
|
+
const float dmin = (float)x[i].dmin;
|
1005
|
+
|
1006
|
+
sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
|
913
1007
|
|
914
1008
|
}
|
915
1009
|
sum[ith] = sumf;
|
916
1010
|
|
1011
|
+
//int mask1 = (ith%4 == 0);
|
1012
|
+
//int mask2 = (ith%16 == 0);
|
1013
|
+
|
1014
|
+
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
1015
|
+
//for (int i = 1; i < 4; ++i) sum[ith] += mask1 * sum[ith + i];
|
1016
|
+
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
1017
|
+
//for (int i = 4; i < 16; i += 4) sum[ith] += mask2 * sum[ith + i];
|
1018
|
+
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
1019
|
+
//if (ith == 0) {
|
1020
|
+
// for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
|
1021
|
+
// dst[r1*ne0 + r0] = sum[0];
|
1022
|
+
//}
|
1023
|
+
|
917
1024
|
//
|
918
1025
|
// Accumulate the sum from all threads in the threadgroup
|
919
1026
|
// This version is slightly faster than the commented out one below,
|
@@ -932,19 +1039,109 @@ kernel void kernel_mul_mat_q2_k_f32(
|
|
932
1039
|
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
|
933
1040
|
dst[r1*ne0 + r0] = sum[0];
|
934
1041
|
}
|
1042
|
+
}
|
935
1043
|
|
936
|
-
|
937
|
-
|
938
|
-
|
939
|
-
|
940
|
-
|
941
|
-
|
942
|
-
|
943
|
-
|
1044
|
+
kernel void kernel_mul_mat_q3_k_f32(
|
1045
|
+
device const void * src0,
|
1046
|
+
device const float * src1,
|
1047
|
+
device float * dst,
|
1048
|
+
constant int64_t & ne00,
|
1049
|
+
constant int64_t & ne10,
|
1050
|
+
constant int64_t & ne0,
|
1051
|
+
constant int64_t & ne1,
|
1052
|
+
threadgroup float * sum [[threadgroup(0)]],
|
1053
|
+
uint2 tgpig[[threadgroup_position_in_grid]],
|
1054
|
+
uint2 tpitg[[thread_position_in_threadgroup]],
|
1055
|
+
uint2 tptg[[threads_per_threadgroup]]) {
|
1056
|
+
|
1057
|
+
const uint16_t kmask1 = 0x0303;
|
1058
|
+
const uint16_t kmask2 = 0x0f0f;
|
1059
|
+
|
1060
|
+
const uint8_t m3 = 3;
|
1061
|
+
const int8_t m4 = 4;
|
1062
|
+
|
1063
|
+
const int nb = ne00/QK_K;
|
1064
|
+
|
1065
|
+
const int64_t r0 = tgpig.x;
|
1066
|
+
const int64_t r1 = tgpig.y;
|
1067
|
+
|
1068
|
+
device const block_q3_k * x = (device const block_q3_k *) src0 + r0*nb;
|
1069
|
+
device const float * yy = (device const float *) src1 + r1*ne10;
|
1070
|
+
|
1071
|
+
const int nth = tptg.x*tptg.y;
|
1072
|
+
const int ith = tptg.y*tpitg.x + tpitg.y;
|
1073
|
+
|
1074
|
+
const int tid = tpitg.y; // expecting 16
|
1075
|
+
const int ip = tid/8; // 0 or 1
|
1076
|
+
const int il = tid/2 - 4*ip; // 0...3
|
1077
|
+
const int ir = tid%2;
|
1078
|
+
const int n = 8;
|
1079
|
+
const int l0 = n*ir;
|
1080
|
+
|
1081
|
+
const uint8_t m = 1 << (4*ip + il);
|
1082
|
+
|
1083
|
+
const int shift = 2*il;
|
1084
|
+
|
1085
|
+
const uint16_t s_shift1 = 4*ip;
|
1086
|
+
const uint16_t s_shift2 = s_shift1 + 2*(il/2);
|
1087
|
+
const int ik = 4 + (il%2);
|
1088
|
+
|
1089
|
+
const int q_offset = 32*ip + l0;
|
1090
|
+
const int y_offset = 128*ip + 32*il + l0;
|
1091
|
+
|
1092
|
+
//float sumf = 0;
|
1093
|
+
float sumf1 = 0, sumf2 = 0;
|
1094
|
+
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
1095
|
+
|
1096
|
+
const float d_all = (float)(x[i].d);
|
1097
|
+
|
1098
|
+
device const uint8_t * q = x[i].qs + q_offset;
|
1099
|
+
device const uint8_t * h = x[i].hmask + l0;
|
1100
|
+
device const float * y = yy + i * QK_K + y_offset;
|
1101
|
+
|
1102
|
+
device const uint16_t * a = (device const uint16_t *)x[i].scales;
|
1103
|
+
const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
|
1104
|
+
|
1105
|
+
float s = 0;
|
1106
|
+
for (int l = 0; l < n; ++l) {
|
1107
|
+
s += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & m3) - ((h[l+ 0] & m) ? 0 : m4));
|
1108
|
+
}
|
1109
|
+
float d = d_all * s;
|
1110
|
+
sumf1 += d * scales[0];
|
1111
|
+
sumf2 += d;
|
1112
|
+
//sumf += d_all * s * (scales[0] - 32);
|
1113
|
+
|
1114
|
+
s = 0;
|
1115
|
+
for (int l = 0; l < n; ++l) {
|
1116
|
+
s += y[l+16] * ((int8_t)((q[l+16] >> shift) & m3) - ((h[l+16] & m) ? 0 : m4));
|
1117
|
+
}
|
1118
|
+
d = d_all * s;
|
1119
|
+
sumf1 += d * scales[1];
|
1120
|
+
sumf2 += d;
|
1121
|
+
//sumf += d_all * s * (scales[1] - 32);
|
1122
|
+
|
1123
|
+
}
|
1124
|
+
|
1125
|
+
//sum[ith] = sumf;
|
1126
|
+
sum[ith] = sumf1 - 32.f*sumf2;
|
1127
|
+
|
1128
|
+
//
|
1129
|
+
// Accumulate the sum from all threads in the threadgroup
|
1130
|
+
//
|
1131
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1132
|
+
if (ith%4 == 0) {
|
1133
|
+
for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
|
1134
|
+
}
|
1135
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1136
|
+
if (ith%16 == 0) {
|
1137
|
+
for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
|
1138
|
+
}
|
1139
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1140
|
+
if (ith == 0) {
|
1141
|
+
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
|
1142
|
+
dst[r1*ne0 + r0] = sum[0];
|
1143
|
+
}
|
944
1144
|
|
945
|
-
//if (ith == 0) {
|
946
|
-
// dst[r1*ne0 + r0] = sum[0];
|
947
|
-
//}
|
948
1145
|
}
|
949
1146
|
|
950
1147
|
kernel void kernel_mul_mat_q4_k_f32(
|
@@ -952,23 +1149,17 @@ kernel void kernel_mul_mat_q4_k_f32(
|
|
952
1149
|
device const float * src1,
|
953
1150
|
device float * dst,
|
954
1151
|
constant int64_t & ne00,
|
955
|
-
constant int64_t & ne01,
|
956
|
-
constant uint64_t & nb00,
|
957
|
-
constant uint64_t & nb01,
|
958
|
-
constant uint64_t & nb02,
|
959
1152
|
constant int64_t & ne10,
|
960
|
-
constant int64_t & ne11,
|
961
|
-
constant uint64_t & nb10,
|
962
|
-
constant uint64_t & nb11,
|
963
|
-
constant uint64_t & nb12,
|
964
1153
|
constant int64_t & ne0,
|
965
|
-
constant int64_t & ne1,
|
966
1154
|
threadgroup float * sum [[threadgroup(0)]],
|
967
1155
|
uint2 tgpig[[threadgroup_position_in_grid]],
|
968
|
-
uint2 tpig[[thread_position_in_grid]], // we don't use this for now
|
969
1156
|
uint2 tpitg[[thread_position_in_threadgroup]],
|
970
1157
|
uint2 tptg[[threads_per_threadgroup]]) {
|
971
1158
|
|
1159
|
+
const uint16_t kmask1 = 0x3f3f;
|
1160
|
+
const uint16_t kmask2 = 0x0f0f;
|
1161
|
+
const uint16_t kmask3 = 0xc0c0;
|
1162
|
+
|
972
1163
|
const int nb = ne00/QK_K;
|
973
1164
|
|
974
1165
|
const int64_t r0 = tgpig.x;
|
@@ -977,37 +1168,55 @@ kernel void kernel_mul_mat_q4_k_f32(
|
|
977
1168
|
device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
|
978
1169
|
device const float * yy = (device const float *) src1 + r1*ne10;
|
979
1170
|
|
980
|
-
const
|
981
|
-
const
|
1171
|
+
const int nth = tptg.x*tptg.y;
|
1172
|
+
const int ith = tptg.y*tpitg.x + tpitg.y;
|
982
1173
|
|
983
1174
|
const int tid = tpitg.y; // 0...16
|
984
1175
|
const int il = tid/4; // 0...3
|
985
|
-
const int ir = tid
|
986
|
-
const int n =
|
987
|
-
|
1176
|
+
const int ir = tid - 4*il;// 0...3
|
1177
|
+
const int n = 4;
|
1178
|
+
|
1179
|
+
const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
|
1180
|
+
const int in = il%2;
|
1181
|
+
|
1182
|
+
const int l0 = n*(2*ir + in);
|
1183
|
+
const int q_offset = 32*im + l0;
|
1184
|
+
const int y_offset = 64*im + l0;
|
988
1185
|
|
989
1186
|
sum[ith] = 0.0f;
|
990
1187
|
|
1188
|
+
uchar2 sc1, sc2, sc3, sc4;
|
1189
|
+
|
991
1190
|
float sumf = 0;
|
992
1191
|
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
993
1192
|
|
994
|
-
device const uint8_t *
|
995
|
-
device const
|
996
|
-
device const
|
1193
|
+
device const uint8_t * q1 = (x + i)->qs + q_offset;
|
1194
|
+
device const uint8_t * q2 = q1 + 64;
|
1195
|
+
device const float * y1 = yy + i*QK_K + y_offset;
|
1196
|
+
device const float * y2 = y1 + 128;
|
997
1197
|
|
998
1198
|
const float dall = (float)((x + i)->d);
|
999
1199
|
const float dmin = (float)((x + i)->dmin);
|
1000
1200
|
|
1001
|
-
const
|
1201
|
+
device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
|
1202
|
+
sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
|
1203
|
+
sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
|
1204
|
+
sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
|
1205
|
+
sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
|
1002
1206
|
|
1003
1207
|
float4 s = {0.f, 0.f, 0.f, 0.f};
|
1208
|
+
float smin = 0;
|
1004
1209
|
for (int l = 0; l < n; ++l) {
|
1005
|
-
|
1006
|
-
s[
|
1210
|
+
|
1211
|
+
s[0] += y1[l] * (q1[l] & 0xF); s[1] += y1[l+32] * (q1[l] >> 4);
|
1212
|
+
s[2] += y2[l] * (q2[l] & 0xF); s[3] += y2[l+32] * (q2[l] >> 4);
|
1213
|
+
smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
|
1214
|
+
|
1007
1215
|
}
|
1008
|
-
sumf += dall * (s[0] *
|
1216
|
+
sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
|
1009
1217
|
|
1010
1218
|
}
|
1219
|
+
|
1011
1220
|
sum[ith] = sumf;
|
1012
1221
|
|
1013
1222
|
//
|
@@ -1043,25 +1252,114 @@ kernel void kernel_mul_mat_q4_k_f32(
|
|
1043
1252
|
//}
|
1044
1253
|
}
|
1045
1254
|
|
1255
|
+
kernel void kernel_mul_mat_q5_k_f32(
|
1256
|
+
device const void * src0,
|
1257
|
+
device const float * src1,
|
1258
|
+
device float * dst,
|
1259
|
+
constant int64_t & ne00,
|
1260
|
+
constant int64_t & ne10,
|
1261
|
+
constant int64_t & ne0,
|
1262
|
+
threadgroup float * sum [[threadgroup(0)]],
|
1263
|
+
uint2 tgpig[[threadgroup_position_in_grid]],
|
1264
|
+
uint2 tpitg[[thread_position_in_threadgroup]],
|
1265
|
+
uint2 tptg[[threads_per_threadgroup]]) {
|
1266
|
+
|
1267
|
+
const uint16_t kmask1 = 0x3f3f;
|
1268
|
+
const uint16_t kmask2 = 0x0f0f;
|
1269
|
+
const uint16_t kmask3 = 0xc0c0;
|
1270
|
+
|
1271
|
+
const int nb = ne00/QK_K;
|
1272
|
+
|
1273
|
+
const int64_t r0 = tgpig.x;
|
1274
|
+
const int64_t r1 = tgpig.y;
|
1275
|
+
|
1276
|
+
device const block_q5_k * x = (device const block_q5_k *) src0 + r0*nb;
|
1277
|
+
device const float * yy = (device const float *) src1 + r1*ne10;
|
1278
|
+
|
1279
|
+
const int nth = tptg.x*tptg.y;
|
1280
|
+
const int ith = tptg.y*tpitg.x + tpitg.y;
|
1281
|
+
|
1282
|
+
const int tid = tpitg.y; // 0...16
|
1283
|
+
const int il = tid/4; // 0...3
|
1284
|
+
const int ir = tid - 4*il;// 0...3
|
1285
|
+
const int n = 4;
|
1286
|
+
|
1287
|
+
const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
|
1288
|
+
const int in = il%2;
|
1289
|
+
|
1290
|
+
const int l0 = n*(2*ir + in);
|
1291
|
+
const int q_offset = 32*im + l0;
|
1292
|
+
const int y_offset = 64*im + l0;
|
1293
|
+
|
1294
|
+
const uint8_t hm1 = 1u << (2*im);
|
1295
|
+
const uint8_t hm2 = hm1 << 1;
|
1296
|
+
const uint8_t hm3 = hm1 << 4;
|
1297
|
+
const uint8_t hm4 = hm2 << 4;
|
1298
|
+
|
1299
|
+
uchar2 sc1, sc2, sc3, sc4;
|
1300
|
+
|
1301
|
+
float sumf = 0;
|
1302
|
+
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
1303
|
+
|
1304
|
+
device const uint8_t * q1 = (x + i)->qs + q_offset;
|
1305
|
+
device const uint8_t * q2 = q1 + 64;
|
1306
|
+
device const uint8_t * qh = (x + i)->qh + l0;
|
1307
|
+
device const float * y1 = yy + i*QK_K + y_offset;
|
1308
|
+
device const float * y2 = y1 + 128;
|
1309
|
+
|
1310
|
+
const float dall = (float)((x + i)->d);
|
1311
|
+
const float dmin = (float)((x + i)->dmin);
|
1312
|
+
|
1313
|
+
device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
|
1314
|
+
sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
|
1315
|
+
sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
|
1316
|
+
sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
|
1317
|
+
sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
|
1318
|
+
|
1319
|
+
float4 s = {0.f, 0.f, 0.f, 0.f};
|
1320
|
+
float smin = 0;
|
1321
|
+
for (int l = 0; l < n; ++l) {
|
1322
|
+
|
1323
|
+
s[0] += y1[l+ 0] * ((q1[l] & 0xF) + (qh[l] & hm1 ? 16 : 0));
|
1324
|
+
s[1] += y1[l+32] * ((q1[l] >> 4) + (qh[l] & hm2 ? 16 : 0));
|
1325
|
+
s[2] += y2[l+ 0] * ((q2[l] & 0xF) + (qh[l] & hm3 ? 16 : 0));
|
1326
|
+
s[3] += y2[l+32] * ((q2[l] >> 4) + (qh[l] & hm4 ? 16 : 0));
|
1327
|
+
smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
|
1328
|
+
|
1329
|
+
}
|
1330
|
+
sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
|
1331
|
+
|
1332
|
+
}
|
1333
|
+
sum[ith] = sumf;
|
1334
|
+
|
1335
|
+
//
|
1336
|
+
// Accumulate the sum from all threads in the threadgroup
|
1337
|
+
//
|
1338
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1339
|
+
if (ith%4 == 0) {
|
1340
|
+
sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
|
1341
|
+
}
|
1342
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1343
|
+
if (ith%16 == 0) {
|
1344
|
+
sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
|
1345
|
+
}
|
1346
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1347
|
+
if (ith == 0) {
|
1348
|
+
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
|
1349
|
+
dst[r1*ne0 + r0] = sum[0];
|
1350
|
+
}
|
1351
|
+
|
1352
|
+
}
|
1353
|
+
|
1046
1354
|
kernel void kernel_mul_mat_q6_k_f32(
|
1047
1355
|
device const void * src0,
|
1048
1356
|
device const float * src1,
|
1049
1357
|
device float * dst,
|
1050
1358
|
constant int64_t & ne00,
|
1051
|
-
constant int64_t & ne01,
|
1052
|
-
constant uint64_t & nb00,
|
1053
|
-
constant uint64_t & nb01,
|
1054
|
-
constant uint64_t & nb02,
|
1055
1359
|
constant int64_t & ne10,
|
1056
|
-
constant int64_t & ne11,
|
1057
|
-
constant uint64_t & nb10,
|
1058
|
-
constant uint64_t & nb11,
|
1059
|
-
constant uint64_t & nb12,
|
1060
1360
|
constant int64_t & ne0,
|
1061
|
-
constant int64_t & ne1,
|
1062
1361
|
threadgroup float * sum [[threadgroup(0)]],
|
1063
1362
|
uint2 tgpig[[threadgroup_position_in_grid]],
|
1064
|
-
uint2 tpig[[thread_position_in_grid]], // we don't use this for now
|
1065
1363
|
uint2 tpitg[[thread_position_in_threadgroup]],
|
1066
1364
|
uint2 tptg[[threads_per_threadgroup]]) {
|
1067
1365
|
|
@@ -1078,24 +1376,29 @@ kernel void kernel_mul_mat_q6_k_f32(
|
|
1078
1376
|
device const block_q6_k * x = (device const block_q6_k *) src0 + r0*nb;
|
1079
1377
|
device const float * yy = (device const float *) src1 + r1*ne10;
|
1080
1378
|
|
1081
|
-
const
|
1082
|
-
const
|
1379
|
+
const int nth = tptg.x*tptg.y;
|
1380
|
+
const int ith = tptg.y*tpitg.x + tpitg.y;
|
1083
1381
|
|
1084
|
-
|
1085
|
-
const int iqs =
|
1382
|
+
// Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
|
1383
|
+
const int iqs = 16 * tpitg.y;
|
1086
1384
|
const int ip = iqs / 128; // 0 or 1
|
1087
1385
|
const int il = (iqs - 128*ip)/16; // 0...7
|
1088
1386
|
const int n = 4;
|
1089
|
-
const int
|
1387
|
+
const int l0 = n*il;
|
1388
|
+
const int is = 8*ip + l0/16;
|
1389
|
+
|
1390
|
+
const int y_offset = 128*ip + l0;
|
1391
|
+
const int q_offset_l = 64*ip + l0;
|
1392
|
+
const int q_offset_h = 32*ip + l0;
|
1090
1393
|
|
1091
1394
|
float sumf = 0;
|
1092
1395
|
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
1093
1396
|
|
1094
|
-
device const uint8_t * ql = x[i].ql +
|
1095
|
-
device const uint8_t * qh = x[i].qh +
|
1397
|
+
device const uint8_t * ql = x[i].ql + q_offset_l;
|
1398
|
+
device const uint8_t * qh = x[i].qh + q_offset_h;
|
1096
1399
|
device const int8_t * sc = x[i].scales + is;
|
1097
1400
|
|
1098
|
-
device const float * y = yy + i * QK_K +
|
1401
|
+
device const float * y = yy + i * QK_K + y_offset;
|
1099
1402
|
|
1100
1403
|
const float dall = x[i].d;
|
1101
1404
|
|