llama_cpp 0.10.0 → 0.10.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/ext/llama_cpp/llama_cpp.cpp +2 -0
- data/ext/llama_cpp/src/ggml-alloc.h +1 -1
- data/ext/llama_cpp/src/ggml-cuda.cu +691 -93
- data/ext/llama_cpp/src/ggml-metal.m +535 -54
- data/ext/llama_cpp/src/ggml-metal.metal +1497 -169
- data/ext/llama_cpp/src/ggml-quants.c +2 -2
- data/ext/llama_cpp/src/ggml.c +325 -159
- data/ext/llama_cpp/src/ggml.h +34 -13
- data/ext/llama_cpp/src/llama.cpp +195 -35
- data/ext/llama_cpp/src/llama.h +1 -1
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +2 -0
- metadata +2 -2
@@ -79,6 +79,7 @@ kernel void kernel_add(
|
|
79
79
|
constant int64_t & nb1,
|
80
80
|
constant int64_t & nb2,
|
81
81
|
constant int64_t & nb3,
|
82
|
+
constant int64_t & offs,
|
82
83
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
83
84
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
84
85
|
uint3 ntg[[threads_per_threadgroup]]) {
|
@@ -90,9 +91,9 @@ kernel void kernel_add(
|
|
90
91
|
const int64_t i12 = i02 % ne12;
|
91
92
|
const int64_t i11 = i01 % ne11;
|
92
93
|
|
93
|
-
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
94
|
+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
|
94
95
|
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
|
95
|
-
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
|
96
|
+
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
|
96
97
|
|
97
98
|
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
98
99
|
const int i10 = i0 % ne10;
|
@@ -204,7 +205,7 @@ kernel void kernel_add_row(
|
|
204
205
|
device const float4 * src0,
|
205
206
|
device const float4 * src1,
|
206
207
|
device float4 * dst,
|
207
|
-
constant int64_t & nb [[buffer(
|
208
|
+
constant int64_t & nb [[buffer(28)]],
|
208
209
|
uint tpig[[thread_position_in_grid]]) {
|
209
210
|
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
210
211
|
}
|
@@ -213,7 +214,7 @@ kernel void kernel_mul_row(
|
|
213
214
|
device const float4 * src0,
|
214
215
|
device const float4 * src1,
|
215
216
|
device float4 * dst,
|
216
|
-
constant int64_t & nb [[buffer(
|
217
|
+
constant int64_t & nb [[buffer(28)]],
|
217
218
|
uint tpig[[thread_position_in_grid]]) {
|
218
219
|
dst[tpig] = src0[tpig] * src1[tpig % nb];
|
219
220
|
}
|
@@ -222,7 +223,7 @@ kernel void kernel_div_row(
|
|
222
223
|
device const float4 * src0,
|
223
224
|
device const float4 * src1,
|
224
225
|
device float4 * dst,
|
225
|
-
constant int64_t & nb [[buffer(
|
226
|
+
constant int64_t & nb [[buffer(28)]],
|
226
227
|
uint tpig[[thread_position_in_grid]]) {
|
227
228
|
dst[tpig] = src0[tpig] / src1[tpig % nb];
|
228
229
|
}
|
@@ -243,19 +244,53 @@ kernel void kernel_scale_4(
|
|
243
244
|
dst[tpig] = src0[tpig] * scale;
|
244
245
|
}
|
245
246
|
|
246
|
-
kernel void
|
247
|
-
device const
|
248
|
-
device
|
247
|
+
kernel void kernel_relu(
|
248
|
+
device const float * src0,
|
249
|
+
device float * dst,
|
249
250
|
uint tpig[[thread_position_in_grid]]) {
|
250
|
-
|
251
|
-
dst[tpig] = x / (1.0f + exp(-x));
|
251
|
+
dst[tpig] = max(0.0f, src0[tpig]);
|
252
252
|
}
|
253
253
|
|
254
|
-
kernel void
|
254
|
+
kernel void kernel_tanh(
|
255
255
|
device const float * src0,
|
256
256
|
device float * dst,
|
257
257
|
uint tpig[[thread_position_in_grid]]) {
|
258
|
-
|
258
|
+
device const float & x = src0[tpig];
|
259
|
+
dst[tpig] = precise::tanh(x);
|
260
|
+
}
|
261
|
+
|
262
|
+
constant float GELU_COEF_A = 0.044715f;
|
263
|
+
constant float GELU_QUICK_COEF = -1.702f;
|
264
|
+
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
265
|
+
|
266
|
+
kernel void kernel_gelu(
|
267
|
+
device const float4 * src0,
|
268
|
+
device float4 * dst,
|
269
|
+
uint tpig[[thread_position_in_grid]]) {
|
270
|
+
device const float4 & x = src0[tpig];
|
271
|
+
|
272
|
+
// BEWARE !!!
|
273
|
+
// Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
|
274
|
+
// This was observed with Falcon 7B and 40B models
|
275
|
+
//
|
276
|
+
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
277
|
+
}
|
278
|
+
|
279
|
+
kernel void kernel_gelu_quick(
|
280
|
+
device const float4 * src0,
|
281
|
+
device float4 * dst,
|
282
|
+
uint tpig[[thread_position_in_grid]]) {
|
283
|
+
device const float4 & x = src0[tpig];
|
284
|
+
|
285
|
+
dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
|
286
|
+
}
|
287
|
+
|
288
|
+
kernel void kernel_silu(
|
289
|
+
device const float4 * src0,
|
290
|
+
device float4 * dst,
|
291
|
+
uint tpig[[thread_position_in_grid]]) {
|
292
|
+
device const float4 & x = src0[tpig];
|
293
|
+
dst[tpig] = x / (1.0f + exp(-x));
|
259
294
|
}
|
260
295
|
|
261
296
|
kernel void kernel_sqr(
|
@@ -313,22 +348,6 @@ kernel void kernel_sum_rows(
|
|
313
348
|
dst_row[0] = row_sum;
|
314
349
|
}
|
315
350
|
|
316
|
-
constant float GELU_COEF_A = 0.044715f;
|
317
|
-
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
318
|
-
|
319
|
-
kernel void kernel_gelu(
|
320
|
-
device const float4 * src0,
|
321
|
-
device float4 * dst,
|
322
|
-
uint tpig[[thread_position_in_grid]]) {
|
323
|
-
device const float4 & x = src0[tpig];
|
324
|
-
|
325
|
-
// BEWARE !!!
|
326
|
-
// Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
|
327
|
-
// This was observed with Falcon 7B and 40B models
|
328
|
-
//
|
329
|
-
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
330
|
-
}
|
331
|
-
|
332
351
|
kernel void kernel_soft_max(
|
333
352
|
device const float * src0,
|
334
353
|
device const float * src1,
|
@@ -347,9 +366,9 @@ kernel void kernel_soft_max(
|
|
347
366
|
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
348
367
|
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
349
368
|
|
350
|
-
device const float * psrc0 =
|
351
|
-
device const float * pmask = src1 ? src1
|
352
|
-
device float * pdst =
|
369
|
+
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
370
|
+
device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
|
371
|
+
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
353
372
|
|
354
373
|
// parallel max
|
355
374
|
float lmax = -INFINITY;
|
@@ -385,7 +404,12 @@ kernel void kernel_soft_max(
|
|
385
404
|
pdst[i00] = exp_psrc0;
|
386
405
|
}
|
387
406
|
|
407
|
+
// This barrier fixes a failing test
|
408
|
+
// ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
|
409
|
+
threadgroup_barrier(mem_flags::mem_none);
|
410
|
+
|
388
411
|
float sum = simd_sum(lsum);
|
412
|
+
|
389
413
|
if (ntg > N_SIMDWIDTH) {
|
390
414
|
if (sgitg == 0) {
|
391
415
|
buf[tiisg] = 0.0f;
|
@@ -428,9 +452,9 @@ kernel void kernel_soft_max_4(
|
|
428
452
|
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
429
453
|
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
430
454
|
|
431
|
-
device const float4 * psrc4 =
|
432
|
-
device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
|
433
|
-
device float4 * pdst4 =
|
455
|
+
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
456
|
+
device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
|
457
|
+
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
434
458
|
|
435
459
|
// parallel max
|
436
460
|
float4 lmax4 = -INFINITY;
|
@@ -468,7 +492,13 @@ kernel void kernel_soft_max_4(
|
|
468
492
|
}
|
469
493
|
|
470
494
|
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
495
|
+
|
496
|
+
// This barrier fixes a failing test
|
497
|
+
// ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
|
498
|
+
threadgroup_barrier(mem_flags::mem_none);
|
499
|
+
|
471
500
|
float sum = simd_sum(lsum);
|
501
|
+
|
472
502
|
if (ntg > N_SIMDWIDTH) {
|
473
503
|
if (sgitg == 0) {
|
474
504
|
buf[tiisg] = 0.0f;
|
@@ -639,6 +669,94 @@ kernel void kernel_rms_norm(
|
|
639
669
|
}
|
640
670
|
}
|
641
671
|
|
672
|
+
kernel void kernel_group_norm(
|
673
|
+
device const float * src0,
|
674
|
+
device float * dst,
|
675
|
+
constant int64_t & ne00,
|
676
|
+
constant int64_t & ne01,
|
677
|
+
constant int64_t & ne02,
|
678
|
+
constant uint64_t & nb00,
|
679
|
+
constant uint64_t & nb01,
|
680
|
+
constant uint64_t & nb02,
|
681
|
+
constant int32_t & n_groups,
|
682
|
+
constant float & eps,
|
683
|
+
threadgroup float * buf [[threadgroup(0)]],
|
684
|
+
uint tgpig[[threadgroup_position_in_grid]],
|
685
|
+
uint tpitg[[thread_position_in_threadgroup]],
|
686
|
+
uint sgitg[[simdgroup_index_in_threadgroup]],
|
687
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
688
|
+
uint ntg[[threads_per_threadgroup]]) {
|
689
|
+
const int64_t ne = ne00*ne01*ne02;
|
690
|
+
const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
|
691
|
+
|
692
|
+
int start = tgpig * gs;
|
693
|
+
int end = start + gs;
|
694
|
+
|
695
|
+
start += tpitg;
|
696
|
+
|
697
|
+
if (end >= ne) {
|
698
|
+
end = ne;
|
699
|
+
}
|
700
|
+
|
701
|
+
float tmp = 0.0f; // partial sum for thread in warp
|
702
|
+
|
703
|
+
for (int j = start; j < end; j += ntg) {
|
704
|
+
tmp += src0[j];
|
705
|
+
}
|
706
|
+
|
707
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
708
|
+
tmp = simd_sum(tmp);
|
709
|
+
if (ntg > N_SIMDWIDTH) {
|
710
|
+
if (sgitg == 0) {
|
711
|
+
buf[tiisg] = 0.0f;
|
712
|
+
}
|
713
|
+
|
714
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
715
|
+
|
716
|
+
if (tiisg == 0) {
|
717
|
+
buf[sgitg] = tmp;
|
718
|
+
}
|
719
|
+
|
720
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
721
|
+
|
722
|
+
tmp = buf[tiisg];
|
723
|
+
tmp = simd_sum(tmp);
|
724
|
+
}
|
725
|
+
|
726
|
+
const float mean = tmp / gs;
|
727
|
+
tmp = 0.0f;
|
728
|
+
|
729
|
+
for (int j = start; j < end; j += ntg) {
|
730
|
+
float xi = src0[j] - mean;
|
731
|
+
dst[j] = xi;
|
732
|
+
tmp += xi * xi;
|
733
|
+
}
|
734
|
+
|
735
|
+
tmp = simd_sum(tmp);
|
736
|
+
if (ntg > N_SIMDWIDTH) {
|
737
|
+
if (sgitg == 0) {
|
738
|
+
buf[tiisg] = 0.0f;
|
739
|
+
}
|
740
|
+
|
741
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
742
|
+
|
743
|
+
if (tiisg == 0) {
|
744
|
+
buf[sgitg] = tmp;
|
745
|
+
}
|
746
|
+
|
747
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
748
|
+
|
749
|
+
tmp = buf[tiisg];
|
750
|
+
tmp = simd_sum(tmp);
|
751
|
+
}
|
752
|
+
|
753
|
+
const float variance = tmp / gs;
|
754
|
+
const float scale = 1.0f/sqrt(variance + eps);
|
755
|
+
for (int j = start; j < end; j += ntg) {
|
756
|
+
dst[j] *= scale;
|
757
|
+
}
|
758
|
+
}
|
759
|
+
|
642
760
|
// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
|
643
761
|
// il indicates where the q4 quants begin (0 or QK4_0/4)
|
644
762
|
// we assume that the yl's have been multiplied with the appropriate scale factor
|
@@ -731,7 +849,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
|
|
731
849
|
// giard against the number of rows not being divisible by
|
732
850
|
// N_DST, so this is another explicit assumption of the implementation.
|
733
851
|
template<typename block_q_type, int nr, int nsg, int nw>
|
734
|
-
void
|
852
|
+
void mul_vec_q_n_f32_impl(
|
735
853
|
device const void * src0,
|
736
854
|
device const float * src1,
|
737
855
|
device float * dst,
|
@@ -813,7 +931,7 @@ kernel void kernel_mul_mv_q4_0_f32(
|
|
813
931
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
814
932
|
uint tiisg[[thread_index_in_simdgroup]],
|
815
933
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
816
|
-
|
934
|
+
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
817
935
|
}
|
818
936
|
|
819
937
|
kernel void kernel_mul_mv_q4_1_f32(
|
@@ -832,7 +950,7 @@ kernel void kernel_mul_mv_q4_1_f32(
|
|
832
950
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
833
951
|
uint tiisg[[thread_index_in_simdgroup]],
|
834
952
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
835
|
-
|
953
|
+
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
836
954
|
}
|
837
955
|
|
838
956
|
kernel void kernel_mul_mv_q5_0_f32(
|
@@ -851,7 +969,7 @@ kernel void kernel_mul_mv_q5_0_f32(
|
|
851
969
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
852
970
|
uint tiisg[[thread_index_in_simdgroup]],
|
853
971
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
854
|
-
|
972
|
+
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
855
973
|
}
|
856
974
|
|
857
975
|
kernel void kernel_mul_mv_q5_1_f32(
|
@@ -870,28 +988,28 @@ kernel void kernel_mul_mv_q5_1_f32(
|
|
870
988
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
871
989
|
uint tiisg[[thread_index_in_simdgroup]],
|
872
990
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
873
|
-
|
991
|
+
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
874
992
|
}
|
875
993
|
|
876
994
|
|
877
995
|
#define NB_Q8_0 8
|
878
996
|
|
879
|
-
|
997
|
+
void kernel_mul_mv_q8_0_f32_impl(
|
880
998
|
device const void * src0,
|
881
999
|
device const float * src1,
|
882
1000
|
device float * dst,
|
883
1001
|
constant int64_t & ne00,
|
884
|
-
constant int64_t & ne01
|
885
|
-
constant int64_t & ne02
|
886
|
-
constant int64_t & ne10
|
887
|
-
constant int64_t & ne12
|
888
|
-
constant int64_t & ne0
|
889
|
-
constant int64_t & ne1
|
890
|
-
constant uint & r2
|
891
|
-
constant uint & r3
|
1002
|
+
constant int64_t & ne01,
|
1003
|
+
constant int64_t & ne02,
|
1004
|
+
constant int64_t & ne10,
|
1005
|
+
constant int64_t & ne12,
|
1006
|
+
constant int64_t & ne0,
|
1007
|
+
constant int64_t & ne1,
|
1008
|
+
constant uint & r2,
|
1009
|
+
constant uint & r3,
|
892
1010
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
893
|
-
uint
|
894
|
-
uint
|
1011
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
1012
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
895
1013
|
const int nr = N_DST;
|
896
1014
|
const int nsg = N_SIMDGROUP;
|
897
1015
|
const int nw = N_SIMDWIDTH;
|
@@ -945,9 +1063,29 @@ kernel void kernel_mul_mv_q8_0_f32(
|
|
945
1063
|
}
|
946
1064
|
}
|
947
1065
|
|
1066
|
+
[[host_name("kernel_mul_mv_q8_0_f32")]]
|
1067
|
+
kernel void kernel_mul_mv_q8_0_f32(
|
1068
|
+
device const void * src0,
|
1069
|
+
device const float * src1,
|
1070
|
+
device float * dst,
|
1071
|
+
constant int64_t & ne00,
|
1072
|
+
constant int64_t & ne01,
|
1073
|
+
constant int64_t & ne02,
|
1074
|
+
constant int64_t & ne10,
|
1075
|
+
constant int64_t & ne12,
|
1076
|
+
constant int64_t & ne0,
|
1077
|
+
constant int64_t & ne1,
|
1078
|
+
constant uint & r2 [[buffer(17)]],
|
1079
|
+
constant uint & r3 [[buffer(18)]],
|
1080
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1081
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
1082
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1083
|
+
kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
1084
|
+
}
|
1085
|
+
|
948
1086
|
#define N_F32_F32 4
|
949
1087
|
|
950
|
-
|
1088
|
+
void kernel_mul_mv_f32_f32_impl(
|
951
1089
|
device const char * src0,
|
952
1090
|
device const char * src1,
|
953
1091
|
device float * dst,
|
@@ -965,8 +1103,8 @@ kernel void kernel_mul_mv_f32_f32(
|
|
965
1103
|
constant uint64_t & nb12,
|
966
1104
|
constant int64_t & ne0,
|
967
1105
|
constant int64_t & ne1,
|
968
|
-
constant uint & r2
|
969
|
-
constant uint & r3
|
1106
|
+
constant uint & r2,
|
1107
|
+
constant uint & r3,
|
970
1108
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
971
1109
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
972
1110
|
|
@@ -1025,6 +1163,32 @@ kernel void kernel_mul_mv_f32_f32(
|
|
1025
1163
|
}
|
1026
1164
|
}
|
1027
1165
|
|
1166
|
+
[[host_name("kernel_mul_mv_f32_f32")]]
|
1167
|
+
kernel void kernel_mul_mv_f32_f32(
|
1168
|
+
device const char * src0,
|
1169
|
+
device const char * src1,
|
1170
|
+
device float * dst,
|
1171
|
+
constant int64_t & ne00,
|
1172
|
+
constant int64_t & ne01,
|
1173
|
+
constant int64_t & ne02,
|
1174
|
+
constant uint64_t & nb00,
|
1175
|
+
constant uint64_t & nb01,
|
1176
|
+
constant uint64_t & nb02,
|
1177
|
+
constant int64_t & ne10,
|
1178
|
+
constant int64_t & ne11,
|
1179
|
+
constant int64_t & ne12,
|
1180
|
+
constant uint64_t & nb10,
|
1181
|
+
constant uint64_t & nb11,
|
1182
|
+
constant uint64_t & nb12,
|
1183
|
+
constant int64_t & ne0,
|
1184
|
+
constant int64_t & ne1,
|
1185
|
+
constant uint & r2 [[buffer(17)]],
|
1186
|
+
constant uint & r3 [[buffer(18)]],
|
1187
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1188
|
+
uint tiisg[[thread_index_in_simdgroup]]) {
|
1189
|
+
kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
1190
|
+
}
|
1191
|
+
|
1028
1192
|
#define N_F16_F16 4
|
1029
1193
|
|
1030
1194
|
kernel void kernel_mul_mv_f16_f16(
|
@@ -1105,7 +1269,7 @@ kernel void kernel_mul_mv_f16_f16(
|
|
1105
1269
|
}
|
1106
1270
|
}
|
1107
1271
|
|
1108
|
-
|
1272
|
+
void kernel_mul_mv_f16_f32_1row_impl(
|
1109
1273
|
device const char * src0,
|
1110
1274
|
device const char * src1,
|
1111
1275
|
device float * dst,
|
@@ -1123,8 +1287,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
|
|
1123
1287
|
constant uint64_t & nb12,
|
1124
1288
|
constant int64_t & ne0,
|
1125
1289
|
constant int64_t & ne1,
|
1126
|
-
constant uint & r2
|
1127
|
-
constant uint & r3
|
1290
|
+
constant uint & r2,
|
1291
|
+
constant uint & r3,
|
1128
1292
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1129
1293
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
1130
1294
|
|
@@ -1161,12 +1325,37 @@ kernel void kernel_mul_mv_f16_f32_1row(
|
|
1161
1325
|
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
1162
1326
|
}
|
1163
1327
|
}
|
1328
|
+
}
|
1164
1329
|
|
1330
|
+
[[host_name("kernel_mul_mv_f16_f32_1row")]]
|
1331
|
+
kernel void kernel_mul_mv_f16_f32_1row(
|
1332
|
+
device const char * src0,
|
1333
|
+
device const char * src1,
|
1334
|
+
device float * dst,
|
1335
|
+
constant int64_t & ne00,
|
1336
|
+
constant int64_t & ne01,
|
1337
|
+
constant int64_t & ne02,
|
1338
|
+
constant uint64_t & nb00,
|
1339
|
+
constant uint64_t & nb01,
|
1340
|
+
constant uint64_t & nb02,
|
1341
|
+
constant int64_t & ne10,
|
1342
|
+
constant int64_t & ne11,
|
1343
|
+
constant int64_t & ne12,
|
1344
|
+
constant uint64_t & nb10,
|
1345
|
+
constant uint64_t & nb11,
|
1346
|
+
constant uint64_t & nb12,
|
1347
|
+
constant int64_t & ne0,
|
1348
|
+
constant int64_t & ne1,
|
1349
|
+
constant uint & r2 [[buffer(17)]],
|
1350
|
+
constant uint & r3 [[buffer(18)]],
|
1351
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1352
|
+
uint tiisg[[thread_index_in_simdgroup]]) {
|
1353
|
+
kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
1165
1354
|
}
|
1166
1355
|
|
1167
1356
|
#define N_F16_F32 4
|
1168
1357
|
|
1169
|
-
|
1358
|
+
void kernel_mul_mv_f16_f32_impl(
|
1170
1359
|
device const char * src0,
|
1171
1360
|
device const char * src1,
|
1172
1361
|
device float * dst,
|
@@ -1184,8 +1373,8 @@ kernel void kernel_mul_mv_f16_f32(
|
|
1184
1373
|
constant uint64_t & nb12,
|
1185
1374
|
constant int64_t & ne0,
|
1186
1375
|
constant int64_t & ne1,
|
1187
|
-
constant uint & r2
|
1188
|
-
constant uint & r3
|
1376
|
+
constant uint & r2,
|
1377
|
+
constant uint & r3,
|
1189
1378
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1190
1379
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
1191
1380
|
|
@@ -1244,6 +1433,32 @@ kernel void kernel_mul_mv_f16_f32(
|
|
1244
1433
|
}
|
1245
1434
|
}
|
1246
1435
|
|
1436
|
+
[[host_name("kernel_mul_mv_f16_f32")]]
|
1437
|
+
kernel void kernel_mul_mv_f16_f32(
|
1438
|
+
device const char * src0,
|
1439
|
+
device const char * src1,
|
1440
|
+
device float * dst,
|
1441
|
+
constant int64_t & ne00,
|
1442
|
+
constant int64_t & ne01,
|
1443
|
+
constant int64_t & ne02,
|
1444
|
+
constant uint64_t & nb00,
|
1445
|
+
constant uint64_t & nb01,
|
1446
|
+
constant uint64_t & nb02,
|
1447
|
+
constant int64_t & ne10,
|
1448
|
+
constant int64_t & ne11,
|
1449
|
+
constant int64_t & ne12,
|
1450
|
+
constant uint64_t & nb10,
|
1451
|
+
constant uint64_t & nb11,
|
1452
|
+
constant uint64_t & nb12,
|
1453
|
+
constant int64_t & ne0,
|
1454
|
+
constant int64_t & ne1,
|
1455
|
+
constant uint & r2 [[buffer(17)]],
|
1456
|
+
constant uint & r3 [[buffer(18)]],
|
1457
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1458
|
+
uint tiisg[[thread_index_in_simdgroup]]) {
|
1459
|
+
kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
1460
|
+
}
|
1461
|
+
|
1247
1462
|
// Assumes row size (ne00) is a multiple of 4
|
1248
1463
|
kernel void kernel_mul_mv_f16_f32_l4(
|
1249
1464
|
device const char * src0,
|
@@ -1548,25 +1763,116 @@ kernel void kernel_im2col_f16(
|
|
1548
1763
|
}
|
1549
1764
|
}
|
1550
1765
|
|
1551
|
-
|
1552
|
-
|
1553
|
-
|
1554
|
-
|
1555
|
-
|
1556
|
-
|
1557
|
-
|
1558
|
-
|
1559
|
-
|
1560
|
-
|
1561
|
-
|
1562
|
-
|
1563
|
-
|
1564
|
-
|
1565
|
-
|
1566
|
-
|
1567
|
-
|
1568
|
-
|
1569
|
-
|
1766
|
+
kernel void kernel_upscale_f32(
|
1767
|
+
device const char * src0,
|
1768
|
+
device char * dst,
|
1769
|
+
constant int64_t & ne00,
|
1770
|
+
constant int64_t & ne01,
|
1771
|
+
constant int64_t & ne02,
|
1772
|
+
constant int64_t & ne03,
|
1773
|
+
constant uint64_t & nb00,
|
1774
|
+
constant uint64_t & nb01,
|
1775
|
+
constant uint64_t & nb02,
|
1776
|
+
constant uint64_t & nb03,
|
1777
|
+
constant int64_t & ne0,
|
1778
|
+
constant int64_t & ne1,
|
1779
|
+
constant int64_t & ne2,
|
1780
|
+
constant int64_t & ne3,
|
1781
|
+
constant uint64_t & nb0,
|
1782
|
+
constant uint64_t & nb1,
|
1783
|
+
constant uint64_t & nb2,
|
1784
|
+
constant uint64_t & nb3,
|
1785
|
+
constant int32_t & sf,
|
1786
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1787
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
1788
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
1789
|
+
|
1790
|
+
const int64_t i3 = tgpig.z;
|
1791
|
+
const int64_t i2 = tgpig.y;
|
1792
|
+
const int64_t i1 = tgpig.x;
|
1793
|
+
|
1794
|
+
const int64_t i03 = i3;
|
1795
|
+
const int64_t i02 = i2;
|
1796
|
+
const int64_t i01 = i1/sf;
|
1797
|
+
|
1798
|
+
device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
|
1799
|
+
device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
|
1800
|
+
|
1801
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
1802
|
+
dst_ptr[i0] = src0_ptr[i0/sf];
|
1803
|
+
}
|
1804
|
+
}
|
1805
|
+
|
1806
|
+
kernel void kernel_pad_f32(
|
1807
|
+
device const char * src0,
|
1808
|
+
device char * dst,
|
1809
|
+
constant int64_t & ne00,
|
1810
|
+
constant int64_t & ne01,
|
1811
|
+
constant int64_t & ne02,
|
1812
|
+
constant int64_t & ne03,
|
1813
|
+
constant uint64_t & nb00,
|
1814
|
+
constant uint64_t & nb01,
|
1815
|
+
constant uint64_t & nb02,
|
1816
|
+
constant uint64_t & nb03,
|
1817
|
+
constant int64_t & ne0,
|
1818
|
+
constant int64_t & ne1,
|
1819
|
+
constant int64_t & ne2,
|
1820
|
+
constant int64_t & ne3,
|
1821
|
+
constant uint64_t & nb0,
|
1822
|
+
constant uint64_t & nb1,
|
1823
|
+
constant uint64_t & nb2,
|
1824
|
+
constant uint64_t & nb3,
|
1825
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1826
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
1827
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
1828
|
+
|
1829
|
+
const int64_t i3 = tgpig.z;
|
1830
|
+
const int64_t i2 = tgpig.y;
|
1831
|
+
const int64_t i1 = tgpig.x;
|
1832
|
+
|
1833
|
+
const int64_t i03 = i3;
|
1834
|
+
const int64_t i02 = i2;
|
1835
|
+
const int64_t i01 = i1;
|
1836
|
+
|
1837
|
+
device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
|
1838
|
+
device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
|
1839
|
+
|
1840
|
+
if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
|
1841
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
1842
|
+
if (i0 < ne00) {
|
1843
|
+
dst_ptr[i0] = src0_ptr[i0];
|
1844
|
+
} else {
|
1845
|
+
dst_ptr[i0] = 0.0f;
|
1846
|
+
}
|
1847
|
+
}
|
1848
|
+
|
1849
|
+
return;
|
1850
|
+
}
|
1851
|
+
|
1852
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
1853
|
+
dst_ptr[i0] = 0.0f;
|
1854
|
+
}
|
1855
|
+
}
|
1856
|
+
|
1857
|
+
// bitonic sort implementation following the CUDA kernels as reference
|
1858
|
+
typedef void (argsort_t)(
|
1859
|
+
device const float * x,
|
1860
|
+
device int32_t * dst,
|
1861
|
+
constant int64_t & ncols,
|
1862
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1863
|
+
uint3 tpitg[[thread_position_in_threadgroup]]);
|
1864
|
+
|
1865
|
+
template<ggml_sort_order order>
|
1866
|
+
kernel void kernel_argsort_f32_i32(
|
1867
|
+
device const float * x,
|
1868
|
+
device int32_t * dst,
|
1869
|
+
constant int64_t & ncols,
|
1870
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1871
|
+
uint3 tpitg[[thread_position_in_threadgroup]]) {
|
1872
|
+
// bitonic sort
|
1873
|
+
int col = tpitg[0];
|
1874
|
+
int row = tgpig[1];
|
1875
|
+
|
1570
1876
|
if (col >= ncols) return;
|
1571
1877
|
|
1572
1878
|
device const float * x_row = x + row * ncols;
|
@@ -1600,9 +1906,17 @@ kernel void kernel_argsort_f32_i32(
|
|
1600
1906
|
template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
|
1601
1907
|
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
|
1602
1908
|
|
1909
|
+
kernel void kernel_leaky_relu_f32(
|
1910
|
+
device const float * src0,
|
1911
|
+
device float * dst,
|
1912
|
+
constant float & slope,
|
1913
|
+
uint tpig[[thread_position_in_grid]]) {
|
1914
|
+
dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
|
1915
|
+
}
|
1916
|
+
|
1603
1917
|
kernel void kernel_cpy_f16_f16(
|
1604
|
-
device
|
1605
|
-
device
|
1918
|
+
device const half * src0,
|
1919
|
+
device half * dst,
|
1606
1920
|
constant int64_t & ne00,
|
1607
1921
|
constant int64_t & ne01,
|
1608
1922
|
constant int64_t & ne02,
|
@@ -1641,6 +1955,47 @@ kernel void kernel_cpy_f16_f16(
|
|
1641
1955
|
}
|
1642
1956
|
}
|
1643
1957
|
|
1958
|
+
kernel void kernel_cpy_f16_f32(
|
1959
|
+
device const half * src0,
|
1960
|
+
device float * dst,
|
1961
|
+
constant int64_t & ne00,
|
1962
|
+
constant int64_t & ne01,
|
1963
|
+
constant int64_t & ne02,
|
1964
|
+
constant int64_t & ne03,
|
1965
|
+
constant uint64_t & nb00,
|
1966
|
+
constant uint64_t & nb01,
|
1967
|
+
constant uint64_t & nb02,
|
1968
|
+
constant uint64_t & nb03,
|
1969
|
+
constant int64_t & ne0,
|
1970
|
+
constant int64_t & ne1,
|
1971
|
+
constant int64_t & ne2,
|
1972
|
+
constant int64_t & ne3,
|
1973
|
+
constant uint64_t & nb0,
|
1974
|
+
constant uint64_t & nb1,
|
1975
|
+
constant uint64_t & nb2,
|
1976
|
+
constant uint64_t & nb3,
|
1977
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1978
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
1979
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
1980
|
+
const int64_t i03 = tgpig[2];
|
1981
|
+
const int64_t i02 = tgpig[1];
|
1982
|
+
const int64_t i01 = tgpig[0];
|
1983
|
+
|
1984
|
+
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
1985
|
+
|
1986
|
+
const int64_t i3 = n / (ne2*ne1*ne0);
|
1987
|
+
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
1988
|
+
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
1989
|
+
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
1990
|
+
|
1991
|
+
device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
1992
|
+
|
1993
|
+
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
1994
|
+
device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
1995
|
+
dst_data[i00] = src[0];
|
1996
|
+
}
|
1997
|
+
}
|
1998
|
+
|
1644
1999
|
kernel void kernel_cpy_f32_f16(
|
1645
2000
|
device const float * src0,
|
1646
2001
|
device half * dst,
|
@@ -1917,9 +2272,9 @@ kernel void kernel_cpy_f32_q4_1(
|
|
1917
2272
|
}
|
1918
2273
|
|
1919
2274
|
kernel void kernel_concat(
|
1920
|
-
device
|
1921
|
-
device
|
1922
|
-
device
|
2275
|
+
device const char * src0,
|
2276
|
+
device const char * src1,
|
2277
|
+
device char * dst,
|
1923
2278
|
constant int64_t & ne00,
|
1924
2279
|
constant int64_t & ne01,
|
1925
2280
|
constant int64_t & ne02,
|
@@ -1956,7 +2311,7 @@ kernel void kernel_concat(
|
|
1956
2311
|
const int64_t i12 = i02 % ne12;
|
1957
2312
|
const int64_t i11 = i01 % ne11;
|
1958
2313
|
|
1959
|
-
device const char * src0_ptr = src0 + i03
|
2314
|
+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
|
1960
2315
|
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
|
1961
2316
|
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
|
1962
2317
|
|
@@ -2064,19 +2419,19 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
|
|
2064
2419
|
|
2065
2420
|
//====================================== dot products =========================
|
2066
2421
|
|
2067
|
-
|
2422
|
+
void kernel_mul_mv_q2_K_f32_impl(
|
2068
2423
|
device const void * src0,
|
2069
2424
|
device const float * src1,
|
2070
2425
|
device float * dst,
|
2071
2426
|
constant int64_t & ne00,
|
2072
|
-
constant int64_t & ne01
|
2073
|
-
constant int64_t & ne02
|
2074
|
-
constant int64_t & ne10
|
2075
|
-
constant int64_t & ne12
|
2076
|
-
constant int64_t & ne0
|
2077
|
-
constant int64_t & ne1
|
2078
|
-
constant uint & r2
|
2079
|
-
constant uint & r3
|
2427
|
+
constant int64_t & ne01,
|
2428
|
+
constant int64_t & ne02,
|
2429
|
+
constant int64_t & ne10,
|
2430
|
+
constant int64_t & ne12,
|
2431
|
+
constant int64_t & ne0,
|
2432
|
+
constant int64_t & ne1,
|
2433
|
+
constant uint & r2,
|
2434
|
+
constant uint & r3,
|
2080
2435
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2081
2436
|
uint tiisg[[thread_index_in_simdgroup]],
|
2082
2437
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -2214,8 +2569,8 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
2214
2569
|
}
|
2215
2570
|
}
|
2216
2571
|
|
2217
|
-
|
2218
|
-
kernel void
|
2572
|
+
[[host_name("kernel_mul_mv_q2_K_f32")]]
|
2573
|
+
kernel void kernel_mul_mv_q2_K_f32(
|
2219
2574
|
device const void * src0,
|
2220
2575
|
device const float * src1,
|
2221
2576
|
device float * dst,
|
@@ -2229,8 +2584,29 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
2229
2584
|
constant uint & r2 [[buffer(17)]],
|
2230
2585
|
constant uint & r3 [[buffer(18)]],
|
2231
2586
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2232
|
-
uint
|
2233
|
-
uint
|
2587
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
2588
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
2589
|
+
|
2590
|
+
kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
2591
|
+
}
|
2592
|
+
|
2593
|
+
#if QK_K == 256
|
2594
|
+
void kernel_mul_mv_q3_K_f32_impl(
|
2595
|
+
device const void * src0,
|
2596
|
+
device const float * src1,
|
2597
|
+
device float * dst,
|
2598
|
+
constant int64_t & ne00,
|
2599
|
+
constant int64_t & ne01,
|
2600
|
+
constant int64_t & ne02,
|
2601
|
+
constant int64_t & ne10,
|
2602
|
+
constant int64_t & ne12,
|
2603
|
+
constant int64_t & ne0,
|
2604
|
+
constant int64_t & ne1,
|
2605
|
+
constant uint & r2,
|
2606
|
+
constant uint & r3,
|
2607
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
2608
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
2609
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
2234
2610
|
|
2235
2611
|
const int nb = ne00/QK_K;
|
2236
2612
|
|
@@ -2373,19 +2749,19 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
2373
2749
|
}
|
2374
2750
|
}
|
2375
2751
|
#else
|
2376
|
-
|
2752
|
+
void kernel_mul_mv_q3_K_f32_impl(
|
2377
2753
|
device const void * src0,
|
2378
2754
|
device const float * src1,
|
2379
2755
|
device float * dst,
|
2380
2756
|
constant int64_t & ne00,
|
2381
|
-
constant int64_t & ne01
|
2382
|
-
constant int64_t & ne02
|
2383
|
-
constant int64_t & ne10
|
2384
|
-
constant int64_t & ne12
|
2385
|
-
constant int64_t & ne0
|
2386
|
-
constant int64_t & ne1
|
2387
|
-
constant uint & r2
|
2388
|
-
constant uint & r3
|
2757
|
+
constant int64_t & ne01,
|
2758
|
+
constant int64_t & ne02,
|
2759
|
+
constant int64_t & ne10,
|
2760
|
+
constant int64_t & ne12,
|
2761
|
+
constant int64_t & ne0,
|
2762
|
+
constant int64_t & ne1,
|
2763
|
+
constant uint & r2,
|
2764
|
+
constant uint & r3,
|
2389
2765
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2390
2766
|
uint tiisg[[thread_index_in_simdgroup]],
|
2391
2767
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -2450,20 +2826,41 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
2450
2826
|
}
|
2451
2827
|
#endif
|
2452
2828
|
|
2829
|
+
[[host_name("kernel_mul_mv_q3_K_f32")]]
|
2830
|
+
kernel void kernel_mul_mv_q3_K_f32(
|
2831
|
+
device const void * src0,
|
2832
|
+
device const float * src1,
|
2833
|
+
device float * dst,
|
2834
|
+
constant int64_t & ne00,
|
2835
|
+
constant int64_t & ne01[[buffer(4)]],
|
2836
|
+
constant int64_t & ne02[[buffer(5)]],
|
2837
|
+
constant int64_t & ne10[[buffer(9)]],
|
2838
|
+
constant int64_t & ne12[[buffer(11)]],
|
2839
|
+
constant int64_t & ne0 [[buffer(15)]],
|
2840
|
+
constant int64_t & ne1 [[buffer(16)]],
|
2841
|
+
constant uint & r2 [[buffer(17)]],
|
2842
|
+
constant uint & r3 [[buffer(18)]],
|
2843
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
2844
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
2845
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
2846
|
+
|
2847
|
+
kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
2848
|
+
}
|
2849
|
+
|
2453
2850
|
#if QK_K == 256
|
2454
|
-
|
2851
|
+
void kernel_mul_mv_q4_K_f32_impl(
|
2455
2852
|
device const void * src0,
|
2456
2853
|
device const float * src1,
|
2457
2854
|
device float * dst,
|
2458
2855
|
constant int64_t & ne00,
|
2459
|
-
constant int64_t & ne01
|
2460
|
-
constant int64_t & ne02
|
2461
|
-
constant int64_t & ne10
|
2462
|
-
constant int64_t & ne12
|
2463
|
-
constant int64_t & ne0
|
2464
|
-
constant int64_t & ne1
|
2465
|
-
constant uint & r2
|
2466
|
-
constant uint & r3
|
2856
|
+
constant int64_t & ne01,
|
2857
|
+
constant int64_t & ne02,
|
2858
|
+
constant int64_t & ne10,
|
2859
|
+
constant int64_t & ne12,
|
2860
|
+
constant int64_t & ne0,
|
2861
|
+
constant int64_t & ne1,
|
2862
|
+
constant uint & r2,
|
2863
|
+
constant uint & r3,
|
2467
2864
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2468
2865
|
uint tiisg[[thread_index_in_simdgroup]],
|
2469
2866
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -2564,19 +2961,19 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
2564
2961
|
}
|
2565
2962
|
}
|
2566
2963
|
#else
|
2567
|
-
|
2964
|
+
void kernel_mul_mv_q4_K_f32_impl(
|
2568
2965
|
device const void * src0,
|
2569
2966
|
device const float * src1,
|
2570
2967
|
device float * dst,
|
2571
2968
|
constant int64_t & ne00,
|
2572
|
-
constant int64_t & ne01
|
2573
|
-
constant int64_t & ne02
|
2574
|
-
constant int64_t & ne10
|
2575
|
-
constant int64_t & ne12
|
2576
|
-
constant int64_t & ne0
|
2577
|
-
constant int64_t & ne1
|
2578
|
-
constant uint & r2
|
2579
|
-
constant uint & r3
|
2969
|
+
constant int64_t & ne01,
|
2970
|
+
constant int64_t & ne02,
|
2971
|
+
constant int64_t & ne10,
|
2972
|
+
constant int64_t & ne12,
|
2973
|
+
constant int64_t & ne0,
|
2974
|
+
constant int64_t & ne1,
|
2975
|
+
constant uint & r2,
|
2976
|
+
constant uint & r3,
|
2580
2977
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2581
2978
|
uint tiisg[[thread_index_in_simdgroup]],
|
2582
2979
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -2660,7 +3057,8 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
2660
3057
|
}
|
2661
3058
|
#endif
|
2662
3059
|
|
2663
|
-
|
3060
|
+
[[host_name("kernel_mul_mv_q4_K_f32")]]
|
3061
|
+
kernel void kernel_mul_mv_q4_K_f32(
|
2664
3062
|
device const void * src0,
|
2665
3063
|
device const float * src1,
|
2666
3064
|
device float * dst,
|
@@ -2677,6 +3075,26 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
2677
3075
|
uint tiisg[[thread_index_in_simdgroup]],
|
2678
3076
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
2679
3077
|
|
3078
|
+
kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
3079
|
+
}
|
3080
|
+
|
3081
|
+
void kernel_mul_mv_q5_K_f32_impl(
|
3082
|
+
device const void * src0,
|
3083
|
+
device const float * src1,
|
3084
|
+
device float * dst,
|
3085
|
+
constant int64_t & ne00,
|
3086
|
+
constant int64_t & ne01,
|
3087
|
+
constant int64_t & ne02,
|
3088
|
+
constant int64_t & ne10,
|
3089
|
+
constant int64_t & ne12,
|
3090
|
+
constant int64_t & ne0,
|
3091
|
+
constant int64_t & ne1,
|
3092
|
+
constant uint & r2,
|
3093
|
+
constant uint & r3,
|
3094
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3095
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
3096
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3097
|
+
|
2680
3098
|
const int nb = ne00/QK_K;
|
2681
3099
|
|
2682
3100
|
const int64_t r0 = tgpig.x;
|
@@ -2836,10 +3254,10 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
2836
3254
|
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
|
2837
3255
|
}
|
2838
3256
|
}
|
2839
|
-
|
2840
3257
|
}
|
2841
3258
|
|
2842
|
-
|
3259
|
+
[[host_name("kernel_mul_mv_q5_K_f32")]]
|
3260
|
+
kernel void kernel_mul_mv_q5_K_f32(
|
2843
3261
|
device const void * src0,
|
2844
3262
|
device const float * src1,
|
2845
3263
|
device float * dst,
|
@@ -2853,18 +3271,38 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
2853
3271
|
constant uint & r2 [[buffer(17)]],
|
2854
3272
|
constant uint & r3 [[buffer(18)]],
|
2855
3273
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2856
|
-
uint
|
2857
|
-
uint
|
2858
|
-
|
2859
|
-
const uint8_t kmask1 = 0x03;
|
2860
|
-
const uint8_t kmask2 = 0x0C;
|
2861
|
-
const uint8_t kmask3 = 0x30;
|
2862
|
-
const uint8_t kmask4 = 0xC0;
|
3274
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
3275
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
2863
3276
|
|
2864
|
-
|
3277
|
+
kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
3278
|
+
}
|
2865
3279
|
|
2866
|
-
|
2867
|
-
|
3280
|
+
void kernel_mul_mv_q6_K_f32_impl(
|
3281
|
+
device const void * src0,
|
3282
|
+
device const float * src1,
|
3283
|
+
device float * dst,
|
3284
|
+
constant int64_t & ne00,
|
3285
|
+
constant int64_t & ne01,
|
3286
|
+
constant int64_t & ne02,
|
3287
|
+
constant int64_t & ne10,
|
3288
|
+
constant int64_t & ne12,
|
3289
|
+
constant int64_t & ne0,
|
3290
|
+
constant int64_t & ne1,
|
3291
|
+
constant uint & r2,
|
3292
|
+
constant uint & r3,
|
3293
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3294
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
3295
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3296
|
+
|
3297
|
+
const uint8_t kmask1 = 0x03;
|
3298
|
+
const uint8_t kmask2 = 0x0C;
|
3299
|
+
const uint8_t kmask3 = 0x30;
|
3300
|
+
const uint8_t kmask4 = 0xC0;
|
3301
|
+
|
3302
|
+
const int nb = ne00/QK_K;
|
3303
|
+
|
3304
|
+
const int64_t r0 = tgpig.x;
|
3305
|
+
const int64_t r1 = tgpig.y;
|
2868
3306
|
const int im = tgpig.z;
|
2869
3307
|
|
2870
3308
|
const int row = 2 * r0 + sgitg;
|
@@ -2945,6 +3383,27 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
2945
3383
|
}
|
2946
3384
|
}
|
2947
3385
|
|
3386
|
+
[[host_name("kernel_mul_mv_q6_K_f32")]]
|
3387
|
+
kernel void kernel_mul_mv_q6_K_f32(
|
3388
|
+
device const void * src0,
|
3389
|
+
device const float * src1,
|
3390
|
+
device float * dst,
|
3391
|
+
constant int64_t & ne00,
|
3392
|
+
constant int64_t & ne01[[buffer(4)]],
|
3393
|
+
constant int64_t & ne02[[buffer(5)]],
|
3394
|
+
constant int64_t & ne10[[buffer(9)]],
|
3395
|
+
constant int64_t & ne12[[buffer(11)]],
|
3396
|
+
constant int64_t & ne0 [[buffer(15)]],
|
3397
|
+
constant int64_t & ne1 [[buffer(16)]],
|
3398
|
+
constant uint & r2 [[buffer(17)]],
|
3399
|
+
constant uint & r3 [[buffer(18)]],
|
3400
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3401
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
3402
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3403
|
+
|
3404
|
+
kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
3405
|
+
}
|
3406
|
+
|
2948
3407
|
//============================= templates and their specializations =============================
|
2949
3408
|
|
2950
3409
|
// NOTE: this is not dequantizing - we are simply fitting the template
|
@@ -3062,10 +3521,10 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
|
|
3062
3521
|
|
3063
3522
|
template <typename type4x4>
|
3064
3523
|
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
|
3065
|
-
const
|
3066
|
-
const
|
3524
|
+
const float d = xb->d;
|
3525
|
+
const float min = xb->dmin;
|
3067
3526
|
device const uint8_t * q = (device const uint8_t *)xb->qs;
|
3068
|
-
|
3527
|
+
float dl, ml;
|
3069
3528
|
uint8_t sc = xb->scales[il];
|
3070
3529
|
|
3071
3530
|
#if QK_K == 256
|
@@ -3135,10 +3594,10 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
|
|
3135
3594
|
q = q + (il/4) * 32 + 16 * (il&1);
|
3136
3595
|
il = il & 3;
|
3137
3596
|
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
|
3138
|
-
const
|
3139
|
-
const
|
3140
|
-
const
|
3141
|
-
const
|
3597
|
+
const float d = il < 2 ? xb->d : xb->d / 16.h;
|
3598
|
+
const float min = xb->dmin;
|
3599
|
+
const float dl = d * sc[0];
|
3600
|
+
const float ml = min * sc[1];
|
3142
3601
|
#else
|
3143
3602
|
q = q + 16 * (il&1);
|
3144
3603
|
device const uint8_t * s = xb->scales;
|
@@ -3165,13 +3624,13 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
|
|
3165
3624
|
uint8_t ul = 1 << (il/2);
|
3166
3625
|
il = il & 3;
|
3167
3626
|
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
|
3168
|
-
const
|
3169
|
-
const
|
3170
|
-
const
|
3171
|
-
const
|
3627
|
+
const float d = il < 2 ? xb->d : xb->d / 16.h;
|
3628
|
+
const float min = xb->dmin;
|
3629
|
+
const float dl = d * sc[0];
|
3630
|
+
const float ml = min * sc[1];
|
3172
3631
|
|
3173
|
-
const ushort mask
|
3174
|
-
const
|
3632
|
+
const ushort mask = il<2 ? 0x0F : 0xF0;
|
3633
|
+
const float qh_val = il<2 ? 16.f : 256.f;
|
3175
3634
|
for (int i = 0; i < 16; ++i) {
|
3176
3635
|
reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
|
3177
3636
|
}
|
@@ -3219,22 +3678,90 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
|
|
3219
3678
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
3220
3679
|
kernel void kernel_get_rows(
|
3221
3680
|
device const void * src0,
|
3222
|
-
device const
|
3681
|
+
device const char * src1,
|
3223
3682
|
device float * dst,
|
3224
3683
|
constant int64_t & ne00,
|
3225
3684
|
constant uint64_t & nb01,
|
3685
|
+
constant uint64_t & nb02,
|
3686
|
+
constant int64_t & ne10,
|
3687
|
+
constant uint64_t & nb10,
|
3688
|
+
constant uint64_t & nb11,
|
3226
3689
|
constant uint64_t & nb1,
|
3227
|
-
|
3690
|
+
constant uint64_t & nb2,
|
3691
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3228
3692
|
uint tiitg[[thread_index_in_threadgroup]],
|
3229
|
-
|
3230
|
-
const
|
3231
|
-
const
|
3693
|
+
uint3 tptg [[threads_per_threadgroup]]) {
|
3694
|
+
//const int64_t i = tgpig;
|
3695
|
+
//const int64_t r = ((device int32_t *) src1)[i];
|
3696
|
+
|
3697
|
+
const int64_t i10 = tgpig.x;
|
3698
|
+
const int64_t i11 = tgpig.y;
|
3232
3699
|
|
3233
|
-
|
3700
|
+
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
3701
|
+
|
3702
|
+
const int64_t i02 = i11;
|
3703
|
+
|
3704
|
+
for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
|
3234
3705
|
float4x4 temp;
|
3235
3706
|
dequantize_func(
|
3236
|
-
((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
|
3237
|
-
*(((device float4x4 *) ((device char *) dst +
|
3707
|
+
((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
|
3708
|
+
*(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
|
3709
|
+
}
|
3710
|
+
}
|
3711
|
+
|
3712
|
+
kernel void kernel_get_rows_f32(
|
3713
|
+
device const void * src0,
|
3714
|
+
device const char * src1,
|
3715
|
+
device float * dst,
|
3716
|
+
constant int64_t & ne00,
|
3717
|
+
constant uint64_t & nb01,
|
3718
|
+
constant uint64_t & nb02,
|
3719
|
+
constant int64_t & ne10,
|
3720
|
+
constant uint64_t & nb10,
|
3721
|
+
constant uint64_t & nb11,
|
3722
|
+
constant uint64_t & nb1,
|
3723
|
+
constant uint64_t & nb2,
|
3724
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3725
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
3726
|
+
uint3 tptg [[threads_per_threadgroup]]) {
|
3727
|
+
const int64_t i10 = tgpig.x;
|
3728
|
+
const int64_t i11 = tgpig.y;
|
3729
|
+
|
3730
|
+
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
3731
|
+
|
3732
|
+
const int64_t i02 = i11;
|
3733
|
+
|
3734
|
+
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
3735
|
+
((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
3736
|
+
((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
3737
|
+
}
|
3738
|
+
}
|
3739
|
+
|
3740
|
+
kernel void kernel_get_rows_f16(
|
3741
|
+
device const void * src0,
|
3742
|
+
device const char * src1,
|
3743
|
+
device float * dst,
|
3744
|
+
constant int64_t & ne00,
|
3745
|
+
constant uint64_t & nb01,
|
3746
|
+
constant uint64_t & nb02,
|
3747
|
+
constant int64_t & ne10,
|
3748
|
+
constant uint64_t & nb10,
|
3749
|
+
constant uint64_t & nb11,
|
3750
|
+
constant uint64_t & nb1,
|
3751
|
+
constant uint64_t & nb2,
|
3752
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3753
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
3754
|
+
uint3 tptg [[threads_per_threadgroup]]) {
|
3755
|
+
const int64_t i10 = tgpig.x;
|
3756
|
+
const int64_t i11 = tgpig.y;
|
3757
|
+
|
3758
|
+
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
3759
|
+
|
3760
|
+
const int64_t i02 = i11;
|
3761
|
+
|
3762
|
+
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
3763
|
+
((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
3764
|
+
((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
3238
3765
|
}
|
3239
3766
|
}
|
3240
3767
|
|
@@ -3426,19 +3953,22 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
3426
3953
|
|
3427
3954
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
3428
3955
|
kernel void kernel_mul_mm_id(
|
3429
|
-
device const
|
3956
|
+
device const uchar * ids,
|
3430
3957
|
device const uchar * src1,
|
3431
|
-
device
|
3958
|
+
device uchar * dst,
|
3959
|
+
constant int64_t & nbi1,
|
3432
3960
|
constant int64_t & ne00,
|
3433
3961
|
constant int64_t & ne02,
|
3434
3962
|
constant int64_t & nb01,
|
3435
3963
|
constant int64_t & nb02,
|
3436
3964
|
constant int64_t & ne12,
|
3965
|
+
constant int64_t & ne13,
|
3437
3966
|
constant int64_t & nb10,
|
3438
3967
|
constant int64_t & nb11,
|
3439
3968
|
constant int64_t & nb12,
|
3440
3969
|
constant int64_t & ne0,
|
3441
3970
|
constant int64_t & ne1,
|
3971
|
+
constant int64_t & nb1,
|
3442
3972
|
constant uint & r2,
|
3443
3973
|
constant uint & r3,
|
3444
3974
|
constant int & idx,
|
@@ -3456,10 +3986,16 @@ kernel void kernel_mul_mm_id(
|
|
3456
3986
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3457
3987
|
device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
3458
3988
|
|
3989
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
3990
|
+
|
3991
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
3992
|
+
|
3993
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
3994
|
+
|
3459
3995
|
kernel_mul_mm_impl<block_q, nl, dequantize_func>(
|
3460
|
-
src0[
|
3461
|
-
src1,
|
3462
|
-
dst,
|
3996
|
+
src0[id],
|
3997
|
+
src1 + bid*nb11,
|
3998
|
+
(device float *) (dst + bid*nb1),
|
3463
3999
|
ne00,
|
3464
4000
|
ne02,
|
3465
4001
|
nb01,
|
@@ -3484,17 +4020,26 @@ kernel void kernel_mul_mm_id(
|
|
3484
4020
|
#define QK_NL 4
|
3485
4021
|
#endif
|
3486
4022
|
|
4023
|
+
//
|
4024
|
+
// get rows
|
4025
|
+
//
|
4026
|
+
|
3487
4027
|
typedef void (get_rows_t)(
|
3488
4028
|
device const void * src0,
|
3489
|
-
device const
|
4029
|
+
device const char * src1,
|
3490
4030
|
device float * dst,
|
3491
4031
|
constant int64_t & ne00,
|
3492
4032
|
constant uint64_t & nb01,
|
4033
|
+
constant uint64_t & nb02,
|
4034
|
+
constant int64_t & ne10,
|
4035
|
+
constant uint64_t & nb10,
|
4036
|
+
constant uint64_t & nb11,
|
3493
4037
|
constant uint64_t & nb1,
|
3494
|
-
|
4038
|
+
constant uint64_t & nb2,
|
4039
|
+
uint3, uint, uint3);
|
3495
4040
|
|
3496
|
-
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
|
3497
|
-
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
4041
|
+
//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
|
4042
|
+
//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
3498
4043
|
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
|
3499
4044
|
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
|
3500
4045
|
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
|
@@ -3506,6 +4051,10 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
|
|
3506
4051
|
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
|
3507
4052
|
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
|
3508
4053
|
|
4054
|
+
//
|
4055
|
+
// matrix-matrix multiplication
|
4056
|
+
//
|
4057
|
+
|
3509
4058
|
typedef void (mat_mm_t)(
|
3510
4059
|
device const uchar * src0,
|
3511
4060
|
device const uchar * src1,
|
@@ -3538,20 +4087,27 @@ template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
|
|
3538
4087
|
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
|
3539
4088
|
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|
3540
4089
|
|
4090
|
+
//
|
4091
|
+
// indirect matrix-matrix multiplication
|
4092
|
+
//
|
4093
|
+
|
3541
4094
|
typedef void (mat_mm_id_t)(
|
3542
|
-
device const
|
4095
|
+
device const uchar * ids,
|
3543
4096
|
device const uchar * src1,
|
3544
|
-
device
|
4097
|
+
device uchar * dst,
|
4098
|
+
constant int64_t & nbi1,
|
3545
4099
|
constant int64_t & ne00,
|
3546
4100
|
constant int64_t & ne02,
|
3547
4101
|
constant int64_t & nb01,
|
3548
4102
|
constant int64_t & nb02,
|
3549
4103
|
constant int64_t & ne12,
|
4104
|
+
constant int64_t & ne13,
|
3550
4105
|
constant int64_t & nb10,
|
3551
4106
|
constant int64_t & nb11,
|
3552
4107
|
constant int64_t & nb12,
|
3553
4108
|
constant int64_t & ne0,
|
3554
4109
|
constant int64_t & ne1,
|
4110
|
+
constant int64_t & nb1,
|
3555
4111
|
constant uint & r2,
|
3556
4112
|
constant uint & r3,
|
3557
4113
|
constant int & idx,
|
@@ -3578,3 +4134,775 @@ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mu
|
|
3578
4134
|
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
|
3579
4135
|
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
|
3580
4136
|
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
|
4137
|
+
|
4138
|
+
//
|
4139
|
+
// matrix-vector multiplication
|
4140
|
+
//
|
4141
|
+
|
4142
|
+
[[host_name("kernel_mul_mv_id_f32_f32")]]
|
4143
|
+
kernel void kernel_mul_mv_id_f32_f32(
|
4144
|
+
device const char * ids,
|
4145
|
+
device const char * src1,
|
4146
|
+
device uchar * dst,
|
4147
|
+
constant int64_t & nbi1,
|
4148
|
+
constant int64_t & ne00,
|
4149
|
+
constant int64_t & ne01,
|
4150
|
+
constant int64_t & ne02,
|
4151
|
+
constant uint64_t & nb00,
|
4152
|
+
constant uint64_t & nb01,
|
4153
|
+
constant uint64_t & nb02,
|
4154
|
+
constant int64_t & ne10,
|
4155
|
+
constant int64_t & ne11,
|
4156
|
+
constant int64_t & ne12,
|
4157
|
+
constant int64_t & ne13,
|
4158
|
+
constant uint64_t & nb10,
|
4159
|
+
constant uint64_t & nb11,
|
4160
|
+
constant uint64_t & nb12,
|
4161
|
+
constant int64_t & ne0,
|
4162
|
+
constant int64_t & ne1,
|
4163
|
+
constant int64_t & nb1,
|
4164
|
+
constant uint & r2,
|
4165
|
+
constant uint & r3,
|
4166
|
+
constant int & idx,
|
4167
|
+
device const char * src00,
|
4168
|
+
device const char * src01,
|
4169
|
+
device const char * src02,
|
4170
|
+
device const char * src03,
|
4171
|
+
device const char * src04,
|
4172
|
+
device const char * src05,
|
4173
|
+
device const char * src06,
|
4174
|
+
device const char * src07,
|
4175
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4176
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4177
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4178
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4179
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
4180
|
+
|
4181
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
4182
|
+
|
4183
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4184
|
+
|
4185
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4186
|
+
|
4187
|
+
kernel_mul_mv_f32_f32_impl(
|
4188
|
+
src0[id],
|
4189
|
+
src1 + bid*nb11,
|
4190
|
+
(device float *) (dst + bid*nb1),
|
4191
|
+
ne00,
|
4192
|
+
ne01,
|
4193
|
+
ne02,
|
4194
|
+
nb00,
|
4195
|
+
nb01,
|
4196
|
+
nb02,
|
4197
|
+
ne10,
|
4198
|
+
ne11,
|
4199
|
+
ne12,
|
4200
|
+
nb10,
|
4201
|
+
nb11,
|
4202
|
+
nb12,
|
4203
|
+
ne0,
|
4204
|
+
ne1,
|
4205
|
+
r2,
|
4206
|
+
r3,
|
4207
|
+
tgpig,
|
4208
|
+
tiisg);
|
4209
|
+
}
|
4210
|
+
|
4211
|
+
[[host_name("kernel_mul_mv_id_f16_f32")]]
|
4212
|
+
kernel void kernel_mul_mv_id_f16_f32(
|
4213
|
+
device const char * ids,
|
4214
|
+
device const char * src1,
|
4215
|
+
device uchar * dst,
|
4216
|
+
constant int64_t & nbi1,
|
4217
|
+
constant int64_t & ne00,
|
4218
|
+
constant int64_t & ne01,
|
4219
|
+
constant int64_t & ne02,
|
4220
|
+
constant uint64_t & nb00,
|
4221
|
+
constant uint64_t & nb01,
|
4222
|
+
constant uint64_t & nb02,
|
4223
|
+
constant int64_t & ne10,
|
4224
|
+
constant int64_t & ne11,
|
4225
|
+
constant int64_t & ne12,
|
4226
|
+
constant int64_t & ne13,
|
4227
|
+
constant uint64_t & nb10,
|
4228
|
+
constant uint64_t & nb11,
|
4229
|
+
constant uint64_t & nb12,
|
4230
|
+
constant int64_t & ne0,
|
4231
|
+
constant int64_t & ne1,
|
4232
|
+
constant int64_t & nb1,
|
4233
|
+
constant uint & r2,
|
4234
|
+
constant uint & r3,
|
4235
|
+
constant int & idx,
|
4236
|
+
device const char * src00,
|
4237
|
+
device const char * src01,
|
4238
|
+
device const char * src02,
|
4239
|
+
device const char * src03,
|
4240
|
+
device const char * src04,
|
4241
|
+
device const char * src05,
|
4242
|
+
device const char * src06,
|
4243
|
+
device const char * src07,
|
4244
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4245
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4246
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4247
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4248
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
4249
|
+
|
4250
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
4251
|
+
|
4252
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4253
|
+
|
4254
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4255
|
+
|
4256
|
+
kernel_mul_mv_f16_f32_impl(
|
4257
|
+
src0[id],
|
4258
|
+
src1 + bid*nb11,
|
4259
|
+
(device float *) (dst + bid*nb1),
|
4260
|
+
ne00,
|
4261
|
+
ne01,
|
4262
|
+
ne02,
|
4263
|
+
nb00,
|
4264
|
+
nb01,
|
4265
|
+
nb02,
|
4266
|
+
ne10,
|
4267
|
+
ne11,
|
4268
|
+
ne12,
|
4269
|
+
nb10,
|
4270
|
+
nb11,
|
4271
|
+
nb12,
|
4272
|
+
ne0,
|
4273
|
+
ne1,
|
4274
|
+
r2,
|
4275
|
+
r3,
|
4276
|
+
tgpig,
|
4277
|
+
tiisg);
|
4278
|
+
}
|
4279
|
+
|
4280
|
+
[[host_name("kernel_mul_mv_id_q8_0_f32")]]
|
4281
|
+
kernel void kernel_mul_mv_id_q8_0_f32(
|
4282
|
+
device const char * ids,
|
4283
|
+
device const char * src1,
|
4284
|
+
device uchar * dst,
|
4285
|
+
constant int64_t & nbi1,
|
4286
|
+
constant int64_t & ne00,
|
4287
|
+
constant int64_t & ne01,
|
4288
|
+
constant int64_t & ne02,
|
4289
|
+
constant uint64_t & nb00,
|
4290
|
+
constant uint64_t & nb01,
|
4291
|
+
constant uint64_t & nb02,
|
4292
|
+
constant int64_t & ne10,
|
4293
|
+
constant int64_t & ne11,
|
4294
|
+
constant int64_t & ne12,
|
4295
|
+
constant int64_t & ne13,
|
4296
|
+
constant uint64_t & nb10,
|
4297
|
+
constant uint64_t & nb11,
|
4298
|
+
constant uint64_t & nb12,
|
4299
|
+
constant int64_t & ne0,
|
4300
|
+
constant int64_t & ne1,
|
4301
|
+
constant int64_t & nb1,
|
4302
|
+
constant uint & r2,
|
4303
|
+
constant uint & r3,
|
4304
|
+
constant int & idx,
|
4305
|
+
device const char * src00,
|
4306
|
+
device const char * src01,
|
4307
|
+
device const char * src02,
|
4308
|
+
device const char * src03,
|
4309
|
+
device const char * src04,
|
4310
|
+
device const char * src05,
|
4311
|
+
device const char * src06,
|
4312
|
+
device const char * src07,
|
4313
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4314
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4315
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4316
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4317
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
4318
|
+
|
4319
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
4320
|
+
|
4321
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4322
|
+
|
4323
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4324
|
+
|
4325
|
+
kernel_mul_mv_q8_0_f32_impl(
|
4326
|
+
src0[id],
|
4327
|
+
(device const float *) (src1 + bid*nb11),
|
4328
|
+
(device float *) ( dst + bid*nb1),
|
4329
|
+
ne00,
|
4330
|
+
ne01,
|
4331
|
+
ne02,
|
4332
|
+
ne10,
|
4333
|
+
ne12,
|
4334
|
+
ne0,
|
4335
|
+
ne1,
|
4336
|
+
r2,
|
4337
|
+
r3,
|
4338
|
+
tgpig,
|
4339
|
+
tiisg,
|
4340
|
+
sgitg);
|
4341
|
+
}
|
4342
|
+
|
4343
|
+
[[host_name("kernel_mul_mv_id_q4_0_f32")]]
|
4344
|
+
kernel void kernel_mul_mv_id_q4_0_f32(
|
4345
|
+
device const char * ids,
|
4346
|
+
device const char * src1,
|
4347
|
+
device uchar * dst,
|
4348
|
+
constant int64_t & nbi1,
|
4349
|
+
constant int64_t & ne00,
|
4350
|
+
constant int64_t & ne01,
|
4351
|
+
constant int64_t & ne02,
|
4352
|
+
constant uint64_t & nb00,
|
4353
|
+
constant uint64_t & nb01,
|
4354
|
+
constant uint64_t & nb02,
|
4355
|
+
constant int64_t & ne10,
|
4356
|
+
constant int64_t & ne11,
|
4357
|
+
constant int64_t & ne12,
|
4358
|
+
constant int64_t & ne13,
|
4359
|
+
constant uint64_t & nb10,
|
4360
|
+
constant uint64_t & nb11,
|
4361
|
+
constant uint64_t & nb12,
|
4362
|
+
constant int64_t & ne0,
|
4363
|
+
constant int64_t & ne1,
|
4364
|
+
constant int64_t & nb1,
|
4365
|
+
constant uint & r2,
|
4366
|
+
constant uint & r3,
|
4367
|
+
constant int & idx,
|
4368
|
+
device const char * src00,
|
4369
|
+
device const char * src01,
|
4370
|
+
device const char * src02,
|
4371
|
+
device const char * src03,
|
4372
|
+
device const char * src04,
|
4373
|
+
device const char * src05,
|
4374
|
+
device const char * src06,
|
4375
|
+
device const char * src07,
|
4376
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4377
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4378
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4379
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4380
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
4381
|
+
|
4382
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
4383
|
+
|
4384
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4385
|
+
|
4386
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4387
|
+
|
4388
|
+
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
4389
|
+
src0[id],
|
4390
|
+
(device const float *) (src1 + bid*nb11),
|
4391
|
+
(device float *) ( dst + bid*nb1),
|
4392
|
+
ne00,
|
4393
|
+
ne01,
|
4394
|
+
ne02,
|
4395
|
+
ne10,
|
4396
|
+
ne12,
|
4397
|
+
ne0,
|
4398
|
+
ne1,
|
4399
|
+
r2,
|
4400
|
+
r3,
|
4401
|
+
tgpig,
|
4402
|
+
tiisg,
|
4403
|
+
sgitg);
|
4404
|
+
}
|
4405
|
+
|
4406
|
+
[[host_name("kernel_mul_mv_id_q4_1_f32")]]
|
4407
|
+
kernel void kernel_mul_mv_id_q4_1_f32(
|
4408
|
+
device const char * ids,
|
4409
|
+
device const char * src1,
|
4410
|
+
device uchar * dst,
|
4411
|
+
constant int64_t & nbi1,
|
4412
|
+
constant int64_t & ne00,
|
4413
|
+
constant int64_t & ne01,
|
4414
|
+
constant int64_t & ne02,
|
4415
|
+
constant uint64_t & nb00,
|
4416
|
+
constant uint64_t & nb01,
|
4417
|
+
constant uint64_t & nb02,
|
4418
|
+
constant int64_t & ne10,
|
4419
|
+
constant int64_t & ne11,
|
4420
|
+
constant int64_t & ne12,
|
4421
|
+
constant int64_t & ne13,
|
4422
|
+
constant uint64_t & nb10,
|
4423
|
+
constant uint64_t & nb11,
|
4424
|
+
constant uint64_t & nb12,
|
4425
|
+
constant int64_t & ne0,
|
4426
|
+
constant int64_t & ne1,
|
4427
|
+
constant int64_t & nb1,
|
4428
|
+
constant uint & r2,
|
4429
|
+
constant uint & r3,
|
4430
|
+
constant int & idx,
|
4431
|
+
device const char * src00,
|
4432
|
+
device const char * src01,
|
4433
|
+
device const char * src02,
|
4434
|
+
device const char * src03,
|
4435
|
+
device const char * src04,
|
4436
|
+
device const char * src05,
|
4437
|
+
device const char * src06,
|
4438
|
+
device const char * src07,
|
4439
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4440
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4441
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4442
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4443
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
4444
|
+
|
4445
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
4446
|
+
|
4447
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4448
|
+
|
4449
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4450
|
+
|
4451
|
+
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
4452
|
+
src0[id],
|
4453
|
+
(device const float *) (src1 + bid*nb11),
|
4454
|
+
(device float *) ( dst + bid*nb1),
|
4455
|
+
ne00,
|
4456
|
+
ne01,
|
4457
|
+
ne02,
|
4458
|
+
ne10,
|
4459
|
+
ne12,
|
4460
|
+
ne0,
|
4461
|
+
ne1,
|
4462
|
+
r2,
|
4463
|
+
r3,
|
4464
|
+
tgpig,
|
4465
|
+
tiisg,
|
4466
|
+
sgitg);
|
4467
|
+
}
|
4468
|
+
|
4469
|
+
[[host_name("kernel_mul_mv_id_q5_0_f32")]]
|
4470
|
+
kernel void kernel_mul_mv_id_q5_0_f32(
|
4471
|
+
device const char * ids,
|
4472
|
+
device const char * src1,
|
4473
|
+
device uchar * dst,
|
4474
|
+
constant int64_t & nbi1,
|
4475
|
+
constant int64_t & ne00,
|
4476
|
+
constant int64_t & ne01,
|
4477
|
+
constant int64_t & ne02,
|
4478
|
+
constant uint64_t & nb00,
|
4479
|
+
constant uint64_t & nb01,
|
4480
|
+
constant uint64_t & nb02,
|
4481
|
+
constant int64_t & ne10,
|
4482
|
+
constant int64_t & ne11,
|
4483
|
+
constant int64_t & ne12,
|
4484
|
+
constant int64_t & ne13,
|
4485
|
+
constant uint64_t & nb10,
|
4486
|
+
constant uint64_t & nb11,
|
4487
|
+
constant uint64_t & nb12,
|
4488
|
+
constant int64_t & ne0,
|
4489
|
+
constant int64_t & ne1,
|
4490
|
+
constant int64_t & nb1,
|
4491
|
+
constant uint & r2,
|
4492
|
+
constant uint & r3,
|
4493
|
+
constant int & idx,
|
4494
|
+
device const char * src00,
|
4495
|
+
device const char * src01,
|
4496
|
+
device const char * src02,
|
4497
|
+
device const char * src03,
|
4498
|
+
device const char * src04,
|
4499
|
+
device const char * src05,
|
4500
|
+
device const char * src06,
|
4501
|
+
device const char * src07,
|
4502
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4503
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4504
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4505
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4506
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
4507
|
+
|
4508
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
4509
|
+
|
4510
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4511
|
+
|
4512
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4513
|
+
|
4514
|
+
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
4515
|
+
src0[id],
|
4516
|
+
(device const float *) (src1 + bid*nb11),
|
4517
|
+
(device float *) ( dst + bid*nb1),
|
4518
|
+
ne00,
|
4519
|
+
ne01,
|
4520
|
+
ne02,
|
4521
|
+
ne10,
|
4522
|
+
ne12,
|
4523
|
+
ne0,
|
4524
|
+
ne1,
|
4525
|
+
r2,
|
4526
|
+
r3,
|
4527
|
+
tgpig,
|
4528
|
+
tiisg,
|
4529
|
+
sgitg);
|
4530
|
+
}
|
4531
|
+
|
4532
|
+
[[host_name("kernel_mul_mv_id_q5_1_f32")]]
|
4533
|
+
kernel void kernel_mul_mv_id_q5_1_f32(
|
4534
|
+
device const char * ids,
|
4535
|
+
device const char * src1,
|
4536
|
+
device uchar * dst,
|
4537
|
+
constant int64_t & nbi1,
|
4538
|
+
constant int64_t & ne00,
|
4539
|
+
constant int64_t & ne01,
|
4540
|
+
constant int64_t & ne02,
|
4541
|
+
constant uint64_t & nb00,
|
4542
|
+
constant uint64_t & nb01,
|
4543
|
+
constant uint64_t & nb02,
|
4544
|
+
constant int64_t & ne10,
|
4545
|
+
constant int64_t & ne11,
|
4546
|
+
constant int64_t & ne12,
|
4547
|
+
constant int64_t & ne13,
|
4548
|
+
constant uint64_t & nb10,
|
4549
|
+
constant uint64_t & nb11,
|
4550
|
+
constant uint64_t & nb12,
|
4551
|
+
constant int64_t & ne0,
|
4552
|
+
constant int64_t & ne1,
|
4553
|
+
constant int64_t & nb1,
|
4554
|
+
constant uint & r2,
|
4555
|
+
constant uint & r3,
|
4556
|
+
constant int & idx,
|
4557
|
+
device const char * src00,
|
4558
|
+
device const char * src01,
|
4559
|
+
device const char * src02,
|
4560
|
+
device const char * src03,
|
4561
|
+
device const char * src04,
|
4562
|
+
device const char * src05,
|
4563
|
+
device const char * src06,
|
4564
|
+
device const char * src07,
|
4565
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4566
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4567
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4568
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4569
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
4570
|
+
|
4571
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
4572
|
+
|
4573
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4574
|
+
|
4575
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4576
|
+
|
4577
|
+
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
4578
|
+
src0[id],
|
4579
|
+
(device const float *) (src1 + bid*nb11),
|
4580
|
+
(device float *) ( dst + bid*nb1),
|
4581
|
+
ne00,
|
4582
|
+
ne01,
|
4583
|
+
ne02,
|
4584
|
+
ne10,
|
4585
|
+
ne12,
|
4586
|
+
ne0,
|
4587
|
+
ne1,
|
4588
|
+
r2,
|
4589
|
+
r3,
|
4590
|
+
tgpig,
|
4591
|
+
tiisg,
|
4592
|
+
sgitg);
|
4593
|
+
}
|
4594
|
+
|
4595
|
+
[[host_name("kernel_mul_mv_id_q2_K_f32")]]
|
4596
|
+
kernel void kernel_mul_mv_id_q2_K_f32(
|
4597
|
+
device const char * ids,
|
4598
|
+
device const char * src1,
|
4599
|
+
device uchar * dst,
|
4600
|
+
constant int64_t & nbi1,
|
4601
|
+
constant int64_t & ne00,
|
4602
|
+
constant int64_t & ne01,
|
4603
|
+
constant int64_t & ne02,
|
4604
|
+
constant uint64_t & nb00,
|
4605
|
+
constant uint64_t & nb01,
|
4606
|
+
constant uint64_t & nb02,
|
4607
|
+
constant int64_t & ne10,
|
4608
|
+
constant int64_t & ne11,
|
4609
|
+
constant int64_t & ne12,
|
4610
|
+
constant int64_t & ne13,
|
4611
|
+
constant uint64_t & nb10,
|
4612
|
+
constant uint64_t & nb11,
|
4613
|
+
constant uint64_t & nb12,
|
4614
|
+
constant int64_t & ne0,
|
4615
|
+
constant int64_t & ne1,
|
4616
|
+
constant int64_t & nb1,
|
4617
|
+
constant uint & r2,
|
4618
|
+
constant uint & r3,
|
4619
|
+
constant int & idx,
|
4620
|
+
device const char * src00,
|
4621
|
+
device const char * src01,
|
4622
|
+
device const char * src02,
|
4623
|
+
device const char * src03,
|
4624
|
+
device const char * src04,
|
4625
|
+
device const char * src05,
|
4626
|
+
device const char * src06,
|
4627
|
+
device const char * src07,
|
4628
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4629
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4630
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4631
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4632
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
4633
|
+
|
4634
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
4635
|
+
|
4636
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4637
|
+
|
4638
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4639
|
+
|
4640
|
+
kernel_mul_mv_q2_K_f32_impl(
|
4641
|
+
src0[id],
|
4642
|
+
(device const float *) (src1 + bid*nb11),
|
4643
|
+
(device float *) ( dst + bid*nb1),
|
4644
|
+
ne00,
|
4645
|
+
ne01,
|
4646
|
+
ne02,
|
4647
|
+
ne10,
|
4648
|
+
ne12,
|
4649
|
+
ne0,
|
4650
|
+
ne1,
|
4651
|
+
r2,
|
4652
|
+
r3,
|
4653
|
+
tgpig,
|
4654
|
+
tiisg,
|
4655
|
+
sgitg);
|
4656
|
+
}
|
4657
|
+
|
4658
|
+
[[host_name("kernel_mul_mv_id_q3_K_f32")]]
|
4659
|
+
kernel void kernel_mul_mv_id_q3_K_f32(
|
4660
|
+
device const char * ids,
|
4661
|
+
device const char * src1,
|
4662
|
+
device uchar * dst,
|
4663
|
+
constant int64_t & nbi1,
|
4664
|
+
constant int64_t & ne00,
|
4665
|
+
constant int64_t & ne01,
|
4666
|
+
constant int64_t & ne02,
|
4667
|
+
constant uint64_t & nb00,
|
4668
|
+
constant uint64_t & nb01,
|
4669
|
+
constant uint64_t & nb02,
|
4670
|
+
constant int64_t & ne10,
|
4671
|
+
constant int64_t & ne11,
|
4672
|
+
constant int64_t & ne12,
|
4673
|
+
constant int64_t & ne13,
|
4674
|
+
constant uint64_t & nb10,
|
4675
|
+
constant uint64_t & nb11,
|
4676
|
+
constant uint64_t & nb12,
|
4677
|
+
constant int64_t & ne0,
|
4678
|
+
constant int64_t & ne1,
|
4679
|
+
constant int64_t & nb1,
|
4680
|
+
constant uint & r2,
|
4681
|
+
constant uint & r3,
|
4682
|
+
constant int & idx,
|
4683
|
+
device const char * src00,
|
4684
|
+
device const char * src01,
|
4685
|
+
device const char * src02,
|
4686
|
+
device const char * src03,
|
4687
|
+
device const char * src04,
|
4688
|
+
device const char * src05,
|
4689
|
+
device const char * src06,
|
4690
|
+
device const char * src07,
|
4691
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4692
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4693
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4694
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4695
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
4696
|
+
|
4697
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
4698
|
+
|
4699
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4700
|
+
|
4701
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4702
|
+
|
4703
|
+
kernel_mul_mv_q3_K_f32_impl(
|
4704
|
+
src0[id],
|
4705
|
+
(device const float *) (src1 + bid*nb11),
|
4706
|
+
(device float *) ( dst + bid*nb1),
|
4707
|
+
ne00,
|
4708
|
+
ne01,
|
4709
|
+
ne02,
|
4710
|
+
ne10,
|
4711
|
+
ne12,
|
4712
|
+
ne0,
|
4713
|
+
ne1,
|
4714
|
+
r2,
|
4715
|
+
r3,
|
4716
|
+
tgpig,
|
4717
|
+
tiisg,
|
4718
|
+
sgitg);
|
4719
|
+
}
|
4720
|
+
|
4721
|
+
[[host_name("kernel_mul_mv_id_q4_K_f32")]]
|
4722
|
+
kernel void kernel_mul_mv_id_q4_K_f32(
|
4723
|
+
device const char * ids,
|
4724
|
+
device const char * src1,
|
4725
|
+
device uchar * dst,
|
4726
|
+
constant int64_t & nbi1,
|
4727
|
+
constant int64_t & ne00,
|
4728
|
+
constant int64_t & ne01,
|
4729
|
+
constant int64_t & ne02,
|
4730
|
+
constant uint64_t & nb00,
|
4731
|
+
constant uint64_t & nb01,
|
4732
|
+
constant uint64_t & nb02,
|
4733
|
+
constant int64_t & ne10,
|
4734
|
+
constant int64_t & ne11,
|
4735
|
+
constant int64_t & ne12,
|
4736
|
+
constant int64_t & ne13,
|
4737
|
+
constant uint64_t & nb10,
|
4738
|
+
constant uint64_t & nb11,
|
4739
|
+
constant uint64_t & nb12,
|
4740
|
+
constant int64_t & ne0,
|
4741
|
+
constant int64_t & ne1,
|
4742
|
+
constant int64_t & nb1,
|
4743
|
+
constant uint & r2,
|
4744
|
+
constant uint & r3,
|
4745
|
+
constant int & idx,
|
4746
|
+
device const char * src00,
|
4747
|
+
device const char * src01,
|
4748
|
+
device const char * src02,
|
4749
|
+
device const char * src03,
|
4750
|
+
device const char * src04,
|
4751
|
+
device const char * src05,
|
4752
|
+
device const char * src06,
|
4753
|
+
device const char * src07,
|
4754
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4755
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4756
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4757
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4758
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
4759
|
+
|
4760
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
4761
|
+
|
4762
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4763
|
+
|
4764
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4765
|
+
|
4766
|
+
kernel_mul_mv_q4_K_f32_impl(
|
4767
|
+
src0[id],
|
4768
|
+
(device const float *) (src1 + bid*nb11),
|
4769
|
+
(device float *) ( dst + bid*nb1),
|
4770
|
+
ne00,
|
4771
|
+
ne01,
|
4772
|
+
ne02,
|
4773
|
+
ne10,
|
4774
|
+
ne12,
|
4775
|
+
ne0,
|
4776
|
+
ne1,
|
4777
|
+
r2,
|
4778
|
+
r3,
|
4779
|
+
tgpig,
|
4780
|
+
tiisg,
|
4781
|
+
sgitg);
|
4782
|
+
}
|
4783
|
+
|
4784
|
+
[[host_name("kernel_mul_mv_id_q5_K_f32")]]
|
4785
|
+
kernel void kernel_mul_mv_id_q5_K_f32(
|
4786
|
+
device const char * ids,
|
4787
|
+
device const char * src1,
|
4788
|
+
device uchar * dst,
|
4789
|
+
constant int64_t & nbi1,
|
4790
|
+
constant int64_t & ne00,
|
4791
|
+
constant int64_t & ne01,
|
4792
|
+
constant int64_t & ne02,
|
4793
|
+
constant uint64_t & nb00,
|
4794
|
+
constant uint64_t & nb01,
|
4795
|
+
constant uint64_t & nb02,
|
4796
|
+
constant int64_t & ne10,
|
4797
|
+
constant int64_t & ne11,
|
4798
|
+
constant int64_t & ne12,
|
4799
|
+
constant int64_t & ne13,
|
4800
|
+
constant uint64_t & nb10,
|
4801
|
+
constant uint64_t & nb11,
|
4802
|
+
constant uint64_t & nb12,
|
4803
|
+
constant int64_t & ne0,
|
4804
|
+
constant int64_t & ne1,
|
4805
|
+
constant int64_t & nb1,
|
4806
|
+
constant uint & r2,
|
4807
|
+
constant uint & r3,
|
4808
|
+
constant int & idx,
|
4809
|
+
device const char * src00,
|
4810
|
+
device const char * src01,
|
4811
|
+
device const char * src02,
|
4812
|
+
device const char * src03,
|
4813
|
+
device const char * src04,
|
4814
|
+
device const char * src05,
|
4815
|
+
device const char * src06,
|
4816
|
+
device const char * src07,
|
4817
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4818
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4819
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4820
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4821
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
4822
|
+
|
4823
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
4824
|
+
|
4825
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4826
|
+
|
4827
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4828
|
+
|
4829
|
+
kernel_mul_mv_q5_K_f32_impl(
|
4830
|
+
src0[id],
|
4831
|
+
(device const float *) (src1 + bid*nb11),
|
4832
|
+
(device float *) ( dst + bid*nb1),
|
4833
|
+
ne00,
|
4834
|
+
ne01,
|
4835
|
+
ne02,
|
4836
|
+
ne10,
|
4837
|
+
ne12,
|
4838
|
+
ne0,
|
4839
|
+
ne1,
|
4840
|
+
r2,
|
4841
|
+
r3,
|
4842
|
+
tgpig,
|
4843
|
+
tiisg,
|
4844
|
+
sgitg);
|
4845
|
+
}
|
4846
|
+
|
4847
|
+
[[host_name("kernel_mul_mv_id_q6_K_f32")]]
|
4848
|
+
kernel void kernel_mul_mv_id_q6_K_f32(
|
4849
|
+
device const char * ids,
|
4850
|
+
device const char * src1,
|
4851
|
+
device uchar * dst,
|
4852
|
+
constant int64_t & nbi1,
|
4853
|
+
constant int64_t & ne00,
|
4854
|
+
constant int64_t & ne01,
|
4855
|
+
constant int64_t & ne02,
|
4856
|
+
constant uint64_t & nb00,
|
4857
|
+
constant uint64_t & nb01,
|
4858
|
+
constant uint64_t & nb02,
|
4859
|
+
constant int64_t & ne10,
|
4860
|
+
constant int64_t & ne11,
|
4861
|
+
constant int64_t & ne12,
|
4862
|
+
constant int64_t & ne13,
|
4863
|
+
constant uint64_t & nb10,
|
4864
|
+
constant uint64_t & nb11,
|
4865
|
+
constant uint64_t & nb12,
|
4866
|
+
constant int64_t & ne0,
|
4867
|
+
constant int64_t & ne1,
|
4868
|
+
constant int64_t & nb1,
|
4869
|
+
constant uint & r2,
|
4870
|
+
constant uint & r3,
|
4871
|
+
constant int & idx,
|
4872
|
+
device const char * src00,
|
4873
|
+
device const char * src01,
|
4874
|
+
device const char * src02,
|
4875
|
+
device const char * src03,
|
4876
|
+
device const char * src04,
|
4877
|
+
device const char * src05,
|
4878
|
+
device const char * src06,
|
4879
|
+
device const char * src07,
|
4880
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4881
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4882
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4883
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4884
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
4885
|
+
|
4886
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
4887
|
+
|
4888
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4889
|
+
|
4890
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4891
|
+
|
4892
|
+
kernel_mul_mv_q6_K_f32_impl(
|
4893
|
+
src0[id],
|
4894
|
+
(device const float *) (src1 + bid*nb11),
|
4895
|
+
(device float *) ( dst + bid*nb1),
|
4896
|
+
ne00,
|
4897
|
+
ne01,
|
4898
|
+
ne02,
|
4899
|
+
ne10,
|
4900
|
+
ne12,
|
4901
|
+
ne0,
|
4902
|
+
ne1,
|
4903
|
+
r2,
|
4904
|
+
r3,
|
4905
|
+
tgpig,
|
4906
|
+
tiisg,
|
4907
|
+
sgitg);
|
4908
|
+
}
|