llama_cpp 0.2.1 → 0.3.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +32 -0
- data/README.md +39 -6
- data/examples/README.md +32 -0
- data/examples/chat.rb +2 -1
- data/examples/embedding.rb +38 -0
- data/ext/llama_cpp/extconf.rb +13 -0
- data/ext/llama_cpp/llama_cpp.cpp +231 -132
- data/ext/llama_cpp/src/ggml-cuda.cu +844 -337
- data/ext/llama_cpp/src/ggml-metal.h +4 -1
- data/ext/llama_cpp/src/ggml-metal.m +193 -49
- data/ext/llama_cpp/src/ggml-metal.metal +477 -84
- data/ext/llama_cpp/src/ggml-opencl.cpp +493 -4
- data/ext/llama_cpp/src/ggml.c +1565 -430
- data/ext/llama_cpp/src/ggml.h +208 -14
- data/ext/llama_cpp/src/k_quants.c +1712 -56
- data/ext/llama_cpp/src/k_quants.h +41 -6
- data/ext/llama_cpp/src/llama-util.h +19 -5
- data/ext/llama_cpp/src/llama.cpp +194 -101
- data/ext/llama_cpp/src/llama.h +41 -14
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +0 -2
- data/sig/llama_cpp.rbs +12 -17
- metadata +3 -3
- data/lib/llama_cpp/client.rb +0 -172
@@ -256,6 +256,72 @@ kernel void kernel_get_rows_q4_1(
|
|
256
256
|
(device float *) ((device char *) dst + i*nb1), ne00);
|
257
257
|
}
|
258
258
|
|
259
|
+
kernel void kernel_norm(
|
260
|
+
device const void * src0,
|
261
|
+
device float * dst,
|
262
|
+
constant int64_t & ne00,
|
263
|
+
constant uint64_t & nb01,
|
264
|
+
constant float & eps,
|
265
|
+
threadgroup float * sum [[threadgroup(0)]],
|
266
|
+
uint tgpig[[threadgroup_position_in_grid]],
|
267
|
+
uint tpitg[[thread_position_in_threadgroup]],
|
268
|
+
uint ntg[[threads_per_threadgroup]]) {
|
269
|
+
device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
|
270
|
+
// MEAN
|
271
|
+
// parallel sum
|
272
|
+
sum[tpitg] = 0.0f;
|
273
|
+
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
274
|
+
sum[tpitg] += x[i00];
|
275
|
+
}
|
276
|
+
// reduce
|
277
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
278
|
+
for (uint i = ntg/2; i > 0; i /= 2) {
|
279
|
+
if (tpitg < i) {
|
280
|
+
sum[tpitg] += sum[tpitg + i];
|
281
|
+
}
|
282
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
283
|
+
}
|
284
|
+
// broadcast
|
285
|
+
if (tpitg == 0) {
|
286
|
+
sum[0] /= ne00;
|
287
|
+
}
|
288
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
289
|
+
const float mean = sum[0];
|
290
|
+
|
291
|
+
// recenter
|
292
|
+
device float * y = dst + tgpig*ne00;
|
293
|
+
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
294
|
+
y[i00] = x[i00] - mean;
|
295
|
+
}
|
296
|
+
|
297
|
+
// VARIANCE
|
298
|
+
// parallel sum
|
299
|
+
sum[tpitg] = 0.0f;
|
300
|
+
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
301
|
+
sum[tpitg] += y[i00] * y[i00];
|
302
|
+
}
|
303
|
+
// reduce
|
304
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
305
|
+
for (uint i = ntg/2; i > 0; i /= 2) {
|
306
|
+
if (tpitg < i) {
|
307
|
+
sum[tpitg] += sum[tpitg + i];
|
308
|
+
}
|
309
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
310
|
+
}
|
311
|
+
// broadcast
|
312
|
+
if (tpitg == 0) {
|
313
|
+
sum[0] /= ne00;
|
314
|
+
}
|
315
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
316
|
+
const float variance = sum[0];
|
317
|
+
|
318
|
+
const float scale = 1.0f/sqrt(variance + eps);
|
319
|
+
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
320
|
+
y[i00] = y[i00] * scale;
|
321
|
+
}
|
322
|
+
}
|
323
|
+
|
324
|
+
|
259
325
|
kernel void kernel_rms_norm(
|
260
326
|
device const void * src0,
|
261
327
|
device float * dst,
|
@@ -362,7 +428,7 @@ kernel void kernel_mul_mat_q4_0_f32(
|
|
362
428
|
}
|
363
429
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
364
430
|
if (ith == 0) {
|
365
|
-
for (
|
431
|
+
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
|
366
432
|
dst[r1*ne0 + r0] = sum[0];
|
367
433
|
}
|
368
434
|
}
|
@@ -431,7 +497,7 @@ kernel void kernel_mul_mat_q4_1_f32(
|
|
431
497
|
}
|
432
498
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
433
499
|
if (ith == 0) {
|
434
|
-
for (
|
500
|
+
for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
|
435
501
|
dst[r1*ne0 + r0] = sum[0];
|
436
502
|
}
|
437
503
|
}
|
@@ -485,6 +551,48 @@ kernel void kernel_mul_mat_f16_f32(
|
|
485
551
|
}
|
486
552
|
}
|
487
553
|
|
554
|
+
kernel void kernel_alibi_f32(
|
555
|
+
device const float * src0,
|
556
|
+
device float * dst,
|
557
|
+
constant int64_t & ne00,
|
558
|
+
constant int64_t & ne01,
|
559
|
+
constant int64_t & ne02,
|
560
|
+
constant int64_t & ne03,
|
561
|
+
constant uint64_t & nb00,
|
562
|
+
constant uint64_t & nb01,
|
563
|
+
constant uint64_t & nb02,
|
564
|
+
constant uint64_t & nb03,
|
565
|
+
constant int64_t & ne0,
|
566
|
+
constant int64_t & ne1,
|
567
|
+
constant int64_t & ne2,
|
568
|
+
constant int64_t & ne3,
|
569
|
+
constant uint64_t & nb0,
|
570
|
+
constant uint64_t & nb1,
|
571
|
+
constant uint64_t & nb2,
|
572
|
+
constant uint64_t & nb3,
|
573
|
+
constant float & m0,
|
574
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
575
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
576
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
577
|
+
const int64_t i03 = tgpig[2];
|
578
|
+
const int64_t i02 = tgpig[1];
|
579
|
+
const int64_t i01 = tgpig[0];
|
580
|
+
|
581
|
+
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
582
|
+
|
583
|
+
const int64_t i3 = n / (ne2*ne1*ne0);
|
584
|
+
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
585
|
+
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
586
|
+
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
587
|
+
|
588
|
+
device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
589
|
+
float m_k = pow(m0, i2 + 1);
|
590
|
+
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
591
|
+
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
592
|
+
dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
|
593
|
+
}
|
594
|
+
}
|
595
|
+
|
488
596
|
kernel void kernel_rope(
|
489
597
|
device const void * src0,
|
490
598
|
device float * dst,
|
@@ -540,6 +648,47 @@ kernel void kernel_rope(
|
|
540
648
|
}
|
541
649
|
}
|
542
650
|
|
651
|
+
kernel void kernel_cpy_f16_f16(
|
652
|
+
device const half * src0,
|
653
|
+
device half * dst,
|
654
|
+
constant int64_t & ne00,
|
655
|
+
constant int64_t & ne01,
|
656
|
+
constant int64_t & ne02,
|
657
|
+
constant int64_t & ne03,
|
658
|
+
constant uint64_t & nb00,
|
659
|
+
constant uint64_t & nb01,
|
660
|
+
constant uint64_t & nb02,
|
661
|
+
constant uint64_t & nb03,
|
662
|
+
constant int64_t & ne0,
|
663
|
+
constant int64_t & ne1,
|
664
|
+
constant int64_t & ne2,
|
665
|
+
constant int64_t & ne3,
|
666
|
+
constant uint64_t & nb0,
|
667
|
+
constant uint64_t & nb1,
|
668
|
+
constant uint64_t & nb2,
|
669
|
+
constant uint64_t & nb3,
|
670
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
671
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
672
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
673
|
+
const int64_t i03 = tgpig[2];
|
674
|
+
const int64_t i02 = tgpig[1];
|
675
|
+
const int64_t i01 = tgpig[0];
|
676
|
+
|
677
|
+
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
678
|
+
|
679
|
+
const int64_t i3 = n / (ne2*ne1*ne0);
|
680
|
+
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
681
|
+
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
682
|
+
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
683
|
+
|
684
|
+
device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
685
|
+
|
686
|
+
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
687
|
+
device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
688
|
+
dst_data[i00] = src[0];
|
689
|
+
}
|
690
|
+
}
|
691
|
+
|
543
692
|
kernel void kernel_cpy_f32_f16(
|
544
693
|
device const float * src0,
|
545
694
|
device half * dst,
|
@@ -626,47 +775,76 @@ kernel void kernel_cpy_f32_f32(
|
|
626
775
|
|
627
776
|
//============================================ k-quants ======================================================
|
628
777
|
|
778
|
+
#ifndef QK_K
|
629
779
|
#define QK_K 256
|
780
|
+
#else
|
781
|
+
static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64");
|
782
|
+
#endif
|
783
|
+
|
784
|
+
#if QK_K == 256
|
785
|
+
#define K_SCALE_SIZE 12
|
786
|
+
#else
|
787
|
+
#define K_SCALE_SIZE 4
|
788
|
+
#endif
|
630
789
|
|
631
790
|
typedef struct {
|
632
791
|
uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
|
633
792
|
uint8_t qs[QK_K/4]; // quants
|
634
793
|
half d; // super-block scale for quantized scales
|
635
794
|
half dmin; // super-block scale for quantized mins
|
636
|
-
}
|
795
|
+
} block_q2_K;
|
637
796
|
// 84 bytes / block
|
638
797
|
|
639
798
|
typedef struct {
|
640
799
|
uint8_t hmask[QK_K/8]; // quants - high bit
|
641
800
|
uint8_t qs[QK_K/4]; // quants - low 2 bits
|
642
|
-
|
643
|
-
|
644
|
-
|
645
|
-
//
|
646
|
-
|
801
|
+
#if QK_K == 64
|
802
|
+
uint8_t scales[2];
|
803
|
+
#else
|
804
|
+
uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
|
805
|
+
#endif
|
806
|
+
half d; // super-block scale
|
807
|
+
} block_q3_K;
|
808
|
+
|
809
|
+
#if QK_K == 64
|
810
|
+
typedef struct {
|
811
|
+
half d[2]; // super-block scales/mins
|
812
|
+
uint8_t scales[2];
|
813
|
+
uint8_t qs[QK_K/2]; // 4-bit quants
|
814
|
+
} block_q4_K;
|
815
|
+
#else
|
647
816
|
typedef struct {
|
648
817
|
half d; // super-block scale for quantized scales
|
649
818
|
half dmin; // super-block scale for quantized mins
|
650
|
-
uint8_t scales[
|
819
|
+
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
|
651
820
|
uint8_t qs[QK_K/2]; // 4--bit quants
|
652
|
-
}
|
653
|
-
|
821
|
+
} block_q4_K;
|
822
|
+
#endif
|
654
823
|
|
824
|
+
#if QK_K == 64
|
825
|
+
typedef struct {
|
826
|
+
half d; // super-block scales/mins
|
827
|
+
int8_t scales[QK_K/16]; // 8-bit block scales
|
828
|
+
uint8_t qh[QK_K/8]; // quants, high bit
|
829
|
+
uint8_t qs[QK_K/2]; // quants, low 4 bits
|
830
|
+
} block_q5_K;
|
831
|
+
#else
|
655
832
|
typedef struct {
|
656
833
|
half d; // super-block scale for quantized scales
|
657
834
|
half dmin; // super-block scale for quantized mins
|
658
835
|
uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
|
659
836
|
uint8_t qh[QK_K/8]; // quants, high bit
|
660
837
|
uint8_t qs[QK_K/2]; // quants, low 4 bits
|
661
|
-
}
|
838
|
+
} block_q5_K;
|
662
839
|
// 176 bytes / block
|
840
|
+
#endif
|
663
841
|
|
664
842
|
typedef struct {
|
665
843
|
uint8_t ql[QK_K/2]; // quants, lower 4 bits
|
666
844
|
uint8_t qh[QK_K/4]; // quants, upper 2 bits
|
667
845
|
int8_t scales[QK_K/16]; // scales, quantized with 8 bits
|
668
846
|
half d; // super-block scale
|
669
|
-
}
|
847
|
+
} block_q6_K;
|
670
848
|
// 210 bytes / block
|
671
849
|
|
672
850
|
static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
|
@@ -687,7 +865,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
|
|
687
865
|
|
688
866
|
//========================================== dequantization =============================
|
689
867
|
|
690
|
-
static void
|
868
|
+
static void dequantize_row_q2_K(device const block_q2_K * x, device float * y, int k) {
|
691
869
|
assert(k % QK_K == 0);
|
692
870
|
const int nb = k / QK_K;
|
693
871
|
|
@@ -698,6 +876,7 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
|
|
698
876
|
|
699
877
|
device const uint8_t * q = x[i].qs;
|
700
878
|
|
879
|
+
#if QK_K == 256
|
701
880
|
int is = 0;
|
702
881
|
float dl, ml;
|
703
882
|
for (int n = 0; n < QK_K; n += 128) {
|
@@ -716,14 +895,29 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
|
|
716
895
|
}
|
717
896
|
q += 32;
|
718
897
|
}
|
898
|
+
#else
|
899
|
+
float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
|
900
|
+
float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
|
901
|
+
float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
|
902
|
+
float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
|
903
|
+
for (int l = 0; l < 16; ++l) {
|
904
|
+
y[l+ 0] = dl1 * ((q[l] >> 0) & 3) - ml1;
|
905
|
+
y[l+16] = dl2 * ((q[l] >> 2) & 3) - ml2;
|
906
|
+
y[l+32] = dl3 * ((q[l] >> 4) & 3) - ml3;
|
907
|
+
y[l+48] = dl4 * ((q[l] >> 6) & 3) - ml4;
|
908
|
+
}
|
909
|
+
y += QK_K;
|
910
|
+
#endif
|
719
911
|
|
720
912
|
}
|
721
913
|
}
|
722
914
|
|
723
|
-
static void
|
915
|
+
static void dequantize_row_q3_K(device const block_q3_K * x, device float * y, int k) {
|
724
916
|
assert(k % QK_K == 0);
|
725
917
|
const int nb = k / QK_K;
|
726
918
|
|
919
|
+
#if QK_K == 256
|
920
|
+
|
727
921
|
const uint16_t kmask1 = 0x0303;
|
728
922
|
const uint16_t kmask2 = 0x0f0f;
|
729
923
|
|
@@ -769,22 +963,49 @@ static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, i
|
|
769
963
|
}
|
770
964
|
q += 32;
|
771
965
|
}
|
966
|
+
}
|
967
|
+
#else
|
968
|
+
for (int i = 0; i < nb; i++) {
|
969
|
+
|
970
|
+
const float d_all = (float)(x[i].d);
|
772
971
|
|
972
|
+
device const uint8_t * q = x[i].qs;
|
973
|
+
device const uint8_t * hm = x[i].hmask;
|
974
|
+
|
975
|
+
const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
|
976
|
+
const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
|
977
|
+
const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
|
978
|
+
const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
|
979
|
+
|
980
|
+
for (int l = 0; l < 8; ++l) {
|
981
|
+
uint8_t h = hm[l];
|
982
|
+
y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
|
983
|
+
y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
|
984
|
+
y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
|
985
|
+
y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
|
986
|
+
y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
|
987
|
+
y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
|
988
|
+
y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
|
989
|
+
y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
|
990
|
+
}
|
991
|
+
y += QK_K;
|
773
992
|
}
|
993
|
+
#endif
|
774
994
|
|
775
995
|
}
|
776
996
|
|
777
|
-
static void
|
997
|
+
static void dequantize_row_q4_K(device const block_q4_K * x, device float * y, int k) {
|
778
998
|
assert(k % QK_K == 0);
|
779
999
|
const int nb = k / QK_K;
|
780
1000
|
|
781
|
-
|
782
1001
|
for (int i = 0; i < nb; i++) {
|
783
1002
|
|
1003
|
+
device const uint8_t * q = x[i].qs;
|
1004
|
+
|
1005
|
+
#if QK_K == 256
|
784
1006
|
const float d = x[i].d;
|
785
1007
|
const float min = x[i].dmin;
|
786
1008
|
|
787
|
-
device const uint8_t * q = x[i].qs;
|
788
1009
|
device const uint8_t * scales = x[i].scales;
|
789
1010
|
|
790
1011
|
int is = 0;
|
@@ -796,14 +1017,29 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i
|
|
796
1017
|
for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
|
797
1018
|
q += 32; is += 2;
|
798
1019
|
}
|
1020
|
+
#else
|
1021
|
+
device const uint8_t * s = x[i].scales;
|
1022
|
+
device const half2 * dh = (device const half2 *)x[i].d;
|
1023
|
+
const float2 d = (float2)dh[0];
|
1024
|
+
const float d1 = d[0] * (s[0] & 0xF);
|
1025
|
+
const float d2 = d[0] * (s[1] & 0xF);
|
1026
|
+
const float m1 = d[1] * (s[0] >> 4);
|
1027
|
+
const float m2 = d[1] * (s[1] >> 4);
|
1028
|
+
for (int l = 0; l < 32; ++l) {
|
1029
|
+
y[l+ 0] = d1 * (q[l] & 0xF) - m1;
|
1030
|
+
y[l+32] = d2 * (q[l] >> 4) - m2;
|
1031
|
+
}
|
1032
|
+
y += QK_K;
|
1033
|
+
#endif
|
799
1034
|
|
800
1035
|
}
|
801
1036
|
}
|
802
1037
|
|
803
|
-
static void
|
1038
|
+
static void dequantize_row_q5_K(device const block_q5_K * x, device float * y, int k) {
|
804
1039
|
assert(k % QK_K == 0);
|
805
1040
|
const int nb = k / QK_K;
|
806
1041
|
|
1042
|
+
#if QK_K == 256
|
807
1043
|
for (int i = 0; i < nb; i++) {
|
808
1044
|
|
809
1045
|
const float d = (float)(x[i].d);
|
@@ -824,10 +1060,32 @@ static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, i
|
|
824
1060
|
u1 <<= 2; u2 <<= 2;
|
825
1061
|
}
|
826
1062
|
}
|
1063
|
+
#else
|
1064
|
+
for (int i = 0; i < nb; i++) {
|
1065
|
+
|
1066
|
+
const float d = (float)x[i].d;
|
1067
|
+
|
1068
|
+
device const uint8_t * ql = x[i].qs;
|
1069
|
+
device const uint8_t * qh = x[i].qh;
|
1070
|
+
device const int8_t * sc = x[i].scales;
|
1071
|
+
|
1072
|
+
for (int l = 0; l < 8; ++l) {
|
1073
|
+
y[l+ 0] = d * sc[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
|
1074
|
+
y[l+ 8] = d * sc[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
|
1075
|
+
y[l+16] = d * sc[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
|
1076
|
+
y[l+24] = d * sc[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
|
1077
|
+
y[l+32] = d * sc[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16));
|
1078
|
+
y[l+40] = d * sc[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16));
|
1079
|
+
y[l+48] = d * sc[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16));
|
1080
|
+
y[l+56] = d * sc[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16));
|
1081
|
+
}
|
1082
|
+
y += QK_K;
|
1083
|
+
}
|
1084
|
+
#endif
|
827
1085
|
|
828
1086
|
}
|
829
1087
|
|
830
|
-
static void
|
1088
|
+
static void dequantize_row_q6_K(device const block_q6_K * x, device float * y, int k) {
|
831
1089
|
assert(k % QK_K == 0);
|
832
1090
|
const int nb = k / QK_K;
|
833
1091
|
|
@@ -839,6 +1097,7 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
|
|
839
1097
|
|
840
1098
|
const float d = x[i].d;
|
841
1099
|
|
1100
|
+
#if QK_K == 256
|
842
1101
|
for (int n = 0; n < QK_K; n += 128) {
|
843
1102
|
for (int l = 0; l < 32; ++l) {
|
844
1103
|
int is = l/16;
|
@@ -856,10 +1115,23 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
|
|
856
1115
|
qh += 32;
|
857
1116
|
sc += 8;
|
858
1117
|
}
|
1118
|
+
#else
|
1119
|
+
for (int l = 0; l < 16; ++l) {
|
1120
|
+
const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
|
1121
|
+
const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
|
1122
|
+
const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
|
1123
|
+
const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
|
1124
|
+
y[l+ 0] = d * sc[0] * q1;
|
1125
|
+
y[l+16] = d * sc[1] * q2;
|
1126
|
+
y[l+32] = d * sc[2] * q3;
|
1127
|
+
y[l+48] = d * sc[3] * q4;
|
1128
|
+
}
|
1129
|
+
y += 64;
|
1130
|
+
#endif
|
859
1131
|
}
|
860
1132
|
}
|
861
1133
|
|
862
|
-
kernel void
|
1134
|
+
kernel void kernel_get_rows_q2_K(
|
863
1135
|
device const void * src0,
|
864
1136
|
device const int * src1,
|
865
1137
|
device float * dst,
|
@@ -870,12 +1142,12 @@ kernel void kernel_get_rows_q2_k(
|
|
870
1142
|
const int i = tpig;
|
871
1143
|
const int r = ((device int32_t *) src1)[i];
|
872
1144
|
|
873
|
-
|
874
|
-
(device const
|
1145
|
+
dequantize_row_q2_K(
|
1146
|
+
(device const block_q2_K *) ((device char *) src0 + r*nb01),
|
875
1147
|
(device float *) ((device char *) dst + i*nb1), ne00);
|
876
1148
|
}
|
877
1149
|
|
878
|
-
kernel void
|
1150
|
+
kernel void kernel_get_rows_q3_K(
|
879
1151
|
device const void * src0,
|
880
1152
|
device const int * src1,
|
881
1153
|
device float * dst,
|
@@ -886,12 +1158,12 @@ kernel void kernel_get_rows_q3_k(
|
|
886
1158
|
const int i = tpig;
|
887
1159
|
const int r = ((device int32_t *) src1)[i];
|
888
1160
|
|
889
|
-
|
890
|
-
(device const
|
1161
|
+
dequantize_row_q3_K(
|
1162
|
+
(device const block_q3_K *) ((device char *) src0 + r*nb01),
|
891
1163
|
(device float *) ((device char *) dst + i*nb1), ne00);
|
892
1164
|
}
|
893
1165
|
|
894
|
-
kernel void
|
1166
|
+
kernel void kernel_get_rows_q4_K(
|
895
1167
|
device const void * src0,
|
896
1168
|
device const int * src1,
|
897
1169
|
device float * dst,
|
@@ -902,12 +1174,12 @@ kernel void kernel_get_rows_q4_k(
|
|
902
1174
|
const int i = tpig;
|
903
1175
|
const int r = ((device int32_t *) src1)[i];
|
904
1176
|
|
905
|
-
|
906
|
-
(device const
|
1177
|
+
dequantize_row_q4_K(
|
1178
|
+
(device const block_q4_K *) ((device char *) src0 + r*nb01),
|
907
1179
|
(device float *) ((device char *) dst + i*nb1), ne00);
|
908
1180
|
}
|
909
1181
|
|
910
|
-
kernel void
|
1182
|
+
kernel void kernel_get_rows_q5_K(
|
911
1183
|
device const void * src0,
|
912
1184
|
device const int * src1,
|
913
1185
|
device float * dst,
|
@@ -918,12 +1190,12 @@ kernel void kernel_get_rows_q5_k(
|
|
918
1190
|
const int i = tpig;
|
919
1191
|
const int r = ((device int32_t *) src1)[i];
|
920
1192
|
|
921
|
-
|
922
|
-
(device const
|
1193
|
+
dequantize_row_q5_K(
|
1194
|
+
(device const block_q5_K *) ((device char *) src0 + r*nb01),
|
923
1195
|
(device float *) ((device char *) dst + i*nb1), ne00);
|
924
1196
|
}
|
925
1197
|
|
926
|
-
kernel void
|
1198
|
+
kernel void kernel_get_rows_q6_K(
|
927
1199
|
device const void * src0,
|
928
1200
|
device const int * src1,
|
929
1201
|
device float * dst,
|
@@ -934,14 +1206,14 @@ kernel void kernel_get_rows_q6_k(
|
|
934
1206
|
const int i = tpig;
|
935
1207
|
const int r = ((device int32_t *) src1)[i];
|
936
1208
|
|
937
|
-
|
938
|
-
(device const
|
1209
|
+
dequantize_row_q6_K(
|
1210
|
+
(device const block_q6_K *) ((device char *) src0 + r*nb01),
|
939
1211
|
(device float *) ((device char *) dst + i*nb1), ne00);
|
940
1212
|
}
|
941
1213
|
|
942
1214
|
//====================================== dot products =========================
|
943
1215
|
|
944
|
-
kernel void
|
1216
|
+
kernel void kernel_mul_mat_q2_K_f32(
|
945
1217
|
device const void * src0,
|
946
1218
|
device const float * src1,
|
947
1219
|
device float * dst,
|
@@ -958,12 +1230,15 @@ kernel void kernel_mul_mat_q2_k_f32(
|
|
958
1230
|
const int64_t r0 = tgpig.x;
|
959
1231
|
const int64_t r1 = tgpig.y;
|
960
1232
|
|
961
|
-
device const
|
1233
|
+
device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb;
|
962
1234
|
device const float * yy = (device const float *) src1 + r1*ne10;
|
963
1235
|
|
964
1236
|
const int nth = tptg.x*tptg.y;
|
965
1237
|
const int ith = tptg.y*tpitg.x + tpitg.y;
|
966
1238
|
|
1239
|
+
float sumf = 0;
|
1240
|
+
|
1241
|
+
#if QK_K == 256
|
967
1242
|
const int tid = tpitg.y; // 0...16
|
968
1243
|
const int il = tid/4; // 0...3
|
969
1244
|
const int ir = tid%4; // 0...3
|
@@ -976,9 +1251,6 @@ kernel void kernel_mul_mat_q2_k_f32(
|
|
976
1251
|
const int y_offset = 64*il + n*ir;
|
977
1252
|
const int q_offset = 32*ip + n*ir;
|
978
1253
|
|
979
|
-
sum[ith] = 0.0f;
|
980
|
-
|
981
|
-
float sumf = 0;
|
982
1254
|
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
983
1255
|
|
984
1256
|
device const uint8_t * q = x[i].qs + q_offset;
|
@@ -991,7 +1263,6 @@ kernel void kernel_mul_mat_q2_k_f32(
|
|
991
1263
|
|
992
1264
|
device const float * y = yy + i*QK_K + y_offset;
|
993
1265
|
|
994
|
-
//float4 s = {0.f, 0.f, 0.f, 0.f};
|
995
1266
|
float2 s = {0.f, 0.f};
|
996
1267
|
float smin = 0;
|
997
1268
|
for (int l = 0; l < n; ++l) {
|
@@ -1006,25 +1277,38 @@ kernel void kernel_mul_mat_q2_k_f32(
|
|
1006
1277
|
sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
|
1007
1278
|
|
1008
1279
|
}
|
1009
|
-
|
1280
|
+
#else
|
1281
|
+
const int il = 4 * tpitg.x;
|
1010
1282
|
|
1011
|
-
|
1012
|
-
|
1283
|
+
uint32_t aux[2];
|
1284
|
+
thread const uint8_t * d = (thread const uint8_t *)aux;
|
1285
|
+
thread const uint8_t * m = (thread const uint8_t *)aux + 4;
|
1013
1286
|
|
1014
|
-
|
1015
|
-
|
1016
|
-
|
1017
|
-
|
1018
|
-
|
1019
|
-
|
1020
|
-
|
1021
|
-
|
1022
|
-
|
1287
|
+
for (int i = tpitg.y; i < nb; i += tptg.y) {
|
1288
|
+
|
1289
|
+
device const uint8_t * q = x[i].qs + il;
|
1290
|
+
device const float * y = yy + i*QK_K + il;
|
1291
|
+
|
1292
|
+
const float dall = (float)x[i].d;
|
1293
|
+
const float dmin = (float)x[i].dmin;
|
1294
|
+
|
1295
|
+
device const uint32_t * a = (device const uint32_t *)x[i].scales;
|
1296
|
+
aux[0] = a[0] & 0x0f0f0f0f;
|
1297
|
+
aux[1] = (a[0] >> 4) & 0x0f0f0f0f;
|
1298
|
+
|
1299
|
+
for (int l = 0; l < 4; ++l) {
|
1300
|
+
sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0])
|
1301
|
+
+ y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1])
|
1302
|
+
+ y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2])
|
1303
|
+
+ y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]);
|
1304
|
+
}
|
1305
|
+
}
|
1306
|
+
#endif
|
1307
|
+
|
1308
|
+
sum[ith] = sumf;
|
1023
1309
|
|
1024
1310
|
//
|
1025
1311
|
// Accumulate the sum from all threads in the threadgroup
|
1026
|
-
// This version is slightly faster than the commented out one below,
|
1027
|
-
// which I copy-pasted from ggerganov's q4_0 dot product for metal.
|
1028
1312
|
//
|
1029
1313
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1030
1314
|
if (ith%4 == 0) {
|
@@ -1041,7 +1325,7 @@ kernel void kernel_mul_mat_q2_k_f32(
|
|
1041
1325
|
}
|
1042
1326
|
}
|
1043
1327
|
|
1044
|
-
kernel void
|
1328
|
+
kernel void kernel_mul_mat_q3_K_f32(
|
1045
1329
|
device const void * src0,
|
1046
1330
|
device const float * src1,
|
1047
1331
|
device float * dst,
|
@@ -1054,23 +1338,25 @@ kernel void kernel_mul_mat_q3_k_f32(
|
|
1054
1338
|
uint2 tpitg[[thread_position_in_threadgroup]],
|
1055
1339
|
uint2 tptg[[threads_per_threadgroup]]) {
|
1056
1340
|
|
1057
|
-
const uint16_t kmask1 = 0x0303;
|
1058
|
-
const uint16_t kmask2 = 0x0f0f;
|
1059
|
-
|
1060
|
-
const uint8_t m3 = 3;
|
1061
|
-
const int8_t m4 = 4;
|
1062
|
-
|
1063
1341
|
const int nb = ne00/QK_K;
|
1064
1342
|
|
1065
1343
|
const int64_t r0 = tgpig.x;
|
1066
1344
|
const int64_t r1 = tgpig.y;
|
1067
1345
|
|
1068
|
-
device const
|
1346
|
+
device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb;
|
1069
1347
|
device const float * yy = (device const float *) src1 + r1*ne10;
|
1070
1348
|
|
1071
1349
|
const int nth = tptg.x*tptg.y;
|
1072
1350
|
const int ith = tptg.y*tpitg.x + tpitg.y;
|
1073
1351
|
|
1352
|
+
#if QK_K == 256
|
1353
|
+
|
1354
|
+
const uint8_t m3 = 3;
|
1355
|
+
const int8_t m4 = 4;
|
1356
|
+
|
1357
|
+
const uint16_t kmask1 = 0x0303;
|
1358
|
+
const uint16_t kmask2 = 0x0f0f;
|
1359
|
+
|
1074
1360
|
const int tid = tpitg.y; // expecting 16
|
1075
1361
|
const int ip = tid/8; // 0 or 1
|
1076
1362
|
const int il = tid/2 - 4*ip; // 0...3
|
@@ -1124,6 +1410,39 @@ kernel void kernel_mul_mat_q3_k_f32(
|
|
1124
1410
|
|
1125
1411
|
//sum[ith] = sumf;
|
1126
1412
|
sum[ith] = sumf1 - 32.f*sumf2;
|
1413
|
+
#else
|
1414
|
+
const int il = 4 * tpitg.x; // 0, 4, 8, 12
|
1415
|
+
const int im = il/8; // 0, 0, 1, 1
|
1416
|
+
const int in = il%8; // 0, 4, 0, 4
|
1417
|
+
|
1418
|
+
float sumf = 0;
|
1419
|
+
|
1420
|
+
for (int i = tpitg.y; i < nb; i += tptg.y) {
|
1421
|
+
|
1422
|
+
const float d_all = (float)(x[i].d);
|
1423
|
+
|
1424
|
+
device const uint8_t * q = x[i].qs + il;
|
1425
|
+
device const uint8_t * h = x[i].hmask + in;
|
1426
|
+
device const float * y = yy + i * QK_K + il;
|
1427
|
+
|
1428
|
+
const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
|
1429
|
+
const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
|
1430
|
+
const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
|
1431
|
+
const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
|
1432
|
+
|
1433
|
+
for (int l = 0; l < 4; ++l) {
|
1434
|
+
const uint8_t hm = h[l] >> im;
|
1435
|
+
sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4))
|
1436
|
+
+ y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4))
|
1437
|
+
+ y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4))
|
1438
|
+
+ y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4));
|
1439
|
+
}
|
1440
|
+
|
1441
|
+
}
|
1442
|
+
|
1443
|
+
sum[ith] = sumf;
|
1444
|
+
|
1445
|
+
#endif
|
1127
1446
|
|
1128
1447
|
//
|
1129
1448
|
// Accumulate the sum from all threads in the threadgroup
|
@@ -1144,7 +1463,7 @@ kernel void kernel_mul_mat_q3_k_f32(
|
|
1144
1463
|
|
1145
1464
|
}
|
1146
1465
|
|
1147
|
-
kernel void
|
1466
|
+
kernel void kernel_mul_mat_q4_K_f32(
|
1148
1467
|
device const void * src0,
|
1149
1468
|
device const float * src1,
|
1150
1469
|
device float * dst,
|
@@ -1156,21 +1475,25 @@ kernel void kernel_mul_mat_q4_k_f32(
|
|
1156
1475
|
uint2 tpitg[[thread_position_in_threadgroup]],
|
1157
1476
|
uint2 tptg[[threads_per_threadgroup]]) {
|
1158
1477
|
|
1159
|
-
const uint16_t kmask1 = 0x3f3f;
|
1160
|
-
const uint16_t kmask2 = 0x0f0f;
|
1161
|
-
const uint16_t kmask3 = 0xc0c0;
|
1162
|
-
|
1163
1478
|
const int nb = ne00/QK_K;
|
1164
1479
|
|
1165
1480
|
const int64_t r0 = tgpig.x;
|
1166
1481
|
const int64_t r1 = tgpig.y;
|
1167
1482
|
|
1168
|
-
device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
|
1169
|
-
device const float * yy = (device const float *) src1 + r1*ne10;
|
1170
|
-
|
1171
1483
|
const int nth = tptg.x*tptg.y;
|
1172
1484
|
const int ith = tptg.y*tpitg.x + tpitg.y;
|
1173
1485
|
|
1486
|
+
device const block_q4_K * x = (device const block_q4_K *) src0 + r0*nb;
|
1487
|
+
device const float * yy = (device const float *) src1 + r1*ne10;
|
1488
|
+
|
1489
|
+
float sumf = 0;
|
1490
|
+
|
1491
|
+
#if QK_K == 256
|
1492
|
+
|
1493
|
+
const uint16_t kmask1 = 0x3f3f;
|
1494
|
+
const uint16_t kmask2 = 0x0f0f;
|
1495
|
+
const uint16_t kmask3 = 0xc0c0;
|
1496
|
+
|
1174
1497
|
const int tid = tpitg.y; // 0...16
|
1175
1498
|
const int il = tid/4; // 0...3
|
1176
1499
|
const int ir = tid - 4*il;// 0...3
|
@@ -1183,11 +1506,8 @@ kernel void kernel_mul_mat_q4_k_f32(
|
|
1183
1506
|
const int q_offset = 32*im + l0;
|
1184
1507
|
const int y_offset = 64*im + l0;
|
1185
1508
|
|
1186
|
-
sum[ith] = 0.0f;
|
1187
|
-
|
1188
1509
|
uchar2 sc1, sc2, sc3, sc4;
|
1189
1510
|
|
1190
|
-
float sumf = 0;
|
1191
1511
|
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
1192
1512
|
|
1193
1513
|
device const uint8_t * q1 = (x + i)->qs + q_offset;
|
@@ -1216,6 +1536,30 @@ kernel void kernel_mul_mat_q4_k_f32(
|
|
1216
1536
|
sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
|
1217
1537
|
|
1218
1538
|
}
|
1539
|
+
#else
|
1540
|
+
uint16_t aux16[2];
|
1541
|
+
thread const uint8_t * scales = (thread const uint8_t *)aux16;
|
1542
|
+
|
1543
|
+
const int il = 4*tpitg.x;
|
1544
|
+
|
1545
|
+
for (int i = tpitg.y; i < nb; i += tptg.y) {
|
1546
|
+
|
1547
|
+
device const uint8_t * q = x[i].qs + il;
|
1548
|
+
device const float * y = yy + i * QK_K + il;
|
1549
|
+
|
1550
|
+
const float d = (float)x[i].d[0];
|
1551
|
+
const float m = (float)x[i].d[1];
|
1552
|
+
|
1553
|
+
device const uint16_t * a = (device const uint16_t *)x[i].scales;
|
1554
|
+
aux16[0] = a[0] & 0x0f0f;
|
1555
|
+
aux16[1] = (a[0] >> 4) & 0x0f0f;
|
1556
|
+
|
1557
|
+
for (int l = 0; l < 4; ++l) {
|
1558
|
+
sumf += d * scales[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - m * scales[2] * (y[l+ 0] + y[l+16])
|
1559
|
+
+ d * scales[1] * (y[l+32] * (q[l] >> 4) + y[l+48] * (q[l+16] >> 4)) - m * scales[3] * (y[l+32] + y[l+48]);
|
1560
|
+
}
|
1561
|
+
}
|
1562
|
+
#endif
|
1219
1563
|
|
1220
1564
|
sum[ith] = sumf;
|
1221
1565
|
|
@@ -1252,7 +1596,7 @@ kernel void kernel_mul_mat_q4_k_f32(
|
|
1252
1596
|
//}
|
1253
1597
|
}
|
1254
1598
|
|
1255
|
-
kernel void
|
1599
|
+
kernel void kernel_mul_mat_q5_K_f32(
|
1256
1600
|
device const void * src0,
|
1257
1601
|
device const float * src1,
|
1258
1602
|
device float * dst,
|
@@ -1264,21 +1608,25 @@ kernel void kernel_mul_mat_q5_k_f32(
|
|
1264
1608
|
uint2 tpitg[[thread_position_in_threadgroup]],
|
1265
1609
|
uint2 tptg[[threads_per_threadgroup]]) {
|
1266
1610
|
|
1267
|
-
const uint16_t kmask1 = 0x3f3f;
|
1268
|
-
const uint16_t kmask2 = 0x0f0f;
|
1269
|
-
const uint16_t kmask3 = 0xc0c0;
|
1270
|
-
|
1271
1611
|
const int nb = ne00/QK_K;
|
1272
1612
|
|
1273
1613
|
const int64_t r0 = tgpig.x;
|
1274
1614
|
const int64_t r1 = tgpig.y;
|
1275
1615
|
|
1276
|
-
device const
|
1616
|
+
device const block_q5_K * x = (device const block_q5_K *) src0 + r0*nb;
|
1277
1617
|
device const float * yy = (device const float *) src1 + r1*ne10;
|
1278
1618
|
|
1279
1619
|
const int nth = tptg.x*tptg.y;
|
1280
1620
|
const int ith = tptg.y*tpitg.x + tpitg.y;
|
1281
1621
|
|
1622
|
+
float sumf = 0;
|
1623
|
+
|
1624
|
+
#if QK_K == 256
|
1625
|
+
|
1626
|
+
const uint16_t kmask1 = 0x3f3f;
|
1627
|
+
const uint16_t kmask2 = 0x0f0f;
|
1628
|
+
const uint16_t kmask3 = 0xc0c0;
|
1629
|
+
|
1282
1630
|
const int tid = tpitg.y; // 0...16
|
1283
1631
|
const int il = tid/4; // 0...3
|
1284
1632
|
const int ir = tid - 4*il;// 0...3
|
@@ -1298,7 +1646,6 @@ kernel void kernel_mul_mat_q5_k_f32(
|
|
1298
1646
|
|
1299
1647
|
uchar2 sc1, sc2, sc3, sc4;
|
1300
1648
|
|
1301
|
-
float sumf = 0;
|
1302
1649
|
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
1303
1650
|
|
1304
1651
|
device const uint8_t * q1 = (x + i)->qs + q_offset;
|
@@ -1330,6 +1677,28 @@ kernel void kernel_mul_mat_q5_k_f32(
|
|
1330
1677
|
sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
|
1331
1678
|
|
1332
1679
|
}
|
1680
|
+
#else
|
1681
|
+
const int il = 4 * tpitg.x; // 0, 4, 8, 12
|
1682
|
+
const int im = il/8; // 0, 0, 1, 1
|
1683
|
+
const int in = il%8; // 0, 4, 0, 4
|
1684
|
+
|
1685
|
+
for (int i = tpitg.y; i < nb; i += tptg.y) {
|
1686
|
+
|
1687
|
+
const float d = (float)x[i].d;
|
1688
|
+
device const uint8_t * q = x[i].qs + il;
|
1689
|
+
device const uint8_t * h = x[i].qh + in;
|
1690
|
+
device const int8_t * s = x[i].scales;
|
1691
|
+
device const float * y = yy + i*QK_K + il;
|
1692
|
+
|
1693
|
+
for (int l = 0; l < 4; ++l) {
|
1694
|
+
const uint8_t hl = h[l] >> im;
|
1695
|
+
sumf += y[l+ 0] * d * s[0] * ((q[l+ 0] & 0xF) - (hl & 0x01 ? 0 : 16))
|
1696
|
+
+ y[l+16] * d * s[1] * ((q[l+16] & 0xF) - (hl & 0x04 ? 0 : 16))
|
1697
|
+
+ y[l+32] * d * s[2] * ((q[l+ 0] >> 4) - (hl & 0x10 ? 0 : 16))
|
1698
|
+
+ y[l+48] * d * s[3] * ((q[l+16] >> 4) - (hl & 0x40 ? 0 : 16));
|
1699
|
+
}
|
1700
|
+
}
|
1701
|
+
#endif
|
1333
1702
|
sum[ith] = sumf;
|
1334
1703
|
|
1335
1704
|
//
|
@@ -1351,7 +1720,7 @@ kernel void kernel_mul_mat_q5_k_f32(
|
|
1351
1720
|
|
1352
1721
|
}
|
1353
1722
|
|
1354
|
-
kernel void
|
1723
|
+
kernel void kernel_mul_mat_q6_K_f32(
|
1355
1724
|
device const void * src0,
|
1356
1725
|
device const float * src1,
|
1357
1726
|
device float * dst,
|
@@ -1373,12 +1742,15 @@ kernel void kernel_mul_mat_q6_k_f32(
|
|
1373
1742
|
const int64_t r0 = tgpig.x;
|
1374
1743
|
const int64_t r1 = tgpig.y;
|
1375
1744
|
|
1376
|
-
device const
|
1745
|
+
device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb;
|
1377
1746
|
device const float * yy = (device const float *) src1 + r1*ne10;
|
1378
1747
|
|
1379
1748
|
const int nth = tptg.x*tptg.y;
|
1380
1749
|
const int ith = tptg.y*tpitg.x + tpitg.y;
|
1381
1750
|
|
1751
|
+
float sumf = 0;
|
1752
|
+
|
1753
|
+
#if QK_K == 256
|
1382
1754
|
// Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
|
1383
1755
|
const int iqs = 16 * tpitg.y;
|
1384
1756
|
const int ip = iqs / 128; // 0 or 1
|
@@ -1391,7 +1763,6 @@ kernel void kernel_mul_mat_q6_k_f32(
|
|
1391
1763
|
const int q_offset_l = 64*ip + l0;
|
1392
1764
|
const int q_offset_h = 32*ip + l0;
|
1393
1765
|
|
1394
|
-
float sumf = 0;
|
1395
1766
|
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
1396
1767
|
|
1397
1768
|
device const uint8_t * ql = x[i].ql + q_offset_l;
|
@@ -1413,6 +1784,28 @@ kernel void kernel_mul_mat_q6_k_f32(
|
|
1413
1784
|
sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
|
1414
1785
|
|
1415
1786
|
}
|
1787
|
+
#else
|
1788
|
+
const int il = 4*tpitg.x; // 0, 4, 8, 12
|
1789
|
+
|
1790
|
+
for (int i = tpitg.y; i < nb; i += tptg.y) {
|
1791
|
+
device const float * y = yy + i * QK_K + il;
|
1792
|
+
device const uint8_t * ql = x[i].ql + il;
|
1793
|
+
device const uint8_t * qh = x[i].qh + il;
|
1794
|
+
device const int8_t * s = x[i].scales;
|
1795
|
+
|
1796
|
+
const float d = x[i].d;
|
1797
|
+
|
1798
|
+
float4 sums = {0.f, 0.f, 0.f, 0.f};
|
1799
|
+
for (int l = 0; l < 4; ++l) {
|
1800
|
+
sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
|
1801
|
+
sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
|
1802
|
+
sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32);
|
1803
|
+
sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
|
1804
|
+
}
|
1805
|
+
sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
|
1806
|
+
}
|
1807
|
+
|
1808
|
+
#endif
|
1416
1809
|
|
1417
1810
|
sum[ith] = sumf;
|
1418
1811
|
|