llama_cpp 0.9.5 → 0.10.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/ext/llama_cpp/llama_cpp.cpp +121 -15
- data/ext/llama_cpp/src/ggml-alloc.c +42 -7
- data/ext/llama_cpp/src/ggml-alloc.h +7 -0
- data/ext/llama_cpp/src/ggml-backend-impl.h +46 -21
- data/ext/llama_cpp/src/ggml-backend.c +563 -156
- data/ext/llama_cpp/src/ggml-backend.h +62 -17
- data/ext/llama_cpp/src/ggml-cuda.cu +1140 -355
- data/ext/llama_cpp/src/ggml-cuda.h +9 -1
- data/ext/llama_cpp/src/ggml-impl.h +1 -1
- data/ext/llama_cpp/src/ggml-metal.h +6 -0
- data/ext/llama_cpp/src/ggml-metal.m +506 -158
- data/ext/llama_cpp/src/ggml-metal.metal +795 -144
- data/ext/llama_cpp/src/ggml.c +331 -111
- data/ext/llama_cpp/src/ggml.h +49 -4
- data/ext/llama_cpp/src/llama.cpp +749 -329
- data/ext/llama_cpp/src/llama.h +28 -5
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +20 -2
- metadata +2 -2
@@ -3,6 +3,8 @@
|
|
3
3
|
using namespace metal;
|
4
4
|
|
5
5
|
#define MAX(x, y) ((x) > (y) ? (x) : (y))
|
6
|
+
#define MIN(x, y) ((x) < (y) ? (x) : (y))
|
7
|
+
#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
|
6
8
|
|
7
9
|
#define QK4_0 32
|
8
10
|
#define QR4_0 2
|
@@ -41,8 +43,13 @@ typedef struct {
|
|
41
43
|
|
42
44
|
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
43
45
|
|
44
|
-
|
45
|
-
|
46
|
+
enum ggml_sort_order {
|
47
|
+
GGML_SORT_ASC,
|
48
|
+
GGML_SORT_DESC,
|
49
|
+
};
|
50
|
+
|
51
|
+
// general-purpose kernel for addition, multiplication and division of two tensors
|
52
|
+
// pros: works for non-contiguous tensors, supports broadcast across all dims
|
46
53
|
// cons: not very efficient
|
47
54
|
kernel void kernel_add(
|
48
55
|
device const char * src0,
|
@@ -83,16 +90,111 @@ kernel void kernel_add(
|
|
83
90
|
const int64_t i12 = i02 % ne12;
|
84
91
|
const int64_t i11 = i01 % ne11;
|
85
92
|
|
86
|
-
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01
|
87
|
-
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11
|
88
|
-
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1
|
93
|
+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
94
|
+
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
|
95
|
+
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
|
89
96
|
|
90
97
|
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
91
|
-
|
98
|
+
const int i10 = i0 % ne10;
|
99
|
+
*((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10));
|
100
|
+
}
|
101
|
+
}
|
92
102
|
|
93
|
-
|
94
|
-
|
95
|
-
|
103
|
+
kernel void kernel_mul(
|
104
|
+
device const char * src0,
|
105
|
+
device const char * src1,
|
106
|
+
device char * dst,
|
107
|
+
constant int64_t & ne00,
|
108
|
+
constant int64_t & ne01,
|
109
|
+
constant int64_t & ne02,
|
110
|
+
constant int64_t & ne03,
|
111
|
+
constant int64_t & nb00,
|
112
|
+
constant int64_t & nb01,
|
113
|
+
constant int64_t & nb02,
|
114
|
+
constant int64_t & nb03,
|
115
|
+
constant int64_t & ne10,
|
116
|
+
constant int64_t & ne11,
|
117
|
+
constant int64_t & ne12,
|
118
|
+
constant int64_t & ne13,
|
119
|
+
constant int64_t & nb10,
|
120
|
+
constant int64_t & nb11,
|
121
|
+
constant int64_t & nb12,
|
122
|
+
constant int64_t & nb13,
|
123
|
+
constant int64_t & ne0,
|
124
|
+
constant int64_t & ne1,
|
125
|
+
constant int64_t & ne2,
|
126
|
+
constant int64_t & ne3,
|
127
|
+
constant int64_t & nb0,
|
128
|
+
constant int64_t & nb1,
|
129
|
+
constant int64_t & nb2,
|
130
|
+
constant int64_t & nb3,
|
131
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
132
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
133
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
134
|
+
const int64_t i03 = tgpig.z;
|
135
|
+
const int64_t i02 = tgpig.y;
|
136
|
+
const int64_t i01 = tgpig.x;
|
137
|
+
|
138
|
+
const int64_t i13 = i03 % ne13;
|
139
|
+
const int64_t i12 = i02 % ne12;
|
140
|
+
const int64_t i11 = i01 % ne11;
|
141
|
+
|
142
|
+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
143
|
+
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
|
144
|
+
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
|
145
|
+
|
146
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
147
|
+
const int i10 = i0 % ne10;
|
148
|
+
*((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10));
|
149
|
+
}
|
150
|
+
}
|
151
|
+
|
152
|
+
kernel void kernel_div(
|
153
|
+
device const char * src0,
|
154
|
+
device const char * src1,
|
155
|
+
device char * dst,
|
156
|
+
constant int64_t & ne00,
|
157
|
+
constant int64_t & ne01,
|
158
|
+
constant int64_t & ne02,
|
159
|
+
constant int64_t & ne03,
|
160
|
+
constant int64_t & nb00,
|
161
|
+
constant int64_t & nb01,
|
162
|
+
constant int64_t & nb02,
|
163
|
+
constant int64_t & nb03,
|
164
|
+
constant int64_t & ne10,
|
165
|
+
constant int64_t & ne11,
|
166
|
+
constant int64_t & ne12,
|
167
|
+
constant int64_t & ne13,
|
168
|
+
constant int64_t & nb10,
|
169
|
+
constant int64_t & nb11,
|
170
|
+
constant int64_t & nb12,
|
171
|
+
constant int64_t & nb13,
|
172
|
+
constant int64_t & ne0,
|
173
|
+
constant int64_t & ne1,
|
174
|
+
constant int64_t & ne2,
|
175
|
+
constant int64_t & ne3,
|
176
|
+
constant int64_t & nb0,
|
177
|
+
constant int64_t & nb1,
|
178
|
+
constant int64_t & nb2,
|
179
|
+
constant int64_t & nb3,
|
180
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
181
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
182
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
183
|
+
const int64_t i03 = tgpig.z;
|
184
|
+
const int64_t i02 = tgpig.y;
|
185
|
+
const int64_t i01 = tgpig.x;
|
186
|
+
|
187
|
+
const int64_t i13 = i03 % ne13;
|
188
|
+
const int64_t i12 = i02 % ne12;
|
189
|
+
const int64_t i11 = i01 % ne11;
|
190
|
+
|
191
|
+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
192
|
+
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
|
193
|
+
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
|
194
|
+
|
195
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
196
|
+
const int i10 = i0 % ne10;
|
197
|
+
*((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10));
|
96
198
|
}
|
97
199
|
}
|
98
200
|
|
@@ -107,23 +209,22 @@ kernel void kernel_add_row(
|
|
107
209
|
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
108
210
|
}
|
109
211
|
|
110
|
-
kernel void
|
212
|
+
kernel void kernel_mul_row(
|
111
213
|
device const float4 * src0,
|
112
214
|
device const float4 * src1,
|
113
215
|
device float4 * dst,
|
216
|
+
constant int64_t & nb [[buffer(27)]],
|
114
217
|
uint tpig[[thread_position_in_grid]]) {
|
115
|
-
dst[tpig] = src0[tpig] * src1[tpig];
|
218
|
+
dst[tpig] = src0[tpig] * src1[tpig % nb];
|
116
219
|
}
|
117
220
|
|
118
|
-
|
119
|
-
// broadcast src1 into src0
|
120
|
-
kernel void kernel_mul_row(
|
221
|
+
kernel void kernel_div_row(
|
121
222
|
device const float4 * src0,
|
122
223
|
device const float4 * src1,
|
123
224
|
device float4 * dst,
|
124
|
-
constant int64_t & nb,
|
225
|
+
constant int64_t & nb [[buffer(27)]],
|
125
226
|
uint tpig[[thread_position_in_grid]]) {
|
126
|
-
dst[tpig] = src0[tpig]
|
227
|
+
dst[tpig] = src0[tpig] / src1[tpig % nb];
|
127
228
|
}
|
128
229
|
|
129
230
|
kernel void kernel_scale(
|
@@ -164,6 +265,54 @@ kernel void kernel_sqr(
|
|
164
265
|
dst[tpig] = src0[tpig] * src0[tpig];
|
165
266
|
}
|
166
267
|
|
268
|
+
kernel void kernel_sum_rows(
|
269
|
+
device const float * src0,
|
270
|
+
device float * dst,
|
271
|
+
constant int64_t & ne00,
|
272
|
+
constant int64_t & ne01,
|
273
|
+
constant int64_t & ne02,
|
274
|
+
constant int64_t & ne03,
|
275
|
+
constant int64_t & nb00,
|
276
|
+
constant int64_t & nb01,
|
277
|
+
constant int64_t & nb02,
|
278
|
+
constant int64_t & nb03,
|
279
|
+
constant int64_t & ne10,
|
280
|
+
constant int64_t & ne11,
|
281
|
+
constant int64_t & ne12,
|
282
|
+
constant int64_t & ne13,
|
283
|
+
constant int64_t & nb10,
|
284
|
+
constant int64_t & nb11,
|
285
|
+
constant int64_t & nb12,
|
286
|
+
constant int64_t & nb13,
|
287
|
+
constant int64_t & ne0,
|
288
|
+
constant int64_t & ne1,
|
289
|
+
constant int64_t & ne2,
|
290
|
+
constant int64_t & ne3,
|
291
|
+
constant int64_t & nb0,
|
292
|
+
constant int64_t & nb1,
|
293
|
+
constant int64_t & nb2,
|
294
|
+
constant int64_t & nb3,
|
295
|
+
uint3 tpig[[thread_position_in_grid]]) {
|
296
|
+
int64_t i3 = tpig.z;
|
297
|
+
int64_t i2 = tpig.y;
|
298
|
+
int64_t i1 = tpig.x;
|
299
|
+
|
300
|
+
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
|
301
|
+
return;
|
302
|
+
}
|
303
|
+
|
304
|
+
device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
|
305
|
+
device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
|
306
|
+
|
307
|
+
float row_sum = 0;
|
308
|
+
|
309
|
+
for (int64_t i0 = 0; i0 < ne00; i0++) {
|
310
|
+
row_sum += src_row[i0];
|
311
|
+
}
|
312
|
+
|
313
|
+
dst_row[0] = row_sum;
|
314
|
+
}
|
315
|
+
|
167
316
|
constant float GELU_COEF_A = 0.044715f;
|
168
317
|
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
169
318
|
|
@@ -582,9 +731,20 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
|
|
582
731
|
// guard against the number of rows not being divisible by
|
583
732
|
// N_DST, so this is another explicit assumption of the implementation.
|
584
733
|
template<typename block_q_type, int nr, int nsg, int nw>
|
585
|
-
void mul_vec_q_n_f32(
|
586
|
-
|
587
|
-
|
734
|
+
void mul_vec_q_n_f32(
|
735
|
+
device const void * src0,
|
736
|
+
device const float * src1,
|
737
|
+
device float * dst,
|
738
|
+
int64_t ne00,
|
739
|
+
int64_t ne01,
|
740
|
+
int64_t ne02,
|
741
|
+
int64_t ne10,
|
742
|
+
int64_t ne12,
|
743
|
+
int64_t ne0,
|
744
|
+
int64_t ne1,
|
745
|
+
uint r2,
|
746
|
+
uint r3,
|
747
|
+
uint3 tgpig, uint tiisg, uint sgitg) {
|
588
748
|
const int nb = ne00/QK4_0;
|
589
749
|
|
590
750
|
const int r0 = tgpig.x;
|
@@ -593,7 +753,10 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
|
|
593
753
|
|
594
754
|
const int first_row = (r0 * nsg + sgitg) * nr;
|
595
755
|
|
596
|
-
const uint
|
756
|
+
const uint i12 = im%ne12;
|
757
|
+
const uint i13 = im/ne12;
|
758
|
+
|
759
|
+
const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
597
760
|
|
598
761
|
device const block_q_type * x = (device const block_q_type *) src0 + offset0;
|
599
762
|
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
@@ -643,13 +806,14 @@ kernel void kernel_mul_mv_q4_0_f32(
|
|
643
806
|
constant int64_t & ne02[[buffer(5)]],
|
644
807
|
constant int64_t & ne10[[buffer(9)]],
|
645
808
|
constant int64_t & ne12[[buffer(11)]],
|
646
|
-
constant int64_t & ne0[[buffer(15)]],
|
647
|
-
constant int64_t & ne1[[buffer(16)]],
|
648
|
-
constant uint &
|
809
|
+
constant int64_t & ne0 [[buffer(15)]],
|
810
|
+
constant int64_t & ne1 [[buffer(16)]],
|
811
|
+
constant uint & r2 [[buffer(17)]],
|
812
|
+
constant uint & r3 [[buffer(18)]],
|
649
813
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
650
814
|
uint tiisg[[thread_index_in_simdgroup]],
|
651
815
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
652
|
-
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,
|
816
|
+
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
653
817
|
}
|
654
818
|
|
655
819
|
kernel void kernel_mul_mv_q4_1_f32(
|
@@ -661,13 +825,14 @@ kernel void kernel_mul_mv_q4_1_f32(
|
|
661
825
|
constant int64_t & ne02[[buffer(5)]],
|
662
826
|
constant int64_t & ne10[[buffer(9)]],
|
663
827
|
constant int64_t & ne12[[buffer(11)]],
|
664
|
-
constant int64_t & ne0[[buffer(15)]],
|
665
|
-
constant int64_t & ne1[[buffer(16)]],
|
666
|
-
constant uint &
|
828
|
+
constant int64_t & ne0 [[buffer(15)]],
|
829
|
+
constant int64_t & ne1 [[buffer(16)]],
|
830
|
+
constant uint & r2 [[buffer(17)]],
|
831
|
+
constant uint & r3 [[buffer(18)]],
|
667
832
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
668
833
|
uint tiisg[[thread_index_in_simdgroup]],
|
669
834
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
670
|
-
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,
|
835
|
+
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
671
836
|
}
|
672
837
|
|
673
838
|
kernel void kernel_mul_mv_q5_0_f32(
|
@@ -679,13 +844,14 @@ kernel void kernel_mul_mv_q5_0_f32(
|
|
679
844
|
constant int64_t & ne02[[buffer(5)]],
|
680
845
|
constant int64_t & ne10[[buffer(9)]],
|
681
846
|
constant int64_t & ne12[[buffer(11)]],
|
682
|
-
constant int64_t & ne0[[buffer(15)]],
|
683
|
-
constant int64_t & ne1[[buffer(16)]],
|
684
|
-
constant uint &
|
847
|
+
constant int64_t & ne0 [[buffer(15)]],
|
848
|
+
constant int64_t & ne1 [[buffer(16)]],
|
849
|
+
constant uint & r2 [[buffer(17)]],
|
850
|
+
constant uint & r3 [[buffer(18)]],
|
685
851
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
686
852
|
uint tiisg[[thread_index_in_simdgroup]],
|
687
853
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
688
|
-
mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,
|
854
|
+
mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
689
855
|
}
|
690
856
|
|
691
857
|
kernel void kernel_mul_mv_q5_1_f32(
|
@@ -697,13 +863,14 @@ kernel void kernel_mul_mv_q5_1_f32(
|
|
697
863
|
constant int64_t & ne02[[buffer(5)]],
|
698
864
|
constant int64_t & ne10[[buffer(9)]],
|
699
865
|
constant int64_t & ne12[[buffer(11)]],
|
700
|
-
constant int64_t & ne0[[buffer(15)]],
|
701
|
-
constant int64_t & ne1[[buffer(16)]],
|
702
|
-
constant uint &
|
866
|
+
constant int64_t & ne0 [[buffer(15)]],
|
867
|
+
constant int64_t & ne1 [[buffer(16)]],
|
868
|
+
constant uint & r2 [[buffer(17)]],
|
869
|
+
constant uint & r3 [[buffer(18)]],
|
703
870
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
704
871
|
uint tiisg[[thread_index_in_simdgroup]],
|
705
872
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
706
|
-
mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,
|
873
|
+
mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
707
874
|
}
|
708
875
|
|
709
876
|
|
@@ -718,9 +885,10 @@ kernel void kernel_mul_mv_q8_0_f32(
|
|
718
885
|
constant int64_t & ne02[[buffer(5)]],
|
719
886
|
constant int64_t & ne10[[buffer(9)]],
|
720
887
|
constant int64_t & ne12[[buffer(11)]],
|
721
|
-
constant int64_t & ne0[[buffer(15)]],
|
722
|
-
constant int64_t & ne1[[buffer(16)]],
|
723
|
-
constant uint &
|
888
|
+
constant int64_t & ne0 [[buffer(15)]],
|
889
|
+
constant int64_t & ne1 [[buffer(16)]],
|
890
|
+
constant uint & r2 [[buffer(17)]],
|
891
|
+
constant uint & r3 [[buffer(18)]],
|
724
892
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
725
893
|
uint tiisg[[thread_index_in_simdgroup]],
|
726
894
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -732,8 +900,14 @@ kernel void kernel_mul_mv_q8_0_f32(
|
|
732
900
|
const int r0 = tgpig.x;
|
733
901
|
const int r1 = tgpig.y;
|
734
902
|
const int im = tgpig.z;
|
903
|
+
|
735
904
|
const int first_row = (r0 * nsg + sgitg) * nr;
|
736
|
-
|
905
|
+
|
906
|
+
const uint i12 = im%ne12;
|
907
|
+
const uint i13 = im/ne12;
|
908
|
+
|
909
|
+
const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
910
|
+
|
737
911
|
device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
|
738
912
|
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
739
913
|
|
@@ -791,6 +965,8 @@ kernel void kernel_mul_mv_f32_f32(
|
|
791
965
|
constant uint64_t & nb12,
|
792
966
|
constant int64_t & ne0,
|
793
967
|
constant int64_t & ne1,
|
968
|
+
constant uint & r2 [[buffer(17)]],
|
969
|
+
constant uint & r3 [[buffer(18)]],
|
794
970
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
795
971
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
796
972
|
|
@@ -798,7 +974,12 @@ kernel void kernel_mul_mv_f32_f32(
|
|
798
974
|
const int64_t rb = tgpig.y*N_F32_F32;
|
799
975
|
const int64_t im = tgpig.z;
|
800
976
|
|
801
|
-
|
977
|
+
const uint i12 = im%ne12;
|
978
|
+
const uint i13 = im/ne12;
|
979
|
+
|
980
|
+
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
981
|
+
|
982
|
+
device const float * x = (device const float *) (src0 + offset0);
|
802
983
|
|
803
984
|
if (ne00 < 128) {
|
804
985
|
for (int row = 0; row < N_F32_F32; ++row) {
|
@@ -864,6 +1045,8 @@ kernel void kernel_mul_mv_f16_f16(
|
|
864
1045
|
constant uint64_t & nb12,
|
865
1046
|
constant int64_t & ne0,
|
866
1047
|
constant int64_t & ne1,
|
1048
|
+
constant uint & r2 [[buffer(17)]],
|
1049
|
+
constant uint & r3 [[buffer(18)]],
|
867
1050
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
868
1051
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
869
1052
|
|
@@ -871,7 +1054,12 @@ kernel void kernel_mul_mv_f16_f16(
|
|
871
1054
|
const int64_t rb = tgpig.y*N_F16_F16;
|
872
1055
|
const int64_t im = tgpig.z;
|
873
1056
|
|
874
|
-
|
1057
|
+
const uint i12 = im%ne12;
|
1058
|
+
const uint i13 = im/ne12;
|
1059
|
+
|
1060
|
+
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
1061
|
+
|
1062
|
+
device const half * x = (device const half *) (src0 + offset0);
|
875
1063
|
|
876
1064
|
if (ne00 < 128) {
|
877
1065
|
for (int row = 0; row < N_F16_F16; ++row) {
|
@@ -935,6 +1123,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
|
|
935
1123
|
constant uint64_t & nb12,
|
936
1124
|
constant int64_t & ne0,
|
937
1125
|
constant int64_t & ne1,
|
1126
|
+
constant uint & r2 [[buffer(17)]],
|
1127
|
+
constant uint & r3 [[buffer(18)]],
|
938
1128
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
939
1129
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
940
1130
|
|
@@ -942,7 +1132,12 @@ kernel void kernel_mul_mv_f16_f32_1row(
|
|
942
1132
|
const int64_t r1 = tgpig.y;
|
943
1133
|
const int64_t im = tgpig.z;
|
944
1134
|
|
945
|
-
|
1135
|
+
const uint i12 = im%ne12;
|
1136
|
+
const uint i13 = im/ne12;
|
1137
|
+
|
1138
|
+
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
1139
|
+
|
1140
|
+
device const half * x = (device const half *) (src0 + offset0);
|
946
1141
|
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
947
1142
|
|
948
1143
|
float sumf = 0;
|
@@ -989,6 +1184,8 @@ kernel void kernel_mul_mv_f16_f32(
|
|
989
1184
|
constant uint64_t & nb12,
|
990
1185
|
constant int64_t & ne0,
|
991
1186
|
constant int64_t & ne1,
|
1187
|
+
constant uint & r2 [[buffer(17)]],
|
1188
|
+
constant uint & r3 [[buffer(18)]],
|
992
1189
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
993
1190
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
994
1191
|
|
@@ -996,7 +1193,12 @@ kernel void kernel_mul_mv_f16_f32(
|
|
996
1193
|
const int64_t rb = tgpig.y*N_F16_F32;
|
997
1194
|
const int64_t im = tgpig.z;
|
998
1195
|
|
999
|
-
|
1196
|
+
const uint i12 = im%ne12;
|
1197
|
+
const uint i13 = im/ne12;
|
1198
|
+
|
1199
|
+
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
1200
|
+
|
1201
|
+
device const half * x = (device const half *) (src0 + offset0);
|
1000
1202
|
|
1001
1203
|
if (ne00 < 128) {
|
1002
1204
|
for (int row = 0; row < N_F16_F32; ++row) {
|
@@ -1061,6 +1263,8 @@ kernel void kernel_mul_mv_f16_f32_l4(
|
|
1061
1263
|
constant uint64_t & nb12,
|
1062
1264
|
constant int64_t & ne0,
|
1063
1265
|
constant int64_t & ne1,
|
1266
|
+
constant uint & r2 [[buffer(17)]],
|
1267
|
+
constant uint & r3 [[buffer(18)]],
|
1064
1268
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1065
1269
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
1066
1270
|
|
@@ -1068,7 +1272,12 @@ kernel void kernel_mul_mv_f16_f32_l4(
|
|
1068
1272
|
const int64_t r0 = tgpig.x;
|
1069
1273
|
const int64_t im = tgpig.z;
|
1070
1274
|
|
1071
|
-
|
1275
|
+
const uint i12 = im%ne12;
|
1276
|
+
const uint i13 = im/ne12;
|
1277
|
+
|
1278
|
+
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
1279
|
+
|
1280
|
+
device const half4 * x4 = (device const half4 *) (src0 + offset0);
|
1072
1281
|
|
1073
1282
|
for (int r1 = 0; r1 < nrows; ++r1) {
|
1074
1283
|
device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
|
@@ -1120,17 +1329,21 @@ kernel void kernel_alibi_f32(
|
|
1120
1329
|
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
1121
1330
|
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
1122
1331
|
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
1332
|
+
const int64_t k = i3*ne3 + i2;
|
1123
1333
|
|
1124
|
-
device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
1125
1334
|
float m_k;
|
1126
|
-
if (
|
1127
|
-
m_k = pow(m0,
|
1335
|
+
if (k < n_heads_log2_floor) {
|
1336
|
+
m_k = pow(m0, k + 1);
|
1128
1337
|
} else {
|
1129
|
-
m_k = pow(m1, 2 * (
|
1338
|
+
m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
|
1130
1339
|
}
|
1340
|
+
|
1341
|
+
device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
|
1342
|
+
device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
1131
1343
|
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
1132
|
-
|
1133
|
-
|
1344
|
+
const float src_v = *(device float *)(src_row + i00*nb00);
|
1345
|
+
device float * dst_v = (device float *)(dst_row + i00*nb0);
|
1346
|
+
*dst_v = i00 * m_k + src_v;
|
1134
1347
|
}
|
1135
1348
|
}
|
1136
1349
|
|
@@ -1335,6 +1548,58 @@ kernel void kernel_im2col_f16(
|
|
1335
1548
|
}
|
1336
1549
|
}
|
1337
1550
|
|
1551
|
+
// bitonic sort implementation following the CUDA kernels as reference
|
1552
|
+
typedef void (argsort_t)(
|
1553
|
+
device const float * x,
|
1554
|
+
device int32_t * dst,
|
1555
|
+
constant int64_t & ncols,
|
1556
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1557
|
+
uint3 tpitg[[thread_position_in_threadgroup]]);
|
1558
|
+
|
1559
|
+
template<ggml_sort_order order>
|
1560
|
+
kernel void kernel_argsort_f32_i32(
|
1561
|
+
device const float * x,
|
1562
|
+
device int32_t * dst,
|
1563
|
+
constant int64_t & ncols,
|
1564
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1565
|
+
uint3 tpitg[[thread_position_in_threadgroup]]) {
|
1566
|
+
// bitonic sort
|
1567
|
+
int col = tpitg[0];
|
1568
|
+
int row = tgpig[1];
|
1569
|
+
|
1570
|
+
if (col >= ncols) return;
|
1571
|
+
|
1572
|
+
device const float * x_row = x + row * ncols;
|
1573
|
+
device int32_t * dst_row = dst + row * ncols;
|
1574
|
+
|
1575
|
+
// initialize indices
|
1576
|
+
if (col < ncols) {
|
1577
|
+
dst_row[col] = col;
|
1578
|
+
}
|
1579
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1580
|
+
|
1581
|
+
for (int k = 2; k <= ncols; k *= 2) {
|
1582
|
+
for (int j = k / 2; j > 0; j /= 2) {
|
1583
|
+
int ixj = col ^ j;
|
1584
|
+
if (ixj > col) {
|
1585
|
+
if ((col & k) == 0) {
|
1586
|
+
if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
|
1587
|
+
SWAP(dst_row[col], dst_row[ixj]);
|
1588
|
+
}
|
1589
|
+
} else {
|
1590
|
+
if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
|
1591
|
+
SWAP(dst_row[col], dst_row[ixj]);
|
1592
|
+
}
|
1593
|
+
}
|
1594
|
+
}
|
1595
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1596
|
+
}
|
1597
|
+
}
|
1598
|
+
}
|
1599
|
+
|
1600
|
+
template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
|
1601
|
+
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
|
1602
|
+
|
1338
1603
|
kernel void kernel_cpy_f16_f16(
|
1339
1604
|
device const half * src0,
|
1340
1605
|
device half * dst,
|
@@ -1460,6 +1725,197 @@ kernel void kernel_cpy_f32_f32(
|
|
1460
1725
|
}
|
1461
1726
|
}
|
1462
1727
|
|
1728
|
+
kernel void kernel_cpy_f32_q8_0(
|
1729
|
+
device const float * src0,
|
1730
|
+
device void * dst,
|
1731
|
+
constant int64_t & ne00,
|
1732
|
+
constant int64_t & ne01,
|
1733
|
+
constant int64_t & ne02,
|
1734
|
+
constant int64_t & ne03,
|
1735
|
+
constant uint64_t & nb00,
|
1736
|
+
constant uint64_t & nb01,
|
1737
|
+
constant uint64_t & nb02,
|
1738
|
+
constant uint64_t & nb03,
|
1739
|
+
constant int64_t & ne0,
|
1740
|
+
constant int64_t & ne1,
|
1741
|
+
constant int64_t & ne2,
|
1742
|
+
constant int64_t & ne3,
|
1743
|
+
constant uint64_t & nb0,
|
1744
|
+
constant uint64_t & nb1,
|
1745
|
+
constant uint64_t & nb2,
|
1746
|
+
constant uint64_t & nb3,
|
1747
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1748
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
1749
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
1750
|
+
const int64_t i03 = tgpig[2];
|
1751
|
+
const int64_t i02 = tgpig[1];
|
1752
|
+
const int64_t i01 = tgpig[0];
|
1753
|
+
|
1754
|
+
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
1755
|
+
|
1756
|
+
const int64_t i3 = n / (ne2*ne1*ne0);
|
1757
|
+
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
1758
|
+
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
1759
|
+
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0;
|
1760
|
+
|
1761
|
+
device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
1762
|
+
|
1763
|
+
for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) {
|
1764
|
+
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
1765
|
+
|
1766
|
+
float amax = 0.0f; // absolute max
|
1767
|
+
|
1768
|
+
for (int j = 0; j < QK8_0; j++) {
|
1769
|
+
const float v = src[j];
|
1770
|
+
amax = MAX(amax, fabs(v));
|
1771
|
+
}
|
1772
|
+
|
1773
|
+
const float d = amax / ((1 << 7) - 1);
|
1774
|
+
const float id = d ? 1.0f/d : 0.0f;
|
1775
|
+
|
1776
|
+
dst_data[i00/QK8_0].d = d;
|
1777
|
+
|
1778
|
+
for (int j = 0; j < QK8_0; ++j) {
|
1779
|
+
const float x0 = src[j]*id;
|
1780
|
+
|
1781
|
+
dst_data[i00/QK8_0].qs[j] = round(x0);
|
1782
|
+
}
|
1783
|
+
}
|
1784
|
+
}
|
1785
|
+
|
1786
|
+
kernel void kernel_cpy_f32_q4_0(
|
1787
|
+
device const float * src0,
|
1788
|
+
device void * dst,
|
1789
|
+
constant int64_t & ne00,
|
1790
|
+
constant int64_t & ne01,
|
1791
|
+
constant int64_t & ne02,
|
1792
|
+
constant int64_t & ne03,
|
1793
|
+
constant uint64_t & nb00,
|
1794
|
+
constant uint64_t & nb01,
|
1795
|
+
constant uint64_t & nb02,
|
1796
|
+
constant uint64_t & nb03,
|
1797
|
+
constant int64_t & ne0,
|
1798
|
+
constant int64_t & ne1,
|
1799
|
+
constant int64_t & ne2,
|
1800
|
+
constant int64_t & ne3,
|
1801
|
+
constant uint64_t & nb0,
|
1802
|
+
constant uint64_t & nb1,
|
1803
|
+
constant uint64_t & nb2,
|
1804
|
+
constant uint64_t & nb3,
|
1805
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1806
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
1807
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
1808
|
+
const int64_t i03 = tgpig[2];
|
1809
|
+
const int64_t i02 = tgpig[1];
|
1810
|
+
const int64_t i01 = tgpig[0];
|
1811
|
+
|
1812
|
+
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
1813
|
+
|
1814
|
+
const int64_t i3 = n / (ne2*ne1*ne0);
|
1815
|
+
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
1816
|
+
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
1817
|
+
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0;
|
1818
|
+
|
1819
|
+
device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
1820
|
+
|
1821
|
+
for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) {
|
1822
|
+
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
1823
|
+
|
1824
|
+
float amax = 0.0f; // absolute max
|
1825
|
+
float max = 0.0f;
|
1826
|
+
|
1827
|
+
for (int j = 0; j < QK4_0; j++) {
|
1828
|
+
const float v = src[j];
|
1829
|
+
if (amax < fabs(v)) {
|
1830
|
+
amax = fabs(v);
|
1831
|
+
max = v;
|
1832
|
+
}
|
1833
|
+
}
|
1834
|
+
|
1835
|
+
const float d = max / -8;
|
1836
|
+
const float id = d ? 1.0f/d : 0.0f;
|
1837
|
+
|
1838
|
+
dst_data[i00/QK4_0].d = d;
|
1839
|
+
|
1840
|
+
for (int j = 0; j < QK4_0/2; ++j) {
|
1841
|
+
const float x0 = src[0 + j]*id;
|
1842
|
+
const float x1 = src[QK4_0/2 + j]*id;
|
1843
|
+
|
1844
|
+
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
|
1845
|
+
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
|
1846
|
+
|
1847
|
+
dst_data[i00/QK4_0].qs[j] = xi0;
|
1848
|
+
dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
|
1849
|
+
}
|
1850
|
+
}
|
1851
|
+
}
|
1852
|
+
|
1853
|
+
kernel void kernel_cpy_f32_q4_1(
|
1854
|
+
device const float * src0,
|
1855
|
+
device void * dst,
|
1856
|
+
constant int64_t & ne00,
|
1857
|
+
constant int64_t & ne01,
|
1858
|
+
constant int64_t & ne02,
|
1859
|
+
constant int64_t & ne03,
|
1860
|
+
constant uint64_t & nb00,
|
1861
|
+
constant uint64_t & nb01,
|
1862
|
+
constant uint64_t & nb02,
|
1863
|
+
constant uint64_t & nb03,
|
1864
|
+
constant int64_t & ne0,
|
1865
|
+
constant int64_t & ne1,
|
1866
|
+
constant int64_t & ne2,
|
1867
|
+
constant int64_t & ne3,
|
1868
|
+
constant uint64_t & nb0,
|
1869
|
+
constant uint64_t & nb1,
|
1870
|
+
constant uint64_t & nb2,
|
1871
|
+
constant uint64_t & nb3,
|
1872
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1873
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
1874
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
1875
|
+
const int64_t i03 = tgpig[2];
|
1876
|
+
const int64_t i02 = tgpig[1];
|
1877
|
+
const int64_t i01 = tgpig[0];
|
1878
|
+
|
1879
|
+
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
1880
|
+
|
1881
|
+
const int64_t i3 = n / (ne2*ne1*ne0);
|
1882
|
+
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
1883
|
+
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
1884
|
+
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1;
|
1885
|
+
|
1886
|
+
device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
1887
|
+
|
1888
|
+
for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) {
|
1889
|
+
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
1890
|
+
|
1891
|
+
float min = FLT_MAX;
|
1892
|
+
float max = -FLT_MAX;
|
1893
|
+
|
1894
|
+
for (int j = 0; j < QK4_1; j++) {
|
1895
|
+
const float v = src[j];
|
1896
|
+
if (min > v) min = v;
|
1897
|
+
if (max < v) max = v;
|
1898
|
+
}
|
1899
|
+
|
1900
|
+
const float d = (max - min) / ((1 << 4) - 1);
|
1901
|
+
const float id = d ? 1.0f/d : 0.0f;
|
1902
|
+
|
1903
|
+
dst_data[i00/QK4_1].d = d;
|
1904
|
+
dst_data[i00/QK4_1].m = min;
|
1905
|
+
|
1906
|
+
for (int j = 0; j < QK4_1/2; ++j) {
|
1907
|
+
const float x0 = (src[0 + j] - min)*id;
|
1908
|
+
const float x1 = (src[QK4_1/2 + j] - min)*id;
|
1909
|
+
|
1910
|
+
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
|
1911
|
+
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
|
1912
|
+
|
1913
|
+
dst_data[i00/QK4_1].qs[j] = xi0;
|
1914
|
+
dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
|
1915
|
+
}
|
1916
|
+
}
|
1917
|
+
}
|
1918
|
+
|
1463
1919
|
kernel void kernel_concat(
|
1464
1920
|
device const char * src0,
|
1465
1921
|
device const char * src1,
|
@@ -1617,23 +2073,30 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
1617
2073
|
constant int64_t & ne02[[buffer(5)]],
|
1618
2074
|
constant int64_t & ne10[[buffer(9)]],
|
1619
2075
|
constant int64_t & ne12[[buffer(11)]],
|
1620
|
-
constant int64_t & ne0[[buffer(15)]],
|
1621
|
-
constant int64_t & ne1[[buffer(16)]],
|
1622
|
-
constant uint &
|
2076
|
+
constant int64_t & ne0 [[buffer(15)]],
|
2077
|
+
constant int64_t & ne1 [[buffer(16)]],
|
2078
|
+
constant uint & r2 [[buffer(17)]],
|
2079
|
+
constant uint & r3 [[buffer(18)]],
|
1623
2080
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1624
|
-
uint
|
1625
|
-
uint
|
2081
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
2082
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1626
2083
|
|
1627
2084
|
const int nb = ne00/QK_K;
|
1628
2085
|
const int r0 = tgpig.x;
|
1629
2086
|
const int r1 = tgpig.y;
|
1630
|
-
const int
|
2087
|
+
const int im = tgpig.z;
|
1631
2088
|
|
1632
2089
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
1633
2090
|
const int ib_row = first_row * nb;
|
1634
|
-
|
2091
|
+
|
2092
|
+
const uint i12 = im%ne12;
|
2093
|
+
const uint i13 = im/ne12;
|
2094
|
+
|
2095
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
2096
|
+
|
1635
2097
|
device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
|
1636
|
-
device const float * y = (device const float *) src1 + r1*ne10 +
|
2098
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
2099
|
+
|
1637
2100
|
float yl[32];
|
1638
2101
|
float sumf[N_DST]={0.f}, all_sum;
|
1639
2102
|
|
@@ -1642,11 +2105,11 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
1642
2105
|
#if QK_K == 256
|
1643
2106
|
const int ix = tiisg/8; // 0...3
|
1644
2107
|
const int it = tiisg%8; // 0...7
|
1645
|
-
const int
|
2108
|
+
const int iq = it/4; // 0 or 1
|
1646
2109
|
const int ir = it%4; // 0...3
|
1647
2110
|
const int is = (8*ir)/16;// 0 or 1
|
1648
2111
|
|
1649
|
-
device const float * y4 = y + ix * QK_K + 128 *
|
2112
|
+
device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
|
1650
2113
|
|
1651
2114
|
for (int ib = ix; ib < nb; ib += 4) {
|
1652
2115
|
|
@@ -1658,8 +2121,8 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
1658
2121
|
yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
|
1659
2122
|
}
|
1660
2123
|
|
1661
|
-
device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*
|
1662
|
-
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 *
|
2124
|
+
device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*iq + is;
|
2125
|
+
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
|
1663
2126
|
device const half * dh = &x[ib].d;
|
1664
2127
|
|
1665
2128
|
for (int row = 0; row < N_DST; row++) {
|
@@ -1746,7 +2209,7 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
1746
2209
|
for (int row = 0; row < N_DST; ++row) {
|
1747
2210
|
all_sum = simd_sum(sumf[row]);
|
1748
2211
|
if (tiisg == 0) {
|
1749
|
-
dst[r1*ne0 +
|
2212
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
1750
2213
|
}
|
1751
2214
|
}
|
1752
2215
|
}
|
@@ -1761,9 +2224,10 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
1761
2224
|
constant int64_t & ne02[[buffer(5)]],
|
1762
2225
|
constant int64_t & ne10[[buffer(9)]],
|
1763
2226
|
constant int64_t & ne12[[buffer(11)]],
|
1764
|
-
constant int64_t & ne0[[buffer(15)]],
|
1765
|
-
constant int64_t & ne1[[buffer(16)]],
|
1766
|
-
constant uint &
|
2227
|
+
constant int64_t & ne0 [[buffer(15)]],
|
2228
|
+
constant int64_t & ne1 [[buffer(16)]],
|
2229
|
+
constant uint & r2 [[buffer(17)]],
|
2230
|
+
constant uint & r3 [[buffer(18)]],
|
1767
2231
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1768
2232
|
uint tiisg[[thread_index_in_simdgroup]],
|
1769
2233
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -1772,12 +2236,17 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
1772
2236
|
|
1773
2237
|
const int64_t r0 = tgpig.x;
|
1774
2238
|
const int64_t r1 = tgpig.y;
|
1775
|
-
const int64_t
|
2239
|
+
const int64_t im = tgpig.z;
|
1776
2240
|
|
1777
2241
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
1778
|
-
|
2242
|
+
|
2243
|
+
const uint i12 = im%ne12;
|
2244
|
+
const uint i13 = im/ne12;
|
2245
|
+
|
2246
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
2247
|
+
|
1779
2248
|
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
|
1780
|
-
device const float * yy = (device const float *) src1 + r1*ne10 +
|
2249
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
1781
2250
|
|
1782
2251
|
float yl[32];
|
1783
2252
|
|
@@ -1899,7 +2368,7 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
1899
2368
|
}
|
1900
2369
|
if (tiisg == 0) {
|
1901
2370
|
for (int row = 0; row < 2; ++row) {
|
1902
|
-
dst[r1*ne0 +
|
2371
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row];
|
1903
2372
|
}
|
1904
2373
|
}
|
1905
2374
|
}
|
@@ -1913,26 +2382,33 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
1913
2382
|
constant int64_t & ne02[[buffer(5)]],
|
1914
2383
|
constant int64_t & ne10[[buffer(9)]],
|
1915
2384
|
constant int64_t & ne12[[buffer(11)]],
|
1916
|
-
constant int64_t & ne0[[buffer(15)]],
|
1917
|
-
constant int64_t & ne1[[buffer(16)]],
|
1918
|
-
constant uint &
|
2385
|
+
constant int64_t & ne0 [[buffer(15)]],
|
2386
|
+
constant int64_t & ne1 [[buffer(16)]],
|
2387
|
+
constant uint & r2 [[buffer(17)]],
|
2388
|
+
constant uint & r3 [[buffer(18)]],
|
1919
2389
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1920
|
-
uint
|
1921
|
-
uint
|
2390
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
2391
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1922
2392
|
|
1923
2393
|
const int nb = ne00/QK_K;
|
1924
2394
|
|
1925
2395
|
const int64_t r0 = tgpig.x;
|
1926
2396
|
const int64_t r1 = tgpig.y;
|
1927
|
-
const int64_t
|
2397
|
+
const int64_t im = tgpig.z;
|
1928
2398
|
|
1929
2399
|
const int row = 2 * r0 + sgitg;
|
1930
|
-
|
2400
|
+
|
2401
|
+
const uint i12 = im%ne12;
|
2402
|
+
const uint i13 = im/ne12;
|
2403
|
+
|
2404
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
2405
|
+
|
1931
2406
|
device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
|
1932
|
-
device const float * yy = (device const float *) src1 + r1*ne10 +
|
2407
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
2408
|
+
|
1933
2409
|
const int ix = tiisg/4;
|
1934
2410
|
const int il = 4 * (tiisg%4);// 0, 4, 8, 12
|
1935
|
-
const int
|
2411
|
+
const int iq = il/8; // 0, 0, 1, 1
|
1936
2412
|
const int in = il%8; // 0, 4, 0, 4
|
1937
2413
|
|
1938
2414
|
float2 sum = {0.f, 0.f};
|
@@ -1952,7 +2428,7 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
1952
2428
|
const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
|
1953
2429
|
|
1954
2430
|
for (int l = 0; l < 4; l += 2) {
|
1955
|
-
const uint16_t hm = h[l/2] >>
|
2431
|
+
const uint16_t hm = h[l/2] >> iq;
|
1956
2432
|
sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
|
1957
2433
|
+ y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
|
1958
2434
|
+ y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
|
@@ -1968,7 +2444,7 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
1968
2444
|
|
1969
2445
|
const float tot = simd_sum(sumf);
|
1970
2446
|
if (tiisg == 0) {
|
1971
|
-
dst[r1*ne0 +
|
2447
|
+
dst[r1*ne0 + im*ne0*ne1 + row] = tot;
|
1972
2448
|
}
|
1973
2449
|
|
1974
2450
|
}
|
@@ -1986,10 +2462,11 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
1986
2462
|
constant int64_t & ne12 [[buffer(11)]],
|
1987
2463
|
constant int64_t & ne0 [[buffer(15)]],
|
1988
2464
|
constant int64_t & ne1 [[buffer(16)]],
|
1989
|
-
constant uint &
|
2465
|
+
constant uint & r2 [[buffer(17)]],
|
2466
|
+
constant uint & r3 [[buffer(18)]],
|
1990
2467
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1991
|
-
uint
|
1992
|
-
uint
|
2468
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
2469
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1993
2470
|
|
1994
2471
|
const uint16_t kmask1 = 0x3f3f;
|
1995
2472
|
const uint16_t kmask2 = 0x0f0f;
|
@@ -1997,26 +2474,32 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
1997
2474
|
|
1998
2475
|
const int ix = tiisg/8; // 0...3
|
1999
2476
|
const int it = tiisg%8; // 0...7
|
2000
|
-
const int
|
2477
|
+
const int iq = it/4; // 0 or 1
|
2001
2478
|
const int ir = it%4; // 0...3
|
2002
2479
|
|
2003
2480
|
const int nb = ne00/QK_K;
|
2004
2481
|
const int r0 = tgpig.x;
|
2005
2482
|
const int r1 = tgpig.y;
|
2006
|
-
const int
|
2483
|
+
const int im = tgpig.z;
|
2007
2484
|
//const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
2008
2485
|
const int first_row = r0 * N_DST;
|
2009
2486
|
const int ib_row = first_row * nb;
|
2010
|
-
|
2487
|
+
|
2488
|
+
const uint i12 = im%ne12;
|
2489
|
+
const uint i13 = im/ne12;
|
2490
|
+
|
2491
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
2492
|
+
|
2011
2493
|
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
|
2012
|
-
device const float * y = (device const float *) src1 + r1*ne10 +
|
2494
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
2495
|
+
|
2013
2496
|
float yl[16];
|
2014
2497
|
float yh[16];
|
2015
2498
|
float sumf[N_DST]={0.f}, all_sum;
|
2016
2499
|
|
2017
2500
|
const int step = sizeof(block_q4_K) * nb / 2;
|
2018
2501
|
|
2019
|
-
device const float * y4 = y + ix * QK_K + 64 *
|
2502
|
+
device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
|
2020
2503
|
|
2021
2504
|
uint16_t sc16[4];
|
2022
2505
|
thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
|
@@ -2031,8 +2514,8 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
2031
2514
|
yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
|
2032
2515
|
}
|
2033
2516
|
|
2034
|
-
device const uint16_t * sc = (device const uint16_t *)x[ib].scales +
|
2035
|
-
device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 *
|
2517
|
+
device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq;
|
2518
|
+
device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
|
2036
2519
|
device const half * dh = &x[ib].d;
|
2037
2520
|
|
2038
2521
|
for (int row = 0; row < N_DST; row++) {
|
@@ -2076,7 +2559,7 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
2076
2559
|
for (int row = 0; row < N_DST; ++row) {
|
2077
2560
|
all_sum = simd_sum(sumf[row]);
|
2078
2561
|
if (tiisg == 0) {
|
2079
|
-
dst[r1*ne0 +
|
2562
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
2080
2563
|
}
|
2081
2564
|
}
|
2082
2565
|
}
|
@@ -2090,9 +2573,10 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
2090
2573
|
constant int64_t & ne02[[buffer(5)]],
|
2091
2574
|
constant int64_t & ne10[[buffer(9)]],
|
2092
2575
|
constant int64_t & ne12[[buffer(11)]],
|
2093
|
-
constant int64_t & ne0[[buffer(15)]],
|
2094
|
-
constant int64_t & ne1[[buffer(16)]],
|
2095
|
-
constant uint &
|
2576
|
+
constant int64_t & ne0 [[buffer(15)]],
|
2577
|
+
constant int64_t & ne1 [[buffer(16)]],
|
2578
|
+
constant uint & r2 [[buffer(17)]],
|
2579
|
+
constant uint & r3 [[buffer(18)]],
|
2096
2580
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2097
2581
|
uint tiisg[[thread_index_in_simdgroup]],
|
2098
2582
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -2103,12 +2587,18 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
2103
2587
|
const int nb = ne00/QK_K;
|
2104
2588
|
const int r0 = tgpig.x;
|
2105
2589
|
const int r1 = tgpig.y;
|
2106
|
-
const int
|
2590
|
+
const int im = tgpig.z;
|
2107
2591
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
2108
2592
|
const int ib_row = first_row * nb;
|
2109
|
-
|
2593
|
+
|
2594
|
+
const uint i12 = im%ne12;
|
2595
|
+
const uint i13 = im/ne12;
|
2596
|
+
|
2597
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
2598
|
+
|
2110
2599
|
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
|
2111
|
-
device const float * y = (device const float *) src1 + r1*ne10 +
|
2600
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
2601
|
+
|
2112
2602
|
float yl[8];
|
2113
2603
|
float yh[8];
|
2114
2604
|
float sumf[N_DST]={0.f}, all_sum;
|
@@ -2164,7 +2654,7 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
2164
2654
|
for (int row = 0; row < N_DST; ++row) {
|
2165
2655
|
all_sum = simd_sum(sumf[row]);
|
2166
2656
|
if (tiisg == 0) {
|
2167
|
-
dst[r1*ne0+
|
2657
|
+
dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum;
|
2168
2658
|
}
|
2169
2659
|
}
|
2170
2660
|
}
|
@@ -2179,9 +2669,10 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
2179
2669
|
constant int64_t & ne02[[buffer(5)]],
|
2180
2670
|
constant int64_t & ne10[[buffer(9)]],
|
2181
2671
|
constant int64_t & ne12[[buffer(11)]],
|
2182
|
-
constant int64_t & ne0[[buffer(15)]],
|
2183
|
-
constant int64_t & ne1[[buffer(16)]],
|
2184
|
-
constant uint &
|
2672
|
+
constant int64_t & ne0 [[buffer(15)]],
|
2673
|
+
constant int64_t & ne1 [[buffer(16)]],
|
2674
|
+
constant uint & r2 [[buffer(17)]],
|
2675
|
+
constant uint & r3 [[buffer(18)]],
|
2185
2676
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2186
2677
|
uint tiisg[[thread_index_in_simdgroup]],
|
2187
2678
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -2190,12 +2681,17 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
2190
2681
|
|
2191
2682
|
const int64_t r0 = tgpig.x;
|
2192
2683
|
const int64_t r1 = tgpig.y;
|
2193
|
-
const int
|
2684
|
+
const int im = tgpig.z;
|
2194
2685
|
|
2195
2686
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
2196
|
-
|
2687
|
+
|
2688
|
+
const uint i12 = im%ne12;
|
2689
|
+
const uint i13 = im/ne12;
|
2690
|
+
|
2691
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
2692
|
+
|
2197
2693
|
device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
|
2198
|
-
device const float * yy = (device const float *) src1 + r1*ne10 +
|
2694
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
2199
2695
|
|
2200
2696
|
float sumf[2]={0.f};
|
2201
2697
|
|
@@ -2211,15 +2707,15 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
2211
2707
|
|
2212
2708
|
const int tid = tiisg/4;
|
2213
2709
|
const int ix = tiisg%4;
|
2214
|
-
const int
|
2710
|
+
const int iq = tid/4;
|
2215
2711
|
const int ir = tid%4;
|
2216
2712
|
const int n = 8;
|
2217
2713
|
|
2218
2714
|
const int l0 = n*ir;
|
2219
|
-
const int q_offset = 32*
|
2220
|
-
const int y_offset = 64*
|
2715
|
+
const int q_offset = 32*iq + l0;
|
2716
|
+
const int y_offset = 64*iq + l0;
|
2221
2717
|
|
2222
|
-
const uint8_t hm1 = 1u << (2*
|
2718
|
+
const uint8_t hm1 = 1u << (2*iq);
|
2223
2719
|
const uint8_t hm2 = hm1 << 1;
|
2224
2720
|
const uint8_t hm3 = hm1 << 4;
|
2225
2721
|
const uint8_t hm4 = hm2 << 4;
|
@@ -2234,7 +2730,7 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
2234
2730
|
device const uint8_t * q1 = x[i].qs + q_offset;
|
2235
2731
|
device const uint8_t * qh = x[i].qh + l0;
|
2236
2732
|
device const half * dh = &x[i].d;
|
2237
|
-
device const uint16_t * a = (device const uint16_t *)x[i].scales +
|
2733
|
+
device const uint16_t * a = (device const uint16_t *)x[i].scales + iq;
|
2238
2734
|
|
2239
2735
|
device const float * y2 = y1 + 128;
|
2240
2736
|
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
@@ -2290,7 +2786,7 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
2290
2786
|
|
2291
2787
|
const int il = 4 * (tiisg/8); // 0, 4, 8, 12
|
2292
2788
|
const int ix = tiisg%8;
|
2293
|
-
const int
|
2789
|
+
const int iq = il/8; // 0, 0, 1, 1
|
2294
2790
|
const int in = il%8; // 0, 4, 0, 4
|
2295
2791
|
|
2296
2792
|
device const float * y = yy + ix*QK_K + il;
|
@@ -2315,7 +2811,7 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
2315
2811
|
|
2316
2812
|
float2 acc = {0.f, 0.f};
|
2317
2813
|
for (int l = 0; l < 4; ++l) {
|
2318
|
-
const uint8_t hl = h[l] >>
|
2814
|
+
const uint8_t hl = h[l] >> iq;
|
2319
2815
|
acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
|
2320
2816
|
+ yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
|
2321
2817
|
acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
|
@@ -2337,7 +2833,7 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
2337
2833
|
for (int row = 0; row < 2; ++row) {
|
2338
2834
|
const float tot = simd_sum(sumf[row]);
|
2339
2835
|
if (tiisg == 0) {
|
2340
|
-
dst[r1*ne0 +
|
2836
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
|
2341
2837
|
}
|
2342
2838
|
}
|
2343
2839
|
|
@@ -2352,9 +2848,10 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
2352
2848
|
constant int64_t & ne02[[buffer(5)]],
|
2353
2849
|
constant int64_t & ne10[[buffer(9)]],
|
2354
2850
|
constant int64_t & ne12[[buffer(11)]],
|
2355
|
-
constant int64_t & ne0[[buffer(15)]],
|
2356
|
-
constant int64_t & ne1[[buffer(16)]],
|
2357
|
-
constant uint &
|
2851
|
+
constant int64_t & ne0 [[buffer(15)]],
|
2852
|
+
constant int64_t & ne1 [[buffer(16)]],
|
2853
|
+
constant uint & r2 [[buffer(17)]],
|
2854
|
+
constant uint & r3 [[buffer(18)]],
|
2358
2855
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2359
2856
|
uint tiisg[[thread_index_in_simdgroup]],
|
2360
2857
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -2368,12 +2865,17 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
2368
2865
|
|
2369
2866
|
const int64_t r0 = tgpig.x;
|
2370
2867
|
const int64_t r1 = tgpig.y;
|
2371
|
-
const int
|
2868
|
+
const int im = tgpig.z;
|
2372
2869
|
|
2373
2870
|
const int row = 2 * r0 + sgitg;
|
2374
|
-
|
2871
|
+
|
2872
|
+
const uint i12 = im%ne12;
|
2873
|
+
const uint i13 = im/ne12;
|
2874
|
+
|
2875
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
2876
|
+
|
2375
2877
|
device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
|
2376
|
-
device const float * yy = (device const float *) src1 + r1*ne10 +
|
2878
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
2377
2879
|
|
2378
2880
|
float sumf = 0;
|
2379
2881
|
|
@@ -2439,7 +2941,7 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
2439
2941
|
|
2440
2942
|
const float tot = simd_sum(sumf);
|
2441
2943
|
if (tiisg == 0) {
|
2442
|
-
dst[r1*ne0 +
|
2944
|
+
dst[r1*ne0 + im*ne0*ne1 + row] = tot;
|
2443
2945
|
}
|
2444
2946
|
}
|
2445
2947
|
|
@@ -2749,24 +3251,25 @@ kernel void kernel_get_rows(
|
|
2749
3251
|
|
2750
3252
|
// each block_q contains 16*nl weights
|
2751
3253
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
2752
|
-
|
2753
|
-
|
2754
|
-
|
2755
|
-
|
2756
|
-
|
2757
|
-
|
2758
|
-
|
2759
|
-
|
2760
|
-
|
2761
|
-
|
2762
|
-
|
2763
|
-
|
2764
|
-
|
2765
|
-
|
2766
|
-
|
2767
|
-
|
2768
|
-
|
2769
|
-
|
3254
|
+
void kernel_mul_mm_impl(device const uchar * src0,
|
3255
|
+
device const uchar * src1,
|
3256
|
+
device float * dst,
|
3257
|
+
constant int64_t & ne00,
|
3258
|
+
constant int64_t & ne02,
|
3259
|
+
constant int64_t & nb01,
|
3260
|
+
constant int64_t & nb02,
|
3261
|
+
constant int64_t & ne12,
|
3262
|
+
constant int64_t & nb10,
|
3263
|
+
constant int64_t & nb11,
|
3264
|
+
constant int64_t & nb12,
|
3265
|
+
constant int64_t & ne0,
|
3266
|
+
constant int64_t & ne1,
|
3267
|
+
constant uint & r2,
|
3268
|
+
constant uint & r3,
|
3269
|
+
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
3270
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3271
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
3272
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
2770
3273
|
|
2771
3274
|
threadgroup half * sa = (threadgroup half *)(shared_memory);
|
2772
3275
|
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
|
@@ -2792,7 +3295,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
2792
3295
|
|
2793
3296
|
short il = (tiitg % THREAD_PER_ROW);
|
2794
3297
|
|
2795
|
-
uint
|
3298
|
+
const uint i12 = im%ne12;
|
3299
|
+
const uint i13 = im/ne12;
|
3300
|
+
|
3301
|
+
uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
|
2796
3302
|
ushort offset1 = il/nl;
|
2797
3303
|
|
2798
3304
|
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
|
@@ -2876,14 +3382,116 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
2876
3382
|
}
|
2877
3383
|
}
|
2878
3384
|
|
3385
|
+
// Kernel entry point for the templated matrix-matrix multiply.
// It holds no logic of its own: every argument is passed through, in the same order,
// to kernel_mul_mm_impl instantiated with the same <block_q, nl, dequantize_func>.
// Keep this parameter list in step with kernel_mul_mm_impl whenever either one changes.
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
3386
|
+
kernel void kernel_mul_mm(device const uchar * src0,
|
3387
|
+
device const uchar * src1,
|
3388
|
+
device float * dst,
|
3389
|
+
constant int64_t & ne00,
|
3390
|
+
constant int64_t & ne02,
|
3391
|
+
constant int64_t & nb01,
|
3392
|
+
constant int64_t & nb02,
|
3393
|
+
constant int64_t & ne12,
|
3394
|
+
constant int64_t & nb10,
|
3395
|
+
constant int64_t & nb11,
|
3396
|
+
constant int64_t & nb12,
|
3397
|
+
constant int64_t & ne0,
|
3398
|
+
constant int64_t & ne1,
|
3399
|
+
constant uint & r2,
|
3400
|
+
constant uint & r3,
|
3401
|
+
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
3402
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3403
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
3404
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3405
|
+
kernel_mul_mm_impl<block_q, nl, dequantize_func>(
|
3406
|
+
src0,
|
3407
|
+
src1,
|
3408
|
+
dst,
|
3409
|
+
ne00,
|
3410
|
+
ne02,
|
3411
|
+
nb01,
|
3412
|
+
nb02,
|
3413
|
+
ne12,
|
3414
|
+
nb10,
|
3415
|
+
nb11,
|
3416
|
+
nb12,
|
3417
|
+
ne0,
|
3418
|
+
ne1,
|
3419
|
+
r2,
|
3420
|
+
r3,
|
3421
|
+
shared_memory,
|
3422
|
+
tgpig,
|
3423
|
+
tiitg,
|
3424
|
+
sgitg);
|
3425
|
+
}
|
3426
|
+
|
3427
|
+
// Variant of the matrix-matrix multiply kernel that picks its src0 buffer at run time.
// The eight candidate buffers src00..src07 are gathered into a local table and the entry
// at index ids[idx] is handed to kernel_mul_mm_impl; all other arguments are forwarded
// unchanged.
// NOTE(review): ids[idx] is used as a table index with no range check in this function,
// so it appears the caller must guarantee 0 <= ids[idx] < 8 -- confirm on the host side.
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
3428
|
+
kernel void kernel_mul_mm_id(
|
3429
|
+
device const int32_t * ids,
|
3430
|
+
device const uchar * src1,
|
3431
|
+
device float * dst,
|
3432
|
+
constant int64_t & ne00,
|
3433
|
+
constant int64_t & ne02,
|
3434
|
+
constant int64_t & nb01,
|
3435
|
+
constant int64_t & nb02,
|
3436
|
+
constant int64_t & ne12,
|
3437
|
+
constant int64_t & nb10,
|
3438
|
+
constant int64_t & nb11,
|
3439
|
+
constant int64_t & nb12,
|
3440
|
+
constant int64_t & ne0,
|
3441
|
+
constant int64_t & ne1,
|
3442
|
+
constant uint & r2,
|
3443
|
+
constant uint & r3,
|
3444
|
+
constant int & idx,
|
3445
|
+
device const uchar * src00,
|
3446
|
+
device const uchar * src01,
|
3447
|
+
device const uchar * src02,
|
3448
|
+
device const uchar * src03,
|
3449
|
+
device const uchar * src04,
|
3450
|
+
device const uchar * src05,
|
3451
|
+
device const uchar * src06,
|
3452
|
+
device const uchar * src07,
|
3453
|
+
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
3454
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3455
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
3456
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3457
|
+
device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
3458
|
+
|
3459
|
+
kernel_mul_mm_impl<block_q, nl, dequantize_func>(
|
3460
|
+
src0[ids[idx]],
|
3461
|
+
src1,
|
3462
|
+
dst,
|
3463
|
+
ne00,
|
3464
|
+
ne02,
|
3465
|
+
nb01,
|
3466
|
+
nb02,
|
3467
|
+
ne12,
|
3468
|
+
nb10,
|
3469
|
+
nb11,
|
3470
|
+
nb12,
|
3471
|
+
ne0,
|
3472
|
+
ne1,
|
3473
|
+
r2,
|
3474
|
+
r3,
|
3475
|
+
shared_memory,
|
3476
|
+
tgpig,
|
3477
|
+
tiitg,
|
3478
|
+
sgitg);
|
3479
|
+
}
|
3480
|
+
|
2879
3481
|
// QK_NL is the value passed as the nl template argument by the K-quant kernel
// instantiations in this file: 16 when QK_K == 256, 4 otherwise.
#if QK_K == 256
|
2880
3482
|
#define QK_NL 16
|
2881
3483
|
#else
|
2882
3484
|
#define QK_NL 4
|
2883
3485
|
#endif
|
2884
3486
|
|
2885
|
-
typedef void (get_rows_t)(
|
2886
|
-
|
3487
|
+
// Function type shared by the explicit kernel_get_rows instantiations that follow.
// NOTE(review): the three trailing unnamed uint parameters presumably line up with the
// thread-index arguments of kernel_get_rows, whose definition is not shown here -- confirm.
typedef void (get_rows_t)(
|
3488
|
+
device const void * src0,
|
3489
|
+
device const int * src1,
|
3490
|
+
device float * dst,
|
3491
|
+
constant int64_t & ne00,
|
3492
|
+
constant uint64_t & nb01,
|
3493
|
+
constant uint64_t & nb1,
|
3494
|
+
uint, uint, uint);
|
2887
3495
|
|
2888
3496
|
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
|
2889
3497
|
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
@@ -2912,8 +3520,10 @@ typedef void (mat_mm_t)(
|
|
2912
3520
|
constant int64_t & nb12,
|
2913
3521
|
constant int64_t & ne0,
|
2914
3522
|
constant int64_t & ne1,
|
2915
|
-
constant uint &
|
2916
|
-
|
3523
|
+
constant uint & r2,
|
3524
|
+
constant uint & r3,
|
3525
|
+
threadgroup uchar *,
|
3526
|
+
uint3, uint, uint);
|
2917
3527
|
|
2918
3528
|
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
|
2919
3529
|
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
@@ -2927,3 +3537,44 @@ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
|
|
2927
3537
|
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
|
2928
3538
|
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
|
2929
3539
|
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|
3540
|
+
|
3541
|
+
// Function type used by the explicit kernel_mul_mm_id instantiations below.
// The parameter list matches the kernel_mul_mm_id signature: the ids buffer, src1, dst,
// the ne*/nb* constants, r2/r3, idx, the eight src0 candidates (src00..src07),
// the threadgroup byte pointer, and the three trailing index arguments (uint3, uint, uint).
typedef void (mat_mm_id_t)(
|
3542
|
+
device const int32_t * ids,
|
3543
|
+
device const uchar * src1,
|
3544
|
+
device float * dst,
|
3545
|
+
constant int64_t & ne00,
|
3546
|
+
constant int64_t & ne02,
|
3547
|
+
constant int64_t & nb01,
|
3548
|
+
constant int64_t & nb02,
|
3549
|
+
constant int64_t & ne12,
|
3550
|
+
constant int64_t & nb10,
|
3551
|
+
constant int64_t & nb11,
|
3552
|
+
constant int64_t & nb12,
|
3553
|
+
constant int64_t & ne0,
|
3554
|
+
constant int64_t & ne1,
|
3555
|
+
constant uint & r2,
|
3556
|
+
constant uint & r3,
|
3557
|
+
constant int & idx,
|
3558
|
+
device const uchar * src00,
|
3559
|
+
device const uchar * src01,
|
3560
|
+
device const uchar * src02,
|
3561
|
+
device const uchar * src03,
|
3562
|
+
device const uchar * src04,
|
3563
|
+
device const uchar * src05,
|
3564
|
+
device const uchar * src06,
|
3565
|
+
device const uchar * src07,
|
3566
|
+
threadgroup uchar *,
|
3567
|
+
uint3, uint, uint);
|
3568
|
+
|
3569
|
+
// Explicit instantiations of kernel_mul_mm_id, one per src0 block type.
// Each is exported under the host_name given in its attribute. The nl template argument
// is 1 for f32/f16, 2 for q4_0/q4_1/q5_0/q5_1/q8_0, and QK_NL for the K-quant types.
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
|
3570
|
+
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
|
3571
|
+
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
|
3572
|
+
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
|
3573
|
+
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
|
3574
|
+
template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
|
3575
|
+
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
|
3576
|
+
template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
|
3577
|
+
template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
|
3578
|
+
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
|
3579
|
+
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
|
3580
|
+
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
|