llama_cpp 0.10.0 → 0.10.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -0
- data/ext/llama_cpp/llama_cpp.cpp +18 -1
- data/ext/llama_cpp/src/ggml-alloc.c +12 -4
- data/ext/llama_cpp/src/ggml-alloc.h +1 -1
- data/ext/llama_cpp/src/ggml-backend-impl.h +12 -8
- data/ext/llama_cpp/src/ggml-backend.c +75 -5
- data/ext/llama_cpp/src/ggml-backend.h +7 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +952 -232
- data/ext/llama_cpp/src/ggml-metal.h +3 -0
- data/ext/llama_cpp/src/ggml-metal.m +725 -98
- data/ext/llama_cpp/src/ggml-metal.metal +1508 -171
- data/ext/llama_cpp/src/ggml-quants.c +2 -2
- data/ext/llama_cpp/src/ggml.c +554 -215
- data/ext/llama_cpp/src/ggml.h +58 -23
- data/ext/llama_cpp/src/llama.cpp +1157 -851
- data/ext/llama_cpp/src/llama.h +9 -4
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +2 -0
- metadata +2 -2
@@ -79,6 +79,7 @@ kernel void kernel_add(
|
|
79
79
|
constant int64_t & nb1,
|
80
80
|
constant int64_t & nb2,
|
81
81
|
constant int64_t & nb3,
|
82
|
+
constant int64_t & offs,
|
82
83
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
83
84
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
84
85
|
uint3 ntg[[threads_per_threadgroup]]) {
|
@@ -90,9 +91,9 @@ kernel void kernel_add(
|
|
90
91
|
const int64_t i12 = i02 % ne12;
|
91
92
|
const int64_t i11 = i01 % ne11;
|
92
93
|
|
93
|
-
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
94
|
+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
|
94
95
|
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
|
95
|
-
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
|
96
|
+
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
|
96
97
|
|
97
98
|
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
98
99
|
const int i10 = i0 % ne10;
|
@@ -204,7 +205,7 @@ kernel void kernel_add_row(
|
|
204
205
|
device const float4 * src0,
|
205
206
|
device const float4 * src1,
|
206
207
|
device float4 * dst,
|
207
|
-
constant int64_t & nb [[buffer(
|
208
|
+
constant int64_t & nb [[buffer(28)]],
|
208
209
|
uint tpig[[thread_position_in_grid]]) {
|
209
210
|
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
210
211
|
}
|
@@ -213,7 +214,7 @@ kernel void kernel_mul_row(
|
|
213
214
|
device const float4 * src0,
|
214
215
|
device const float4 * src1,
|
215
216
|
device float4 * dst,
|
216
|
-
constant int64_t & nb [[buffer(
|
217
|
+
constant int64_t & nb [[buffer(28)]],
|
217
218
|
uint tpig[[thread_position_in_grid]]) {
|
218
219
|
dst[tpig] = src0[tpig] * src1[tpig % nb];
|
219
220
|
}
|
@@ -222,7 +223,7 @@ kernel void kernel_div_row(
|
|
222
223
|
device const float4 * src0,
|
223
224
|
device const float4 * src1,
|
224
225
|
device float4 * dst,
|
225
|
-
constant int64_t & nb [[buffer(
|
226
|
+
constant int64_t & nb [[buffer(28)]],
|
226
227
|
uint tpig[[thread_position_in_grid]]) {
|
227
228
|
dst[tpig] = src0[tpig] / src1[tpig % nb];
|
228
229
|
}
|
@@ -243,19 +244,53 @@ kernel void kernel_scale_4(
|
|
243
244
|
dst[tpig] = src0[tpig] * scale;
|
244
245
|
}
|
245
246
|
|
246
|
-
kernel void
|
247
|
-
device const
|
248
|
-
device
|
247
|
+
kernel void kernel_relu(
|
248
|
+
device const float * src0,
|
249
|
+
device float * dst,
|
249
250
|
uint tpig[[thread_position_in_grid]]) {
|
250
|
-
|
251
|
-
dst[tpig] = x / (1.0f + exp(-x));
|
251
|
+
dst[tpig] = max(0.0f, src0[tpig]);
|
252
252
|
}
|
253
253
|
|
254
|
-
kernel void
|
254
|
+
kernel void kernel_tanh(
|
255
255
|
device const float * src0,
|
256
256
|
device float * dst,
|
257
257
|
uint tpig[[thread_position_in_grid]]) {
|
258
|
-
|
258
|
+
device const float & x = src0[tpig];
|
259
|
+
dst[tpig] = precise::tanh(x);
|
260
|
+
}
|
261
|
+
|
262
|
+
constant float GELU_COEF_A = 0.044715f;
|
263
|
+
constant float GELU_QUICK_COEF = -1.702f;
|
264
|
+
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
265
|
+
|
266
|
+
kernel void kernel_gelu(
|
267
|
+
device const float4 * src0,
|
268
|
+
device float4 * dst,
|
269
|
+
uint tpig[[thread_position_in_grid]]) {
|
270
|
+
device const float4 & x = src0[tpig];
|
271
|
+
|
272
|
+
// BEWARE !!!
|
273
|
+
// Simply using "tanh" instead of "precise::tanh" will sometimes result in NaNs!
|
274
|
+
// This was observed with Falcon 7B and 40B models
|
275
|
+
//
|
276
|
+
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
277
|
+
}
|
278
|
+
|
279
|
+
kernel void kernel_gelu_quick(
|
280
|
+
device const float4 * src0,
|
281
|
+
device float4 * dst,
|
282
|
+
uint tpig[[thread_position_in_grid]]) {
|
283
|
+
device const float4 & x = src0[tpig];
|
284
|
+
|
285
|
+
dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
|
286
|
+
}
|
287
|
+
|
288
|
+
kernel void kernel_silu(
|
289
|
+
device const float4 * src0,
|
290
|
+
device float4 * dst,
|
291
|
+
uint tpig[[thread_position_in_grid]]) {
|
292
|
+
device const float4 & x = src0[tpig];
|
293
|
+
dst[tpig] = x / (1.0f + exp(-x));
|
259
294
|
}
|
260
295
|
|
261
296
|
kernel void kernel_sqr(
|
@@ -313,22 +348,6 @@ kernel void kernel_sum_rows(
|
|
313
348
|
dst_row[0] = row_sum;
|
314
349
|
}
|
315
350
|
|
316
|
-
constant float GELU_COEF_A = 0.044715f;
|
317
|
-
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
318
|
-
|
319
|
-
kernel void kernel_gelu(
|
320
|
-
device const float4 * src0,
|
321
|
-
device float4 * dst,
|
322
|
-
uint tpig[[thread_position_in_grid]]) {
|
323
|
-
device const float4 & x = src0[tpig];
|
324
|
-
|
325
|
-
// BEWARE !!!
|
326
|
-
// Simply using "tanh" instead of "precise::tanh" will sometimes result in NaNs!
|
327
|
-
// This was observed with Falcon 7B and 40B models
|
328
|
-
//
|
329
|
-
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
330
|
-
}
|
331
|
-
|
332
351
|
kernel void kernel_soft_max(
|
333
352
|
device const float * src0,
|
334
353
|
device const float * src1,
|
@@ -347,9 +366,9 @@ kernel void kernel_soft_max(
|
|
347
366
|
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
348
367
|
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
349
368
|
|
350
|
-
device const float * psrc0 =
|
351
|
-
device const float * pmask = src1 ? src1
|
352
|
-
device float * pdst =
|
369
|
+
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
370
|
+
device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
|
371
|
+
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
353
372
|
|
354
373
|
// parallel max
|
355
374
|
float lmax = -INFINITY;
|
@@ -385,7 +404,12 @@ kernel void kernel_soft_max(
|
|
385
404
|
pdst[i00] = exp_psrc0;
|
386
405
|
}
|
387
406
|
|
407
|
+
// This barrier fixes a failing test
|
408
|
+
// ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
|
409
|
+
threadgroup_barrier(mem_flags::mem_none);
|
410
|
+
|
388
411
|
float sum = simd_sum(lsum);
|
412
|
+
|
389
413
|
if (ntg > N_SIMDWIDTH) {
|
390
414
|
if (sgitg == 0) {
|
391
415
|
buf[tiisg] = 0.0f;
|
@@ -428,9 +452,9 @@ kernel void kernel_soft_max_4(
|
|
428
452
|
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
429
453
|
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
430
454
|
|
431
|
-
device const float4 * psrc4 =
|
432
|
-
device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
|
433
|
-
device float4 * pdst4 =
|
455
|
+
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
456
|
+
device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
|
457
|
+
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
434
458
|
|
435
459
|
// parallel max
|
436
460
|
float4 lmax4 = -INFINITY;
|
@@ -468,7 +492,13 @@ kernel void kernel_soft_max_4(
|
|
468
492
|
}
|
469
493
|
|
470
494
|
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
495
|
+
|
496
|
+
// This barrier fixes a failing test
|
497
|
+
// ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
|
498
|
+
threadgroup_barrier(mem_flags::mem_none);
|
499
|
+
|
471
500
|
float sum = simd_sum(lsum);
|
501
|
+
|
472
502
|
if (ntg > N_SIMDWIDTH) {
|
473
503
|
if (sgitg == 0) {
|
474
504
|
buf[tiisg] = 0.0f;
|
@@ -639,6 +669,94 @@ kernel void kernel_rms_norm(
|
|
639
669
|
}
|
640
670
|
}
|
641
671
|
|
672
|
+
kernel void kernel_group_norm(
|
673
|
+
device const float * src0,
|
674
|
+
device float * dst,
|
675
|
+
constant int64_t & ne00,
|
676
|
+
constant int64_t & ne01,
|
677
|
+
constant int64_t & ne02,
|
678
|
+
constant uint64_t & nb00,
|
679
|
+
constant uint64_t & nb01,
|
680
|
+
constant uint64_t & nb02,
|
681
|
+
constant int32_t & n_groups,
|
682
|
+
constant float & eps,
|
683
|
+
threadgroup float * buf [[threadgroup(0)]],
|
684
|
+
uint tgpig[[threadgroup_position_in_grid]],
|
685
|
+
uint tpitg[[thread_position_in_threadgroup]],
|
686
|
+
uint sgitg[[simdgroup_index_in_threadgroup]],
|
687
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
688
|
+
uint ntg[[threads_per_threadgroup]]) {
|
689
|
+
const int64_t ne = ne00*ne01*ne02;
|
690
|
+
const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
|
691
|
+
|
692
|
+
int start = tgpig * gs;
|
693
|
+
int end = start + gs;
|
694
|
+
|
695
|
+
start += tpitg;
|
696
|
+
|
697
|
+
if (end >= ne) {
|
698
|
+
end = ne;
|
699
|
+
}
|
700
|
+
|
701
|
+
float tmp = 0.0f; // partial sum for thread in warp
|
702
|
+
|
703
|
+
for (int j = start; j < end; j += ntg) {
|
704
|
+
tmp += src0[j];
|
705
|
+
}
|
706
|
+
|
707
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
708
|
+
tmp = simd_sum(tmp);
|
709
|
+
if (ntg > N_SIMDWIDTH) {
|
710
|
+
if (sgitg == 0) {
|
711
|
+
buf[tiisg] = 0.0f;
|
712
|
+
}
|
713
|
+
|
714
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
715
|
+
|
716
|
+
if (tiisg == 0) {
|
717
|
+
buf[sgitg] = tmp;
|
718
|
+
}
|
719
|
+
|
720
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
721
|
+
|
722
|
+
tmp = buf[tiisg];
|
723
|
+
tmp = simd_sum(tmp);
|
724
|
+
}
|
725
|
+
|
726
|
+
const float mean = tmp / gs;
|
727
|
+
tmp = 0.0f;
|
728
|
+
|
729
|
+
for (int j = start; j < end; j += ntg) {
|
730
|
+
float xi = src0[j] - mean;
|
731
|
+
dst[j] = xi;
|
732
|
+
tmp += xi * xi;
|
733
|
+
}
|
734
|
+
|
735
|
+
tmp = simd_sum(tmp);
|
736
|
+
if (ntg > N_SIMDWIDTH) {
|
737
|
+
if (sgitg == 0) {
|
738
|
+
buf[tiisg] = 0.0f;
|
739
|
+
}
|
740
|
+
|
741
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
742
|
+
|
743
|
+
if (tiisg == 0) {
|
744
|
+
buf[sgitg] = tmp;
|
745
|
+
}
|
746
|
+
|
747
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
748
|
+
|
749
|
+
tmp = buf[tiisg];
|
750
|
+
tmp = simd_sum(tmp);
|
751
|
+
}
|
752
|
+
|
753
|
+
const float variance = tmp / gs;
|
754
|
+
const float scale = 1.0f/sqrt(variance + eps);
|
755
|
+
for (int j = start; j < end; j += ntg) {
|
756
|
+
dst[j] *= scale;
|
757
|
+
}
|
758
|
+
}
|
759
|
+
|
642
760
|
// function to calculate the inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
|
643
761
|
// il indicates where the q4 quants begin (0 or QK4_0/4)
|
644
762
|
// we assume that the yl's have been multiplied with the appropriate scale factor
|
@@ -731,7 +849,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
|
|
731
849
|
// guard against the number of rows not being divisible by
|
732
850
|
// N_DST, so this is another explicit assumption of the implementation.
|
733
851
|
template<typename block_q_type, int nr, int nsg, int nw>
|
734
|
-
void
|
852
|
+
void mul_vec_q_n_f32_impl(
|
735
853
|
device const void * src0,
|
736
854
|
device const float * src1,
|
737
855
|
device float * dst,
|
@@ -813,7 +931,7 @@ kernel void kernel_mul_mv_q4_0_f32(
|
|
813
931
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
814
932
|
uint tiisg[[thread_index_in_simdgroup]],
|
815
933
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
816
|
-
|
934
|
+
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
817
935
|
}
|
818
936
|
|
819
937
|
kernel void kernel_mul_mv_q4_1_f32(
|
@@ -832,7 +950,7 @@ kernel void kernel_mul_mv_q4_1_f32(
|
|
832
950
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
833
951
|
uint tiisg[[thread_index_in_simdgroup]],
|
834
952
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
835
|
-
|
953
|
+
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
836
954
|
}
|
837
955
|
|
838
956
|
kernel void kernel_mul_mv_q5_0_f32(
|
@@ -851,7 +969,7 @@ kernel void kernel_mul_mv_q5_0_f32(
|
|
851
969
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
852
970
|
uint tiisg[[thread_index_in_simdgroup]],
|
853
971
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
854
|
-
|
972
|
+
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
855
973
|
}
|
856
974
|
|
857
975
|
kernel void kernel_mul_mv_q5_1_f32(
|
@@ -870,28 +988,28 @@ kernel void kernel_mul_mv_q5_1_f32(
|
|
870
988
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
871
989
|
uint tiisg[[thread_index_in_simdgroup]],
|
872
990
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
873
|
-
|
991
|
+
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
874
992
|
}
|
875
993
|
|
876
994
|
|
877
995
|
#define NB_Q8_0 8
|
878
996
|
|
879
|
-
|
997
|
+
void kernel_mul_mv_q8_0_f32_impl(
|
880
998
|
device const void * src0,
|
881
999
|
device const float * src1,
|
882
1000
|
device float * dst,
|
883
1001
|
constant int64_t & ne00,
|
884
|
-
constant int64_t & ne01
|
885
|
-
constant int64_t & ne02
|
886
|
-
constant int64_t & ne10
|
887
|
-
constant int64_t & ne12
|
888
|
-
constant int64_t & ne0
|
889
|
-
constant int64_t & ne1
|
890
|
-
constant uint & r2
|
891
|
-
constant uint & r3
|
1002
|
+
constant int64_t & ne01,
|
1003
|
+
constant int64_t & ne02,
|
1004
|
+
constant int64_t & ne10,
|
1005
|
+
constant int64_t & ne12,
|
1006
|
+
constant int64_t & ne0,
|
1007
|
+
constant int64_t & ne1,
|
1008
|
+
constant uint & r2,
|
1009
|
+
constant uint & r3,
|
892
1010
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
893
|
-
uint
|
894
|
-
uint
|
1011
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
1012
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
895
1013
|
const int nr = N_DST;
|
896
1014
|
const int nsg = N_SIMDGROUP;
|
897
1015
|
const int nw = N_SIMDWIDTH;
|
@@ -945,9 +1063,29 @@ kernel void kernel_mul_mv_q8_0_f32(
|
|
945
1063
|
}
|
946
1064
|
}
|
947
1065
|
|
1066
|
+
[[host_name("kernel_mul_mv_q8_0_f32")]]
|
1067
|
+
kernel void kernel_mul_mv_q8_0_f32(
|
1068
|
+
device const void * src0,
|
1069
|
+
device const float * src1,
|
1070
|
+
device float * dst,
|
1071
|
+
constant int64_t & ne00,
|
1072
|
+
constant int64_t & ne01,
|
1073
|
+
constant int64_t & ne02,
|
1074
|
+
constant int64_t & ne10,
|
1075
|
+
constant int64_t & ne12,
|
1076
|
+
constant int64_t & ne0,
|
1077
|
+
constant int64_t & ne1,
|
1078
|
+
constant uint & r2 [[buffer(17)]],
|
1079
|
+
constant uint & r3 [[buffer(18)]],
|
1080
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1081
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
1082
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1083
|
+
kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
1084
|
+
}
|
1085
|
+
|
948
1086
|
#define N_F32_F32 4
|
949
1087
|
|
950
|
-
|
1088
|
+
void kernel_mul_mv_f32_f32_impl(
|
951
1089
|
device const char * src0,
|
952
1090
|
device const char * src1,
|
953
1091
|
device float * dst,
|
@@ -965,8 +1103,8 @@ kernel void kernel_mul_mv_f32_f32(
|
|
965
1103
|
constant uint64_t & nb12,
|
966
1104
|
constant int64_t & ne0,
|
967
1105
|
constant int64_t & ne1,
|
968
|
-
constant uint & r2
|
969
|
-
constant uint & r3
|
1106
|
+
constant uint & r2,
|
1107
|
+
constant uint & r3,
|
970
1108
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
971
1109
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
972
1110
|
|
@@ -1025,6 +1163,32 @@ kernel void kernel_mul_mv_f32_f32(
|
|
1025
1163
|
}
|
1026
1164
|
}
|
1027
1165
|
|
1166
|
+
[[host_name("kernel_mul_mv_f32_f32")]]
|
1167
|
+
kernel void kernel_mul_mv_f32_f32(
|
1168
|
+
device const char * src0,
|
1169
|
+
device const char * src1,
|
1170
|
+
device float * dst,
|
1171
|
+
constant int64_t & ne00,
|
1172
|
+
constant int64_t & ne01,
|
1173
|
+
constant int64_t & ne02,
|
1174
|
+
constant uint64_t & nb00,
|
1175
|
+
constant uint64_t & nb01,
|
1176
|
+
constant uint64_t & nb02,
|
1177
|
+
constant int64_t & ne10,
|
1178
|
+
constant int64_t & ne11,
|
1179
|
+
constant int64_t & ne12,
|
1180
|
+
constant uint64_t & nb10,
|
1181
|
+
constant uint64_t & nb11,
|
1182
|
+
constant uint64_t & nb12,
|
1183
|
+
constant int64_t & ne0,
|
1184
|
+
constant int64_t & ne1,
|
1185
|
+
constant uint & r2 [[buffer(17)]],
|
1186
|
+
constant uint & r3 [[buffer(18)]],
|
1187
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1188
|
+
uint tiisg[[thread_index_in_simdgroup]]) {
|
1189
|
+
kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
1190
|
+
}
|
1191
|
+
|
1028
1192
|
#define N_F16_F16 4
|
1029
1193
|
|
1030
1194
|
kernel void kernel_mul_mv_f16_f16(
|
@@ -1105,7 +1269,7 @@ kernel void kernel_mul_mv_f16_f16(
|
|
1105
1269
|
}
|
1106
1270
|
}
|
1107
1271
|
|
1108
|
-
|
1272
|
+
void kernel_mul_mv_f16_f32_1row_impl(
|
1109
1273
|
device const char * src0,
|
1110
1274
|
device const char * src1,
|
1111
1275
|
device float * dst,
|
@@ -1123,8 +1287,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
|
|
1123
1287
|
constant uint64_t & nb12,
|
1124
1288
|
constant int64_t & ne0,
|
1125
1289
|
constant int64_t & ne1,
|
1126
|
-
constant uint & r2
|
1127
|
-
constant uint & r3
|
1290
|
+
constant uint & r2,
|
1291
|
+
constant uint & r3,
|
1128
1292
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1129
1293
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
1130
1294
|
|
@@ -1161,12 +1325,37 @@ kernel void kernel_mul_mv_f16_f32_1row(
|
|
1161
1325
|
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
1162
1326
|
}
|
1163
1327
|
}
|
1328
|
+
}
|
1164
1329
|
|
1330
|
+
[[host_name("kernel_mul_mv_f16_f32_1row")]]
|
1331
|
+
kernel void kernel_mul_mv_f16_f32_1row(
|
1332
|
+
device const char * src0,
|
1333
|
+
device const char * src1,
|
1334
|
+
device float * dst,
|
1335
|
+
constant int64_t & ne00,
|
1336
|
+
constant int64_t & ne01,
|
1337
|
+
constant int64_t & ne02,
|
1338
|
+
constant uint64_t & nb00,
|
1339
|
+
constant uint64_t & nb01,
|
1340
|
+
constant uint64_t & nb02,
|
1341
|
+
constant int64_t & ne10,
|
1342
|
+
constant int64_t & ne11,
|
1343
|
+
constant int64_t & ne12,
|
1344
|
+
constant uint64_t & nb10,
|
1345
|
+
constant uint64_t & nb11,
|
1346
|
+
constant uint64_t & nb12,
|
1347
|
+
constant int64_t & ne0,
|
1348
|
+
constant int64_t & ne1,
|
1349
|
+
constant uint & r2 [[buffer(17)]],
|
1350
|
+
constant uint & r3 [[buffer(18)]],
|
1351
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1352
|
+
uint tiisg[[thread_index_in_simdgroup]]) {
|
1353
|
+
kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
1165
1354
|
}
|
1166
1355
|
|
1167
1356
|
#define N_F16_F32 4
|
1168
1357
|
|
1169
|
-
|
1358
|
+
void kernel_mul_mv_f16_f32_impl(
|
1170
1359
|
device const char * src0,
|
1171
1360
|
device const char * src1,
|
1172
1361
|
device float * dst,
|
@@ -1184,8 +1373,8 @@ kernel void kernel_mul_mv_f16_f32(
|
|
1184
1373
|
constant uint64_t & nb12,
|
1185
1374
|
constant int64_t & ne0,
|
1186
1375
|
constant int64_t & ne1,
|
1187
|
-
constant uint & r2
|
1188
|
-
constant uint & r3
|
1376
|
+
constant uint & r2,
|
1377
|
+
constant uint & r3,
|
1189
1378
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1190
1379
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
1191
1380
|
|
@@ -1244,6 +1433,32 @@ kernel void kernel_mul_mv_f16_f32(
|
|
1244
1433
|
}
|
1245
1434
|
}
|
1246
1435
|
|
1436
|
+
[[host_name("kernel_mul_mv_f16_f32")]]
|
1437
|
+
kernel void kernel_mul_mv_f16_f32(
|
1438
|
+
device const char * src0,
|
1439
|
+
device const char * src1,
|
1440
|
+
device float * dst,
|
1441
|
+
constant int64_t & ne00,
|
1442
|
+
constant int64_t & ne01,
|
1443
|
+
constant int64_t & ne02,
|
1444
|
+
constant uint64_t & nb00,
|
1445
|
+
constant uint64_t & nb01,
|
1446
|
+
constant uint64_t & nb02,
|
1447
|
+
constant int64_t & ne10,
|
1448
|
+
constant int64_t & ne11,
|
1449
|
+
constant int64_t & ne12,
|
1450
|
+
constant uint64_t & nb10,
|
1451
|
+
constant uint64_t & nb11,
|
1452
|
+
constant uint64_t & nb12,
|
1453
|
+
constant int64_t & ne0,
|
1454
|
+
constant int64_t & ne1,
|
1455
|
+
constant uint & r2 [[buffer(17)]],
|
1456
|
+
constant uint & r3 [[buffer(18)]],
|
1457
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1458
|
+
uint tiisg[[thread_index_in_simdgroup]]) {
|
1459
|
+
kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
1460
|
+
}
|
1461
|
+
|
1247
1462
|
// Assumes row size (ne00) is a multiple of 4
|
1248
1463
|
kernel void kernel_mul_mv_f16_f32_l4(
|
1249
1464
|
device const char * src0,
|
@@ -1487,8 +1702,9 @@ kernel void kernel_rope(
|
|
1487
1702
|
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
1488
1703
|
}
|
1489
1704
|
} else {
|
1490
|
-
for (int64_t
|
1491
|
-
|
1705
|
+
for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
|
1706
|
+
if (ic < n_dims) {
|
1707
|
+
const int64_t ib = 0;
|
1492
1708
|
|
1493
1709
|
// simplified from `(ib * n_dims + ic) * inv_ndims`
|
1494
1710
|
const float cur_rot = inv_ndims*ic - ib;
|
@@ -1507,6 +1723,14 @@ kernel void kernel_rope(
|
|
1507
1723
|
|
1508
1724
|
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
1509
1725
|
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
1726
|
+
} else {
|
1727
|
+
const int64_t i0 = ic;
|
1728
|
+
|
1729
|
+
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
1730
|
+
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
1731
|
+
|
1732
|
+
dst_data[0] = src[0];
|
1733
|
+
dst_data[1] = src[1];
|
1510
1734
|
}
|
1511
1735
|
}
|
1512
1736
|
}
|
@@ -1548,21 +1772,112 @@ kernel void kernel_im2col_f16(
|
|
1548
1772
|
}
|
1549
1773
|
}
|
1550
1774
|
|
1551
|
-
|
1552
|
-
|
1553
|
-
|
1554
|
-
|
1555
|
-
|
1556
|
-
|
1557
|
-
|
1558
|
-
|
1559
|
-
|
1560
|
-
|
1561
|
-
|
1562
|
-
|
1563
|
-
|
1564
|
-
|
1565
|
-
|
1775
|
+
kernel void kernel_upscale_f32(
|
1776
|
+
device const char * src0,
|
1777
|
+
device char * dst,
|
1778
|
+
constant int64_t & ne00,
|
1779
|
+
constant int64_t & ne01,
|
1780
|
+
constant int64_t & ne02,
|
1781
|
+
constant int64_t & ne03,
|
1782
|
+
constant uint64_t & nb00,
|
1783
|
+
constant uint64_t & nb01,
|
1784
|
+
constant uint64_t & nb02,
|
1785
|
+
constant uint64_t & nb03,
|
1786
|
+
constant int64_t & ne0,
|
1787
|
+
constant int64_t & ne1,
|
1788
|
+
constant int64_t & ne2,
|
1789
|
+
constant int64_t & ne3,
|
1790
|
+
constant uint64_t & nb0,
|
1791
|
+
constant uint64_t & nb1,
|
1792
|
+
constant uint64_t & nb2,
|
1793
|
+
constant uint64_t & nb3,
|
1794
|
+
constant int32_t & sf,
|
1795
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1796
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
1797
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
1798
|
+
|
1799
|
+
const int64_t i3 = tgpig.z;
|
1800
|
+
const int64_t i2 = tgpig.y;
|
1801
|
+
const int64_t i1 = tgpig.x;
|
1802
|
+
|
1803
|
+
const int64_t i03 = i3;
|
1804
|
+
const int64_t i02 = i2;
|
1805
|
+
const int64_t i01 = i1/sf;
|
1806
|
+
|
1807
|
+
device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
|
1808
|
+
device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
|
1809
|
+
|
1810
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
1811
|
+
dst_ptr[i0] = src0_ptr[i0/sf];
|
1812
|
+
}
|
1813
|
+
}
|
1814
|
+
|
1815
|
+
kernel void kernel_pad_f32(
|
1816
|
+
device const char * src0,
|
1817
|
+
device char * dst,
|
1818
|
+
constant int64_t & ne00,
|
1819
|
+
constant int64_t & ne01,
|
1820
|
+
constant int64_t & ne02,
|
1821
|
+
constant int64_t & ne03,
|
1822
|
+
constant uint64_t & nb00,
|
1823
|
+
constant uint64_t & nb01,
|
1824
|
+
constant uint64_t & nb02,
|
1825
|
+
constant uint64_t & nb03,
|
1826
|
+
constant int64_t & ne0,
|
1827
|
+
constant int64_t & ne1,
|
1828
|
+
constant int64_t & ne2,
|
1829
|
+
constant int64_t & ne3,
|
1830
|
+
constant uint64_t & nb0,
|
1831
|
+
constant uint64_t & nb1,
|
1832
|
+
constant uint64_t & nb2,
|
1833
|
+
constant uint64_t & nb3,
|
1834
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1835
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
1836
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
1837
|
+
|
1838
|
+
const int64_t i3 = tgpig.z;
|
1839
|
+
const int64_t i2 = tgpig.y;
|
1840
|
+
const int64_t i1 = tgpig.x;
|
1841
|
+
|
1842
|
+
const int64_t i03 = i3;
|
1843
|
+
const int64_t i02 = i2;
|
1844
|
+
const int64_t i01 = i1;
|
1845
|
+
|
1846
|
+
device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
|
1847
|
+
device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
|
1848
|
+
|
1849
|
+
if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
|
1850
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
1851
|
+
if (i0 < ne00) {
|
1852
|
+
dst_ptr[i0] = src0_ptr[i0];
|
1853
|
+
} else {
|
1854
|
+
dst_ptr[i0] = 0.0f;
|
1855
|
+
}
|
1856
|
+
}
|
1857
|
+
|
1858
|
+
return;
|
1859
|
+
}
|
1860
|
+
|
1861
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
1862
|
+
dst_ptr[i0] = 0.0f;
|
1863
|
+
}
|
1864
|
+
}
|
1865
|
+
|
1866
|
+
// bitonic sort implementation following the CUDA kernels as reference
|
1867
|
+
typedef void (argsort_t)(
|
1868
|
+
device const float * x,
|
1869
|
+
device int32_t * dst,
|
1870
|
+
constant int64_t & ncols,
|
1871
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1872
|
+
uint3 tpitg[[thread_position_in_threadgroup]]);
|
1873
|
+
|
1874
|
+
template<ggml_sort_order order>
|
1875
|
+
kernel void kernel_argsort_f32_i32(
|
1876
|
+
device const float * x,
|
1877
|
+
device int32_t * dst,
|
1878
|
+
constant int64_t & ncols,
|
1879
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1880
|
+
uint3 tpitg[[thread_position_in_threadgroup]]) {
|
1566
1881
|
// bitonic sort
|
1567
1882
|
int col = tpitg[0];
|
1568
1883
|
int row = tgpig[1];
|
@@ -1600,9 +1915,17 @@ kernel void kernel_argsort_f32_i32(
|
|
1600
1915
|
template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
|
1601
1916
|
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
|
1602
1917
|
|
1918
|
+
kernel void kernel_leaky_relu_f32(
|
1919
|
+
device const float * src0,
|
1920
|
+
device float * dst,
|
1921
|
+
constant float & slope,
|
1922
|
+
uint tpig[[thread_position_in_grid]]) {
|
1923
|
+
dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
|
1924
|
+
}
|
1925
|
+
|
1603
1926
|
kernel void kernel_cpy_f16_f16(
|
1604
|
-
device
|
1605
|
-
device
|
1927
|
+
device const half * src0,
|
1928
|
+
device half * dst,
|
1606
1929
|
constant int64_t & ne00,
|
1607
1930
|
constant int64_t & ne01,
|
1608
1931
|
constant int64_t & ne02,
|
@@ -1641,6 +1964,47 @@ kernel void kernel_cpy_f16_f16(
|
|
1641
1964
|
}
|
1642
1965
|
}
|
1643
1966
|
|
1967
|
+
kernel void kernel_cpy_f16_f32(
|
1968
|
+
device const half * src0,
|
1969
|
+
device float * dst,
|
1970
|
+
constant int64_t & ne00,
|
1971
|
+
constant int64_t & ne01,
|
1972
|
+
constant int64_t & ne02,
|
1973
|
+
constant int64_t & ne03,
|
1974
|
+
constant uint64_t & nb00,
|
1975
|
+
constant uint64_t & nb01,
|
1976
|
+
constant uint64_t & nb02,
|
1977
|
+
constant uint64_t & nb03,
|
1978
|
+
constant int64_t & ne0,
|
1979
|
+
constant int64_t & ne1,
|
1980
|
+
constant int64_t & ne2,
|
1981
|
+
constant int64_t & ne3,
|
1982
|
+
constant uint64_t & nb0,
|
1983
|
+
constant uint64_t & nb1,
|
1984
|
+
constant uint64_t & nb2,
|
1985
|
+
constant uint64_t & nb3,
|
1986
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1987
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
1988
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
1989
|
+
const int64_t i03 = tgpig[2];
|
1990
|
+
const int64_t i02 = tgpig[1];
|
1991
|
+
const int64_t i01 = tgpig[0];
|
1992
|
+
|
1993
|
+
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
1994
|
+
|
1995
|
+
const int64_t i3 = n / (ne2*ne1*ne0);
|
1996
|
+
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
1997
|
+
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
1998
|
+
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
1999
|
+
|
2000
|
+
device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
2001
|
+
|
2002
|
+
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
2003
|
+
device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
2004
|
+
dst_data[i00] = src[0];
|
2005
|
+
}
|
2006
|
+
}
|
2007
|
+
|
1644
2008
|
kernel void kernel_cpy_f32_f16(
|
1645
2009
|
device const float * src0,
|
1646
2010
|
device half * dst,
|
@@ -1917,9 +2281,9 @@ kernel void kernel_cpy_f32_q4_1(
|
|
1917
2281
|
}
|
1918
2282
|
|
1919
2283
|
kernel void kernel_concat(
|
1920
|
-
device
|
1921
|
-
device
|
1922
|
-
device
|
2284
|
+
device const char * src0,
|
2285
|
+
device const char * src1,
|
2286
|
+
device char * dst,
|
1923
2287
|
constant int64_t & ne00,
|
1924
2288
|
constant int64_t & ne01,
|
1925
2289
|
constant int64_t & ne02,
|
@@ -1956,7 +2320,7 @@ kernel void kernel_concat(
|
|
1956
2320
|
const int64_t i12 = i02 % ne12;
|
1957
2321
|
const int64_t i11 = i01 % ne11;
|
1958
2322
|
|
1959
|
-
device const char * src0_ptr = src0 + i03
|
2323
|
+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
|
1960
2324
|
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
|
1961
2325
|
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
|
1962
2326
|
|
@@ -2064,19 +2428,19 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
|
|
2064
2428
|
|
2065
2429
|
//====================================== dot products =========================
|
2066
2430
|
|
2067
|
-
|
2431
|
+
void kernel_mul_mv_q2_K_f32_impl(
|
2068
2432
|
device const void * src0,
|
2069
2433
|
device const float * src1,
|
2070
2434
|
device float * dst,
|
2071
2435
|
constant int64_t & ne00,
|
2072
|
-
constant int64_t & ne01
|
2073
|
-
constant int64_t & ne02
|
2074
|
-
constant int64_t & ne10
|
2075
|
-
constant int64_t & ne12
|
2076
|
-
constant int64_t & ne0
|
2077
|
-
constant int64_t & ne1
|
2078
|
-
constant uint & r2
|
2079
|
-
constant uint & r3
|
2436
|
+
constant int64_t & ne01,
|
2437
|
+
constant int64_t & ne02,
|
2438
|
+
constant int64_t & ne10,
|
2439
|
+
constant int64_t & ne12,
|
2440
|
+
constant int64_t & ne0,
|
2441
|
+
constant int64_t & ne1,
|
2442
|
+
constant uint & r2,
|
2443
|
+
constant uint & r3,
|
2080
2444
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2081
2445
|
uint tiisg[[thread_index_in_simdgroup]],
|
2082
2446
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -2214,8 +2578,8 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
2214
2578
|
}
|
2215
2579
|
}
|
2216
2580
|
|
2217
|
-
|
2218
|
-
kernel void
|
2581
|
+
[[host_name("kernel_mul_mv_q2_K_f32")]]
|
2582
|
+
kernel void kernel_mul_mv_q2_K_f32(
|
2219
2583
|
device const void * src0,
|
2220
2584
|
device const float * src1,
|
2221
2585
|
device float * dst,
|
@@ -2229,8 +2593,29 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
2229
2593
|
constant uint & r2 [[buffer(17)]],
|
2230
2594
|
constant uint & r3 [[buffer(18)]],
|
2231
2595
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2232
|
-
uint
|
2233
|
-
uint
|
2596
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
2597
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
2598
|
+
|
2599
|
+
kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
2600
|
+
}
|
2601
|
+
|
2602
|
+
#if QK_K == 256
|
2603
|
+
void kernel_mul_mv_q3_K_f32_impl(
|
2604
|
+
device const void * src0,
|
2605
|
+
device const float * src1,
|
2606
|
+
device float * dst,
|
2607
|
+
constant int64_t & ne00,
|
2608
|
+
constant int64_t & ne01,
|
2609
|
+
constant int64_t & ne02,
|
2610
|
+
constant int64_t & ne10,
|
2611
|
+
constant int64_t & ne12,
|
2612
|
+
constant int64_t & ne0,
|
2613
|
+
constant int64_t & ne1,
|
2614
|
+
constant uint & r2,
|
2615
|
+
constant uint & r3,
|
2616
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
2617
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
2618
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
2234
2619
|
|
2235
2620
|
const int nb = ne00/QK_K;
|
2236
2621
|
|
@@ -2373,19 +2758,19 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
2373
2758
|
}
|
2374
2759
|
}
|
2375
2760
|
#else
|
2376
|
-
|
2761
|
+
void kernel_mul_mv_q3_K_f32_impl(
|
2377
2762
|
device const void * src0,
|
2378
2763
|
device const float * src1,
|
2379
2764
|
device float * dst,
|
2380
2765
|
constant int64_t & ne00,
|
2381
|
-
constant int64_t & ne01
|
2382
|
-
constant int64_t & ne02
|
2383
|
-
constant int64_t & ne10
|
2384
|
-
constant int64_t & ne12
|
2385
|
-
constant int64_t & ne0
|
2386
|
-
constant int64_t & ne1
|
2387
|
-
constant uint & r2
|
2388
|
-
constant uint & r3
|
2766
|
+
constant int64_t & ne01,
|
2767
|
+
constant int64_t & ne02,
|
2768
|
+
constant int64_t & ne10,
|
2769
|
+
constant int64_t & ne12,
|
2770
|
+
constant int64_t & ne0,
|
2771
|
+
constant int64_t & ne1,
|
2772
|
+
constant uint & r2,
|
2773
|
+
constant uint & r3,
|
2389
2774
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2390
2775
|
uint tiisg[[thread_index_in_simdgroup]],
|
2391
2776
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -2450,20 +2835,41 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
2450
2835
|
}
|
2451
2836
|
#endif
|
2452
2837
|
|
2838
|
+
[[host_name("kernel_mul_mv_q3_K_f32")]]
|
2839
|
+
kernel void kernel_mul_mv_q3_K_f32(
|
2840
|
+
device const void * src0,
|
2841
|
+
device const float * src1,
|
2842
|
+
device float * dst,
|
2843
|
+
constant int64_t & ne00,
|
2844
|
+
constant int64_t & ne01[[buffer(4)]],
|
2845
|
+
constant int64_t & ne02[[buffer(5)]],
|
2846
|
+
constant int64_t & ne10[[buffer(9)]],
|
2847
|
+
constant int64_t & ne12[[buffer(11)]],
|
2848
|
+
constant int64_t & ne0 [[buffer(15)]],
|
2849
|
+
constant int64_t & ne1 [[buffer(16)]],
|
2850
|
+
constant uint & r2 [[buffer(17)]],
|
2851
|
+
constant uint & r3 [[buffer(18)]],
|
2852
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
2853
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
2854
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
2855
|
+
|
2856
|
+
kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
2857
|
+
}
|
2858
|
+
|
2453
2859
|
#if QK_K == 256
|
2454
|
-
|
2860
|
+
void kernel_mul_mv_q4_K_f32_impl(
|
2455
2861
|
device const void * src0,
|
2456
2862
|
device const float * src1,
|
2457
2863
|
device float * dst,
|
2458
2864
|
constant int64_t & ne00,
|
2459
|
-
constant int64_t & ne01
|
2460
|
-
constant int64_t & ne02
|
2461
|
-
constant int64_t & ne10
|
2462
|
-
constant int64_t & ne12
|
2463
|
-
constant int64_t & ne0
|
2464
|
-
constant int64_t & ne1
|
2465
|
-
constant uint & r2
|
2466
|
-
constant uint & r3
|
2865
|
+
constant int64_t & ne01,
|
2866
|
+
constant int64_t & ne02,
|
2867
|
+
constant int64_t & ne10,
|
2868
|
+
constant int64_t & ne12,
|
2869
|
+
constant int64_t & ne0,
|
2870
|
+
constant int64_t & ne1,
|
2871
|
+
constant uint & r2,
|
2872
|
+
constant uint & r3,
|
2467
2873
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2468
2874
|
uint tiisg[[thread_index_in_simdgroup]],
|
2469
2875
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -2564,19 +2970,19 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
2564
2970
|
}
|
2565
2971
|
}
|
2566
2972
|
#else
|
2567
|
-
|
2973
|
+
void kernel_mul_mv_q4_K_f32_impl(
|
2568
2974
|
device const void * src0,
|
2569
2975
|
device const float * src1,
|
2570
2976
|
device float * dst,
|
2571
2977
|
constant int64_t & ne00,
|
2572
|
-
constant int64_t & ne01
|
2573
|
-
constant int64_t & ne02
|
2574
|
-
constant int64_t & ne10
|
2575
|
-
constant int64_t & ne12
|
2576
|
-
constant int64_t & ne0
|
2577
|
-
constant int64_t & ne1
|
2578
|
-
constant uint & r2
|
2579
|
-
constant uint & r3
|
2978
|
+
constant int64_t & ne01,
|
2979
|
+
constant int64_t & ne02,
|
2980
|
+
constant int64_t & ne10,
|
2981
|
+
constant int64_t & ne12,
|
2982
|
+
constant int64_t & ne0,
|
2983
|
+
constant int64_t & ne1,
|
2984
|
+
constant uint & r2,
|
2985
|
+
constant uint & r3,
|
2580
2986
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2581
2987
|
uint tiisg[[thread_index_in_simdgroup]],
|
2582
2988
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -2660,7 +3066,8 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
2660
3066
|
}
|
2661
3067
|
#endif
|
2662
3068
|
|
2663
|
-
|
3069
|
+
[[host_name("kernel_mul_mv_q4_K_f32")]]
|
3070
|
+
kernel void kernel_mul_mv_q4_K_f32(
|
2664
3071
|
device const void * src0,
|
2665
3072
|
device const float * src1,
|
2666
3073
|
device float * dst,
|
@@ -2677,6 +3084,26 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
2677
3084
|
uint tiisg[[thread_index_in_simdgroup]],
|
2678
3085
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
2679
3086
|
|
3087
|
+
kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
3088
|
+
}
|
3089
|
+
|
3090
|
+
void kernel_mul_mv_q5_K_f32_impl(
|
3091
|
+
device const void * src0,
|
3092
|
+
device const float * src1,
|
3093
|
+
device float * dst,
|
3094
|
+
constant int64_t & ne00,
|
3095
|
+
constant int64_t & ne01,
|
3096
|
+
constant int64_t & ne02,
|
3097
|
+
constant int64_t & ne10,
|
3098
|
+
constant int64_t & ne12,
|
3099
|
+
constant int64_t & ne0,
|
3100
|
+
constant int64_t & ne1,
|
3101
|
+
constant uint & r2,
|
3102
|
+
constant uint & r3,
|
3103
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3104
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
3105
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3106
|
+
|
2680
3107
|
const int nb = ne00/QK_K;
|
2681
3108
|
|
2682
3109
|
const int64_t r0 = tgpig.x;
|
@@ -2836,10 +3263,10 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
2836
3263
|
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
|
2837
3264
|
}
|
2838
3265
|
}
|
2839
|
-
|
2840
3266
|
}
|
2841
3267
|
|
2842
|
-
|
3268
|
+
[[host_name("kernel_mul_mv_q5_K_f32")]]
|
3269
|
+
kernel void kernel_mul_mv_q5_K_f32(
|
2843
3270
|
device const void * src0,
|
2844
3271
|
device const float * src1,
|
2845
3272
|
device float * dst,
|
@@ -2853,21 +3280,41 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
2853
3280
|
constant uint & r2 [[buffer(17)]],
|
2854
3281
|
constant uint & r3 [[buffer(18)]],
|
2855
3282
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2856
|
-
uint
|
2857
|
-
uint
|
2858
|
-
|
2859
|
-
const uint8_t kmask1 = 0x03;
|
2860
|
-
const uint8_t kmask2 = 0x0C;
|
2861
|
-
const uint8_t kmask3 = 0x30;
|
2862
|
-
const uint8_t kmask4 = 0xC0;
|
2863
|
-
|
2864
|
-
const int nb = ne00/QK_K;
|
2865
|
-
|
2866
|
-
const int64_t r0 = tgpig.x;
|
2867
|
-
const int64_t r1 = tgpig.y;
|
2868
|
-
const int im = tgpig.z;
|
3283
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
3284
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
2869
3285
|
|
2870
|
-
|
3286
|
+
kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
3287
|
+
}
|
3288
|
+
|
3289
|
+
void kernel_mul_mv_q6_K_f32_impl(
|
3290
|
+
device const void * src0,
|
3291
|
+
device const float * src1,
|
3292
|
+
device float * dst,
|
3293
|
+
constant int64_t & ne00,
|
3294
|
+
constant int64_t & ne01,
|
3295
|
+
constant int64_t & ne02,
|
3296
|
+
constant int64_t & ne10,
|
3297
|
+
constant int64_t & ne12,
|
3298
|
+
constant int64_t & ne0,
|
3299
|
+
constant int64_t & ne1,
|
3300
|
+
constant uint & r2,
|
3301
|
+
constant uint & r3,
|
3302
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3303
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
3304
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3305
|
+
|
3306
|
+
const uint8_t kmask1 = 0x03;
|
3307
|
+
const uint8_t kmask2 = 0x0C;
|
3308
|
+
const uint8_t kmask3 = 0x30;
|
3309
|
+
const uint8_t kmask4 = 0xC0;
|
3310
|
+
|
3311
|
+
const int nb = ne00/QK_K;
|
3312
|
+
|
3313
|
+
const int64_t r0 = tgpig.x;
|
3314
|
+
const int64_t r1 = tgpig.y;
|
3315
|
+
const int im = tgpig.z;
|
3316
|
+
|
3317
|
+
const int row = 2 * r0 + sgitg;
|
2871
3318
|
|
2872
3319
|
const uint i12 = im%ne12;
|
2873
3320
|
const uint i13 = im/ne12;
|
@@ -2945,6 +3392,27 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
2945
3392
|
}
|
2946
3393
|
}
|
2947
3394
|
|
3395
|
+
[[host_name("kernel_mul_mv_q6_K_f32")]]
|
3396
|
+
kernel void kernel_mul_mv_q6_K_f32(
|
3397
|
+
device const void * src0,
|
3398
|
+
device const float * src1,
|
3399
|
+
device float * dst,
|
3400
|
+
constant int64_t & ne00,
|
3401
|
+
constant int64_t & ne01[[buffer(4)]],
|
3402
|
+
constant int64_t & ne02[[buffer(5)]],
|
3403
|
+
constant int64_t & ne10[[buffer(9)]],
|
3404
|
+
constant int64_t & ne12[[buffer(11)]],
|
3405
|
+
constant int64_t & ne0 [[buffer(15)]],
|
3406
|
+
constant int64_t & ne1 [[buffer(16)]],
|
3407
|
+
constant uint & r2 [[buffer(17)]],
|
3408
|
+
constant uint & r3 [[buffer(18)]],
|
3409
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3410
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
3411
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3412
|
+
|
3413
|
+
kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
3414
|
+
}
|
3415
|
+
|
2948
3416
|
//============================= templates and their specializations =============================
|
2949
3417
|
|
2950
3418
|
// NOTE: this is not dequantizing - we are simply fitting the template
|
@@ -3062,10 +3530,10 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
|
|
3062
3530
|
|
3063
3531
|
template <typename type4x4>
|
3064
3532
|
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
|
3065
|
-
const
|
3066
|
-
const
|
3533
|
+
const float d = xb->d;
|
3534
|
+
const float min = xb->dmin;
|
3067
3535
|
device const uint8_t * q = (device const uint8_t *)xb->qs;
|
3068
|
-
|
3536
|
+
float dl, ml;
|
3069
3537
|
uint8_t sc = xb->scales[il];
|
3070
3538
|
|
3071
3539
|
#if QK_K == 256
|
@@ -3135,10 +3603,10 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
|
|
3135
3603
|
q = q + (il/4) * 32 + 16 * (il&1);
|
3136
3604
|
il = il & 3;
|
3137
3605
|
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
|
3138
|
-
const
|
3139
|
-
const
|
3140
|
-
const
|
3141
|
-
const
|
3606
|
+
const float d = il < 2 ? xb->d : xb->d / 16.h;
|
3607
|
+
const float min = xb->dmin;
|
3608
|
+
const float dl = d * sc[0];
|
3609
|
+
const float ml = min * sc[1];
|
3142
3610
|
#else
|
3143
3611
|
q = q + 16 * (il&1);
|
3144
3612
|
device const uint8_t * s = xb->scales;
|
@@ -3165,13 +3633,13 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
|
|
3165
3633
|
uint8_t ul = 1 << (il/2);
|
3166
3634
|
il = il & 3;
|
3167
3635
|
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
|
3168
|
-
const
|
3169
|
-
const
|
3170
|
-
const
|
3171
|
-
const
|
3636
|
+
const float d = il < 2 ? xb->d : xb->d / 16.h;
|
3637
|
+
const float min = xb->dmin;
|
3638
|
+
const float dl = d * sc[0];
|
3639
|
+
const float ml = min * sc[1];
|
3172
3640
|
|
3173
|
-
const ushort mask
|
3174
|
-
const
|
3641
|
+
const ushort mask = il<2 ? 0x0F : 0xF0;
|
3642
|
+
const float qh_val = il<2 ? 16.f : 256.f;
|
3175
3643
|
for (int i = 0; i < 16; ++i) {
|
3176
3644
|
reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
|
3177
3645
|
}
|
@@ -3219,22 +3687,90 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
|
|
3219
3687
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
3220
3688
|
kernel void kernel_get_rows(
|
3221
3689
|
device const void * src0,
|
3222
|
-
device const
|
3690
|
+
device const char * src1,
|
3223
3691
|
device float * dst,
|
3224
3692
|
constant int64_t & ne00,
|
3225
3693
|
constant uint64_t & nb01,
|
3694
|
+
constant uint64_t & nb02,
|
3695
|
+
constant int64_t & ne10,
|
3696
|
+
constant uint64_t & nb10,
|
3697
|
+
constant uint64_t & nb11,
|
3226
3698
|
constant uint64_t & nb1,
|
3227
|
-
|
3699
|
+
constant uint64_t & nb2,
|
3700
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3228
3701
|
uint tiitg[[thread_index_in_threadgroup]],
|
3229
|
-
|
3230
|
-
const
|
3231
|
-
const
|
3702
|
+
uint3 tptg [[threads_per_threadgroup]]) {
|
3703
|
+
//const int64_t i = tgpig;
|
3704
|
+
//const int64_t r = ((device int32_t *) src1)[i];
|
3705
|
+
|
3706
|
+
const int64_t i10 = tgpig.x;
|
3707
|
+
const int64_t i11 = tgpig.y;
|
3708
|
+
|
3709
|
+
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
3232
3710
|
|
3233
|
-
|
3711
|
+
const int64_t i02 = i11;
|
3712
|
+
|
3713
|
+
for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
|
3234
3714
|
float4x4 temp;
|
3235
3715
|
dequantize_func(
|
3236
|
-
((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
|
3237
|
-
*(((device float4x4 *) ((device char *) dst +
|
3716
|
+
((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
|
3717
|
+
*(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
|
3718
|
+
}
|
3719
|
+
}
|
3720
|
+
|
3721
|
+
kernel void kernel_get_rows_f32(
|
3722
|
+
device const void * src0,
|
3723
|
+
device const char * src1,
|
3724
|
+
device float * dst,
|
3725
|
+
constant int64_t & ne00,
|
3726
|
+
constant uint64_t & nb01,
|
3727
|
+
constant uint64_t & nb02,
|
3728
|
+
constant int64_t & ne10,
|
3729
|
+
constant uint64_t & nb10,
|
3730
|
+
constant uint64_t & nb11,
|
3731
|
+
constant uint64_t & nb1,
|
3732
|
+
constant uint64_t & nb2,
|
3733
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3734
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
3735
|
+
uint3 tptg [[threads_per_threadgroup]]) {
|
3736
|
+
const int64_t i10 = tgpig.x;
|
3737
|
+
const int64_t i11 = tgpig.y;
|
3738
|
+
|
3739
|
+
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
3740
|
+
|
3741
|
+
const int64_t i02 = i11;
|
3742
|
+
|
3743
|
+
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
3744
|
+
((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
3745
|
+
((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
3746
|
+
}
|
3747
|
+
}
|
3748
|
+
|
3749
|
+
kernel void kernel_get_rows_f16(
|
3750
|
+
device const void * src0,
|
3751
|
+
device const char * src1,
|
3752
|
+
device float * dst,
|
3753
|
+
constant int64_t & ne00,
|
3754
|
+
constant uint64_t & nb01,
|
3755
|
+
constant uint64_t & nb02,
|
3756
|
+
constant int64_t & ne10,
|
3757
|
+
constant uint64_t & nb10,
|
3758
|
+
constant uint64_t & nb11,
|
3759
|
+
constant uint64_t & nb1,
|
3760
|
+
constant uint64_t & nb2,
|
3761
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3762
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
3763
|
+
uint3 tptg [[threads_per_threadgroup]]) {
|
3764
|
+
const int64_t i10 = tgpig.x;
|
3765
|
+
const int64_t i11 = tgpig.y;
|
3766
|
+
|
3767
|
+
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
3768
|
+
|
3769
|
+
const int64_t i02 = i11;
|
3770
|
+
|
3771
|
+
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
3772
|
+
((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
3773
|
+
((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
3238
3774
|
}
|
3239
3775
|
}
|
3240
3776
|
|
@@ -3426,19 +3962,22 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
3426
3962
|
|
3427
3963
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
3428
3964
|
kernel void kernel_mul_mm_id(
|
3429
|
-
device const
|
3965
|
+
device const uchar * ids,
|
3430
3966
|
device const uchar * src1,
|
3431
|
-
device
|
3967
|
+
device uchar * dst,
|
3968
|
+
constant int64_t & nbi1,
|
3432
3969
|
constant int64_t & ne00,
|
3433
3970
|
constant int64_t & ne02,
|
3434
3971
|
constant int64_t & nb01,
|
3435
3972
|
constant int64_t & nb02,
|
3436
3973
|
constant int64_t & ne12,
|
3974
|
+
constant int64_t & ne13,
|
3437
3975
|
constant int64_t & nb10,
|
3438
3976
|
constant int64_t & nb11,
|
3439
3977
|
constant int64_t & nb12,
|
3440
3978
|
constant int64_t & ne0,
|
3441
3979
|
constant int64_t & ne1,
|
3980
|
+
constant int64_t & nb1,
|
3442
3981
|
constant uint & r2,
|
3443
3982
|
constant uint & r3,
|
3444
3983
|
constant int & idx,
|
@@ -3456,10 +3995,16 @@ kernel void kernel_mul_mm_id(
|
|
3456
3995
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3457
3996
|
device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
3458
3997
|
|
3998
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
3999
|
+
|
4000
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4001
|
+
|
4002
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4003
|
+
|
3459
4004
|
kernel_mul_mm_impl<block_q, nl, dequantize_func>(
|
3460
|
-
src0[
|
3461
|
-
src1,
|
3462
|
-
dst,
|
4005
|
+
src0[id],
|
4006
|
+
src1 + bid*nb11,
|
4007
|
+
(device float *) (dst + bid*nb1),
|
3463
4008
|
ne00,
|
3464
4009
|
ne02,
|
3465
4010
|
nb01,
|
@@ -3484,17 +4029,26 @@ kernel void kernel_mul_mm_id(
|
|
3484
4029
|
#define QK_NL 4
|
3485
4030
|
#endif
|
3486
4031
|
|
4032
|
+
//
|
4033
|
+
// get rows
|
4034
|
+
//
|
4035
|
+
|
3487
4036
|
typedef void (get_rows_t)(
|
3488
4037
|
device const void * src0,
|
3489
|
-
device const
|
4038
|
+
device const char * src1,
|
3490
4039
|
device float * dst,
|
3491
4040
|
constant int64_t & ne00,
|
3492
4041
|
constant uint64_t & nb01,
|
4042
|
+
constant uint64_t & nb02,
|
4043
|
+
constant int64_t & ne10,
|
4044
|
+
constant uint64_t & nb10,
|
4045
|
+
constant uint64_t & nb11,
|
3493
4046
|
constant uint64_t & nb1,
|
3494
|
-
|
4047
|
+
constant uint64_t & nb2,
|
4048
|
+
uint3, uint, uint3);
|
3495
4049
|
|
3496
|
-
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
|
3497
|
-
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
4050
|
+
//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
|
4051
|
+
//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
3498
4052
|
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
|
3499
4053
|
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
|
3500
4054
|
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
|
@@ -3506,6 +4060,10 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
|
|
3506
4060
|
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
|
3507
4061
|
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
|
3508
4062
|
|
4063
|
+
//
|
4064
|
+
// matrix-matrix multiplication
|
4065
|
+
//
|
4066
|
+
|
3509
4067
|
typedef void (mat_mm_t)(
|
3510
4068
|
device const uchar * src0,
|
3511
4069
|
device const uchar * src1,
|
@@ -3538,20 +4096,27 @@ template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
|
|
3538
4096
|
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
|
3539
4097
|
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|
3540
4098
|
|
4099
|
+
//
|
4100
|
+
// indirect matrix-matrix multiplication
|
4101
|
+
//
|
4102
|
+
|
3541
4103
|
typedef void (mat_mm_id_t)(
|
3542
|
-
device const
|
4104
|
+
device const uchar * ids,
|
3543
4105
|
device const uchar * src1,
|
3544
|
-
device
|
4106
|
+
device uchar * dst,
|
4107
|
+
constant int64_t & nbi1,
|
3545
4108
|
constant int64_t & ne00,
|
3546
4109
|
constant int64_t & ne02,
|
3547
4110
|
constant int64_t & nb01,
|
3548
4111
|
constant int64_t & nb02,
|
3549
4112
|
constant int64_t & ne12,
|
4113
|
+
constant int64_t & ne13,
|
3550
4114
|
constant int64_t & nb10,
|
3551
4115
|
constant int64_t & nb11,
|
3552
4116
|
constant int64_t & nb12,
|
3553
4117
|
constant int64_t & ne0,
|
3554
4118
|
constant int64_t & ne1,
|
4119
|
+
constant int64_t & nb1,
|
3555
4120
|
constant uint & r2,
|
3556
4121
|
constant uint & r3,
|
3557
4122
|
constant int & idx,
|
@@ -3578,3 +4143,775 @@ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mu
|
|
3578
4143
|
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
|
3579
4144
|
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
|
3580
4145
|
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
|
4146
|
+
|
4147
|
+
//
|
4148
|
+
// matrix-vector multiplication
|
4149
|
+
//
|
4150
|
+
|
4151
|
+
[[host_name("kernel_mul_mv_id_f32_f32")]]
|
4152
|
+
kernel void kernel_mul_mv_id_f32_f32(
|
4153
|
+
device const char * ids,
|
4154
|
+
device const char * src1,
|
4155
|
+
device uchar * dst,
|
4156
|
+
constant int64_t & nbi1,
|
4157
|
+
constant int64_t & ne00,
|
4158
|
+
constant int64_t & ne01,
|
4159
|
+
constant int64_t & ne02,
|
4160
|
+
constant uint64_t & nb00,
|
4161
|
+
constant uint64_t & nb01,
|
4162
|
+
constant uint64_t & nb02,
|
4163
|
+
constant int64_t & ne10,
|
4164
|
+
constant int64_t & ne11,
|
4165
|
+
constant int64_t & ne12,
|
4166
|
+
constant int64_t & ne13,
|
4167
|
+
constant uint64_t & nb10,
|
4168
|
+
constant uint64_t & nb11,
|
4169
|
+
constant uint64_t & nb12,
|
4170
|
+
constant int64_t & ne0,
|
4171
|
+
constant int64_t & ne1,
|
4172
|
+
constant int64_t & nb1,
|
4173
|
+
constant uint & r2,
|
4174
|
+
constant uint & r3,
|
4175
|
+
constant int & idx,
|
4176
|
+
device const char * src00,
|
4177
|
+
device const char * src01,
|
4178
|
+
device const char * src02,
|
4179
|
+
device const char * src03,
|
4180
|
+
device const char * src04,
|
4181
|
+
device const char * src05,
|
4182
|
+
device const char * src06,
|
4183
|
+
device const char * src07,
|
4184
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4185
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4186
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4187
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4188
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
4189
|
+
|
4190
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
4191
|
+
|
4192
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4193
|
+
|
4194
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4195
|
+
|
4196
|
+
kernel_mul_mv_f32_f32_impl(
|
4197
|
+
src0[id],
|
4198
|
+
src1 + bid*nb11,
|
4199
|
+
(device float *) (dst + bid*nb1),
|
4200
|
+
ne00,
|
4201
|
+
ne01,
|
4202
|
+
ne02,
|
4203
|
+
nb00,
|
4204
|
+
nb01,
|
4205
|
+
nb02,
|
4206
|
+
ne10,
|
4207
|
+
ne11,
|
4208
|
+
ne12,
|
4209
|
+
nb10,
|
4210
|
+
nb11,
|
4211
|
+
nb12,
|
4212
|
+
ne0,
|
4213
|
+
ne1,
|
4214
|
+
r2,
|
4215
|
+
r3,
|
4216
|
+
tgpig,
|
4217
|
+
tiisg);
|
4218
|
+
}
|
4219
|
+
|
4220
|
+
[[host_name("kernel_mul_mv_id_f16_f32")]]
|
4221
|
+
kernel void kernel_mul_mv_id_f16_f32(
|
4222
|
+
device const char * ids,
|
4223
|
+
device const char * src1,
|
4224
|
+
device uchar * dst,
|
4225
|
+
constant int64_t & nbi1,
|
4226
|
+
constant int64_t & ne00,
|
4227
|
+
constant int64_t & ne01,
|
4228
|
+
constant int64_t & ne02,
|
4229
|
+
constant uint64_t & nb00,
|
4230
|
+
constant uint64_t & nb01,
|
4231
|
+
constant uint64_t & nb02,
|
4232
|
+
constant int64_t & ne10,
|
4233
|
+
constant int64_t & ne11,
|
4234
|
+
constant int64_t & ne12,
|
4235
|
+
constant int64_t & ne13,
|
4236
|
+
constant uint64_t & nb10,
|
4237
|
+
constant uint64_t & nb11,
|
4238
|
+
constant uint64_t & nb12,
|
4239
|
+
constant int64_t & ne0,
|
4240
|
+
constant int64_t & ne1,
|
4241
|
+
constant int64_t & nb1,
|
4242
|
+
constant uint & r2,
|
4243
|
+
constant uint & r3,
|
4244
|
+
constant int & idx,
|
4245
|
+
device const char * src00,
|
4246
|
+
device const char * src01,
|
4247
|
+
device const char * src02,
|
4248
|
+
device const char * src03,
|
4249
|
+
device const char * src04,
|
4250
|
+
device const char * src05,
|
4251
|
+
device const char * src06,
|
4252
|
+
device const char * src07,
|
4253
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4254
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4255
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4256
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4257
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
4258
|
+
|
4259
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
4260
|
+
|
4261
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4262
|
+
|
4263
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4264
|
+
|
4265
|
+
kernel_mul_mv_f16_f32_impl(
|
4266
|
+
src0[id],
|
4267
|
+
src1 + bid*nb11,
|
4268
|
+
(device float *) (dst + bid*nb1),
|
4269
|
+
ne00,
|
4270
|
+
ne01,
|
4271
|
+
ne02,
|
4272
|
+
nb00,
|
4273
|
+
nb01,
|
4274
|
+
nb02,
|
4275
|
+
ne10,
|
4276
|
+
ne11,
|
4277
|
+
ne12,
|
4278
|
+
nb10,
|
4279
|
+
nb11,
|
4280
|
+
nb12,
|
4281
|
+
ne0,
|
4282
|
+
ne1,
|
4283
|
+
r2,
|
4284
|
+
r3,
|
4285
|
+
tgpig,
|
4286
|
+
tiisg);
|
4287
|
+
}
|
4288
|
+
|
4289
|
+
[[host_name("kernel_mul_mv_id_q8_0_f32")]]
|
4290
|
+
kernel void kernel_mul_mv_id_q8_0_f32(
|
4291
|
+
device const char * ids,
|
4292
|
+
device const char * src1,
|
4293
|
+
device uchar * dst,
|
4294
|
+
constant int64_t & nbi1,
|
4295
|
+
constant int64_t & ne00,
|
4296
|
+
constant int64_t & ne01,
|
4297
|
+
constant int64_t & ne02,
|
4298
|
+
constant uint64_t & nb00,
|
4299
|
+
constant uint64_t & nb01,
|
4300
|
+
constant uint64_t & nb02,
|
4301
|
+
constant int64_t & ne10,
|
4302
|
+
constant int64_t & ne11,
|
4303
|
+
constant int64_t & ne12,
|
4304
|
+
constant int64_t & ne13,
|
4305
|
+
constant uint64_t & nb10,
|
4306
|
+
constant uint64_t & nb11,
|
4307
|
+
constant uint64_t & nb12,
|
4308
|
+
constant int64_t & ne0,
|
4309
|
+
constant int64_t & ne1,
|
4310
|
+
constant int64_t & nb1,
|
4311
|
+
constant uint & r2,
|
4312
|
+
constant uint & r3,
|
4313
|
+
constant int & idx,
|
4314
|
+
device const char * src00,
|
4315
|
+
device const char * src01,
|
4316
|
+
device const char * src02,
|
4317
|
+
device const char * src03,
|
4318
|
+
device const char * src04,
|
4319
|
+
device const char * src05,
|
4320
|
+
device const char * src06,
|
4321
|
+
device const char * src07,
|
4322
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4323
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4324
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4325
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4326
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
4327
|
+
|
4328
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
4329
|
+
|
4330
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4331
|
+
|
4332
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4333
|
+
|
4334
|
+
kernel_mul_mv_q8_0_f32_impl(
|
4335
|
+
src0[id],
|
4336
|
+
(device const float *) (src1 + bid*nb11),
|
4337
|
+
(device float *) ( dst + bid*nb1),
|
4338
|
+
ne00,
|
4339
|
+
ne01,
|
4340
|
+
ne02,
|
4341
|
+
ne10,
|
4342
|
+
ne12,
|
4343
|
+
ne0,
|
4344
|
+
ne1,
|
4345
|
+
r2,
|
4346
|
+
r3,
|
4347
|
+
tgpig,
|
4348
|
+
tiisg,
|
4349
|
+
sgitg);
|
4350
|
+
}
|
4351
|
+
|
4352
|
+
[[host_name("kernel_mul_mv_id_q4_0_f32")]]
|
4353
|
+
kernel void kernel_mul_mv_id_q4_0_f32(
|
4354
|
+
device const char * ids,
|
4355
|
+
device const char * src1,
|
4356
|
+
device uchar * dst,
|
4357
|
+
constant int64_t & nbi1,
|
4358
|
+
constant int64_t & ne00,
|
4359
|
+
constant int64_t & ne01,
|
4360
|
+
constant int64_t & ne02,
|
4361
|
+
constant uint64_t & nb00,
|
4362
|
+
constant uint64_t & nb01,
|
4363
|
+
constant uint64_t & nb02,
|
4364
|
+
constant int64_t & ne10,
|
4365
|
+
constant int64_t & ne11,
|
4366
|
+
constant int64_t & ne12,
|
4367
|
+
constant int64_t & ne13,
|
4368
|
+
constant uint64_t & nb10,
|
4369
|
+
constant uint64_t & nb11,
|
4370
|
+
constant uint64_t & nb12,
|
4371
|
+
constant int64_t & ne0,
|
4372
|
+
constant int64_t & ne1,
|
4373
|
+
constant int64_t & nb1,
|
4374
|
+
constant uint & r2,
|
4375
|
+
constant uint & r3,
|
4376
|
+
constant int & idx,
|
4377
|
+
device const char * src00,
|
4378
|
+
device const char * src01,
|
4379
|
+
device const char * src02,
|
4380
|
+
device const char * src03,
|
4381
|
+
device const char * src04,
|
4382
|
+
device const char * src05,
|
4383
|
+
device const char * src06,
|
4384
|
+
device const char * src07,
|
4385
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4386
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4387
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4388
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4389
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
4390
|
+
|
4391
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
4392
|
+
|
4393
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4394
|
+
|
4395
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4396
|
+
|
4397
|
+
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
4398
|
+
src0[id],
|
4399
|
+
(device const float *) (src1 + bid*nb11),
|
4400
|
+
(device float *) ( dst + bid*nb1),
|
4401
|
+
ne00,
|
4402
|
+
ne01,
|
4403
|
+
ne02,
|
4404
|
+
ne10,
|
4405
|
+
ne12,
|
4406
|
+
ne0,
|
4407
|
+
ne1,
|
4408
|
+
r2,
|
4409
|
+
r3,
|
4410
|
+
tgpig,
|
4411
|
+
tiisg,
|
4412
|
+
sgitg);
|
4413
|
+
}
|
4414
|
+
|
4415
|
+
// Indirect ("id") matrix-vector multiply entry point for q4_1 weights.
//
// tgpig.z is split in two: the quotient by ne12*ne13 becomes a row index and the
// remainder is written back into tgpig.z before it is forwarded. The row index
// selects the row of `ids` (byte stride nbi1), of `src1` (byte stride nb11) and
// of `dst` (byte stride nb1). Element `idx` of that `ids` row chooses which of
// src00..src07 is handed to mul_vec_q_n_f32_impl<block_q4_1, ...> as the weights.
//
// nb00, nb01, nb02, ne11, nb10, nb12 and tiitg are not referenced in this body.
// NOTE(review): the selected index is used unchecked on the 8-entry table;
// presumably the host only supplies values in [0, 8) - confirm against the caller.
[[host_name("kernel_mul_mv_id_q4_1_f32")]]
kernel void kernel_mul_mv_id_q4_1_f32(
        device const    char * ids,
        device const    char * src1,
        device         uchar * dst,
        constant     int64_t & nbi1,
        constant     int64_t & ne00,
        constant     int64_t & ne01,
        constant     int64_t & ne02,
        constant    uint64_t & nb00,
        constant    uint64_t & nb01,
        constant    uint64_t & nb02,
        constant     int64_t & ne10,
        constant     int64_t & ne11,
        constant     int64_t & ne12,
        constant     int64_t & ne13,
        constant    uint64_t & nb10,
        constant    uint64_t & nb11,
        constant    uint64_t & nb12,
        constant     int64_t & ne0,
        constant     int64_t & ne1,
        constant     int64_t & nb1,
        constant        uint & r2,
        constant        uint & r3,
        constant         int & idx,
        device const    char * src00,
        device const    char * src01,
        device const    char * src02,
        device const    char * src03,
        device const    char * src04,
        device const    char * src05,
        device const    char * src06,
        device const    char * src07,
        uint3                  tgpig[[threadgroup_position_in_grid]],
        uint                   tiitg[[thread_index_in_threadgroup]],
        uint                   tiisg[[thread_index_in_simdgroup]],
        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
    // table of the eight candidate weight buffers
    device const char * candidates[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // quotient -> row index, remainder -> z coordinate seen by the implementation
    const int64_t n_z = ne12*ne13;
    const int64_t row = tgpig.z/n_z;

    tgpig.z = tgpig.z%n_z;

    // read the selector for this row from the ids buffer
    device const int32_t * row_ids = (device const int32_t *) (ids + row*nbi1);
    const int32_t sel = row_ids[idx];

    device const float * src1_row = (device const float *) (src1 + row*nb11);
    device       float * dst_row  = (device       float *) (dst  + row*nb1);

    mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
        candidates[sel],
        src1_row,
        dst_row,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0,  ne1,
        r2,   r3,
        tgpig, tiisg, sgitg);
}
|
4477
|
+
|
4478
|
+
// Indirect ("id") matrix-vector multiply entry point for q5_0 weights.
//
// tgpig.z is split in two: the quotient by ne12*ne13 becomes a row index and the
// remainder is written back into tgpig.z before it is forwarded. The row index
// selects the row of `ids` (byte stride nbi1), of `src1` (byte stride nb11) and
// of `dst` (byte stride nb1). Element `idx` of that `ids` row chooses which of
// src00..src07 is handed to mul_vec_q_n_f32_impl<block_q5_0, ...> as the weights.
//
// nb00, nb01, nb02, ne11, nb10, nb12 and tiitg are not referenced in this body.
// NOTE(review): the selected index is used unchecked on the 8-entry table;
// presumably the host only supplies values in [0, 8) - confirm against the caller.
[[host_name("kernel_mul_mv_id_q5_0_f32")]]
kernel void kernel_mul_mv_id_q5_0_f32(
        device const    char * ids,
        device const    char * src1,
        device         uchar * dst,
        constant     int64_t & nbi1,
        constant     int64_t & ne00,
        constant     int64_t & ne01,
        constant     int64_t & ne02,
        constant    uint64_t & nb00,
        constant    uint64_t & nb01,
        constant    uint64_t & nb02,
        constant     int64_t & ne10,
        constant     int64_t & ne11,
        constant     int64_t & ne12,
        constant     int64_t & ne13,
        constant    uint64_t & nb10,
        constant    uint64_t & nb11,
        constant    uint64_t & nb12,
        constant     int64_t & ne0,
        constant     int64_t & ne1,
        constant     int64_t & nb1,
        constant        uint & r2,
        constant        uint & r3,
        constant         int & idx,
        device const    char * src00,
        device const    char * src01,
        device const    char * src02,
        device const    char * src03,
        device const    char * src04,
        device const    char * src05,
        device const    char * src06,
        device const    char * src07,
        uint3                  tgpig[[threadgroup_position_in_grid]],
        uint                   tiitg[[thread_index_in_threadgroup]],
        uint                   tiisg[[thread_index_in_simdgroup]],
        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
    // table of the eight candidate weight buffers
    device const char * candidates[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // quotient -> row index, remainder -> z coordinate seen by the implementation
    const int64_t n_z = ne12*ne13;
    const int64_t row = tgpig.z/n_z;

    tgpig.z = tgpig.z%n_z;

    // read the selector for this row from the ids buffer
    device const int32_t * row_ids = (device const int32_t *) (ids + row*nbi1);
    const int32_t sel = row_ids[idx];

    device const float * src1_row = (device const float *) (src1 + row*nb11);
    device       float * dst_row  = (device       float *) (dst  + row*nb1);

    mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
        candidates[sel],
        src1_row,
        dst_row,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0,  ne1,
        r2,   r3,
        tgpig, tiisg, sgitg);
}
|
4540
|
+
|
4541
|
+
// Indirect ("id") matrix-vector multiply entry point for q5_1 weights.
//
// tgpig.z is split in two: the quotient by ne12*ne13 becomes a row index and the
// remainder is written back into tgpig.z before it is forwarded. The row index
// selects the row of `ids` (byte stride nbi1), of `src1` (byte stride nb11) and
// of `dst` (byte stride nb1). Element `idx` of that `ids` row chooses which of
// src00..src07 is handed to mul_vec_q_n_f32_impl<block_q5_1, ...> as the weights.
//
// nb00, nb01, nb02, ne11, nb10, nb12 and tiitg are not referenced in this body.
// NOTE(review): the selected index is used unchecked on the 8-entry table;
// presumably the host only supplies values in [0, 8) - confirm against the caller.
[[host_name("kernel_mul_mv_id_q5_1_f32")]]
kernel void kernel_mul_mv_id_q5_1_f32(
        device const    char * ids,
        device const    char * src1,
        device         uchar * dst,
        constant     int64_t & nbi1,
        constant     int64_t & ne00,
        constant     int64_t & ne01,
        constant     int64_t & ne02,
        constant    uint64_t & nb00,
        constant    uint64_t & nb01,
        constant    uint64_t & nb02,
        constant     int64_t & ne10,
        constant     int64_t & ne11,
        constant     int64_t & ne12,
        constant     int64_t & ne13,
        constant    uint64_t & nb10,
        constant    uint64_t & nb11,
        constant    uint64_t & nb12,
        constant     int64_t & ne0,
        constant     int64_t & ne1,
        constant     int64_t & nb1,
        constant        uint & r2,
        constant        uint & r3,
        constant         int & idx,
        device const    char * src00,
        device const    char * src01,
        device const    char * src02,
        device const    char * src03,
        device const    char * src04,
        device const    char * src05,
        device const    char * src06,
        device const    char * src07,
        uint3                  tgpig[[threadgroup_position_in_grid]],
        uint                   tiitg[[thread_index_in_threadgroup]],
        uint                   tiisg[[thread_index_in_simdgroup]],
        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
    // table of the eight candidate weight buffers
    device const char * candidates[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // quotient -> row index, remainder -> z coordinate seen by the implementation
    const int64_t n_z = ne12*ne13;
    const int64_t row = tgpig.z/n_z;

    tgpig.z = tgpig.z%n_z;

    // read the selector for this row from the ids buffer
    device const int32_t * row_ids = (device const int32_t *) (ids + row*nbi1);
    const int32_t sel = row_ids[idx];

    device const float * src1_row = (device const float *) (src1 + row*nb11);
    device       float * dst_row  = (device       float *) (dst  + row*nb1);

    mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
        candidates[sel],
        src1_row,
        dst_row,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0,  ne1,
        r2,   r3,
        tgpig, tiisg, sgitg);
}
|
4603
|
+
|
4604
|
+
// Indirect ("id") matrix-vector multiply entry point for q2_K weights.
//
// tgpig.z is split in two: the quotient by ne12*ne13 becomes a row index and the
// remainder is written back into tgpig.z before it is forwarded. The row index
// selects the row of `ids` (byte stride nbi1), of `src1` (byte stride nb11) and
// of `dst` (byte stride nb1). Element `idx` of that `ids` row chooses which of
// src00..src07 is handed to kernel_mul_mv_q2_K_f32_impl as the weights.
//
// nb00, nb01, nb02, ne11, nb10, nb12 and tiitg are not referenced in this body.
// NOTE(review): the selected index is used unchecked on the 8-entry table;
// presumably the host only supplies values in [0, 8) - confirm against the caller.
[[host_name("kernel_mul_mv_id_q2_K_f32")]]
kernel void kernel_mul_mv_id_q2_K_f32(
        device const    char * ids,
        device const    char * src1,
        device         uchar * dst,
        constant     int64_t & nbi1,
        constant     int64_t & ne00,
        constant     int64_t & ne01,
        constant     int64_t & ne02,
        constant    uint64_t & nb00,
        constant    uint64_t & nb01,
        constant    uint64_t & nb02,
        constant     int64_t & ne10,
        constant     int64_t & ne11,
        constant     int64_t & ne12,
        constant     int64_t & ne13,
        constant    uint64_t & nb10,
        constant    uint64_t & nb11,
        constant    uint64_t & nb12,
        constant     int64_t & ne0,
        constant     int64_t & ne1,
        constant     int64_t & nb1,
        constant        uint & r2,
        constant        uint & r3,
        constant         int & idx,
        device const    char * src00,
        device const    char * src01,
        device const    char * src02,
        device const    char * src03,
        device const    char * src04,
        device const    char * src05,
        device const    char * src06,
        device const    char * src07,
        uint3                  tgpig[[threadgroup_position_in_grid]],
        uint                   tiitg[[thread_index_in_threadgroup]],
        uint                   tiisg[[thread_index_in_simdgroup]],
        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
    // table of the eight candidate weight buffers
    device const char * candidates[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // quotient -> row index, remainder -> z coordinate seen by the implementation
    const int64_t n_z = ne12*ne13;
    const int64_t row = tgpig.z/n_z;

    tgpig.z = tgpig.z%n_z;

    // read the selector for this row from the ids buffer
    device const int32_t * row_ids = (device const int32_t *) (ids + row*nbi1);
    const int32_t sel = row_ids[idx];

    device const float * src1_row = (device const float *) (src1 + row*nb11);
    device       float * dst_row  = (device       float *) (dst  + row*nb1);

    kernel_mul_mv_q2_K_f32_impl(
        candidates[sel],
        src1_row,
        dst_row,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0,  ne1,
        r2,   r3,
        tgpig, tiisg, sgitg);
}
|
4666
|
+
|
4667
|
+
// Indirect ("id") matrix-vector multiply entry point for q3_K weights.
//
// tgpig.z is split in two: the quotient by ne12*ne13 becomes a row index and the
// remainder is written back into tgpig.z before it is forwarded. The row index
// selects the row of `ids` (byte stride nbi1), of `src1` (byte stride nb11) and
// of `dst` (byte stride nb1). Element `idx` of that `ids` row chooses which of
// src00..src07 is handed to kernel_mul_mv_q3_K_f32_impl as the weights.
//
// nb00, nb01, nb02, ne11, nb10, nb12 and tiitg are not referenced in this body.
// NOTE(review): the selected index is used unchecked on the 8-entry table;
// presumably the host only supplies values in [0, 8) - confirm against the caller.
[[host_name("kernel_mul_mv_id_q3_K_f32")]]
kernel void kernel_mul_mv_id_q3_K_f32(
        device const    char * ids,
        device const    char * src1,
        device         uchar * dst,
        constant     int64_t & nbi1,
        constant     int64_t & ne00,
        constant     int64_t & ne01,
        constant     int64_t & ne02,
        constant    uint64_t & nb00,
        constant    uint64_t & nb01,
        constant    uint64_t & nb02,
        constant     int64_t & ne10,
        constant     int64_t & ne11,
        constant     int64_t & ne12,
        constant     int64_t & ne13,
        constant    uint64_t & nb10,
        constant    uint64_t & nb11,
        constant    uint64_t & nb12,
        constant     int64_t & ne0,
        constant     int64_t & ne1,
        constant     int64_t & nb1,
        constant        uint & r2,
        constant        uint & r3,
        constant         int & idx,
        device const    char * src00,
        device const    char * src01,
        device const    char * src02,
        device const    char * src03,
        device const    char * src04,
        device const    char * src05,
        device const    char * src06,
        device const    char * src07,
        uint3                  tgpig[[threadgroup_position_in_grid]],
        uint                   tiitg[[thread_index_in_threadgroup]],
        uint                   tiisg[[thread_index_in_simdgroup]],
        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
    // table of the eight candidate weight buffers
    device const char * candidates[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // quotient -> row index, remainder -> z coordinate seen by the implementation
    const int64_t n_z = ne12*ne13;
    const int64_t row = tgpig.z/n_z;

    tgpig.z = tgpig.z%n_z;

    // read the selector for this row from the ids buffer
    device const int32_t * row_ids = (device const int32_t *) (ids + row*nbi1);
    const int32_t sel = row_ids[idx];

    device const float * src1_row = (device const float *) (src1 + row*nb11);
    device       float * dst_row  = (device       float *) (dst  + row*nb1);

    kernel_mul_mv_q3_K_f32_impl(
        candidates[sel],
        src1_row,
        dst_row,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0,  ne1,
        r2,   r3,
        tgpig, tiisg, sgitg);
}
|
4729
|
+
|
4730
|
+
// Indirect ("id") matrix-vector multiply entry point for q4_K weights.
//
// tgpig.z is split in two: the quotient by ne12*ne13 becomes a row index and the
// remainder is written back into tgpig.z before it is forwarded. The row index
// selects the row of `ids` (byte stride nbi1), of `src1` (byte stride nb11) and
// of `dst` (byte stride nb1). Element `idx` of that `ids` row chooses which of
// src00..src07 is handed to kernel_mul_mv_q4_K_f32_impl as the weights.
//
// nb00, nb01, nb02, ne11, nb10, nb12 and tiitg are not referenced in this body.
// NOTE(review): the selected index is used unchecked on the 8-entry table;
// presumably the host only supplies values in [0, 8) - confirm against the caller.
[[host_name("kernel_mul_mv_id_q4_K_f32")]]
kernel void kernel_mul_mv_id_q4_K_f32(
        device const    char * ids,
        device const    char * src1,
        device         uchar * dst,
        constant     int64_t & nbi1,
        constant     int64_t & ne00,
        constant     int64_t & ne01,
        constant     int64_t & ne02,
        constant    uint64_t & nb00,
        constant    uint64_t & nb01,
        constant    uint64_t & nb02,
        constant     int64_t & ne10,
        constant     int64_t & ne11,
        constant     int64_t & ne12,
        constant     int64_t & ne13,
        constant    uint64_t & nb10,
        constant    uint64_t & nb11,
        constant    uint64_t & nb12,
        constant     int64_t & ne0,
        constant     int64_t & ne1,
        constant     int64_t & nb1,
        constant        uint & r2,
        constant        uint & r3,
        constant         int & idx,
        device const    char * src00,
        device const    char * src01,
        device const    char * src02,
        device const    char * src03,
        device const    char * src04,
        device const    char * src05,
        device const    char * src06,
        device const    char * src07,
        uint3                  tgpig[[threadgroup_position_in_grid]],
        uint                   tiitg[[thread_index_in_threadgroup]],
        uint                   tiisg[[thread_index_in_simdgroup]],
        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
    // table of the eight candidate weight buffers
    device const char * candidates[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // quotient -> row index, remainder -> z coordinate seen by the implementation
    const int64_t n_z = ne12*ne13;
    const int64_t row = tgpig.z/n_z;

    tgpig.z = tgpig.z%n_z;

    // read the selector for this row from the ids buffer
    device const int32_t * row_ids = (device const int32_t *) (ids + row*nbi1);
    const int32_t sel = row_ids[idx];

    device const float * src1_row = (device const float *) (src1 + row*nb11);
    device       float * dst_row  = (device       float *) (dst  + row*nb1);

    kernel_mul_mv_q4_K_f32_impl(
        candidates[sel],
        src1_row,
        dst_row,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0,  ne1,
        r2,   r3,
        tgpig, tiisg, sgitg);
}
|
4792
|
+
|
4793
|
+
// Indirect ("id") matrix-vector multiply entry point for q5_K weights.
//
// tgpig.z is split in two: the quotient by ne12*ne13 becomes a row index and the
// remainder is written back into tgpig.z before it is forwarded. The row index
// selects the row of `ids` (byte stride nbi1), of `src1` (byte stride nb11) and
// of `dst` (byte stride nb1). Element `idx` of that `ids` row chooses which of
// src00..src07 is handed to kernel_mul_mv_q5_K_f32_impl as the weights.
//
// nb00, nb01, nb02, ne11, nb10, nb12 and tiitg are not referenced in this body.
// NOTE(review): the selected index is used unchecked on the 8-entry table;
// presumably the host only supplies values in [0, 8) - confirm against the caller.
[[host_name("kernel_mul_mv_id_q5_K_f32")]]
kernel void kernel_mul_mv_id_q5_K_f32(
        device const    char * ids,
        device const    char * src1,
        device         uchar * dst,
        constant     int64_t & nbi1,
        constant     int64_t & ne00,
        constant     int64_t & ne01,
        constant     int64_t & ne02,
        constant    uint64_t & nb00,
        constant    uint64_t & nb01,
        constant    uint64_t & nb02,
        constant     int64_t & ne10,
        constant     int64_t & ne11,
        constant     int64_t & ne12,
        constant     int64_t & ne13,
        constant    uint64_t & nb10,
        constant    uint64_t & nb11,
        constant    uint64_t & nb12,
        constant     int64_t & ne0,
        constant     int64_t & ne1,
        constant     int64_t & nb1,
        constant        uint & r2,
        constant        uint & r3,
        constant         int & idx,
        device const    char * src00,
        device const    char * src01,
        device const    char * src02,
        device const    char * src03,
        device const    char * src04,
        device const    char * src05,
        device const    char * src06,
        device const    char * src07,
        uint3                  tgpig[[threadgroup_position_in_grid]],
        uint                   tiitg[[thread_index_in_threadgroup]],
        uint                   tiisg[[thread_index_in_simdgroup]],
        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
    // table of the eight candidate weight buffers
    device const char * candidates[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // quotient -> row index, remainder -> z coordinate seen by the implementation
    const int64_t n_z = ne12*ne13;
    const int64_t row = tgpig.z/n_z;

    tgpig.z = tgpig.z%n_z;

    // read the selector for this row from the ids buffer
    device const int32_t * row_ids = (device const int32_t *) (ids + row*nbi1);
    const int32_t sel = row_ids[idx];

    device const float * src1_row = (device const float *) (src1 + row*nb11);
    device       float * dst_row  = (device       float *) (dst  + row*nb1);

    kernel_mul_mv_q5_K_f32_impl(
        candidates[sel],
        src1_row,
        dst_row,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0,  ne1,
        r2,   r3,
        tgpig, tiisg, sgitg);
}
|
4855
|
+
|
4856
|
+
// Indirect ("id") matrix-vector multiply entry point for q6_K weights.
//
// tgpig.z is split in two: the quotient by ne12*ne13 becomes a row index and the
// remainder is written back into tgpig.z before it is forwarded. The row index
// selects the row of `ids` (byte stride nbi1), of `src1` (byte stride nb11) and
// of `dst` (byte stride nb1). Element `idx` of that `ids` row chooses which of
// src00..src07 is handed to kernel_mul_mv_q6_K_f32_impl as the weights.
//
// nb00, nb01, nb02, ne11, nb10, nb12 and tiitg are not referenced in this body.
// NOTE(review): the selected index is used unchecked on the 8-entry table;
// presumably the host only supplies values in [0, 8) - confirm against the caller.
[[host_name("kernel_mul_mv_id_q6_K_f32")]]
kernel void kernel_mul_mv_id_q6_K_f32(
        device const    char * ids,
        device const    char * src1,
        device         uchar * dst,
        constant     int64_t & nbi1,
        constant     int64_t & ne00,
        constant     int64_t & ne01,
        constant     int64_t & ne02,
        constant    uint64_t & nb00,
        constant    uint64_t & nb01,
        constant    uint64_t & nb02,
        constant     int64_t & ne10,
        constant     int64_t & ne11,
        constant     int64_t & ne12,
        constant     int64_t & ne13,
        constant    uint64_t & nb10,
        constant    uint64_t & nb11,
        constant    uint64_t & nb12,
        constant     int64_t & ne0,
        constant     int64_t & ne1,
        constant     int64_t & nb1,
        constant        uint & r2,
        constant        uint & r3,
        constant         int & idx,
        device const    char * src00,
        device const    char * src01,
        device const    char * src02,
        device const    char * src03,
        device const    char * src04,
        device const    char * src05,
        device const    char * src06,
        device const    char * src07,
        uint3                  tgpig[[threadgroup_position_in_grid]],
        uint                   tiitg[[thread_index_in_threadgroup]],
        uint                   tiisg[[thread_index_in_simdgroup]],
        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
    // table of the eight candidate weight buffers
    device const char * candidates[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // quotient -> row index, remainder -> z coordinate seen by the implementation
    const int64_t n_z = ne12*ne13;
    const int64_t row = tgpig.z/n_z;

    tgpig.z = tgpig.z%n_z;

    // read the selector for this row from the ids buffer
    device const int32_t * row_ids = (device const int32_t *) (ids + row*nbi1);
    const int32_t sel = row_ids[idx];

    device const float * src1_row = (device const float *) (src1 + row*nb11);
    device       float * dst_row  = (device       float *) (dst  + row*nb1);

    kernel_mul_mv_q6_K_f32_impl(
        candidates[sel],
        src1_row,
        dst_row,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0,  ne1,
        r2,   r3,
        tgpig, tiisg, sgitg);
}
|