llama_cpp 0.5.2 → 0.6.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +13 -0
- data/README.md +6 -5
- data/examples/chat.rb +13 -13
- data/examples/embedding.rb +9 -9
- data/ext/llama_cpp/llama_cpp.cpp +547 -272
- data/ext/llama_cpp/src/ggml-alloc.c +14 -8
- data/ext/llama_cpp/src/ggml-alloc.h +1 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +307 -127
- data/ext/llama_cpp/src/ggml-cuda.h +1 -0
- data/ext/llama_cpp/src/ggml-metal.h +4 -0
- data/ext/llama_cpp/src/ggml-metal.m +200 -94
- data/ext/llama_cpp/src/ggml-metal.metal +264 -82
- data/ext/llama_cpp/src/ggml-opencl.cpp +3 -3
- data/ext/llama_cpp/src/ggml.c +1647 -865
- data/ext/llama_cpp/src/ggml.h +143 -52
- data/ext/llama_cpp/src/llama.cpp +1427 -635
- data/ext/llama_cpp/src/llama.h +308 -119
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +5 -9
- data/sig/llama_cpp.rbs +65 -34
- metadata +3 -3
@@ -24,12 +24,59 @@ typedef struct {
|
|
24
24
|
int8_t qs[QK8_0]; // quants
|
25
25
|
} block_q8_0;
|
26
26
|
|
27
|
+
// general-purpose kernel for addition of two tensors
|
28
|
+
// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
|
29
|
+
// cons: not very efficient
|
27
30
|
kernel void kernel_add(
|
28
|
-
device const
|
29
|
-
device const
|
30
|
-
device
|
31
|
-
|
32
|
-
|
31
|
+
device const char * src0,
|
32
|
+
device const char * src1,
|
33
|
+
device char * dst,
|
34
|
+
constant int64_t & ne00,
|
35
|
+
constant int64_t & ne01,
|
36
|
+
constant int64_t & ne02,
|
37
|
+
constant int64_t & ne03,
|
38
|
+
constant int64_t & nb00,
|
39
|
+
constant int64_t & nb01,
|
40
|
+
constant int64_t & nb02,
|
41
|
+
constant int64_t & nb03,
|
42
|
+
constant int64_t & ne10,
|
43
|
+
constant int64_t & ne11,
|
44
|
+
constant int64_t & ne12,
|
45
|
+
constant int64_t & ne13,
|
46
|
+
constant int64_t & nb10,
|
47
|
+
constant int64_t & nb11,
|
48
|
+
constant int64_t & nb12,
|
49
|
+
constant int64_t & nb13,
|
50
|
+
constant int64_t & ne0,
|
51
|
+
constant int64_t & ne1,
|
52
|
+
constant int64_t & ne2,
|
53
|
+
constant int64_t & ne3,
|
54
|
+
constant int64_t & nb0,
|
55
|
+
constant int64_t & nb1,
|
56
|
+
constant int64_t & nb2,
|
57
|
+
constant int64_t & nb3,
|
58
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
59
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
60
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
61
|
+
const int64_t i03 = tgpig.z;
|
62
|
+
const int64_t i02 = tgpig.y;
|
63
|
+
const int64_t i01 = tgpig.x;
|
64
|
+
|
65
|
+
const int64_t i13 = i03 % ne13;
|
66
|
+
const int64_t i12 = i02 % ne12;
|
67
|
+
const int64_t i11 = i01 % ne11;
|
68
|
+
|
69
|
+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
|
70
|
+
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
|
71
|
+
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
|
72
|
+
|
73
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
74
|
+
((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0];
|
75
|
+
|
76
|
+
src0_ptr += ntg.x*nb00;
|
77
|
+
src1_ptr += ntg.x*nb10;
|
78
|
+
dst_ptr += ntg.x*nb0;
|
79
|
+
}
|
33
80
|
}
|
34
81
|
|
35
82
|
// assumption: src1 is a row
|
@@ -38,7 +85,7 @@ kernel void kernel_add_row(
|
|
38
85
|
device const float4 * src0,
|
39
86
|
device const float4 * src1,
|
40
87
|
device float4 * dst,
|
41
|
-
constant
|
88
|
+
constant int64_t & nb [[buffer(27)]],
|
42
89
|
uint tpig[[thread_position_in_grid]]) {
|
43
90
|
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
44
91
|
}
|
@@ -118,7 +165,7 @@ kernel void kernel_soft_max(
|
|
118
165
|
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
119
166
|
|
120
167
|
// parallel max
|
121
|
-
float lmax = psrc0[tpitg[0]];
|
168
|
+
float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY;
|
122
169
|
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
|
123
170
|
lmax = MAX(lmax, psrc0[i00]);
|
124
171
|
}
|
@@ -158,7 +205,7 @@ kernel void kernel_soft_max_4(
|
|
158
205
|
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
159
206
|
|
160
207
|
// parallel max
|
161
|
-
float4 lmax4 = psrc4[tpitg[0]];
|
208
|
+
float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY;
|
162
209
|
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
|
163
210
|
lmax4 = fmax(lmax4, psrc4[i00]);
|
164
211
|
}
|
@@ -523,6 +570,79 @@ kernel void kernel_mul_mat_q8_0_f32(
|
|
523
570
|
}
|
524
571
|
}
|
525
572
|
|
573
|
+
#define N_F32_F32 4
|
574
|
+
|
575
|
+
kernel void kernel_mul_mat_f32_f32(
|
576
|
+
device const char * src0,
|
577
|
+
device const char * src1,
|
578
|
+
device float * dst,
|
579
|
+
constant int64_t & ne00,
|
580
|
+
constant int64_t & ne01,
|
581
|
+
constant int64_t & ne02,
|
582
|
+
constant uint64_t & nb00,
|
583
|
+
constant uint64_t & nb01,
|
584
|
+
constant uint64_t & nb02,
|
585
|
+
constant int64_t & ne10,
|
586
|
+
constant int64_t & ne11,
|
587
|
+
constant int64_t & ne12,
|
588
|
+
constant uint64_t & nb10,
|
589
|
+
constant uint64_t & nb11,
|
590
|
+
constant uint64_t & nb12,
|
591
|
+
constant int64_t & ne0,
|
592
|
+
constant int64_t & ne1,
|
593
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
594
|
+
uint tiisg[[thread_index_in_simdgroup]]) {
|
595
|
+
|
596
|
+
const int64_t r0 = tgpig.x;
|
597
|
+
const int64_t rb = tgpig.y*N_F32_F32;
|
598
|
+
const int64_t im = tgpig.z;
|
599
|
+
|
600
|
+
device const float * x = (device const float *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
|
601
|
+
|
602
|
+
if (ne00 < 128) {
|
603
|
+
for (int row = 0; row < N_F32_F32; ++row) {
|
604
|
+
int r1 = rb + row;
|
605
|
+
if (r1 >= ne11) {
|
606
|
+
break;
|
607
|
+
}
|
608
|
+
|
609
|
+
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
610
|
+
|
611
|
+
float sumf = 0;
|
612
|
+
for (int i = tiisg; i < ne00; i += 32) {
|
613
|
+
sumf += (float) x[i] * (float) y[i];
|
614
|
+
}
|
615
|
+
|
616
|
+
float all_sum = simd_sum(sumf);
|
617
|
+
if (tiisg == 0) {
|
618
|
+
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
619
|
+
}
|
620
|
+
}
|
621
|
+
} else {
|
622
|
+
device const float4 * x4 = (device const float4 *)x;
|
623
|
+
for (int row = 0; row < N_F32_F32; ++row) {
|
624
|
+
int r1 = rb + row;
|
625
|
+
if (r1 >= ne11) {
|
626
|
+
break;
|
627
|
+
}
|
628
|
+
|
629
|
+
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
630
|
+
device const float4 * y4 = (device const float4 *) y;
|
631
|
+
|
632
|
+
float sumf = 0;
|
633
|
+
for (int i = tiisg; i < ne00/4; i += 32) {
|
634
|
+
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
|
635
|
+
}
|
636
|
+
|
637
|
+
float all_sum = simd_sum(sumf);
|
638
|
+
if (tiisg == 0) {
|
639
|
+
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
|
640
|
+
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
641
|
+
}
|
642
|
+
}
|
643
|
+
}
|
644
|
+
}
|
645
|
+
|
526
646
|
kernel void kernel_mul_mat_f16_f32_1row(
|
527
647
|
device const char * src0,
|
528
648
|
device const char * src1,
|
@@ -733,30 +853,61 @@ kernel void kernel_alibi_f32(
|
|
733
853
|
}
|
734
854
|
}
|
735
855
|
|
856
|
+
typedef void (rope_t)(
|
857
|
+
device const void * src0,
|
858
|
+
device const int32_t * src1,
|
859
|
+
device float * dst,
|
860
|
+
constant int64_t & ne00,
|
861
|
+
constant int64_t & ne01,
|
862
|
+
constant int64_t & ne02,
|
863
|
+
constant int64_t & ne03,
|
864
|
+
constant uint64_t & nb00,
|
865
|
+
constant uint64_t & nb01,
|
866
|
+
constant uint64_t & nb02,
|
867
|
+
constant uint64_t & nb03,
|
868
|
+
constant int64_t & ne0,
|
869
|
+
constant int64_t & ne1,
|
870
|
+
constant int64_t & ne2,
|
871
|
+
constant int64_t & ne3,
|
872
|
+
constant uint64_t & nb0,
|
873
|
+
constant uint64_t & nb1,
|
874
|
+
constant uint64_t & nb2,
|
875
|
+
constant uint64_t & nb3,
|
876
|
+
constant int & n_past,
|
877
|
+
constant int & n_dims,
|
878
|
+
constant int & mode,
|
879
|
+
constant float & freq_base,
|
880
|
+
constant float & freq_scale,
|
881
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
882
|
+
uint3 tptg[[threads_per_threadgroup]],
|
883
|
+
uint3 tgpig[[threadgroup_position_in_grid]]);
|
884
|
+
|
885
|
+
template<typename T>
|
736
886
|
kernel void kernel_rope(
|
737
|
-
device const
|
738
|
-
device
|
739
|
-
|
740
|
-
constant
|
741
|
-
constant
|
742
|
-
constant
|
743
|
-
constant
|
744
|
-
constant
|
745
|
-
constant
|
746
|
-
constant
|
747
|
-
constant
|
748
|
-
constant
|
749
|
-
constant
|
750
|
-
constant
|
751
|
-
constant
|
752
|
-
constant
|
753
|
-
constant
|
754
|
-
constant
|
755
|
-
constant
|
756
|
-
constant
|
757
|
-
constant
|
758
|
-
constant
|
759
|
-
constant
|
887
|
+
device const void * src0,
|
888
|
+
device const int32_t * src1,
|
889
|
+
device float * dst,
|
890
|
+
constant int64_t & ne00,
|
891
|
+
constant int64_t & ne01,
|
892
|
+
constant int64_t & ne02,
|
893
|
+
constant int64_t & ne03,
|
894
|
+
constant uint64_t & nb00,
|
895
|
+
constant uint64_t & nb01,
|
896
|
+
constant uint64_t & nb02,
|
897
|
+
constant uint64_t & nb03,
|
898
|
+
constant int64_t & ne0,
|
899
|
+
constant int64_t & ne1,
|
900
|
+
constant int64_t & ne2,
|
901
|
+
constant int64_t & ne3,
|
902
|
+
constant uint64_t & nb0,
|
903
|
+
constant uint64_t & nb1,
|
904
|
+
constant uint64_t & nb2,
|
905
|
+
constant uint64_t & nb3,
|
906
|
+
constant int & n_past,
|
907
|
+
constant int & n_dims,
|
908
|
+
constant int & mode,
|
909
|
+
constant float & freq_base,
|
910
|
+
constant float & freq_scale,
|
760
911
|
uint tiitg[[thread_index_in_threadgroup]],
|
761
912
|
uint3 tptg[[threads_per_threadgroup]],
|
762
913
|
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
@@ -766,7 +917,9 @@ kernel void kernel_rope(
|
|
766
917
|
|
767
918
|
const bool is_neox = mode & 2;
|
768
919
|
|
769
|
-
const
|
920
|
+
device const int32_t * pos = src1;
|
921
|
+
|
922
|
+
const int64_t p = pos[i2];
|
770
923
|
|
771
924
|
const float theta_0 = freq_scale * (float)p;
|
772
925
|
const float inv_ndims = -1.f/n_dims;
|
@@ -778,11 +931,11 @@ kernel void kernel_rope(
|
|
778
931
|
const float cos_theta = cos(theta);
|
779
932
|
const float sin_theta = sin(theta);
|
780
933
|
|
781
|
-
device const
|
782
|
-
device
|
934
|
+
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
935
|
+
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
783
936
|
|
784
|
-
const
|
785
|
-
const
|
937
|
+
const T x0 = src[0];
|
938
|
+
const T x1 = src[1];
|
786
939
|
|
787
940
|
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
788
941
|
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
@@ -797,8 +950,8 @@ kernel void kernel_rope(
|
|
797
950
|
|
798
951
|
const int64_t i0 = ib*n_dims + ic/2;
|
799
952
|
|
800
|
-
device const
|
801
|
-
device
|
953
|
+
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
954
|
+
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
802
955
|
|
803
956
|
const float x0 = src[0];
|
804
957
|
const float x1 = src[n_dims/2];
|
@@ -810,6 +963,9 @@ kernel void kernel_rope(
|
|
810
963
|
}
|
811
964
|
}
|
812
965
|
|
966
|
+
template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
|
967
|
+
template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
|
968
|
+
|
813
969
|
kernel void kernel_cpy_f16_f16(
|
814
970
|
device const half * src0,
|
815
971
|
device half * dst,
|
@@ -1200,8 +1356,8 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1200
1356
|
|
1201
1357
|
float yl[32];
|
1202
1358
|
|
1203
|
-
const uint16_t kmask1 = 0x3030;
|
1204
|
-
const uint16_t kmask2 = 0x0f0f;
|
1359
|
+
//const uint16_t kmask1 = 0x3030;
|
1360
|
+
//const uint16_t kmask2 = 0x0f0f;
|
1205
1361
|
|
1206
1362
|
const int tid = tiisg/4;
|
1207
1363
|
const int ix = tiisg%4;
|
@@ -1321,7 +1477,6 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1321
1477
|
dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
|
1322
1478
|
}
|
1323
1479
|
}
|
1324
|
-
|
1325
1480
|
}
|
1326
1481
|
#else
|
1327
1482
|
kernel void kernel_mul_mat_q3_K_f32(
|
@@ -1400,13 +1555,13 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|
1400
1555
|
device const float * src1,
|
1401
1556
|
device float * dst,
|
1402
1557
|
constant int64_t & ne00,
|
1403
|
-
constant int64_t & ne01[[buffer(4)]],
|
1404
|
-
constant int64_t & ne02[[buffer(5)]],
|
1405
|
-
constant int64_t & ne10[[buffer(9)]],
|
1406
|
-
constant int64_t & ne12[[buffer(11)]],
|
1407
|
-
constant int64_t & ne0[[buffer(15)]],
|
1408
|
-
constant int64_t & ne1[[buffer(16)]],
|
1409
|
-
constant uint & gqa[[buffer(17)]],
|
1558
|
+
constant int64_t & ne01 [[buffer(4)]],
|
1559
|
+
constant int64_t & ne02 [[buffer(5)]],
|
1560
|
+
constant int64_t & ne10 [[buffer(9)]],
|
1561
|
+
constant int64_t & ne12 [[buffer(11)]],
|
1562
|
+
constant int64_t & ne0 [[buffer(15)]],
|
1563
|
+
constant int64_t & ne1 [[buffer(16)]],
|
1564
|
+
constant uint & gqa [[buffer(17)]],
|
1410
1565
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1411
1566
|
uint tiisg[[thread_index_in_simdgroup]],
|
1412
1567
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -1865,6 +2020,15 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|
1865
2020
|
|
1866
2021
|
//============================= templates and their specializations =============================
|
1867
2022
|
|
2023
|
+
// NOTE: this is not dequantizing - we are simply fitting the template
|
2024
|
+
template <typename type4x4>
|
2025
|
+
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
|
2026
|
+
float4x4 temp = *(((device float4x4 *)src));
|
2027
|
+
for (int i = 0; i < 16; i++){
|
2028
|
+
reg[i/4][i%4] = temp[i/4][i%4];
|
2029
|
+
}
|
2030
|
+
}
|
2031
|
+
|
1868
2032
|
template <typename type4x4>
|
1869
2033
|
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
|
1870
2034
|
half4x4 temp = *(((device half4x4 *)src));
|
@@ -1875,7 +2039,6 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
|
|
1875
2039
|
|
1876
2040
|
template <typename type4x4>
|
1877
2041
|
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
|
1878
|
-
|
1879
2042
|
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
1880
2043
|
const float d1 = il ? (xb->d / 16.h) : xb->d;
|
1881
2044
|
const float d2 = d1 / 256.f;
|
@@ -1887,12 +2050,10 @@ void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg
|
|
1887
2050
|
reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
|
1888
2051
|
reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
|
1889
2052
|
}
|
1890
|
-
|
1891
2053
|
}
|
1892
2054
|
|
1893
2055
|
template <typename type4x4>
|
1894
2056
|
void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
|
1895
|
-
|
1896
2057
|
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
|
1897
2058
|
const float d1 = il ? (xb->d / 16.h) : xb->d;
|
1898
2059
|
const float d2 = d1 / 256.f;
|
@@ -1964,7 +2125,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
|
|
1964
2125
|
for (int i = 0; i < 16; ++i) {
|
1965
2126
|
reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
|
1966
2127
|
}
|
1967
|
-
|
1968
2128
|
#else
|
1969
2129
|
float kcoef = il&1 ? 1.f/16.f : 1.f;
|
1970
2130
|
uint16_t kmask = il&1 ? 0xF0 : 0x0F;
|
@@ -2008,7 +2168,6 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
|
|
2008
2168
|
for (int i = 0; i < 16; ++i) {
|
2009
2169
|
reg[i/4][i%4] = dl * (q[i] & mask) - ml;
|
2010
2170
|
}
|
2011
|
-
|
2012
2171
|
}
|
2013
2172
|
|
2014
2173
|
template <typename type4x4>
|
@@ -2110,22 +2269,25 @@ kernel void kernel_get_rows(
|
|
2110
2269
|
// each block_q contains 16*nl weights
|
2111
2270
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
2112
2271
|
kernel void kernel_mul_mm(device const uchar * src0,
|
2113
|
-
|
2114
|
-
|
2115
|
-
|
2116
|
-
|
2117
|
-
|
2118
|
-
|
2119
|
-
|
2120
|
-
|
2121
|
-
|
2122
|
-
|
2123
|
-
|
2124
|
-
|
2125
|
-
|
2126
|
-
|
2127
|
-
|
2128
|
-
|
2272
|
+
device const uchar * src1,
|
2273
|
+
device float * dst,
|
2274
|
+
constant int64_t & ne00,
|
2275
|
+
constant int64_t & ne02,
|
2276
|
+
constant int64_t & nb01,
|
2277
|
+
constant int64_t & nb02,
|
2278
|
+
constant int64_t & ne12,
|
2279
|
+
constant int64_t & nb10,
|
2280
|
+
constant int64_t & nb11,
|
2281
|
+
constant int64_t & nb12,
|
2282
|
+
constant int64_t & ne0,
|
2283
|
+
constant int64_t & ne1,
|
2284
|
+
constant uint & gqa,
|
2285
|
+
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
2286
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
2287
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
2288
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
2289
|
+
|
2290
|
+
threadgroup half * sa = (threadgroup half *)(shared_memory);
|
2129
2291
|
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
|
2130
2292
|
|
2131
2293
|
const uint r0 = tgpig.y;
|
@@ -2138,7 +2300,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
2138
2300
|
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
2139
2301
|
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
2140
2302
|
|
2141
|
-
simdgroup_half8x8
|
2303
|
+
simdgroup_half8x8 ma[4];
|
2142
2304
|
simdgroup_float8x8 mb[2];
|
2143
2305
|
simdgroup_float8x8 c_res[8];
|
2144
2306
|
for (int i = 0; i < 8; i++){
|
@@ -2146,10 +2308,15 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
2146
2308
|
}
|
2147
2309
|
|
2148
2310
|
short il = (tiitg % THREAD_PER_ROW);
|
2149
|
-
|
2150
|
-
|
2151
|
-
|
2152
|
-
|
2311
|
+
|
2312
|
+
uint offset0 = im/gqa*nb02;
|
2313
|
+
ushort offset1 = il/nl;
|
2314
|
+
|
2315
|
+
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
|
2316
|
+
device const float * y = (device const float *)(src1
|
2317
|
+
+ nb12 * im
|
2318
|
+
+ nb11 * (r1 * BLOCK_SIZE_N + thread_col)
|
2319
|
+
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
2153
2320
|
|
2154
2321
|
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
|
2155
2322
|
//load data and store to threadgroup memory
|
@@ -2229,6 +2396,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
2229
2396
|
typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
|
2230
2397
|
constant uint64_t &, constant uint64_t &, uint, uint, uint);
|
2231
2398
|
|
2399
|
+
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
|
2232
2400
|
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
2233
2401
|
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
|
2234
2402
|
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
|
@@ -2239,14 +2407,28 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
|
|
2239
2407
|
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
|
2240
2408
|
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
|
2241
2409
|
|
2242
|
-
typedef void (mat_mm_t)(
|
2243
|
-
|
2244
|
-
|
2245
|
-
|
2246
|
-
|
2247
|
-
|
2248
|
-
|
2249
|
-
|
2410
|
+
typedef void (mat_mm_t)(
|
2411
|
+
device const uchar * src0,
|
2412
|
+
device const uchar * src1,
|
2413
|
+
device float * dst,
|
2414
|
+
constant int64_t & ne00,
|
2415
|
+
constant int64_t & ne02,
|
2416
|
+
constant int64_t & nb01,
|
2417
|
+
constant int64_t & nb02,
|
2418
|
+
constant int64_t & ne12,
|
2419
|
+
constant int64_t & nb10,
|
2420
|
+
constant int64_t & nb11,
|
2421
|
+
constant int64_t & nb12,
|
2422
|
+
constant int64_t & ne0,
|
2423
|
+
constant int64_t & ne1,
|
2424
|
+
constant uint & gqa,
|
2425
|
+
threadgroup uchar *, uint3, uint, uint);
|
2426
|
+
|
2427
|
+
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
|
2428
|
+
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
2429
|
+
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
|
2430
|
+
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
|
2431
|
+
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
|
2250
2432
|
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
|
2251
2433
|
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
|
2252
2434
|
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
|
@@ -847,7 +847,7 @@ std::array<std::string, 2> mul_str_values = {
|
|
847
847
|
"mul_f32", "float"
|
848
848
|
};
|
849
849
|
|
850
|
-
std::string& replace(std::string& s, const std::string& from, const std::string& to) {
|
850
|
+
static std::string& replace(std::string& s, const std::string& from, const std::string& to) {
|
851
851
|
size_t pos = 0;
|
852
852
|
while ((pos = s.find(from, pos)) != std::string::npos) {
|
853
853
|
s.replace(pos, from.length(), to);
|
@@ -856,7 +856,7 @@ std::string& replace(std::string& s, const std::string& from, const std::string&
|
|
856
856
|
return s;
|
857
857
|
}
|
858
858
|
|
859
|
-
std::string generate_kernels() {
|
859
|
+
static std::string generate_kernels() {
|
860
860
|
std::stringstream src;
|
861
861
|
src << program_source << '\n';
|
862
862
|
src << k_quants_source << '\n';
|
@@ -1788,7 +1788,7 @@ bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens
|
|
1788
1788
|
return false;
|
1789
1789
|
}
|
1790
1790
|
|
1791
|
-
bool ggml_cl_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */) {
|
1791
|
+
static bool ggml_cl_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */) {
|
1792
1792
|
// If device doesn't support FP16
|
1793
1793
|
if (!fp16_support) {
|
1794
1794
|
return false;
|