llama_cpp 0.10.0 → 0.10.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -0
- data/ext/llama_cpp/llama_cpp.cpp +18 -1
- data/ext/llama_cpp/src/ggml-alloc.c +12 -4
- data/ext/llama_cpp/src/ggml-alloc.h +1 -1
- data/ext/llama_cpp/src/ggml-backend-impl.h +12 -8
- data/ext/llama_cpp/src/ggml-backend.c +75 -5
- data/ext/llama_cpp/src/ggml-backend.h +7 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +952 -232
- data/ext/llama_cpp/src/ggml-metal.h +3 -0
- data/ext/llama_cpp/src/ggml-metal.m +725 -98
- data/ext/llama_cpp/src/ggml-metal.metal +1508 -171
- data/ext/llama_cpp/src/ggml-quants.c +2 -2
- data/ext/llama_cpp/src/ggml.c +554 -215
- data/ext/llama_cpp/src/ggml.h +58 -23
- data/ext/llama_cpp/src/llama.cpp +1157 -851
- data/ext/llama_cpp/src/llama.h +9 -4
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +2 -0
- metadata +2 -2
data/ext/llama_cpp/src/ggml-metal.cpp — wait, the hunks below belong to data/ext/llama_cpp/src/ggml-metal.metal

@@ -79,6 +79,7 @@ kernel void kernel_add(
         constant int64_t & nb1,
         constant int64_t & nb2,
         constant int64_t & nb3,
+        constant int64_t & offs,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3 ntg[[threads_per_threadgroup]]) {
@@ -90,9 +91,9 @@ kernel void kernel_add(
     const int64_t i12 = i02 % ne12;
     const int64_t i11 = i01 % ne11;

-    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
     device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
-    device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
+    device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;

     for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
         const int i10 = i0 % ne10;
@@ -204,7 +205,7 @@ kernel void kernel_add_row(
         device const float4 * src0,
         device const float4 * src1,
         device float4 * dst,
-        constant int64_t & nb [[buffer(
+        constant int64_t & nb [[buffer(28)]],
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] + src1[tpig % nb];
 }
@@ -213,7 +214,7 @@ kernel void kernel_mul_row(
         device const float4 * src0,
         device const float4 * src1,
         device float4 * dst,
-        constant int64_t & nb [[buffer(
+        constant int64_t & nb [[buffer(28)]],
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] * src1[tpig % nb];
 }
@@ -222,7 +223,7 @@ kernel void kernel_div_row(
         device const float4 * src0,
         device const float4 * src1,
         device float4 * dst,
-        constant int64_t & nb [[buffer(
+        constant int64_t & nb [[buffer(28)]],
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] / src1[tpig % nb];
 }
@@ -243,19 +244,53 @@ kernel void kernel_scale_4(
     dst[tpig] = src0[tpig] * scale;
 }

-kernel void
-        device const
-        device
+kernel void kernel_relu(
+        device const float * src0,
+        device float * dst,
         uint tpig[[thread_position_in_grid]]) {
-
-    dst[tpig] = x / (1.0f + exp(-x));
+    dst[tpig] = max(0.0f, src0[tpig]);
 }

-kernel void
+kernel void kernel_tanh(
         device const float * src0,
         device float * dst,
         uint tpig[[thread_position_in_grid]]) {
-
+    device const float & x = src0[tpig];
+    dst[tpig] = precise::tanh(x);
+}
+
+constant float GELU_COEF_A = 0.044715f;
+constant float GELU_QUICK_COEF = -1.702f;
+constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+
+kernel void kernel_gelu(
+        device const float4 * src0,
+        device float4 * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+
+    // BEWARE !!!
+    // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
+    // This was observed with Falcon 7B and 40B models
+    //
+    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+kernel void kernel_gelu_quick(
+        device const float4 * src0,
+        device float4 * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+
+    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
+}
+
+kernel void kernel_silu(
+        device const float4 * src0,
+        device float4 * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+    dst[tpig] = x / (1.0f + exp(-x));
 }

 kernel void kernel_sqr(
@@ -313,22 +348,6 @@ kernel void kernel_sum_rows(
     dst_row[0] = row_sum;
 }

-constant float GELU_COEF_A = 0.044715f;
-constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
-
-kernel void kernel_gelu(
-        device const float4 * src0,
-        device float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
-
-    // BEWARE !!!
-    // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
-    // This was observed with Falcon 7B and 40B models
-    //
-    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
-}
-
 kernel void kernel_soft_max(
         device const float * src0,
         device const float * src1,
@@ -347,9 +366,9 @@ kernel void kernel_soft_max(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

-    device const float * psrc0 =
-    device const float * pmask = src1 ? src1
-    device float * pdst =
+    device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
+    device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

     // parallel max
     float lmax = -INFINITY;
@@ -385,7 +404,12 @@ kernel void kernel_soft_max(
         pdst[i00] = exp_psrc0;
     }

+    // This barrier fixes a failing test
+    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+    threadgroup_barrier(mem_flags::mem_none);
+
     float sum = simd_sum(lsum);
+
     if (ntg > N_SIMDWIDTH) {
         if (sgitg == 0) {
             buf[tiisg] = 0.0f;
@@ -428,9 +452,9 @@ kernel void kernel_soft_max_4(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

-    device const float4 * psrc4 =
-    device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
-    device float4 * pdst4 =
+    device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
+    device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);

     // parallel max
     float4 lmax4 = -INFINITY;
@@ -468,7 +492,13 @@ kernel void kernel_soft_max_4(
     }

     const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+
+    // This barrier fixes a failing test
+    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+    threadgroup_barrier(mem_flags::mem_none);
+
     float sum = simd_sum(lsum);
+
     if (ntg > N_SIMDWIDTH) {
         if (sgitg == 0) {
             buf[tiisg] = 0.0f;
@@ -639,6 +669,94 @@ kernel void kernel_rms_norm(
     }
 }

+kernel void kernel_group_norm(
+        device const float * src0,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int32_t & n_groups,
+        constant float & eps,
+        threadgroup float * buf [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint ntg[[threads_per_threadgroup]]) {
+    const int64_t ne = ne00*ne01*ne02;
+    const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
+
+    int start = tgpig * gs;
+    int end = start + gs;
+
+    start += tpitg;
+
+    if (end >= ne) {
+        end = ne;
+    }
+
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int j = start; j < end; j += ntg) {
+        tmp += src0[j];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    tmp = simd_sum(tmp);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = tmp;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        tmp = buf[tiisg];
+        tmp = simd_sum(tmp);
+    }
+
+    const float mean = tmp / gs;
+    tmp = 0.0f;
+
+    for (int j = start; j < end; j += ntg) {
+        float xi = src0[j] - mean;
+        dst[j] = xi;
+        tmp += xi * xi;
+    }
+
+    tmp = simd_sum(tmp);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = tmp;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        tmp = buf[tiisg];
+        tmp = simd_sum(tmp);
+    }
+
+    const float variance = tmp / gs;
+    const float scale = 1.0f/sqrt(variance + eps);
+    for (int j = start; j < end; j += ntg) {
+        dst[j] *= scale;
+    }
+}
+
 // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
 // il indicates where the q4 quants begin (0 or QK4_0/4)
 // we assume that the yl's have been multiplied with the appropriate scale factor
@@ -731,7 +849,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
 // giard against the number of rows not being divisible by
 // N_DST, so this is another explicit assumption of the implementation.
 template<typename block_q_type, int nr, int nsg, int nw>
-void
+void mul_vec_q_n_f32_impl(
         device const void * src0,
         device const float * src1,
         device float * dst,
@@ -813,7 +931,7 @@ kernel void kernel_mul_mv_q4_0_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
+    mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }

 kernel void kernel_mul_mv_q4_1_f32(
@@ -832,7 +950,7 @@ kernel void kernel_mul_mv_q4_1_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
+    mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }

 kernel void kernel_mul_mv_q5_0_f32(
@@ -851,7 +969,7 @@ kernel void kernel_mul_mv_q5_0_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
+    mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }

 kernel void kernel_mul_mv_q5_1_f32(
@@ -870,28 +988,28 @@ kernel void kernel_mul_mv_q5_1_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
+    mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }


 #define NB_Q8_0 8

-
+void kernel_mul_mv_q8_0_f32_impl(
         device const void * src0,
         device const float * src1,
         device float * dst,
         constant int64_t & ne00,
-        constant int64_t & ne01
-        constant int64_t & ne02
-        constant int64_t & ne10
-        constant int64_t & ne12
-        constant int64_t & ne0
-        constant int64_t & ne1
-        constant uint & r2
-        constant uint & r3
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne10,
+        constant int64_t & ne12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant uint & r2,
+        constant uint & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint
-        uint
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
     const int nr = N_DST;
     const int nsg = N_SIMDGROUP;
     const int nw = N_SIMDWIDTH;
@@ -945,9 +1063,29 @@ kernel void kernel_mul_mv_q8_0_f32(
     }
 }

+[[host_name("kernel_mul_mv_q8_0_f32")]]
+kernel void kernel_mul_mv_q8_0_f32(
+        device const void * src0,
+        device const float * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne10,
+        constant int64_t & ne12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant uint & r2 [[buffer(17)]],
+        constant uint & r3 [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+}
+
 #define N_F32_F32 4

-
+void kernel_mul_mv_f32_f32_impl(
         device const char * src0,
         device const char * src1,
         device float * dst,
@@ -965,8 +1103,8 @@ kernel void kernel_mul_mv_f32_f32(
         constant uint64_t & nb12,
         constant int64_t & ne0,
         constant int64_t & ne1,
-        constant uint & r2
-        constant uint & r3
+        constant uint & r2,
+        constant uint & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]]) {

@@ -1025,6 +1163,32 @@ kernel void kernel_mul_mv_f32_f32(
     }
 }

+[[host_name("kernel_mul_mv_f32_f32")]]
+kernel void kernel_mul_mv_f32_f32(
+        device const char * src0,
+        device const char * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant uint & r2 [[buffer(17)]],
+        constant uint & r3 [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]]) {
+    kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+}
+
 #define N_F16_F16 4

 kernel void kernel_mul_mv_f16_f16(
@@ -1105,7 +1269,7 @@ kernel void kernel_mul_mv_f16_f16(
     }
 }

-
+void kernel_mul_mv_f16_f32_1row_impl(
         device const char * src0,
         device const char * src1,
         device float * dst,
@@ -1123,8 +1287,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
         constant uint64_t & nb12,
         constant int64_t & ne0,
         constant int64_t & ne1,
-        constant uint & r2
-        constant uint & r3
+        constant uint & r2,
+        constant uint & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]]) {

@@ -1161,12 +1325,37 @@ kernel void kernel_mul_mv_f16_f32_1row(
         dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
     }
 }
+}

+[[host_name("kernel_mul_mv_f16_f32_1row")]]
+kernel void kernel_mul_mv_f16_f32_1row(
+        device const char * src0,
+        device const char * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant uint & r2 [[buffer(17)]],
+        constant uint & r3 [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]]) {
+    kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
 }

 #define N_F16_F32 4

-
+void kernel_mul_mv_f16_f32_impl(
         device const char * src0,
         device const char * src1,
         device float * dst,
@@ -1184,8 +1373,8 @@ kernel void kernel_mul_mv_f16_f32(
         constant uint64_t & nb12,
         constant int64_t & ne0,
         constant int64_t & ne1,
-        constant uint & r2
-        constant uint & r3
+        constant uint & r2,
+        constant uint & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]]) {

@@ -1244,6 +1433,32 @@ kernel void kernel_mul_mv_f16_f32(
     }
 }

+[[host_name("kernel_mul_mv_f16_f32")]]
+kernel void kernel_mul_mv_f16_f32(
+        device const char * src0,
+        device const char * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant uint & r2 [[buffer(17)]],
+        constant uint & r3 [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]]) {
+    kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+}
+
 // Assumes row size (ne00) is a multiple of 4
 kernel void kernel_mul_mv_f16_f32_l4(
         device const char * src0,
@@ -1487,8 +1702,9 @@ kernel void kernel_rope(
             dst_data[1] = x0*sin_theta + x1*cos_theta;
         }
     } else {
-        for (int64_t
-
+        for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
+            if (ic < n_dims) {
+                const int64_t ib = 0;

                 // simplified from `(ib * n_dims + ic) * inv_ndims`
                 const float cur_rot = inv_ndims*ic - ib;
@@ -1507,6 +1723,14 @@ kernel void kernel_rope(

                 dst_data[0] = x0*cos_theta - x1*sin_theta;
                 dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+            } else {
+                const int64_t i0 = ic;
+
+                device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+                dst_data[0] = src[0];
+                dst_data[1] = src[1];
             }
         }
     }
@@ -1548,21 +1772,112 @@ kernel void kernel_im2col_f16(
     }
 }

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+kernel void kernel_upscale_f32(
+        device const char * src0,
+        device char * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & ne2,
+        constant int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        constant int32_t & sf,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3;
+    const int64_t i02 = i2;
+    const int64_t i01 = i1/sf;
+
+    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        dst_ptr[i0] = src0_ptr[i0/sf];
+    }
+}
+
+kernel void kernel_pad_f32(
+        device const char * src0,
+        device char * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & ne2,
+        constant int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3;
+    const int64_t i02 = i2;
+    const int64_t i01 = i1;
+
+    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
+
+    if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
+        for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+            if (i0 < ne00) {
+                dst_ptr[i0] = src0_ptr[i0];
+            } else {
+                dst_ptr[i0] = 0.0f;
+            }
+        }
+
+        return;
+    }
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        dst_ptr[i0] = 0.0f;
+    }
+}
+
+// bitonic sort implementation following the CUDA kernels as reference
+typedef void (argsort_t)(
+        device const float * x,
+        device int32_t * dst,
+        constant int64_t & ncols,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]]);
+
+template<ggml_sort_order order>
+kernel void kernel_argsort_f32_i32(
+        device const float * x,
+        device int32_t * dst,
+        constant int64_t & ncols,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]]) {
     // bitonic sort
     int col = tpitg[0];
     int row = tgpig[1];
@@ -1600,9 +1915,17 @@ kernel void kernel_argsort_f32_i32(
 template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
 template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;

+kernel void kernel_leaky_relu_f32(
+        device const float * src0,
+        device float * dst,
+        constant float & slope,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
+}
+
 kernel void kernel_cpy_f16_f16(
-        device
-        device
+        device const half * src0,
+        device half * dst,
         constant int64_t & ne00,
         constant int64_t & ne01,
         constant int64_t & ne02,
@@ -1641,6 +1964,47 @@ kernel void kernel_cpy_f16_f16(
     }
 }

+kernel void kernel_cpy_f16_f32(
+        device const half * src0,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & ne2,
+        constant int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0];
+    }
+}
+
 kernel void kernel_cpy_f32_f16(
         device const float * src0,
         device half * dst,
@@ -1917,9 +2281,9 @@ kernel void kernel_cpy_f32_q4_1(
 }

 kernel void kernel_concat(
-        device
-        device
-        device
+        device const char * src0,
+        device const char * src1,
+        device char * dst,
         constant int64_t & ne00,
         constant int64_t & ne01,
         constant int64_t & ne02,
@@ -1956,7 +2320,7 @@ kernel void kernel_concat(
     const int64_t i12 = i02 % ne12;
     const int64_t i11 = i01 % ne11;

-    device const char * src0_ptr = src0 + i03
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
     device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
     device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;

@@ -2064,19 +2428,19 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {

 //====================================== dot products =========================

-
+void kernel_mul_mv_q2_K_f32_impl(
         device const void * src0,
         device const float * src1,
         device float * dst,
         constant int64_t & ne00,
-        constant int64_t & ne01
-        constant int64_t & ne02
-        constant int64_t & ne10
-        constant int64_t & ne12
-        constant int64_t & ne0
-        constant int64_t & ne1
-        constant uint & r2
-        constant uint & r3
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne10,
+        constant int64_t & ne12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant uint & r2,
+        constant uint & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2214,8 +2578,8 @@ kernel void kernel_mul_mv_q2_K_f32(
     }
 }

-
-kernel void
+[[host_name("kernel_mul_mv_q2_K_f32")]]
+kernel void kernel_mul_mv_q2_K_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,
@@ -2229,8 +2593,29 @@ kernel void kernel_mul_mv_q3_K_f32(
         constant uint & r2 [[buffer(17)]],
         constant uint & r3 [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint
-        uint
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+#if QK_K == 256
+void kernel_mul_mv_q3_K_f32_impl(
+        device const void * src0,
+        device const float * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne10,
+        constant int64_t & ne12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant uint & r2,
+        constant uint & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {

     const int nb = ne00/QK_K;

@@ -2373,19 +2758,19 @@ kernel void kernel_mul_mv_q3_K_f32(
     }
 }
 #else
-
+void kernel_mul_mv_q3_K_f32_impl(
         device const void * src0,
         device const float * src1,
         device float * dst,
         constant int64_t & ne00,
-        constant int64_t & ne01
-        constant int64_t & ne02
-        constant int64_t & ne10
-        constant int64_t & ne12
-        constant int64_t & ne0
-        constant int64_t & ne1
-        constant uint & r2
-        constant uint & r3
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne10,
+        constant int64_t & ne12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant uint & r2,
+        constant uint & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2450,20 +2835,41 @@ kernel void kernel_mul_mv_q3_K_f32(
 }
 #endif

+[[host_name("kernel_mul_mv_q3_K_f32")]]
+kernel void kernel_mul_mv_q3_K_f32(
+        device const void * src0,
+        device const float * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01[[buffer(4)]],
+        constant int64_t & ne02[[buffer(5)]],
+        constant int64_t & ne10[[buffer(9)]],
+        constant int64_t & ne12[[buffer(11)]],
+        constant int64_t & ne0 [[buffer(15)]],
+        constant int64_t & ne1 [[buffer(16)]],
+        constant uint & r2 [[buffer(17)]],
+        constant uint & r3 [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
 #if QK_K == 256
-
+void kernel_mul_mv_q4_K_f32_impl(
         device const void * src0,
         device const float * src1,
         device float * dst,
         constant int64_t & ne00,
-        constant int64_t & ne01
-        constant int64_t & ne02
-        constant int64_t & ne10
-        constant int64_t & ne12
-        constant int64_t & ne0
-        constant int64_t & ne1
-        constant uint & r2
-        constant uint & r3
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne10,
+        constant int64_t & ne12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant uint & r2,
+        constant uint & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2564,19 +2970,19 @@ kernel void kernel_mul_mv_q4_K_f32(
     }
 }
 #else
-
+void kernel_mul_mv_q4_K_f32_impl(
         device const void * src0,
         device const float * src1,
         device float * dst,
         constant int64_t & ne00,
-        constant int64_t & ne01
-        constant int64_t & ne02
-        constant int64_t & ne10
-        constant int64_t & ne12
-        constant int64_t & ne0
-        constant int64_t & ne1
-        constant uint & r2
-        constant uint & r3
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne10,
+        constant int64_t & ne12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant uint & r2,
+        constant uint & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2660,7 +3066,8 @@ kernel void kernel_mul_mv_q4_K_f32(
 }
 #endif

-
+[[host_name("kernel_mul_mv_q4_K_f32")]]
+kernel void kernel_mul_mv_q4_K_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,
@@ -2677,6 +3084,26 @@ kernel void kernel_mul_mv_q5_K_f32(
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {

+    kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+void kernel_mul_mv_q5_K_f32_impl(
+        device const void * src0,
+        device const float * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne10,
+        constant int64_t & ne12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant uint & r2,
+        constant uint & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
     const int nb = ne00/QK_K;

     const int64_t r0 = tgpig.x;
@@ -2836,10 +3263,10 @@ kernel void kernel_mul_mv_q5_K_f32(
         dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
     }
 }
-
 }

-
+[[host_name("kernel_mul_mv_q5_K_f32")]]
+kernel void kernel_mul_mv_q5_K_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,
@@ -2853,21 +3280,41 @@ kernel void kernel_mul_mv_q6_K_f32(
         constant uint & r2 [[buffer(17)]],
         constant uint & r3 [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint
-        uint
-
-    const uint8_t kmask1 = 0x03;
-    const uint8_t kmask2 = 0x0C;
-    const uint8_t kmask3 = 0x30;
-    const uint8_t kmask4 = 0xC0;
-
-    const int nb = ne00/QK_K;
-
-    const int64_t r0 = tgpig.x;
-    const int64_t r1 = tgpig.y;
-    const int im = tgpig.z;
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {

-
+    kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+void kernel_mul_mv_q6_K_f32_impl(
+        device const void * src0,
+        device const float * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne10,
+        constant int64_t & ne12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant uint & r2,
+        constant uint & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    const uint8_t kmask1 = 0x03;
+    const uint8_t kmask2 = 0x0C;
+    const uint8_t kmask3 = 0x30;
+    const uint8_t kmask4 = 0xC0;
+
+    const int nb = ne00/QK_K;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int row = 2 * r0 + sgitg;

     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
@@ -2945,6 +3392,27 @@ kernel void kernel_mul_mv_q6_K_f32(
     }
 }

+[[host_name("kernel_mul_mv_q6_K_f32")]]
+kernel void kernel_mul_mv_q6_K_f32(
+        device const void * src0,
+        device const float * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01[[buffer(4)]],
+        constant int64_t & ne02[[buffer(5)]],
+        constant int64_t & ne10[[buffer(9)]],
+        constant int64_t & ne12[[buffer(11)]],
+        constant int64_t & ne0 [[buffer(15)]],
+        constant int64_t & ne1 [[buffer(16)]],
+        constant uint & r2 [[buffer(17)]],
+        constant uint & r3 [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
 //============================= templates and their specializations =============================

 // NOTE: this is not dequantizing - we are simply fitting the template
@@ -3062,10 +3530,10 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg

 template <typename type4x4>
 void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
-    const
-    const
+    const float d = xb->d;
+    const float min = xb->dmin;
     device const uint8_t * q = (device const uint8_t *)xb->qs;
-
+    float dl, ml;
     uint8_t sc = xb->scales[il];

 #if QK_K == 256
@@ -3135,10 +3603,10 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
     q = q + (il/4) * 32 + 16 * (il&1);
     il = il & 3;
     const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
-    const
-    const
-    const
-    const
+    const float d = il < 2 ? xb->d : xb->d / 16.h;
+    const float min = xb->dmin;
+    const float dl = d * sc[0];
+    const float ml = min * sc[1];
 #else
     q = q + 16 * (il&1);
     device const uint8_t * s = xb->scales;
@@ -3165,13 +3633,13 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
     uint8_t ul = 1 << (il/2);
     il = il & 3;
     const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
-    const
-    const
-    const
-    const
+    const float d = il < 2 ? xb->d : xb->d / 16.h;
+    const float min = xb->dmin;
+    const float dl = d * sc[0];
+    const float ml = min * sc[1];

-    const ushort mask
-    const
+    const ushort mask = il<2 ? 0x0F : 0xF0;
+    const float qh_val = il<2 ? 16.f : 256.f;
     for (int i = 0; i < 16; ++i) {
         reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
     }
@@ -3219,22 +3687,90 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
 kernel void kernel_get_rows(
         device const void * src0,
-        device const
+        device const char * src1,
         device float * dst,
         constant int64_t & ne00,
         constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
         constant uint64_t & nb1,
-
+        constant uint64_t & nb2,
+        uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiitg[[thread_index_in_threadgroup]],
-
-    const
-    const
+        uint3 tptg [[threads_per_threadgroup]]) {
+    //const int64_t i = tgpig;
+    //const int64_t r = ((device int32_t *) src1)[i];
+
+    const int64_t i10 = tgpig.x;
+    const int64_t i11 = tgpig.y;
+
+    const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];

-
+    const int64_t i02 = i11;
+
+    for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
         float4x4 temp;
         dequantize_func(
-            ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
-        *(((device float4x4 *) ((device char *) dst +
+            ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
+        *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
+    }
+}
+
+kernel void kernel_get_rows_f32(
+        device const void * src0,
+        device const char * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiitg[[thread_index_in_threadgroup]],
+        uint3 tptg [[threads_per_threadgroup]]) {
+    const int64_t i10 = tgpig.x;
+    const int64_t i11 = tgpig.y;
+
+    const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+    const int64_t i02 = i11;
+
+    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+        ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
+            ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
+    }
+}
+
+kernel void kernel_get_rows_f16(
+        device const void * src0,
+        device const char * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiitg[[thread_index_in_threadgroup]],
+        uint3 tptg [[threads_per_threadgroup]]) {
+    const int64_t i10 = tgpig.x;
+    const int64_t i11 = tgpig.y;
+
+    const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+    const int64_t i02 = i11;
+
+    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+        ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
+            ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
     }
 }

@@ -3426,19 +3962,22 @@ kernel void kernel_mul_mm(device const uchar * src0,

 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
 kernel void kernel_mul_mm_id(
-        device const
+        device const uchar * ids,
         device const uchar * src1,
-        device
+        device uchar * dst,
+        constant int64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne02,
         constant int64_t & nb01,
         constant int64_t & nb02,
         constant int64_t & ne12,
+        constant int64_t & ne13,
         constant int64_t & nb10,
         constant int64_t & nb11,
         constant int64_t & nb12,
         constant int64_t & ne0,
         constant int64_t & ne1,
+        constant int64_t & nb1,
         constant uint & r2,
         constant uint & r3,
         constant int & idx,
@@ -3456,10 +3995,16 @@ kernel void kernel_mul_mm_id(
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
     device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
     kernel_mul_mm_impl<block_q, nl, dequantize_func>(
-        src0[
-        src1,
-        dst,
+        src0[id],
+        src1 + bid*nb11,
+        (device float *) (dst + bid*nb1),
         ne00,
         ne02,
         nb01,
@@ -3484,17 +4029,26 @@ kernel void kernel_mul_mm_id(
 #define QK_NL 4
 #endif

+//
+// get rows
+//
+
 typedef void (get_rows_t)(
         device const void * src0,
-        device const
+        device const char * src1,
         device float * dst,
         constant int64_t & ne00,
         constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
         constant uint64_t & nb1,
-
+        constant uint64_t & nb2,
+        uint3, uint, uint3);

-template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
-template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
+//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
+//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
 template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
 template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
@@ -3506,6 +4060,10 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
 template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
 template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;

+//
+// matrix-matrix multiplication
+//
+
 typedef void (mat_mm_t)(
         device const uchar * src0,
         device const uchar * src1,
@@ -3538,20 +4096,27 @@ template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
 template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
 template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;

+//
+// indirect matrix-matrix multiplication
+//
+
 typedef void (mat_mm_id_t)(
-        device const
+        device const uchar * ids,
         device const uchar * src1,
-        device
+        device uchar * dst,
+        constant int64_t & nbi1,
         constant int64_t & ne00,
         constant int64_t & ne02,
         constant int64_t & nb01,
         constant int64_t & nb02,
         constant int64_t & ne12,
+        constant int64_t & ne13,
        constant int64_t & nb10,
         constant int64_t & nb11,
         constant int64_t & nb12,
         constant int64_t & ne0,
         constant int64_t & ne1,
+        constant int64_t & nb1,
         constant uint & r2,
         constant uint & r3,
         constant int & idx,
@@ -3578,3 +4143,775 @@ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mu
|
|
3578
4143
|
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
|
3579
4144
|
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
|
3580
4145
|
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
|
4146
|
+
|
4147
|
+
//
|
4148
|
+
// matrix-vector multiplication
|
4149
|
+
//
|
4150
|
+
|
4151
|
+
[[host_name("kernel_mul_mv_id_f32_f32")]]
|
4152
|
+
kernel void kernel_mul_mv_id_f32_f32(
|
4153
|
+
device const char * ids,
|
4154
|
+
device const char * src1,
|
4155
|
+
device uchar * dst,
|
4156
|
+
constant int64_t & nbi1,
|
4157
|
+
constant int64_t & ne00,
|
4158
|
+
constant int64_t & ne01,
|
4159
|
+
constant int64_t & ne02,
|
4160
|
+
constant uint64_t & nb00,
|
4161
|
+
constant uint64_t & nb01,
|
4162
|
+
constant uint64_t & nb02,
|
4163
|
+
constant int64_t & ne10,
|
4164
|
+
constant int64_t & ne11,
|
4165
|
+
constant int64_t & ne12,
|
4166
|
+
constant int64_t & ne13,
|
4167
|
+
constant uint64_t & nb10,
|
4168
|
+
constant uint64_t & nb11,
|
4169
|
+
constant uint64_t & nb12,
|
4170
|
+
constant int64_t & ne0,
|
4171
|
+
constant int64_t & ne1,
|
4172
|
+
constant int64_t & nb1,
|
4173
|
+
constant uint & r2,
|
4174
|
+
constant uint & r3,
|
4175
|
+
constant int & idx,
|
4176
|
+
device const char * src00,
|
4177
|
+
device const char * src01,
|
4178
|
+
device const char * src02,
|
4179
|
+
device const char * src03,
|
4180
|
+
device const char * src04,
|
4181
|
+
device const char * src05,
|
4182
|
+
device const char * src06,
|
4183
|
+
device const char * src07,
|
4184
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4185
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4186
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4187
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4188
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
4189
|
+
|
4190
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
4191
|
+
|
4192
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4193
|
+
|
4194
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4195
|
+
|
4196
|
+
kernel_mul_mv_f32_f32_impl(
|
4197
|
+
src0[id],
|
4198
|
+
src1 + bid*nb11,
|
4199
|
+
(device float *) (dst + bid*nb1),
|
4200
|
+
ne00,
|
4201
|
+
ne01,
|
4202
|
+
ne02,
|
4203
|
+
nb00,
|
4204
|
+
nb01,
|
4205
|
+
nb02,
|
4206
|
+
ne10,
|
4207
|
+
ne11,
|
4208
|
+
ne12,
|
4209
|
+
nb10,
|
4210
|
+
nb11,
|
4211
|
+
nb12,
|
4212
|
+
ne0,
|
4213
|
+
ne1,
|
4214
|
+
r2,
|
4215
|
+
r3,
|
4216
|
+
tgpig,
|
4217
|
+
tiisg);
|
4218
|
+
}
|
4219
|
+
|
4220
|
+
[[host_name("kernel_mul_mv_id_f16_f32")]]
|
4221
|
+
kernel void kernel_mul_mv_id_f16_f32(
|
4222
|
+
device const char * ids,
|
4223
|
+
device const char * src1,
|
4224
|
+
device uchar * dst,
|
4225
|
+
constant int64_t & nbi1,
|
4226
|
+
constant int64_t & ne00,
|
4227
|
+
constant int64_t & ne01,
|
4228
|
+
constant int64_t & ne02,
|
4229
|
+
constant uint64_t & nb00,
|
4230
|
+
constant uint64_t & nb01,
|
4231
|
+
constant uint64_t & nb02,
|
4232
|
+
constant int64_t & ne10,
|
4233
|
+
constant int64_t & ne11,
|
4234
|
+
constant int64_t & ne12,
|
4235
|
+
constant int64_t & ne13,
|
4236
|
+
constant uint64_t & nb10,
|
4237
|
+
constant uint64_t & nb11,
|
4238
|
+
constant uint64_t & nb12,
|
4239
|
+
constant int64_t & ne0,
|
4240
|
+
constant int64_t & ne1,
|
4241
|
+
constant int64_t & nb1,
|
4242
|
+
constant uint & r2,
|
4243
|
+
constant uint & r3,
|
4244
|
+
constant int & idx,
|
4245
|
+
device const char * src00,
|
4246
|
+
device const char * src01,
|
4247
|
+
device const char * src02,
|
4248
|
+
device const char * src03,
|
4249
|
+
device const char * src04,
|
4250
|
+
device const char * src05,
|
4251
|
+
device const char * src06,
|
4252
|
+
device const char * src07,
|
4253
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4254
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4255
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4256
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4257
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
4258
|
+
|
4259
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
4260
|
+
|
4261
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4262
|
+
|
4263
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4264
|
+
|
4265
|
+
kernel_mul_mv_f16_f32_impl(
|
4266
|
+
src0[id],
|
4267
|
+
src1 + bid*nb11,
|
4268
|
+
(device float *) (dst + bid*nb1),
|
4269
|
+
ne00,
|
4270
|
+
ne01,
|
4271
|
+
ne02,
|
4272
|
+
nb00,
|
4273
|
+
nb01,
|
4274
|
+
nb02,
|
4275
|
+
ne10,
|
4276
|
+
ne11,
|
4277
|
+
ne12,
|
4278
|
+
nb10,
|
4279
|
+
nb11,
|
4280
|
+
nb12,
|
4281
|
+
ne0,
|
4282
|
+
ne1,
|
4283
|
+
r2,
|
4284
|
+
r3,
|
4285
|
+
tgpig,
|
4286
|
+
tiisg);
|
4287
|
+
}
|
4288
|
+
|
4289
|
+
[[host_name("kernel_mul_mv_id_q8_0_f32")]]
|
4290
|
+
kernel void kernel_mul_mv_id_q8_0_f32(
|
4291
|
+
device const char * ids,
|
4292
|
+
device const char * src1,
|
4293
|
+
device uchar * dst,
|
4294
|
+
constant int64_t & nbi1,
|
4295
|
+
constant int64_t & ne00,
|
4296
|
+
constant int64_t & ne01,
|
4297
|
+
constant int64_t & ne02,
|
4298
|
+
constant uint64_t & nb00,
|
4299
|
+
constant uint64_t & nb01,
|
4300
|
+
constant uint64_t & nb02,
|
4301
|
+
constant int64_t & ne10,
|
4302
|
+
constant int64_t & ne11,
|
4303
|
+
constant int64_t & ne12,
|
4304
|
+
constant int64_t & ne13,
|
4305
|
+
constant uint64_t & nb10,
|
4306
|
+
constant uint64_t & nb11,
|
4307
|
+
constant uint64_t & nb12,
|
4308
|
+
constant int64_t & ne0,
|
4309
|
+
constant int64_t & ne1,
|
4310
|
+
constant int64_t & nb1,
|
4311
|
+
constant uint & r2,
|
4312
|
+
constant uint & r3,
|
4313
|
+
constant int & idx,
|
4314
|
+
device const char * src00,
|
4315
|
+
device const char * src01,
|
4316
|
+
device const char * src02,
|
4317
|
+
device const char * src03,
|
4318
|
+
device const char * src04,
|
4319
|
+
device const char * src05,
|
4320
|
+
device const char * src06,
|
4321
|
+
device const char * src07,
|
4322
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4323
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4324
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4325
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4326
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
4327
|
+
|
4328
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
4329
|
+
|
4330
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4331
|
+
|
4332
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4333
|
+
|
4334
|
+
kernel_mul_mv_q8_0_f32_impl(
|
4335
|
+
src0[id],
|
4336
|
+
(device const float *) (src1 + bid*nb11),
|
4337
|
+
(device float *) ( dst + bid*nb1),
|
4338
|
+
ne00,
|
4339
|
+
ne01,
|
4340
|
+
ne02,
|
4341
|
+
ne10,
|
4342
|
+
ne12,
|
4343
|
+
ne0,
|
4344
|
+
ne1,
|
4345
|
+
r2,
|
4346
|
+
r3,
|
4347
|
+
tgpig,
|
4348
|
+
tiisg,
|
4349
|
+
sgitg);
|
4350
|
+
}
|
4351
|
+
|
4352
|
+
[[host_name("kernel_mul_mv_id_q4_0_f32")]]
|
4353
|
+
kernel void kernel_mul_mv_id_q4_0_f32(
|
4354
|
+
device const char * ids,
|
4355
|
+
device const char * src1,
|
4356
|
+
device uchar * dst,
|
4357
|
+
constant int64_t & nbi1,
|
4358
|
+
constant int64_t & ne00,
|
4359
|
+
constant int64_t & ne01,
|
4360
|
+
constant int64_t & ne02,
|
4361
|
+
constant uint64_t & nb00,
|
4362
|
+
constant uint64_t & nb01,
|
4363
|
+
constant uint64_t & nb02,
|
4364
|
+
constant int64_t & ne10,
|
4365
|
+
constant int64_t & ne11,
|
4366
|
+
constant int64_t & ne12,
|
4367
|
+
constant int64_t & ne13,
|
4368
|
+
constant uint64_t & nb10,
|
4369
|
+
constant uint64_t & nb11,
|
4370
|
+
constant uint64_t & nb12,
|
4371
|
+
constant int64_t & ne0,
|
4372
|
+
constant int64_t & ne1,
|
4373
|
+
constant int64_t & nb1,
|
4374
|
+
constant uint & r2,
|
4375
|
+
constant uint & r3,
|
4376
|
+
constant int & idx,
|
4377
|
+
device const char * src00,
|
4378
|
+
device const char * src01,
|
4379
|
+
device const char * src02,
|
4380
|
+
device const char * src03,
|
4381
|
+
device const char * src04,
|
4382
|
+
device const char * src05,
|
4383
|
+
device const char * src06,
|
4384
|
+
device const char * src07,
|
4385
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4386
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4387
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4388
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4389
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
4390
|
+
|
4391
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
4392
|
+
|
4393
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4394
|
+
|
4395
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4396
|
+
|
4397
|
+
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
4398
|
+
src0[id],
|
4399
|
+
(device const float *) (src1 + bid*nb11),
|
4400
|
+
(device float *) ( dst + bid*nb1),
|
4401
|
+
ne00,
|
4402
|
+
ne01,
|
4403
|
+
ne02,
|
4404
|
+
ne10,
|
4405
|
+
ne12,
|
4406
|
+
ne0,
|
4407
|
+
ne1,
|
4408
|
+
r2,
|
4409
|
+
r3,
|
4410
|
+
tgpig,
|
4411
|
+
tiisg,
|
4412
|
+
sgitg);
|
4413
|
+
}
|
4414
|
+
|
4415
|
+
[[host_name("kernel_mul_mv_id_q4_1_f32")]]
|
4416
|
+
kernel void kernel_mul_mv_id_q4_1_f32(
|
4417
|
+
device const char * ids,
|
4418
|
+
device const char * src1,
|
4419
|
+
device uchar * dst,
|
4420
|
+
constant int64_t & nbi1,
|
4421
|
+
constant int64_t & ne00,
|
4422
|
+
constant int64_t & ne01,
|
4423
|
+
constant int64_t & ne02,
|
4424
|
+
constant uint64_t & nb00,
|
4425
|
+
constant uint64_t & nb01,
|
4426
|
+
constant uint64_t & nb02,
|
4427
|
+
constant int64_t & ne10,
|
4428
|
+
constant int64_t & ne11,
|
4429
|
+
constant int64_t & ne12,
|
4430
|
+
constant int64_t & ne13,
|
4431
|
+
constant uint64_t & nb10,
|
4432
|
+
constant uint64_t & nb11,
|
4433
|
+
constant uint64_t & nb12,
|
4434
|
+
constant int64_t & ne0,
|
4435
|
+
constant int64_t & ne1,
|
4436
|
+
constant int64_t & nb1,
|
4437
|
+
constant uint & r2,
|
4438
|
+
constant uint & r3,
|
4439
|
+
constant int & idx,
|
4440
|
+
device const char * src00,
|
4441
|
+
device const char * src01,
|
4442
|
+
device const char * src02,
|
4443
|
+
device const char * src03,
|
4444
|
+
device const char * src04,
|
4445
|
+
device const char * src05,
|
4446
|
+
device const char * src06,
|
4447
|
+
device const char * src07,
|
4448
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4449
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4450
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4451
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4452
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
4453
|
+
|
4454
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
4455
|
+
|
4456
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4457
|
+
|
4458
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4459
|
+
|
4460
|
+
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
4461
|
+
src0[id],
|
4462
|
+
(device const float *) (src1 + bid*nb11),
|
4463
|
+
(device float *) ( dst + bid*nb1),
|
4464
|
+
ne00,
|
4465
|
+
ne01,
|
4466
|
+
ne02,
|
4467
|
+
ne10,
|
4468
|
+
ne12,
|
4469
|
+
ne0,
|
4470
|
+
ne1,
|
4471
|
+
r2,
|
4472
|
+
r3,
|
4473
|
+
tgpig,
|
4474
|
+
tiisg,
|
4475
|
+
sgitg);
|
4476
|
+
}
|
4477
|
+
|
4478
|
+
[[host_name("kernel_mul_mv_id_q5_0_f32")]]
|
4479
|
+
kernel void kernel_mul_mv_id_q5_0_f32(
|
4480
|
+
device const char * ids,
|
4481
|
+
device const char * src1,
|
4482
|
+
device uchar * dst,
|
4483
|
+
constant int64_t & nbi1,
|
4484
|
+
constant int64_t & ne00,
|
4485
|
+
constant int64_t & ne01,
|
4486
|
+
constant int64_t & ne02,
|
4487
|
+
constant uint64_t & nb00,
|
4488
|
+
constant uint64_t & nb01,
|
4489
|
+
constant uint64_t & nb02,
|
4490
|
+
constant int64_t & ne10,
|
4491
|
+
constant int64_t & ne11,
|
4492
|
+
constant int64_t & ne12,
|
4493
|
+
constant int64_t & ne13,
|
4494
|
+
constant uint64_t & nb10,
|
4495
|
+
constant uint64_t & nb11,
|
4496
|
+
constant uint64_t & nb12,
|
4497
|
+
constant int64_t & ne0,
|
4498
|
+
constant int64_t & ne1,
|
4499
|
+
constant int64_t & nb1,
|
4500
|
+
constant uint & r2,
|
4501
|
+
constant uint & r3,
|
4502
|
+
constant int & idx,
|
4503
|
+
device const char * src00,
|
4504
|
+
device const char * src01,
|
4505
|
+
device const char * src02,
|
4506
|
+
device const char * src03,
|
4507
|
+
device const char * src04,
|
4508
|
+
device const char * src05,
|
4509
|
+
device const char * src06,
|
4510
|
+
device const char * src07,
|
4511
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4512
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4513
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4514
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4515
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
4516
|
+
|
4517
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
4518
|
+
|
4519
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4520
|
+
|
4521
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4522
|
+
|
4523
|
+
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
4524
|
+
src0[id],
|
4525
|
+
(device const float *) (src1 + bid*nb11),
|
4526
|
+
(device float *) ( dst + bid*nb1),
|
4527
|
+
ne00,
|
4528
|
+
ne01,
|
4529
|
+
ne02,
|
4530
|
+
ne10,
|
4531
|
+
ne12,
|
4532
|
+
ne0,
|
4533
|
+
ne1,
|
4534
|
+
r2,
|
4535
|
+
r3,
|
4536
|
+
tgpig,
|
4537
|
+
tiisg,
|
4538
|
+
sgitg);
|
4539
|
+
}
|
4540
|
+
|
4541
|
+
[[host_name("kernel_mul_mv_id_q5_1_f32")]]
|
4542
|
+
kernel void kernel_mul_mv_id_q5_1_f32(
|
4543
|
+
device const char * ids,
|
4544
|
+
device const char * src1,
|
4545
|
+
device uchar * dst,
|
4546
|
+
constant int64_t & nbi1,
|
4547
|
+
constant int64_t & ne00,
|
4548
|
+
constant int64_t & ne01,
|
4549
|
+
constant int64_t & ne02,
|
4550
|
+
constant uint64_t & nb00,
|
4551
|
+
constant uint64_t & nb01,
|
4552
|
+
constant uint64_t & nb02,
|
4553
|
+
constant int64_t & ne10,
|
4554
|
+
constant int64_t & ne11,
|
4555
|
+
constant int64_t & ne12,
|
4556
|
+
constant int64_t & ne13,
|
4557
|
+
constant uint64_t & nb10,
|
4558
|
+
constant uint64_t & nb11,
|
4559
|
+
constant uint64_t & nb12,
|
4560
|
+
constant int64_t & ne0,
|
4561
|
+
constant int64_t & ne1,
|
4562
|
+
constant int64_t & nb1,
|
4563
|
+
constant uint & r2,
|
4564
|
+
constant uint & r3,
|
4565
|
+
constant int & idx,
|
4566
|
+
device const char * src00,
|
4567
|
+
device const char * src01,
|
4568
|
+
device const char * src02,
|
4569
|
+
device const char * src03,
|
4570
|
+
device const char * src04,
|
4571
|
+
device const char * src05,
|
4572
|
+
device const char * src06,
|
4573
|
+
device const char * src07,
|
4574
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4575
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4576
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4577
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4578
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
4579
|
+
|
4580
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
4581
|
+
|
4582
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4583
|
+
|
4584
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4585
|
+
|
4586
|
+
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
4587
|
+
src0[id],
|
4588
|
+
(device const float *) (src1 + bid*nb11),
|
4589
|
+
(device float *) ( dst + bid*nb1),
|
4590
|
+
ne00,
|
4591
|
+
ne01,
|
4592
|
+
ne02,
|
4593
|
+
ne10,
|
4594
|
+
ne12,
|
4595
|
+
ne0,
|
4596
|
+
ne1,
|
4597
|
+
r2,
|
4598
|
+
r3,
|
4599
|
+
tgpig,
|
4600
|
+
tiisg,
|
4601
|
+
sgitg);
|
4602
|
+
}
|
4603
|
+
|
4604
|
+
[[host_name("kernel_mul_mv_id_q2_K_f32")]]
|
4605
|
+
kernel void kernel_mul_mv_id_q2_K_f32(
|
4606
|
+
device const char * ids,
|
4607
|
+
device const char * src1,
|
4608
|
+
device uchar * dst,
|
4609
|
+
constant int64_t & nbi1,
|
4610
|
+
constant int64_t & ne00,
|
4611
|
+
constant int64_t & ne01,
|
4612
|
+
constant int64_t & ne02,
|
4613
|
+
constant uint64_t & nb00,
|
4614
|
+
constant uint64_t & nb01,
|
4615
|
+
constant uint64_t & nb02,
|
4616
|
+
constant int64_t & ne10,
|
4617
|
+
constant int64_t & ne11,
|
4618
|
+
constant int64_t & ne12,
|
4619
|
+
constant int64_t & ne13,
|
4620
|
+
constant uint64_t & nb10,
|
4621
|
+
constant uint64_t & nb11,
|
4622
|
+
constant uint64_t & nb12,
|
4623
|
+
constant int64_t & ne0,
|
4624
|
+
constant int64_t & ne1,
|
4625
|
+
constant int64_t & nb1,
|
4626
|
+
constant uint & r2,
|
4627
|
+
constant uint & r3,
|
4628
|
+
constant int & idx,
|
4629
|
+
device const char * src00,
|
4630
|
+
device const char * src01,
|
4631
|
+
device const char * src02,
|
4632
|
+
device const char * src03,
|
4633
|
+
device const char * src04,
|
4634
|
+
device const char * src05,
|
4635
|
+
device const char * src06,
|
4636
|
+
device const char * src07,
|
4637
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4638
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4639
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4640
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4641
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
4642
|
+
|
4643
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
4644
|
+
|
4645
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4646
|
+
|
4647
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4648
|
+
|
4649
|
+
kernel_mul_mv_q2_K_f32_impl(
|
4650
|
+
src0[id],
|
4651
|
+
(device const float *) (src1 + bid*nb11),
|
4652
|
+
(device float *) ( dst + bid*nb1),
|
4653
|
+
ne00,
|
4654
|
+
ne01,
|
4655
|
+
ne02,
|
4656
|
+
ne10,
|
4657
|
+
ne12,
|
4658
|
+
ne0,
|
4659
|
+
ne1,
|
4660
|
+
r2,
|
4661
|
+
r3,
|
4662
|
+
tgpig,
|
4663
|
+
tiisg,
|
4664
|
+
sgitg);
|
4665
|
+
}
|
4666
|
+
|
4667
|
+
[[host_name("kernel_mul_mv_id_q3_K_f32")]]
|
4668
|
+
kernel void kernel_mul_mv_id_q3_K_f32(
|
4669
|
+
device const char * ids,
|
4670
|
+
device const char * src1,
|
4671
|
+
device uchar * dst,
|
4672
|
+
constant int64_t & nbi1,
|
4673
|
+
constant int64_t & ne00,
|
4674
|
+
constant int64_t & ne01,
|
4675
|
+
constant int64_t & ne02,
|
4676
|
+
constant uint64_t & nb00,
|
4677
|
+
constant uint64_t & nb01,
|
4678
|
+
constant uint64_t & nb02,
|
4679
|
+
constant int64_t & ne10,
|
4680
|
+
constant int64_t & ne11,
|
4681
|
+
constant int64_t & ne12,
|
4682
|
+
constant int64_t & ne13,
|
4683
|
+
constant uint64_t & nb10,
|
4684
|
+
constant uint64_t & nb11,
|
4685
|
+
constant uint64_t & nb12,
|
4686
|
+
constant int64_t & ne0,
|
4687
|
+
constant int64_t & ne1,
|
4688
|
+
constant int64_t & nb1,
|
4689
|
+
constant uint & r2,
|
4690
|
+
constant uint & r3,
|
4691
|
+
constant int & idx,
|
4692
|
+
device const char * src00,
|
4693
|
+
device const char * src01,
|
4694
|
+
device const char * src02,
|
4695
|
+
device const char * src03,
|
4696
|
+
device const char * src04,
|
4697
|
+
device const char * src05,
|
4698
|
+
device const char * src06,
|
4699
|
+
device const char * src07,
|
4700
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4701
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4702
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4703
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4704
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
4705
|
+
|
4706
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
4707
|
+
|
4708
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4709
|
+
|
4710
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4711
|
+
|
4712
|
+
kernel_mul_mv_q3_K_f32_impl(
|
4713
|
+
src0[id],
|
4714
|
+
(device const float *) (src1 + bid*nb11),
|
4715
|
+
(device float *) ( dst + bid*nb1),
|
4716
|
+
ne00,
|
4717
|
+
ne01,
|
4718
|
+
ne02,
|
4719
|
+
ne10,
|
4720
|
+
ne12,
|
4721
|
+
ne0,
|
4722
|
+
ne1,
|
4723
|
+
r2,
|
4724
|
+
r3,
|
4725
|
+
tgpig,
|
4726
|
+
tiisg,
|
4727
|
+
sgitg);
|
4728
|
+
}
|
4729
|
+
|
4730
|
+
[[host_name("kernel_mul_mv_id_q4_K_f32")]]
|
4731
|
+
kernel void kernel_mul_mv_id_q4_K_f32(
|
4732
|
+
device const char * ids,
|
4733
|
+
device const char * src1,
|
4734
|
+
device uchar * dst,
|
4735
|
+
constant int64_t & nbi1,
|
4736
|
+
constant int64_t & ne00,
|
4737
|
+
constant int64_t & ne01,
|
4738
|
+
constant int64_t & ne02,
|
4739
|
+
constant uint64_t & nb00,
|
4740
|
+
constant uint64_t & nb01,
|
4741
|
+
constant uint64_t & nb02,
|
4742
|
+
constant int64_t & ne10,
|
4743
|
+
constant int64_t & ne11,
|
4744
|
+
constant int64_t & ne12,
|
4745
|
+
constant int64_t & ne13,
|
4746
|
+
constant uint64_t & nb10,
|
4747
|
+
constant uint64_t & nb11,
|
4748
|
+
constant uint64_t & nb12,
|
4749
|
+
constant int64_t & ne0,
|
4750
|
+
constant int64_t & ne1,
|
4751
|
+
constant int64_t & nb1,
|
4752
|
+
constant uint & r2,
|
4753
|
+
constant uint & r3,
|
4754
|
+
constant int & idx,
|
4755
|
+
device const char * src00,
|
4756
|
+
device const char * src01,
|
4757
|
+
device const char * src02,
|
4758
|
+
device const char * src03,
|
4759
|
+
device const char * src04,
|
4760
|
+
device const char * src05,
|
4761
|
+
device const char * src06,
|
4762
|
+
device const char * src07,
|
4763
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4764
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4765
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4766
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4767
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
4768
|
+
|
4769
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
4770
|
+
|
4771
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4772
|
+
|
4773
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4774
|
+
|
4775
|
+
kernel_mul_mv_q4_K_f32_impl(
|
4776
|
+
src0[id],
|
4777
|
+
(device const float *) (src1 + bid*nb11),
|
4778
|
+
(device float *) ( dst + bid*nb1),
|
4779
|
+
ne00,
|
4780
|
+
ne01,
|
4781
|
+
ne02,
|
4782
|
+
ne10,
|
4783
|
+
ne12,
|
4784
|
+
ne0,
|
4785
|
+
ne1,
|
4786
|
+
r2,
|
4787
|
+
r3,
|
4788
|
+
tgpig,
|
4789
|
+
tiisg,
|
4790
|
+
sgitg);
|
4791
|
+
}
|
4792
|
+
|
4793
|
+
[[host_name("kernel_mul_mv_id_q5_K_f32")]]
|
4794
|
+
kernel void kernel_mul_mv_id_q5_K_f32(
|
4795
|
+
device const char * ids,
|
4796
|
+
device const char * src1,
|
4797
|
+
device uchar * dst,
|
4798
|
+
constant int64_t & nbi1,
|
4799
|
+
constant int64_t & ne00,
|
4800
|
+
constant int64_t & ne01,
|
4801
|
+
constant int64_t & ne02,
|
4802
|
+
constant uint64_t & nb00,
|
4803
|
+
constant uint64_t & nb01,
|
4804
|
+
constant uint64_t & nb02,
|
4805
|
+
constant int64_t & ne10,
|
4806
|
+
constant int64_t & ne11,
|
4807
|
+
constant int64_t & ne12,
|
4808
|
+
constant int64_t & ne13,
|
4809
|
+
constant uint64_t & nb10,
|
4810
|
+
constant uint64_t & nb11,
|
4811
|
+
constant uint64_t & nb12,
|
4812
|
+
constant int64_t & ne0,
|
4813
|
+
constant int64_t & ne1,
|
4814
|
+
constant int64_t & nb1,
|
4815
|
+
constant uint & r2,
|
4816
|
+
constant uint & r3,
|
4817
|
+
constant int & idx,
|
4818
|
+
device const char * src00,
|
4819
|
+
device const char * src01,
|
4820
|
+
device const char * src02,
|
4821
|
+
device const char * src03,
|
4822
|
+
device const char * src04,
|
4823
|
+
device const char * src05,
|
4824
|
+
device const char * src06,
|
4825
|
+
device const char * src07,
|
4826
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4827
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4828
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4829
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4830
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
4831
|
+
|
4832
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
4833
|
+
|
4834
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4835
|
+
|
4836
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4837
|
+
|
4838
|
+
kernel_mul_mv_q5_K_f32_impl(
|
4839
|
+
src0[id],
|
4840
|
+
(device const float *) (src1 + bid*nb11),
|
4841
|
+
(device float *) ( dst + bid*nb1),
|
4842
|
+
ne00,
|
4843
|
+
ne01,
|
4844
|
+
ne02,
|
4845
|
+
ne10,
|
4846
|
+
ne12,
|
4847
|
+
ne0,
|
4848
|
+
ne1,
|
4849
|
+
r2,
|
4850
|
+
r3,
|
4851
|
+
tgpig,
|
4852
|
+
tiisg,
|
4853
|
+
sgitg);
|
4854
|
+
}
|
4855
|
+
|
4856
|
+
[[host_name("kernel_mul_mv_id_q6_K_f32")]]
|
4857
|
+
kernel void kernel_mul_mv_id_q6_K_f32(
|
4858
|
+
device const char * ids,
|
4859
|
+
device const char * src1,
|
4860
|
+
device uchar * dst,
|
4861
|
+
constant int64_t & nbi1,
|
4862
|
+
constant int64_t & ne00,
|
4863
|
+
constant int64_t & ne01,
|
4864
|
+
constant int64_t & ne02,
|
4865
|
+
constant uint64_t & nb00,
|
4866
|
+
constant uint64_t & nb01,
|
4867
|
+
constant uint64_t & nb02,
|
4868
|
+
constant int64_t & ne10,
|
4869
|
+
constant int64_t & ne11,
|
4870
|
+
constant int64_t & ne12,
|
4871
|
+
constant int64_t & ne13,
|
4872
|
+
constant uint64_t & nb10,
|
4873
|
+
constant uint64_t & nb11,
|
4874
|
+
constant uint64_t & nb12,
|
4875
|
+
constant int64_t & ne0,
|
4876
|
+
constant int64_t & ne1,
|
4877
|
+
constant int64_t & nb1,
|
4878
|
+
constant uint & r2,
|
4879
|
+
constant uint & r3,
|
4880
|
+
constant int & idx,
|
4881
|
+
device const char * src00,
|
4882
|
+
device const char * src01,
|
4883
|
+
device const char * src02,
|
4884
|
+
device const char * src03,
|
4885
|
+
device const char * src04,
|
4886
|
+
device const char * src05,
|
4887
|
+
device const char * src06,
|
4888
|
+
device const char * src07,
|
4889
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4890
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4891
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4892
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4893
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
4894
|
+
|
4895
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
4896
|
+
|
4897
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4898
|
+
|
4899
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4900
|
+
|
4901
|
+
kernel_mul_mv_q6_K_f32_impl(
|
4902
|
+
src0[id],
|
4903
|
+
(device const float *) (src1 + bid*nb11),
|
4904
|
+
(device float *) ( dst + bid*nb1),
|
4905
|
+
ne00,
|
4906
|
+
ne01,
|
4907
|
+
ne02,
|
4908
|
+
ne10,
|
4909
|
+
ne12,
|
4910
|
+
ne0,
|
4911
|
+
ne1,
|
4912
|
+
r2,
|
4913
|
+
r3,
|
4914
|
+
tgpig,
|
4915
|
+
tiisg,
|
4916
|
+
sgitg);
|
4917
|
+
}
|
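Every `kernel_mul_mv_id_*` entry point added above follows the same pattern: the eight candidate expert matrices `src00`..`src07` are collected into an array, the batch index is derived from `tgpig.z`, the int32 expert id for that batch row is read from `ids` (row stride `nbi1`, column `idx`), and the call is forwarded to the existing non-indirect mul-mv implementation for that expert. The sketch below mirrors only the selection step; it is illustrative, not part of the released sources, and the function name `select_expert` is made up for the example.

```c
#include <stdint.h>

/* Minimal sketch of the expert-selection logic shared by the new
 * kernel_mul_mv_id_* kernels (hypothetical helper, C rather than Metal). */
static const char * select_expert(
        const char * src0[8],  /* src00..src07 as bound by the host      */
        const char * ids,      /* int32 expert indices, row stride nbi1  */
        int64_t      bid,      /* batch index (from tgpig.z in the kernel)*/
        int64_t      nbi1,     /* byte stride between rows of `ids`      */
        int          idx) {    /* which expert slot of the row to use    */
    const int32_t id = ((const int32_t *)(ids + bid*nbi1))[idx];
    return src0[id];           /* matrix that the mul-mv impl then reads */
}
```

Passing the experts as eight separate buffer arguments keeps the kernel signature fixed while still letting each threadgroup pick its expert matrix at run time.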