llama_cpp 0.5.2 → 0.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +13 -0
- data/README.md +6 -5
- data/examples/chat.rb +13 -13
- data/examples/embedding.rb +9 -9
- data/ext/llama_cpp/llama_cpp.cpp +547 -272
- data/ext/llama_cpp/src/ggml-alloc.c +14 -8
- data/ext/llama_cpp/src/ggml-alloc.h +1 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +307 -127
- data/ext/llama_cpp/src/ggml-cuda.h +1 -0
- data/ext/llama_cpp/src/ggml-metal.h +4 -0
- data/ext/llama_cpp/src/ggml-metal.m +200 -94
- data/ext/llama_cpp/src/ggml-metal.metal +264 -82
- data/ext/llama_cpp/src/ggml-opencl.cpp +3 -3
- data/ext/llama_cpp/src/ggml.c +1647 -865
- data/ext/llama_cpp/src/ggml.h +143 -52
- data/ext/llama_cpp/src/llama.cpp +1427 -635
- data/ext/llama_cpp/src/llama.h +308 -119
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +5 -9
- data/sig/llama_cpp.rbs +65 -34
- metadata +3 -3
@@ -24,12 +24,59 @@ typedef struct {
|
|
24
24
|
int8_t qs[QK8_0]; // quants
|
25
25
|
} block_q8_0;
|
26
26
|
|
27
|
+
// general-purpose kernel for addition of two tensors
|
28
|
+
// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
|
29
|
+
// cons: not very efficient
|
27
30
|
kernel void kernel_add(
|
28
|
-
device const
|
29
|
-
device const
|
30
|
-
device
|
31
|
-
|
32
|
-
|
31
|
+
device const char * src0,
|
32
|
+
device const char * src1,
|
33
|
+
device char * dst,
|
34
|
+
constant int64_t & ne00,
|
35
|
+
constant int64_t & ne01,
|
36
|
+
constant int64_t & ne02,
|
37
|
+
constant int64_t & ne03,
|
38
|
+
constant int64_t & nb00,
|
39
|
+
constant int64_t & nb01,
|
40
|
+
constant int64_t & nb02,
|
41
|
+
constant int64_t & nb03,
|
42
|
+
constant int64_t & ne10,
|
43
|
+
constant int64_t & ne11,
|
44
|
+
constant int64_t & ne12,
|
45
|
+
constant int64_t & ne13,
|
46
|
+
constant int64_t & nb10,
|
47
|
+
constant int64_t & nb11,
|
48
|
+
constant int64_t & nb12,
|
49
|
+
constant int64_t & nb13,
|
50
|
+
constant int64_t & ne0,
|
51
|
+
constant int64_t & ne1,
|
52
|
+
constant int64_t & ne2,
|
53
|
+
constant int64_t & ne3,
|
54
|
+
constant int64_t & nb0,
|
55
|
+
constant int64_t & nb1,
|
56
|
+
constant int64_t & nb2,
|
57
|
+
constant int64_t & nb3,
|
58
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
59
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
60
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
61
|
+
const int64_t i03 = tgpig.z;
|
62
|
+
const int64_t i02 = tgpig.y;
|
63
|
+
const int64_t i01 = tgpig.x;
|
64
|
+
|
65
|
+
const int64_t i13 = i03 % ne13;
|
66
|
+
const int64_t i12 = i02 % ne12;
|
67
|
+
const int64_t i11 = i01 % ne11;
|
68
|
+
|
69
|
+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
|
70
|
+
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
|
71
|
+
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
|
72
|
+
|
73
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
74
|
+
((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0];
|
75
|
+
|
76
|
+
src0_ptr += ntg.x*nb00;
|
77
|
+
src1_ptr += ntg.x*nb10;
|
78
|
+
dst_ptr += ntg.x*nb0;
|
79
|
+
}
|
33
80
|
}
|
34
81
|
|
35
82
|
// assumption: src1 is a row
|
@@ -38,7 +85,7 @@ kernel void kernel_add_row(
|
|
38
85
|
device const float4 * src0,
|
39
86
|
device const float4 * src1,
|
40
87
|
device float4 * dst,
|
41
|
-
constant
|
88
|
+
constant int64_t & nb [[buffer(27)]],
|
42
89
|
uint tpig[[thread_position_in_grid]]) {
|
43
90
|
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
44
91
|
}
|
@@ -118,7 +165,7 @@ kernel void kernel_soft_max(
|
|
118
165
|
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
119
166
|
|
120
167
|
// parallel max
|
121
|
-
float lmax = psrc0[tpitg[0]];
|
168
|
+
float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY;
|
122
169
|
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
|
123
170
|
lmax = MAX(lmax, psrc0[i00]);
|
124
171
|
}
|
@@ -158,7 +205,7 @@ kernel void kernel_soft_max_4(
|
|
158
205
|
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
159
206
|
|
160
207
|
// parallel max
|
161
|
-
float4 lmax4 = psrc4[tpitg[0]];
|
208
|
+
float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY;
|
162
209
|
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
|
163
210
|
lmax4 = fmax(lmax4, psrc4[i00]);
|
164
211
|
}
|
@@ -523,6 +570,79 @@ kernel void kernel_mul_mat_q8_0_f32(
|
|
523
570
|
}
|
524
571
|
}
|
525
572
|
|
573
|
+
#define N_F32_F32 4
|
574
|
+
|
575
|
+
kernel void kernel_mul_mat_f32_f32(
|
576
|
+
device const char * src0,
|
577
|
+
device const char * src1,
|
578
|
+
device float * dst,
|
579
|
+
constant int64_t & ne00,
|
580
|
+
constant int64_t & ne01,
|
581
|
+
constant int64_t & ne02,
|
582
|
+
constant uint64_t & nb00,
|
583
|
+
constant uint64_t & nb01,
|
584
|
+
constant uint64_t & nb02,
|
585
|
+
constant int64_t & ne10,
|
586
|
+
constant int64_t & ne11,
|
587
|
+
constant int64_t & ne12,
|
588
|
+
constant uint64_t & nb10,
|
589
|
+
constant uint64_t & nb11,
|
590
|
+
constant uint64_t & nb12,
|
591
|
+
constant int64_t & ne0,
|
592
|
+
constant int64_t & ne1,
|
593
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
594
|
+
uint tiisg[[thread_index_in_simdgroup]]) {
|
595
|
+
|
596
|
+
const int64_t r0 = tgpig.x;
|
597
|
+
const int64_t rb = tgpig.y*N_F32_F32;
|
598
|
+
const int64_t im = tgpig.z;
|
599
|
+
|
600
|
+
device const float * x = (device const float *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
|
601
|
+
|
602
|
+
if (ne00 < 128) {
|
603
|
+
for (int row = 0; row < N_F32_F32; ++row) {
|
604
|
+
int r1 = rb + row;
|
605
|
+
if (r1 >= ne11) {
|
606
|
+
break;
|
607
|
+
}
|
608
|
+
|
609
|
+
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
610
|
+
|
611
|
+
float sumf = 0;
|
612
|
+
for (int i = tiisg; i < ne00; i += 32) {
|
613
|
+
sumf += (float) x[i] * (float) y[i];
|
614
|
+
}
|
615
|
+
|
616
|
+
float all_sum = simd_sum(sumf);
|
617
|
+
if (tiisg == 0) {
|
618
|
+
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
619
|
+
}
|
620
|
+
}
|
621
|
+
} else {
|
622
|
+
device const float4 * x4 = (device const float4 *)x;
|
623
|
+
for (int row = 0; row < N_F32_F32; ++row) {
|
624
|
+
int r1 = rb + row;
|
625
|
+
if (r1 >= ne11) {
|
626
|
+
break;
|
627
|
+
}
|
628
|
+
|
629
|
+
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
630
|
+
device const float4 * y4 = (device const float4 *) y;
|
631
|
+
|
632
|
+
float sumf = 0;
|
633
|
+
for (int i = tiisg; i < ne00/4; i += 32) {
|
634
|
+
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
|
635
|
+
}
|
636
|
+
|
637
|
+
float all_sum = simd_sum(sumf);
|
638
|
+
if (tiisg == 0) {
|
639
|
+
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
|
640
|
+
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
641
|
+
}
|
642
|
+
}
|
643
|
+
}
|
644
|
+
}
|
645
|
+
|
526
646
|
kernel void kernel_mul_mat_f16_f32_1row(
|
527
647
|
device const char * src0,
|
528
648
|
device const char * src1,
|
@@ -733,30 +853,61 @@ kernel void kernel_alibi_f32(
|
|
733
853
|
}
|
734
854
|
}
|
735
855
|
|
856
|
+
typedef void (rope_t)(
|
857
|
+
device const void * src0,
|
858
|
+
device const int32_t * src1,
|
859
|
+
device float * dst,
|
860
|
+
constant int64_t & ne00,
|
861
|
+
constant int64_t & ne01,
|
862
|
+
constant int64_t & ne02,
|
863
|
+
constant int64_t & ne03,
|
864
|
+
constant uint64_t & nb00,
|
865
|
+
constant uint64_t & nb01,
|
866
|
+
constant uint64_t & nb02,
|
867
|
+
constant uint64_t & nb03,
|
868
|
+
constant int64_t & ne0,
|
869
|
+
constant int64_t & ne1,
|
870
|
+
constant int64_t & ne2,
|
871
|
+
constant int64_t & ne3,
|
872
|
+
constant uint64_t & nb0,
|
873
|
+
constant uint64_t & nb1,
|
874
|
+
constant uint64_t & nb2,
|
875
|
+
constant uint64_t & nb3,
|
876
|
+
constant int & n_past,
|
877
|
+
constant int & n_dims,
|
878
|
+
constant int & mode,
|
879
|
+
constant float & freq_base,
|
880
|
+
constant float & freq_scale,
|
881
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
882
|
+
uint3 tptg[[threads_per_threadgroup]],
|
883
|
+
uint3 tgpig[[threadgroup_position_in_grid]]);
|
884
|
+
|
885
|
+
template<typename T>
|
736
886
|
kernel void kernel_rope(
|
737
|
-
device const
|
738
|
-
device
|
739
|
-
|
740
|
-
constant
|
741
|
-
constant
|
742
|
-
constant
|
743
|
-
constant
|
744
|
-
constant
|
745
|
-
constant
|
746
|
-
constant
|
747
|
-
constant
|
748
|
-
constant
|
749
|
-
constant
|
750
|
-
constant
|
751
|
-
constant
|
752
|
-
constant
|
753
|
-
constant
|
754
|
-
constant
|
755
|
-
constant
|
756
|
-
constant
|
757
|
-
constant
|
758
|
-
constant
|
759
|
-
constant
|
887
|
+
device const void * src0,
|
888
|
+
device const int32_t * src1,
|
889
|
+
device float * dst,
|
890
|
+
constant int64_t & ne00,
|
891
|
+
constant int64_t & ne01,
|
892
|
+
constant int64_t & ne02,
|
893
|
+
constant int64_t & ne03,
|
894
|
+
constant uint64_t & nb00,
|
895
|
+
constant uint64_t & nb01,
|
896
|
+
constant uint64_t & nb02,
|
897
|
+
constant uint64_t & nb03,
|
898
|
+
constant int64_t & ne0,
|
899
|
+
constant int64_t & ne1,
|
900
|
+
constant int64_t & ne2,
|
901
|
+
constant int64_t & ne3,
|
902
|
+
constant uint64_t & nb0,
|
903
|
+
constant uint64_t & nb1,
|
904
|
+
constant uint64_t & nb2,
|
905
|
+
constant uint64_t & nb3,
|
906
|
+
constant int & n_past,
|
907
|
+
constant int & n_dims,
|
908
|
+
constant int & mode,
|
909
|
+
constant float & freq_base,
|
910
|
+
constant float & freq_scale,
|
760
911
|
uint tiitg[[thread_index_in_threadgroup]],
|
761
912
|
uint3 tptg[[threads_per_threadgroup]],
|
762
913
|
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
@@ -766,7 +917,9 @@ kernel void kernel_rope(
|
|
766
917
|
|
767
918
|
const bool is_neox = mode & 2;
|
768
919
|
|
769
|
-
const
|
920
|
+
device const int32_t * pos = src1;
|
921
|
+
|
922
|
+
const int64_t p = pos[i2];
|
770
923
|
|
771
924
|
const float theta_0 = freq_scale * (float)p;
|
772
925
|
const float inv_ndims = -1.f/n_dims;
|
@@ -778,11 +931,11 @@ kernel void kernel_rope(
|
|
778
931
|
const float cos_theta = cos(theta);
|
779
932
|
const float sin_theta = sin(theta);
|
780
933
|
|
781
|
-
device const
|
782
|
-
device
|
934
|
+
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
935
|
+
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
783
936
|
|
784
|
-
const
|
785
|
-
const
|
937
|
+
const T x0 = src[0];
|
938
|
+
const T x1 = src[1];
|
786
939
|
|
787
940
|
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
788
941
|
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
@@ -797,8 +950,8 @@ kernel void kernel_rope(
|
|
797
950
|
|
798
951
|
const int64_t i0 = ib*n_dims + ic/2;
|
799
952
|
|
800
|
-
device const
|
801
|
-
device
|
953
|
+
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
954
|
+
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
802
955
|
|
803
956
|
const float x0 = src[0];
|
804
957
|
const float x1 = src[n_dims/2];
|
@@ -810,6 +963,9 @@ kernel void kernel_rope(
|
|
810
963
|
}
|
811
964
|
}
|
812
965
|
|
966
|
+
template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
|
967
|
+
template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
|
968
|
+
|
813
969
|
kernel void kernel_cpy_f16_f16(
|
814
970
|
device const half * src0,
|
815
971
|
device half * dst,
|
@@ -1200,8 +1356,8 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1200
1356
|
|
1201
1357
|
float yl[32];
|
1202
1358
|
|
1203
|
-
const uint16_t kmask1 = 0x3030;
|
1204
|
-
const uint16_t kmask2 = 0x0f0f;
|
1359
|
+
//const uint16_t kmask1 = 0x3030;
|
1360
|
+
//const uint16_t kmask2 = 0x0f0f;
|
1205
1361
|
|
1206
1362
|
const int tid = tiisg/4;
|
1207
1363
|
const int ix = tiisg%4;
|
@@ -1321,7 +1477,6 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1321
1477
|
dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
|
1322
1478
|
}
|
1323
1479
|
}
|
1324
|
-
|
1325
1480
|
}
|
1326
1481
|
#else
|
1327
1482
|
kernel void kernel_mul_mat_q3_K_f32(
|
@@ -1400,13 +1555,13 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|
1400
1555
|
device const float * src1,
|
1401
1556
|
device float * dst,
|
1402
1557
|
constant int64_t & ne00,
|
1403
|
-
constant int64_t & ne01[[buffer(4)]],
|
1404
|
-
constant int64_t & ne02[[buffer(5)]],
|
1405
|
-
constant int64_t & ne10[[buffer(9)]],
|
1406
|
-
constant int64_t & ne12[[buffer(11)]],
|
1407
|
-
constant int64_t & ne0[[buffer(15)]],
|
1408
|
-
constant int64_t & ne1[[buffer(16)]],
|
1409
|
-
constant uint & gqa[[buffer(17)]],
|
1558
|
+
constant int64_t & ne01 [[buffer(4)]],
|
1559
|
+
constant int64_t & ne02 [[buffer(5)]],
|
1560
|
+
constant int64_t & ne10 [[buffer(9)]],
|
1561
|
+
constant int64_t & ne12 [[buffer(11)]],
|
1562
|
+
constant int64_t & ne0 [[buffer(15)]],
|
1563
|
+
constant int64_t & ne1 [[buffer(16)]],
|
1564
|
+
constant uint & gqa [[buffer(17)]],
|
1410
1565
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1411
1566
|
uint tiisg[[thread_index_in_simdgroup]],
|
1412
1567
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -1865,6 +2020,15 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|
1865
2020
|
|
1866
2021
|
//============================= templates and their specializations =============================
|
1867
2022
|
|
2023
|
+
// NOTE: this is not dequantizing - we are simply fitting the template
|
2024
|
+
template <typename type4x4>
|
2025
|
+
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
|
2026
|
+
float4x4 temp = *(((device float4x4 *)src));
|
2027
|
+
for (int i = 0; i < 16; i++){
|
2028
|
+
reg[i/4][i%4] = temp[i/4][i%4];
|
2029
|
+
}
|
2030
|
+
}
|
2031
|
+
|
1868
2032
|
template <typename type4x4>
|
1869
2033
|
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
|
1870
2034
|
half4x4 temp = *(((device half4x4 *)src));
|
@@ -1875,7 +2039,6 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
|
|
1875
2039
|
|
1876
2040
|
template <typename type4x4>
|
1877
2041
|
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
|
1878
|
-
|
1879
2042
|
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
1880
2043
|
const float d1 = il ? (xb->d / 16.h) : xb->d;
|
1881
2044
|
const float d2 = d1 / 256.f;
|
@@ -1887,12 +2050,10 @@ void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg
|
|
1887
2050
|
reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
|
1888
2051
|
reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
|
1889
2052
|
}
|
1890
|
-
|
1891
2053
|
}
|
1892
2054
|
|
1893
2055
|
template <typename type4x4>
|
1894
2056
|
void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
|
1895
|
-
|
1896
2057
|
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
|
1897
2058
|
const float d1 = il ? (xb->d / 16.h) : xb->d;
|
1898
2059
|
const float d2 = d1 / 256.f;
|
@@ -1964,7 +2125,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
|
|
1964
2125
|
for (int i = 0; i < 16; ++i) {
|
1965
2126
|
reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
|
1966
2127
|
}
|
1967
|
-
|
1968
2128
|
#else
|
1969
2129
|
float kcoef = il&1 ? 1.f/16.f : 1.f;
|
1970
2130
|
uint16_t kmask = il&1 ? 0xF0 : 0x0F;
|
@@ -2008,7 +2168,6 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
|
|
2008
2168
|
for (int i = 0; i < 16; ++i) {
|
2009
2169
|
reg[i/4][i%4] = dl * (q[i] & mask) - ml;
|
2010
2170
|
}
|
2011
|
-
|
2012
2171
|
}
|
2013
2172
|
|
2014
2173
|
template <typename type4x4>
|
@@ -2110,22 +2269,25 @@ kernel void kernel_get_rows(
|
|
2110
2269
|
// each block_q contains 16*nl weights
|
2111
2270
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
2112
2271
|
kernel void kernel_mul_mm(device const uchar * src0,
|
2113
|
-
|
2114
|
-
|
2115
|
-
|
2116
|
-
|
2117
|
-
|
2118
|
-
|
2119
|
-
|
2120
|
-
|
2121
|
-
|
2122
|
-
|
2123
|
-
|
2124
|
-
|
2125
|
-
|
2126
|
-
|
2127
|
-
|
2128
|
-
|
2272
|
+
device const uchar * src1,
|
2273
|
+
device float * dst,
|
2274
|
+
constant int64_t & ne00,
|
2275
|
+
constant int64_t & ne02,
|
2276
|
+
constant int64_t & nb01,
|
2277
|
+
constant int64_t & nb02,
|
2278
|
+
constant int64_t & ne12,
|
2279
|
+
constant int64_t & nb10,
|
2280
|
+
constant int64_t & nb11,
|
2281
|
+
constant int64_t & nb12,
|
2282
|
+
constant int64_t & ne0,
|
2283
|
+
constant int64_t & ne1,
|
2284
|
+
constant uint & gqa,
|
2285
|
+
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
2286
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
2287
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
2288
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
2289
|
+
|
2290
|
+
threadgroup half * sa = (threadgroup half *)(shared_memory);
|
2129
2291
|
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
|
2130
2292
|
|
2131
2293
|
const uint r0 = tgpig.y;
|
@@ -2138,7 +2300,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
2138
2300
|
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
2139
2301
|
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
2140
2302
|
|
2141
|
-
simdgroup_half8x8
|
2303
|
+
simdgroup_half8x8 ma[4];
|
2142
2304
|
simdgroup_float8x8 mb[2];
|
2143
2305
|
simdgroup_float8x8 c_res[8];
|
2144
2306
|
for (int i = 0; i < 8; i++){
|
@@ -2146,10 +2308,15 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
2146
2308
|
}
|
2147
2309
|
|
2148
2310
|
short il = (tiitg % THREAD_PER_ROW);
|
2149
|
-
|
2150
|
-
|
2151
|
-
|
2152
|
-
|
2311
|
+
|
2312
|
+
uint offset0 = im/gqa*nb02;
|
2313
|
+
ushort offset1 = il/nl;
|
2314
|
+
|
2315
|
+
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
|
2316
|
+
device const float * y = (device const float *)(src1
|
2317
|
+
+ nb12 * im
|
2318
|
+
+ nb11 * (r1 * BLOCK_SIZE_N + thread_col)
|
2319
|
+
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
2153
2320
|
|
2154
2321
|
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
|
2155
2322
|
//load data and store to threadgroup memory
|
@@ -2229,6 +2396,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
2229
2396
|
typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
|
2230
2397
|
constant uint64_t &, constant uint64_t &, uint, uint, uint);
|
2231
2398
|
|
2399
|
+
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
|
2232
2400
|
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
2233
2401
|
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
|
2234
2402
|
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
|
@@ -2239,14 +2407,28 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
|
|
2239
2407
|
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
|
2240
2408
|
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
|
2241
2409
|
|
2242
|
-
typedef void (mat_mm_t)(
|
2243
|
-
|
2244
|
-
|
2245
|
-
|
2246
|
-
|
2247
|
-
|
2248
|
-
|
2249
|
-
|
2410
|
+
typedef void (mat_mm_t)(
|
2411
|
+
device const uchar * src0,
|
2412
|
+
device const uchar * src1,
|
2413
|
+
device float * dst,
|
2414
|
+
constant int64_t & ne00,
|
2415
|
+
constant int64_t & ne02,
|
2416
|
+
constant int64_t & nb01,
|
2417
|
+
constant int64_t & nb02,
|
2418
|
+
constant int64_t & ne12,
|
2419
|
+
constant int64_t & nb10,
|
2420
|
+
constant int64_t & nb11,
|
2421
|
+
constant int64_t & nb12,
|
2422
|
+
constant int64_t & ne0,
|
2423
|
+
constant int64_t & ne1,
|
2424
|
+
constant uint & gqa,
|
2425
|
+
threadgroup uchar *, uint3, uint, uint);
|
2426
|
+
|
2427
|
+
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
|
2428
|
+
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
2429
|
+
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
|
2430
|
+
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
|
2431
|
+
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
|
2250
2432
|
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
|
2251
2433
|
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
|
2252
2434
|
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
|
@@ -847,7 +847,7 @@ std::array<std::string, 2> mul_str_values = {
|
|
847
847
|
"mul_f32", "float"
|
848
848
|
};
|
849
849
|
|
850
|
-
std::string& replace(std::string& s, const std::string& from, const std::string& to) {
|
850
|
+
static std::string& replace(std::string& s, const std::string& from, const std::string& to) {
|
851
851
|
size_t pos = 0;
|
852
852
|
while ((pos = s.find(from, pos)) != std::string::npos) {
|
853
853
|
s.replace(pos, from.length(), to);
|
@@ -856,7 +856,7 @@ std::string& replace(std::string& s, const std::string& from, const std::string&
|
|
856
856
|
return s;
|
857
857
|
}
|
858
858
|
|
859
|
-
std::string generate_kernels() {
|
859
|
+
static std::string generate_kernels() {
|
860
860
|
std::stringstream src;
|
861
861
|
src << program_source << '\n';
|
862
862
|
src << k_quants_source << '\n';
|
@@ -1788,7 +1788,7 @@ bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens
|
|
1788
1788
|
return false;
|
1789
1789
|
}
|
1790
1790
|
|
1791
|
-
bool ggml_cl_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */) {
|
1791
|
+
static bool ggml_cl_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */) {
|
1792
1792
|
// If device doesn't support FP16
|
1793
1793
|
if (!fp16_support) {
|
1794
1794
|
return false;
|