llama_cpp 0.2.1 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +32 -0
- data/README.md +39 -6
- data/examples/README.md +32 -0
- data/examples/chat.rb +2 -1
- data/examples/embedding.rb +38 -0
- data/ext/llama_cpp/extconf.rb +13 -0
- data/ext/llama_cpp/llama_cpp.cpp +231 -132
- data/ext/llama_cpp/src/ggml-cuda.cu +844 -337
- data/ext/llama_cpp/src/ggml-metal.h +4 -1
- data/ext/llama_cpp/src/ggml-metal.m +193 -49
- data/ext/llama_cpp/src/ggml-metal.metal +477 -84
- data/ext/llama_cpp/src/ggml-opencl.cpp +493 -4
- data/ext/llama_cpp/src/ggml.c +1565 -430
- data/ext/llama_cpp/src/ggml.h +208 -14
- data/ext/llama_cpp/src/k_quants.c +1712 -56
- data/ext/llama_cpp/src/k_quants.h +41 -6
- data/ext/llama_cpp/src/llama-util.h +19 -5
- data/ext/llama_cpp/src/llama.cpp +194 -101
- data/ext/llama_cpp/src/llama.h +41 -14
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +0 -2
- data/sig/llama_cpp.rbs +12 -17
- metadata +3 -3
- data/lib/llama_cpp/client.rb +0 -172
@@ -256,6 +256,72 @@ kernel void kernel_get_rows_q4_1(
|
|
256
256
|
(device float *) ((device char *) dst + i*nb1), ne00);
|
257
257
|
}
|
258
258
|
|
259
|
+
// Normalizes one row per threadgroup:
//   dst_row[i] = (x[i] - mean(x)) / sqrt(var(x) + eps)
//
// Indexing used by this kernel:
//   - threadgroup `tgpig` reads the row that starts `tgpig*nb01` bytes into src0
//   - the `ntg` threads of the group stride over the `ne00` elements of that row
//   - the result row is written at `dst + tgpig*ne00` (dst rows addressed as
//     densely packed floats)
//   - `sum` is threadgroup scratch indexed by `tpitg`, so it must provide at
//     least `ntg` floats
//
// The tree reduction halves the active range on every step, so partial sums are
// only all folded in when `ntg` is a power of two.
// NOTE(review): confirm the host side always dispatches a power-of-two size.
kernel void kernel_norm(
        device const  void * src0,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant     float & eps,
        threadgroup float  * sum [[threadgroup(0)]],
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);

    // MEAN
    // parallel sum: each thread accumulates its strided share of the row
    sum[tpitg] = 0.0f;
    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
        sum[tpitg] += x[i00];
    }
    // reduce: fold the per-thread partials into sum[0]
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint i = ntg/2; i > 0; i /= 2) {
        if (tpitg < i) {
            sum[tpitg] += sum[tpitg + i];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }
    // Every thread derives the mean from the reduced total itself, so no thread
    // has to write sum[0] again just to broadcast it.
    const float mean = sum[0] / ne00;

    // All threads must have read sum[0] before any thread resets its scratch
    // slot for the variance pass; without this barrier thread 0 could zero
    // sum[0] while slower threads are still reading the mean.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // recenter
    device float * y = dst + tgpig*ne00;
    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
        y[i00] = x[i00] - mean;
    }

    // VARIANCE
    // parallel sum of squares; each thread re-reads only the elements it wrote
    sum[tpitg] = 0.0f;
    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
        sum[tpitg] += y[i00] * y[i00];
    }
    // reduce
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint i = ntg/2; i > 0; i /= 2) {
        if (tpitg < i) {
            sum[tpitg] += sum[tpitg + i];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }
    // sum is not written again after this point, so a plain read is sufficient
    const float variance = sum[0] / ne00;

    const float scale = 1.0f/sqrt(variance + eps);
    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
        y[i00] = y[i00] * scale;
    }
}
|
323
|
+
|
324
|
+
|
259
325
|
kernel void kernel_rms_norm(
|
260
326
|
device const void * src0,
|
261
327
|
device float * dst,
|
@@ -362,7 +428,7 @@ kernel void kernel_mul_mat_q4_0_f32(
|
|
362
428
|
}
|
363
429
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
364
430
|
if (ith == 0) {
|
365
|
-
for (
|
431
|
+
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
|
366
432
|
dst[r1*ne0 + r0] = sum[0];
|
367
433
|
}
|
368
434
|
}
|
@@ -431,7 +497,7 @@ kernel void kernel_mul_mat_q4_1_f32(
|
|
431
497
|
}
|
432
498
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
433
499
|
if (ith == 0) {
|
434
|
-
for (
|
500
|
+
for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
|
435
501
|
dst[r1*ne0 + r0] = sum[0];
|
436
502
|
}
|
437
503
|
}
|
@@ -485,6 +551,48 @@ kernel void kernel_mul_mat_f16_f32(
|
|
485
551
|
}
|
486
552
|
}
|
487
553
|
|
554
|
+
// Adds a position-dependent linear bias to every element of one source row:
//   dst[i00] = src[i00] + m0^(i2 + 1) * (i00 - ne00 + 1)
//
// One threadgroup handles one source row, selected by the threadgroup position
// (x -> i01, y -> i02, z -> i03). The threads of the group stride over the row
// using the x component of their position / group size.
// The flat element index of the row start is re-expressed in destination
// coordinates (ne0..ne2) to locate the output row and the exponent i2.
// Output elements are written as consecutive floats from that row start.
kernel void kernel_alibi_f32(
        device const float * src0,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        constant     float & m0,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    // source row coordinates
    const int64_t s3 = tgpig[2];
    const int64_t s2 = tgpig[1];
    const int64_t s1 = tgpig[0];

    // flat index of the first element of this source row
    const int64_t flat = s3*ne02*ne01*ne00 + s2*ne01*ne00 + s1*ne00;

    // peel destination coordinates off the flat index, outermost first
    const int64_t d3   = flat / (ne2*ne1*ne0);
    const int64_t rem3 = flat - d3*ne2*ne1*ne0;
    const int64_t d2   = rem3 / (ne1*ne0);
    const int64_t rem2 = rem3 - d2*ne1*ne0;
    const int64_t d1   = rem2 / ne0;
    const int64_t d0   = rem2 - d1*ne0;

    device float * out_row = (device float *) ((device char *) dst + d3*nb3 + d2*nb2 + d1*nb1 + d0*nb0);

    // byte address of the source row; elements are stepped by nb00 below
    device const char * in_row = (device const char *) src0 + s3*nb03 + s2*nb02 + s1*nb01;

    // per-row slope: m0 raised to (d2 + 1)
    const float slope = pow(m0, d2 + 1);

    for (int64_t col = tpitg.x; col < ne00; col += ntg.x) {
        device const float * in_elem = (device const float *) (in_row + col*nb00);
        out_row[col] = in_elem[0] + slope * (col - ne00 + 1);
    }
}
|
595
|
+
|
488
596
|
kernel void kernel_rope(
|
489
597
|
device const void * src0,
|
490
598
|
device float * dst,
|
@@ -540,6 +648,47 @@ kernel void kernel_rope(
|
|
540
648
|
}
|
541
649
|
}
|
542
650
|
|
651
|
+
// Copies one row of half values from src0 to dst per threadgroup.
//
// The source row is selected by the threadgroup position (x -> i01, y -> i02,
// z -> i03) and read through the source byte strides nb00..nb03. The flat index
// of the row start is re-expressed in destination coordinates (ne0..ne2) and
// the destination byte strides nb0..nb3 give the output row address.
// Threads of the group stride over the ne00 elements of the row; output
// elements are written as consecutive halves from the row start.
kernel void kernel_cpy_f16_f16(
        device const half * src0,
        device       half * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    // source row coordinates
    const int64_t s3 = tgpig[2];
    const int64_t s2 = tgpig[1];
    const int64_t s1 = tgpig[0];

    // flat index of the first element of this source row
    const int64_t flat = s3*ne02*ne01*ne00 + s2*ne01*ne00 + s1*ne00;

    // peel destination coordinates off the flat index, outermost first
    const int64_t d3   = flat / (ne2*ne1*ne0);
    const int64_t rem3 = flat - d3*ne2*ne1*ne0;
    const int64_t d2   = rem3 / (ne1*ne0);
    const int64_t rem2 = rem3 - d2*ne1*ne0;
    const int64_t d1   = rem2 / ne0;
    const int64_t d0   = rem2 - d1*ne0;

    device half * out_row = (device half *) ((device char *) dst + d3*nb3 + d2*nb2 + d1*nb1 + d0*nb0);

    // byte address of the source row; elements are stepped by nb00 below
    device const char * in_row = (device const char *) src0 + s3*nb03 + s2*nb02 + s1*nb01;

    for (int64_t col = tpitg.x; col < ne00; col += ntg.x) {
        device const half * in_elem = (device const half *) (in_row + col*nb00);
        out_row[col] = in_elem[0];
    }
}
|
691
|
+
|
543
692
|
kernel void kernel_cpy_f32_f16(
|
544
693
|
device const float * src0,
|
545
694
|
device half * dst,
|
@@ -626,47 +775,76 @@ kernel void kernel_cpy_f32_f32(
|
|
626
775
|
|
627
776
|
//============================================ k-quants ======================================================
|
628
777
|
|
778
|
+
#ifndef QK_K
|
629
779
|
#define QK_K 256
|
780
|
+
#else
|
781
|
+
static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64");
|
782
|
+
#endif
|
783
|
+
|
784
|
+
#if QK_K == 256
|
785
|
+
#define K_SCALE_SIZE 12
|
786
|
+
#else
|
787
|
+
#define K_SCALE_SIZE 4
|
788
|
+
#endif
|
630
789
|
|
631
790
|
typedef struct {
|
632
791
|
uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
|
633
792
|
uint8_t qs[QK_K/4]; // quants
|
634
793
|
half d; // super-block scale for quantized scales
|
635
794
|
half dmin; // super-block scale for quantized mins
|
636
|
-
}
|
795
|
+
} block_q2_K;
|
637
796
|
// 84 bytes / block
|
638
797
|
|
639
798
|
typedef struct {
|
640
799
|
uint8_t hmask[QK_K/8]; // quants - high bit
|
641
800
|
uint8_t qs[QK_K/4]; // quants - low 2 bits
|
642
|
-
|
643
|
-
|
644
|
-
|
645
|
-
//
|
646
|
-
|
801
|
+
#if QK_K == 64
|
802
|
+
uint8_t scales[2];
|
803
|
+
#else
|
804
|
+
uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
|
805
|
+
#endif
|
806
|
+
half d; // super-block scale
|
807
|
+
} block_q3_K;
|
808
|
+
|
809
|
+
#if QK_K == 64
|
810
|
+
typedef struct {
|
811
|
+
half d[2]; // super-block scales/mins
|
812
|
+
uint8_t scales[2];
|
813
|
+
uint8_t qs[QK_K/2]; // 4-bit quants
|
814
|
+
} block_q4_K;
|
815
|
+
#else
|
647
816
|
typedef struct {
|
648
817
|
half d; // super-block scale for quantized scales
|
649
818
|
half dmin; // super-block scale for quantized mins
|
650
|
-
uint8_t scales[
|
819
|
+
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
|
651
820
|
uint8_t qs[QK_K/2]; // 4--bit quants
|
652
|
-
}
|
653
|
-
|
821
|
+
} block_q4_K;
|
822
|
+
#endif
|
654
823
|
|
824
|
+
#if QK_K == 64
|
825
|
+
typedef struct {
|
826
|
+
half d; // super-block scales/mins
|
827
|
+
int8_t scales[QK_K/16]; // 8-bit block scales
|
828
|
+
uint8_t qh[QK_K/8]; // quants, high bit
|
829
|
+
uint8_t qs[QK_K/2]; // quants, low 4 bits
|
830
|
+
} block_q5_K;
|
831
|
+
#else
|
655
832
|
typedef struct {
|
656
833
|
half d; // super-block scale for quantized scales
|
657
834
|
half dmin; // super-block scale for quantized mins
|
658
835
|
uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
|
659
836
|
uint8_t qh[QK_K/8]; // quants, high bit
|
660
837
|
uint8_t qs[QK_K/2]; // quants, low 4 bits
|
661
|
-
}
|
838
|
+
} block_q5_K;
|
662
839
|
// 176 bytes / block
|
840
|
+
#endif
|
663
841
|
|
664
842
|
typedef struct {
|
665
843
|
uint8_t ql[QK_K/2]; // quants, lower 4 bits
|
666
844
|
uint8_t qh[QK_K/4]; // quants, upper 2 bits
|
667
845
|
int8_t scales[QK_K/16]; // scales, quantized with 8 bits
|
668
846
|
half d; // super-block scale
|
669
|
-
}
|
847
|
+
} block_q6_K;
|
670
848
|
// 210 bytes / block
|
671
849
|
|
672
850
|
static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
|
@@ -687,7 +865,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
|
|
687
865
|
|
688
866
|
//========================================== dequantization =============================
|
689
867
|
|
690
|
-
static void
|
868
|
+
static void dequantize_row_q2_K(device const block_q2_K * x, device float * y, int k) {
|
691
869
|
assert(k % QK_K == 0);
|
692
870
|
const int nb = k / QK_K;
|
693
871
|
|
@@ -698,6 +876,7 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
|
|
698
876
|
|
699
877
|
device const uint8_t * q = x[i].qs;
|
700
878
|
|
879
|
+
#if QK_K == 256
|
701
880
|
int is = 0;
|
702
881
|
float dl, ml;
|
703
882
|
for (int n = 0; n < QK_K; n += 128) {
|
@@ -716,14 +895,29 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
|
|
716
895
|
}
|
717
896
|
q += 32;
|
718
897
|
}
|
898
|
+
#else
|
899
|
+
float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
|
900
|
+
float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
|
901
|
+
float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
|
902
|
+
float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
|
903
|
+
for (int l = 0; l < 16; ++l) {
|
904
|
+
y[l+ 0] = dl1 * ((q[l] >> 0) & 3) - ml1;
|
905
|
+
y[l+16] = dl2 * ((q[l] >> 2) & 3) - ml2;
|
906
|
+
y[l+32] = dl3 * ((q[l] >> 4) & 3) - ml3;
|
907
|
+
y[l+48] = dl4 * ((q[l] >> 6) & 3) - ml4;
|
908
|
+
}
|
909
|
+
y += QK_K;
|
910
|
+
#endif
|
719
911
|
|
720
912
|
}
|
721
913
|
}
|
722
914
|
|
723
|
-
static void
|
915
|
+
static void dequantize_row_q3_K(device const block_q3_K * x, device float * y, int k) {
|
724
916
|
assert(k % QK_K == 0);
|
725
917
|
const int nb = k / QK_K;
|
726
918
|
|
919
|
+
#if QK_K == 256
|
920
|
+
|
727
921
|
const uint16_t kmask1 = 0x0303;
|
728
922
|
const uint16_t kmask2 = 0x0f0f;
|
729
923
|
|
@@ -769,22 +963,49 @@ static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, i
|
|
769
963
|
}
|
770
964
|
q += 32;
|
771
965
|
}
|
966
|
+
}
|
967
|
+
#else
|
968
|
+
for (int i = 0; i < nb; i++) {
|
969
|
+
|
970
|
+
const float d_all = (float)(x[i].d);
|
772
971
|
|
972
|
+
device const uint8_t * q = x[i].qs;
|
973
|
+
device const uint8_t * hm = x[i].hmask;
|
974
|
+
|
975
|
+
const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
|
976
|
+
const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
|
977
|
+
const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
|
978
|
+
const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
|
979
|
+
|
980
|
+
for (int l = 0; l < 8; ++l) {
|
981
|
+
uint8_t h = hm[l];
|
982
|
+
y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
|
983
|
+
y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
|
984
|
+
y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
|
985
|
+
y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
|
986
|
+
y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
|
987
|
+
y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
|
988
|
+
y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
|
989
|
+
y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
|
990
|
+
}
|
991
|
+
y += QK_K;
|
773
992
|
}
|
993
|
+
#endif
|
774
994
|
|
775
995
|
}
|
776
996
|
|
777
|
-
static void
|
997
|
+
static void dequantize_row_q4_K(device const block_q4_K * x, device float * y, int k) {
|
778
998
|
assert(k % QK_K == 0);
|
779
999
|
const int nb = k / QK_K;
|
780
1000
|
|
781
|
-
|
782
1001
|
for (int i = 0; i < nb; i++) {
|
783
1002
|
|
1003
|
+
device const uint8_t * q = x[i].qs;
|
1004
|
+
|
1005
|
+
#if QK_K == 256
|
784
1006
|
const float d = x[i].d;
|
785
1007
|
const float min = x[i].dmin;
|
786
1008
|
|
787
|
-
device const uint8_t * q = x[i].qs;
|
788
1009
|
device const uint8_t * scales = x[i].scales;
|
789
1010
|
|
790
1011
|
int is = 0;
|
@@ -796,14 +1017,29 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i
|
|
796
1017
|
for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
|
797
1018
|
q += 32; is += 2;
|
798
1019
|
}
|
1020
|
+
#else
|
1021
|
+
device const uint8_t * s = x[i].scales;
|
1022
|
+
device const half2 * dh = (device const half2 *)x[i].d;
|
1023
|
+
const float2 d = (float2)dh[0];
|
1024
|
+
const float d1 = d[0] * (s[0] & 0xF);
|
1025
|
+
const float d2 = d[0] * (s[1] & 0xF);
|
1026
|
+
const float m1 = d[1] * (s[0] >> 4);
|
1027
|
+
const float m2 = d[1] * (s[1] >> 4);
|
1028
|
+
for (int l = 0; l < 32; ++l) {
|
1029
|
+
y[l+ 0] = d1 * (q[l] & 0xF) - m1;
|
1030
|
+
y[l+32] = d2 * (q[l] >> 4) - m2;
|
1031
|
+
}
|
1032
|
+
y += QK_K;
|
1033
|
+
#endif
|
799
1034
|
|
800
1035
|
}
|
801
1036
|
}
|
802
1037
|
|
803
|
-
static void
|
1038
|
+
static void dequantize_row_q5_K(device const block_q5_K * x, device float * y, int k) {
|
804
1039
|
assert(k % QK_K == 0);
|
805
1040
|
const int nb = k / QK_K;
|
806
1041
|
|
1042
|
+
#if QK_K == 256
|
807
1043
|
for (int i = 0; i < nb; i++) {
|
808
1044
|
|
809
1045
|
const float d = (float)(x[i].d);
|
@@ -824,10 +1060,32 @@ static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, i
|
|
824
1060
|
u1 <<= 2; u2 <<= 2;
|
825
1061
|
}
|
826
1062
|
}
|
1063
|
+
#else
|
1064
|
+
for (int i = 0; i < nb; i++) {
|
1065
|
+
|
1066
|
+
const float d = (float)x[i].d;
|
1067
|
+
|
1068
|
+
device const uint8_t * ql = x[i].qs;
|
1069
|
+
device const uint8_t * qh = x[i].qh;
|
1070
|
+
device const int8_t * sc = x[i].scales;
|
1071
|
+
|
1072
|
+
for (int l = 0; l < 8; ++l) {
|
1073
|
+
y[l+ 0] = d * sc[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
|
1074
|
+
y[l+ 8] = d * sc[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
|
1075
|
+
y[l+16] = d * sc[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
|
1076
|
+
y[l+24] = d * sc[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
|
1077
|
+
y[l+32] = d * sc[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16));
|
1078
|
+
y[l+40] = d * sc[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16));
|
1079
|
+
y[l+48] = d * sc[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16));
|
1080
|
+
y[l+56] = d * sc[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16));
|
1081
|
+
}
|
1082
|
+
y += QK_K;
|
1083
|
+
}
|
1084
|
+
#endif
|
827
1085
|
|
828
1086
|
}
|
829
1087
|
|
830
|
-
static void
|
1088
|
+
static void dequantize_row_q6_K(device const block_q6_K * x, device float * y, int k) {
|
831
1089
|
assert(k % QK_K == 0);
|
832
1090
|
const int nb = k / QK_K;
|
833
1091
|
|
@@ -839,6 +1097,7 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
|
|
839
1097
|
|
840
1098
|
const float d = x[i].d;
|
841
1099
|
|
1100
|
+
#if QK_K == 256
|
842
1101
|
for (int n = 0; n < QK_K; n += 128) {
|
843
1102
|
for (int l = 0; l < 32; ++l) {
|
844
1103
|
int is = l/16;
|
@@ -856,10 +1115,23 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
|
|
856
1115
|
qh += 32;
|
857
1116
|
sc += 8;
|
858
1117
|
}
|
1118
|
+
#else
|
1119
|
+
for (int l = 0; l < 16; ++l) {
|
1120
|
+
const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
|
1121
|
+
const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
|
1122
|
+
const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
|
1123
|
+
const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
|
1124
|
+
y[l+ 0] = d * sc[0] * q1;
|
1125
|
+
y[l+16] = d * sc[1] * q2;
|
1126
|
+
y[l+32] = d * sc[2] * q3;
|
1127
|
+
y[l+48] = d * sc[3] * q4;
|
1128
|
+
}
|
1129
|
+
y += 64;
|
1130
|
+
#endif
|
859
1131
|
}
|
860
1132
|
}
|
861
1133
|
|
862
|
-
kernel void
|
1134
|
+
kernel void kernel_get_rows_q2_K(
|
863
1135
|
device const void * src0,
|
864
1136
|
device const int * src1,
|
865
1137
|
device float * dst,
|
@@ -870,12 +1142,12 @@ kernel void kernel_get_rows_q2_k(
|
|
870
1142
|
const int i = tpig;
|
871
1143
|
const int r = ((device int32_t *) src1)[i];
|
872
1144
|
|
873
|
-
|
874
|
-
(device const
|
1145
|
+
dequantize_row_q2_K(
|
1146
|
+
(device const block_q2_K *) ((device char *) src0 + r*nb01),
|
875
1147
|
(device float *) ((device char *) dst + i*nb1), ne00);
|
876
1148
|
}
|
877
1149
|
|
878
|
-
kernel void
|
1150
|
+
kernel void kernel_get_rows_q3_K(
|
879
1151
|
device const void * src0,
|
880
1152
|
device const int * src1,
|
881
1153
|
device float * dst,
|
@@ -886,12 +1158,12 @@ kernel void kernel_get_rows_q3_k(
|
|
886
1158
|
const int i = tpig;
|
887
1159
|
const int r = ((device int32_t *) src1)[i];
|
888
1160
|
|
889
|
-
|
890
|
-
(device const
|
1161
|
+
dequantize_row_q3_K(
|
1162
|
+
(device const block_q3_K *) ((device char *) src0 + r*nb01),
|
891
1163
|
(device float *) ((device char *) dst + i*nb1), ne00);
|
892
1164
|
}
|
893
1165
|
|
894
|
-
kernel void
|
1166
|
+
kernel void kernel_get_rows_q4_K(
|
895
1167
|
device const void * src0,
|
896
1168
|
device const int * src1,
|
897
1169
|
device float * dst,
|
@@ -902,12 +1174,12 @@ kernel void kernel_get_rows_q4_k(
|
|
902
1174
|
const int i = tpig;
|
903
1175
|
const int r = ((device int32_t *) src1)[i];
|
904
1176
|
|
905
|
-
|
906
|
-
(device const
|
1177
|
+
dequantize_row_q4_K(
|
1178
|
+
(device const block_q4_K *) ((device char *) src0 + r*nb01),
|
907
1179
|
(device float *) ((device char *) dst + i*nb1), ne00);
|
908
1180
|
}
|
909
1181
|
|
910
|
-
kernel void
|
1182
|
+
kernel void kernel_get_rows_q5_K(
|
911
1183
|
device const void * src0,
|
912
1184
|
device const int * src1,
|
913
1185
|
device float * dst,
|
@@ -918,12 +1190,12 @@ kernel void kernel_get_rows_q5_k(
|
|
918
1190
|
const int i = tpig;
|
919
1191
|
const int r = ((device int32_t *) src1)[i];
|
920
1192
|
|
921
|
-
|
922
|
-
(device const
|
1193
|
+
dequantize_row_q5_K(
|
1194
|
+
(device const block_q5_K *) ((device char *) src0 + r*nb01),
|
923
1195
|
(device float *) ((device char *) dst + i*nb1), ne00);
|
924
1196
|
}
|
925
1197
|
|
926
|
-
kernel void
|
1198
|
+
kernel void kernel_get_rows_q6_K(
|
927
1199
|
device const void * src0,
|
928
1200
|
device const int * src1,
|
929
1201
|
device float * dst,
|
@@ -934,14 +1206,14 @@ kernel void kernel_get_rows_q6_k(
|
|
934
1206
|
const int i = tpig;
|
935
1207
|
const int r = ((device int32_t *) src1)[i];
|
936
1208
|
|
937
|
-
|
938
|
-
(device const
|
1209
|
+
dequantize_row_q6_K(
|
1210
|
+
(device const block_q6_K *) ((device char *) src0 + r*nb01),
|
939
1211
|
(device float *) ((device char *) dst + i*nb1), ne00);
|
940
1212
|
}
|
941
1213
|
|
942
1214
|
//====================================== dot products =========================
|
943
1215
|
|
944
|
-
kernel void
|
1216
|
+
kernel void kernel_mul_mat_q2_K_f32(
|
945
1217
|
device const void * src0,
|
946
1218
|
device const float * src1,
|
947
1219
|
device float * dst,
|
@@ -958,12 +1230,15 @@ kernel void kernel_mul_mat_q2_k_f32(
|
|
958
1230
|
const int64_t r0 = tgpig.x;
|
959
1231
|
const int64_t r1 = tgpig.y;
|
960
1232
|
|
961
|
-
device const
|
1233
|
+
device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb;
|
962
1234
|
device const float * yy = (device const float *) src1 + r1*ne10;
|
963
1235
|
|
964
1236
|
const int nth = tptg.x*tptg.y;
|
965
1237
|
const int ith = tptg.y*tpitg.x + tpitg.y;
|
966
1238
|
|
1239
|
+
float sumf = 0;
|
1240
|
+
|
1241
|
+
#if QK_K == 256
|
967
1242
|
const int tid = tpitg.y; // 0...16
|
968
1243
|
const int il = tid/4; // 0...3
|
969
1244
|
const int ir = tid%4; // 0...3
|
@@ -976,9 +1251,6 @@ kernel void kernel_mul_mat_q2_k_f32(
|
|
976
1251
|
const int y_offset = 64*il + n*ir;
|
977
1252
|
const int q_offset = 32*ip + n*ir;
|
978
1253
|
|
979
|
-
sum[ith] = 0.0f;
|
980
|
-
|
981
|
-
float sumf = 0;
|
982
1254
|
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
983
1255
|
|
984
1256
|
device const uint8_t * q = x[i].qs + q_offset;
|
@@ -991,7 +1263,6 @@ kernel void kernel_mul_mat_q2_k_f32(
|
|
991
1263
|
|
992
1264
|
device const float * y = yy + i*QK_K + y_offset;
|
993
1265
|
|
994
|
-
//float4 s = {0.f, 0.f, 0.f, 0.f};
|
995
1266
|
float2 s = {0.f, 0.f};
|
996
1267
|
float smin = 0;
|
997
1268
|
for (int l = 0; l < n; ++l) {
|
@@ -1006,25 +1277,38 @@ kernel void kernel_mul_mat_q2_k_f32(
|
|
1006
1277
|
sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
|
1007
1278
|
|
1008
1279
|
}
|
1009
|
-
|
1280
|
+
#else
|
1281
|
+
const int il = 4 * tpitg.x;
|
1010
1282
|
|
1011
|
-
|
1012
|
-
|
1283
|
+
uint32_t aux[2];
|
1284
|
+
thread const uint8_t * d = (thread const uint8_t *)aux;
|
1285
|
+
thread const uint8_t * m = (thread const uint8_t *)aux + 4;
|
1013
1286
|
|
1014
|
-
|
1015
|
-
|
1016
|
-
|
1017
|
-
|
1018
|
-
|
1019
|
-
|
1020
|
-
|
1021
|
-
|
1022
|
-
|
1287
|
+
for (int i = tpitg.y; i < nb; i += tptg.y) {
|
1288
|
+
|
1289
|
+
device const uint8_t * q = x[i].qs + il;
|
1290
|
+
device const float * y = yy + i*QK_K + il;
|
1291
|
+
|
1292
|
+
const float dall = (float)x[i].d;
|
1293
|
+
const float dmin = (float)x[i].dmin;
|
1294
|
+
|
1295
|
+
device const uint32_t * a = (device const uint32_t *)x[i].scales;
|
1296
|
+
aux[0] = a[0] & 0x0f0f0f0f;
|
1297
|
+
aux[1] = (a[0] >> 4) & 0x0f0f0f0f;
|
1298
|
+
|
1299
|
+
for (int l = 0; l < 4; ++l) {
|
1300
|
+
sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0])
|
1301
|
+
+ y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1])
|
1302
|
+
+ y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2])
|
1303
|
+
+ y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]);
|
1304
|
+
}
|
1305
|
+
}
|
1306
|
+
#endif
|
1307
|
+
|
1308
|
+
sum[ith] = sumf;
|
1023
1309
|
|
1024
1310
|
//
|
1025
1311
|
// Accumulate the sum from all threads in the threadgroup
|
1026
|
-
// This version is slightly faster than the commented out one below,
|
1027
|
-
// which I copy-pasted from ggerganov's q4_0 dot product for metal.
|
1028
1312
|
//
|
1029
1313
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1030
1314
|
if (ith%4 == 0) {
|
@@ -1041,7 +1325,7 @@ kernel void kernel_mul_mat_q2_k_f32(
|
|
1041
1325
|
}
|
1042
1326
|
}
|
1043
1327
|
|
1044
|
-
kernel void
|
1328
|
+
kernel void kernel_mul_mat_q3_K_f32(
|
1045
1329
|
device const void * src0,
|
1046
1330
|
device const float * src1,
|
1047
1331
|
device float * dst,
|
@@ -1054,23 +1338,25 @@ kernel void kernel_mul_mat_q3_k_f32(
|
|
1054
1338
|
uint2 tpitg[[thread_position_in_threadgroup]],
|
1055
1339
|
uint2 tptg[[threads_per_threadgroup]]) {
|
1056
1340
|
|
1057
|
-
const uint16_t kmask1 = 0x0303;
|
1058
|
-
const uint16_t kmask2 = 0x0f0f;
|
1059
|
-
|
1060
|
-
const uint8_t m3 = 3;
|
1061
|
-
const int8_t m4 = 4;
|
1062
|
-
|
1063
1341
|
const int nb = ne00/QK_K;
|
1064
1342
|
|
1065
1343
|
const int64_t r0 = tgpig.x;
|
1066
1344
|
const int64_t r1 = tgpig.y;
|
1067
1345
|
|
1068
|
-
device const
|
1346
|
+
device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb;
|
1069
1347
|
device const float * yy = (device const float *) src1 + r1*ne10;
|
1070
1348
|
|
1071
1349
|
const int nth = tptg.x*tptg.y;
|
1072
1350
|
const int ith = tptg.y*tpitg.x + tpitg.y;
|
1073
1351
|
|
1352
|
+
#if QK_K == 256
|
1353
|
+
|
1354
|
+
const uint8_t m3 = 3;
|
1355
|
+
const int8_t m4 = 4;
|
1356
|
+
|
1357
|
+
const uint16_t kmask1 = 0x0303;
|
1358
|
+
const uint16_t kmask2 = 0x0f0f;
|
1359
|
+
|
1074
1360
|
const int tid = tpitg.y; // expecting 16
|
1075
1361
|
const int ip = tid/8; // 0 or 1
|
1076
1362
|
const int il = tid/2 - 4*ip; // 0...3
|
@@ -1124,6 +1410,39 @@ kernel void kernel_mul_mat_q3_k_f32(
|
|
1124
1410
|
|
1125
1411
|
//sum[ith] = sumf;
|
1126
1412
|
sum[ith] = sumf1 - 32.f*sumf2;
|
1413
|
+
#else
|
1414
|
+
const int il = 4 * tpitg.x; // 0, 4, 8, 12
|
1415
|
+
const int im = il/8; // 0, 0, 1, 1
|
1416
|
+
const int in = il%8; // 0, 4, 0, 4
|
1417
|
+
|
1418
|
+
float sumf = 0;
|
1419
|
+
|
1420
|
+
for (int i = tpitg.y; i < nb; i += tptg.y) {
|
1421
|
+
|
1422
|
+
const float d_all = (float)(x[i].d);
|
1423
|
+
|
1424
|
+
device const uint8_t * q = x[i].qs + il;
|
1425
|
+
device const uint8_t * h = x[i].hmask + in;
|
1426
|
+
device const float * y = yy + i * QK_K + il;
|
1427
|
+
|
1428
|
+
const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
|
1429
|
+
const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
|
1430
|
+
const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
|
1431
|
+
const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
|
1432
|
+
|
1433
|
+
for (int l = 0; l < 4; ++l) {
|
1434
|
+
const uint8_t hm = h[l] >> im;
|
1435
|
+
sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4))
|
1436
|
+
+ y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4))
|
1437
|
+
+ y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4))
|
1438
|
+
+ y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4));
|
1439
|
+
}
|
1440
|
+
|
1441
|
+
}
|
1442
|
+
|
1443
|
+
sum[ith] = sumf;
|
1444
|
+
|
1445
|
+
#endif
|
1127
1446
|
|
1128
1447
|
//
|
1129
1448
|
// Accumulate the sum from all threads in the threadgroup
|
@@ -1144,7 +1463,7 @@ kernel void kernel_mul_mat_q3_k_f32(
|
|
1144
1463
|
|
1145
1464
|
}
|
1146
1465
|
|
1147
|
-
kernel void
|
1466
|
+
kernel void kernel_mul_mat_q4_K_f32(
|
1148
1467
|
device const void * src0,
|
1149
1468
|
device const float * src1,
|
1150
1469
|
device float * dst,
|
@@ -1156,21 +1475,25 @@ kernel void kernel_mul_mat_q4_k_f32(
|
|
1156
1475
|
uint2 tpitg[[thread_position_in_threadgroup]],
|
1157
1476
|
uint2 tptg[[threads_per_threadgroup]]) {
|
1158
1477
|
|
1159
|
-
const uint16_t kmask1 = 0x3f3f;
|
1160
|
-
const uint16_t kmask2 = 0x0f0f;
|
1161
|
-
const uint16_t kmask3 = 0xc0c0;
|
1162
|
-
|
1163
1478
|
const int nb = ne00/QK_K;
|
1164
1479
|
|
1165
1480
|
const int64_t r0 = tgpig.x;
|
1166
1481
|
const int64_t r1 = tgpig.y;
|
1167
1482
|
|
1168
|
-
device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
|
1169
|
-
device const float * yy = (device const float *) src1 + r1*ne10;
|
1170
|
-
|
1171
1483
|
const int nth = tptg.x*tptg.y;
|
1172
1484
|
const int ith = tptg.y*tpitg.x + tpitg.y;
|
1173
1485
|
|
1486
|
+
device const block_q4_K * x = (device const block_q4_K *) src0 + r0*nb;
|
1487
|
+
device const float * yy = (device const float *) src1 + r1*ne10;
|
1488
|
+
|
1489
|
+
float sumf = 0;
|
1490
|
+
|
1491
|
+
#if QK_K == 256
|
1492
|
+
|
1493
|
+
const uint16_t kmask1 = 0x3f3f;
|
1494
|
+
const uint16_t kmask2 = 0x0f0f;
|
1495
|
+
const uint16_t kmask3 = 0xc0c0;
|
1496
|
+
|
1174
1497
|
const int tid = tpitg.y; // 0...16
|
1175
1498
|
const int il = tid/4; // 0...3
|
1176
1499
|
const int ir = tid - 4*il;// 0...3
|
@@ -1183,11 +1506,8 @@ kernel void kernel_mul_mat_q4_k_f32(
|
|
1183
1506
|
const int q_offset = 32*im + l0;
|
1184
1507
|
const int y_offset = 64*im + l0;
|
1185
1508
|
|
1186
|
-
sum[ith] = 0.0f;
|
1187
|
-
|
1188
1509
|
uchar2 sc1, sc2, sc3, sc4;
|
1189
1510
|
|
1190
|
-
float sumf = 0;
|
1191
1511
|
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
1192
1512
|
|
1193
1513
|
device const uint8_t * q1 = (x + i)->qs + q_offset;
|
@@ -1216,6 +1536,30 @@ kernel void kernel_mul_mat_q4_k_f32(
|
|
1216
1536
|
sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
|
1217
1537
|
|
1218
1538
|
}
|
1539
|
+
#else
|
1540
|
+
uint16_t aux16[2];
|
1541
|
+
thread const uint8_t * scales = (thread const uint8_t *)aux16;
|
1542
|
+
|
1543
|
+
const int il = 4*tpitg.x;
|
1544
|
+
|
1545
|
+
for (int i = tpitg.y; i < nb; i += tptg.y) {
|
1546
|
+
|
1547
|
+
device const uint8_t * q = x[i].qs + il;
|
1548
|
+
device const float * y = yy + i * QK_K + il;
|
1549
|
+
|
1550
|
+
const float d = (float)x[i].d[0];
|
1551
|
+
const float m = (float)x[i].d[1];
|
1552
|
+
|
1553
|
+
device const uint16_t * a = (device const uint16_t *)x[i].scales;
|
1554
|
+
aux16[0] = a[0] & 0x0f0f;
|
1555
|
+
aux16[1] = (a[0] >> 4) & 0x0f0f;
|
1556
|
+
|
1557
|
+
for (int l = 0; l < 4; ++l) {
|
1558
|
+
sumf += d * scales[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - m * scales[2] * (y[l+ 0] + y[l+16])
|
1559
|
+
+ d * scales[1] * (y[l+32] * (q[l] >> 4) + y[l+48] * (q[l+16] >> 4)) - m * scales[3] * (y[l+32] + y[l+48]);
|
1560
|
+
}
|
1561
|
+
}
|
1562
|
+
#endif
|
1219
1563
|
|
1220
1564
|
sum[ith] = sumf;
|
1221
1565
|
|
@@ -1252,7 +1596,7 @@ kernel void kernel_mul_mat_q4_k_f32(
|
|
1252
1596
|
//}
|
1253
1597
|
}
|
1254
1598
|
|
1255
|
-
kernel void
|
1599
|
+
kernel void kernel_mul_mat_q5_K_f32(
|
1256
1600
|
device const void * src0,
|
1257
1601
|
device const float * src1,
|
1258
1602
|
device float * dst,
|
@@ -1264,21 +1608,25 @@ kernel void kernel_mul_mat_q5_k_f32(
|
|
1264
1608
|
uint2 tpitg[[thread_position_in_threadgroup]],
|
1265
1609
|
uint2 tptg[[threads_per_threadgroup]]) {
|
1266
1610
|
|
1267
|
-
const uint16_t kmask1 = 0x3f3f;
|
1268
|
-
const uint16_t kmask2 = 0x0f0f;
|
1269
|
-
const uint16_t kmask3 = 0xc0c0;
|
1270
|
-
|
1271
1611
|
const int nb = ne00/QK_K;
|
1272
1612
|
|
1273
1613
|
const int64_t r0 = tgpig.x;
|
1274
1614
|
const int64_t r1 = tgpig.y;
|
1275
1615
|
|
1276
|
-
device const
|
1616
|
+
device const block_q5_K * x = (device const block_q5_K *) src0 + r0*nb;
|
1277
1617
|
device const float * yy = (device const float *) src1 + r1*ne10;
|
1278
1618
|
|
1279
1619
|
const int nth = tptg.x*tptg.y;
|
1280
1620
|
const int ith = tptg.y*tpitg.x + tpitg.y;
|
1281
1621
|
|
1622
|
+
float sumf = 0;
|
1623
|
+
|
1624
|
+
#if QK_K == 256
|
1625
|
+
|
1626
|
+
const uint16_t kmask1 = 0x3f3f;
|
1627
|
+
const uint16_t kmask2 = 0x0f0f;
|
1628
|
+
const uint16_t kmask3 = 0xc0c0;
|
1629
|
+
|
1282
1630
|
const int tid = tpitg.y; // 0...16
|
1283
1631
|
const int il = tid/4; // 0...3
|
1284
1632
|
const int ir = tid - 4*il;// 0...3
|
@@ -1298,7 +1646,6 @@ kernel void kernel_mul_mat_q5_k_f32(
|
|
1298
1646
|
|
1299
1647
|
uchar2 sc1, sc2, sc3, sc4;
|
1300
1648
|
|
1301
|
-
float sumf = 0;
|
1302
1649
|
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
1303
1650
|
|
1304
1651
|
device const uint8_t * q1 = (x + i)->qs + q_offset;
|
@@ -1330,6 +1677,28 @@ kernel void kernel_mul_mat_q5_k_f32(
|
|
1330
1677
|
sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
|
1331
1678
|
|
1332
1679
|
}
|
1680
|
+
#else
|
1681
|
+
const int il = 4 * tpitg.x; // 0, 4, 8, 12
|
1682
|
+
const int im = il/8; // 0, 0, 1, 1
|
1683
|
+
const int in = il%8; // 0, 4, 0, 4
|
1684
|
+
|
1685
|
+
for (int i = tpitg.y; i < nb; i += tptg.y) {
|
1686
|
+
|
1687
|
+
const float d = (float)x[i].d;
|
1688
|
+
device const uint8_t * q = x[i].qs + il;
|
1689
|
+
device const uint8_t * h = x[i].qh + in;
|
1690
|
+
device const int8_t * s = x[i].scales;
|
1691
|
+
device const float * y = yy + i*QK_K + il;
|
1692
|
+
|
1693
|
+
for (int l = 0; l < 4; ++l) {
|
1694
|
+
const uint8_t hl = h[l] >> im;
|
1695
|
+
sumf += y[l+ 0] * d * s[0] * ((q[l+ 0] & 0xF) - (hl & 0x01 ? 0 : 16))
|
1696
|
+
+ y[l+16] * d * s[1] * ((q[l+16] & 0xF) - (hl & 0x04 ? 0 : 16))
|
1697
|
+
+ y[l+32] * d * s[2] * ((q[l+ 0] >> 4) - (hl & 0x10 ? 0 : 16))
|
1698
|
+
+ y[l+48] * d * s[3] * ((q[l+16] >> 4) - (hl & 0x40 ? 0 : 16));
|
1699
|
+
}
|
1700
|
+
}
|
1701
|
+
#endif
|
1333
1702
|
sum[ith] = sumf;
|
1334
1703
|
|
1335
1704
|
//
|
@@ -1351,7 +1720,7 @@ kernel void kernel_mul_mat_q5_k_f32(
|
|
1351
1720
|
|
1352
1721
|
}
|
1353
1722
|
|
1354
|
-
kernel void
|
1723
|
+
kernel void kernel_mul_mat_q6_K_f32(
|
1355
1724
|
device const void * src0,
|
1356
1725
|
device const float * src1,
|
1357
1726
|
device float * dst,
|
@@ -1373,12 +1742,15 @@ kernel void kernel_mul_mat_q6_k_f32(
|
|
1373
1742
|
const int64_t r0 = tgpig.x;
|
1374
1743
|
const int64_t r1 = tgpig.y;
|
1375
1744
|
|
1376
|
-
device const
|
1745
|
+
device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb;
|
1377
1746
|
device const float * yy = (device const float *) src1 + r1*ne10;
|
1378
1747
|
|
1379
1748
|
const int nth = tptg.x*tptg.y;
|
1380
1749
|
const int ith = tptg.y*tpitg.x + tpitg.y;
|
1381
1750
|
|
1751
|
+
float sumf = 0;
|
1752
|
+
|
1753
|
+
#if QK_K == 256
|
1382
1754
|
// Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
|
1383
1755
|
const int iqs = 16 * tpitg.y;
|
1384
1756
|
const int ip = iqs / 128; // 0 or 1
|
@@ -1391,7 +1763,6 @@ kernel void kernel_mul_mat_q6_k_f32(
|
|
1391
1763
|
const int q_offset_l = 64*ip + l0;
|
1392
1764
|
const int q_offset_h = 32*ip + l0;
|
1393
1765
|
|
1394
|
-
float sumf = 0;
|
1395
1766
|
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
1396
1767
|
|
1397
1768
|
device const uint8_t * ql = x[i].ql + q_offset_l;
|
@@ -1413,6 +1784,28 @@ kernel void kernel_mul_mat_q6_k_f32(
|
|
1413
1784
|
sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
|
1414
1785
|
|
1415
1786
|
}
|
1787
|
+
#else
|
1788
|
+
const int il = 4*tpitg.x; // 0, 4, 8, 12
|
1789
|
+
|
1790
|
+
for (int i = tpitg.y; i < nb; i += tptg.y) {
|
1791
|
+
device const float * y = yy + i * QK_K + il;
|
1792
|
+
device const uint8_t * ql = x[i].ql + il;
|
1793
|
+
device const uint8_t * qh = x[i].qh + il;
|
1794
|
+
device const int8_t * s = x[i].scales;
|
1795
|
+
|
1796
|
+
const float d = x[i].d;
|
1797
|
+
|
1798
|
+
float4 sums = {0.f, 0.f, 0.f, 0.f};
|
1799
|
+
for (int l = 0; l < 4; ++l) {
|
1800
|
+
sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
|
1801
|
+
sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
|
1802
|
+
sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32);
|
1803
|
+
sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
|
1804
|
+
}
|
1805
|
+
sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
|
1806
|
+
}
|
1807
|
+
|
1808
|
+
#endif
|
1416
1809
|
|
1417
1810
|
sum[ith] = sumf;
|
1418
1811
|
|