llama_cpp 0.9.5 → 0.10.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -3,6 +3,8 @@
3
3
  using namespace metal;
4
4
 
5
5
  #define MAX(x, y) ((x) > (y) ? (x) : (y))
6
+ #define MIN(x, y) ((x) < (y) ? (x) : (y))
7
+ #define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
6
8
 
7
9
  #define QK4_0 32
8
10
  #define QR4_0 2
@@ -41,8 +43,13 @@ typedef struct {
41
43
 
42
44
  #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
43
45
 
44
- // general-purpose kernel for addition of two tensors
45
- // pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
46
+ enum ggml_sort_order {
47
+ GGML_SORT_ASC,
48
+ GGML_SORT_DESC,
49
+ };
50
+
51
+ // general-purpose kernel for addition, multiplication and division of two tensors
52
+ // pros: works for non-contiguous tensors, supports broadcast across all dims
46
53
  // cons: not very efficient
47
54
  kernel void kernel_add(
48
55
  device const char * src0,
@@ -83,16 +90,111 @@ kernel void kernel_add(
83
90
  const int64_t i12 = i02 % ne12;
84
91
  const int64_t i11 = i01 % ne11;
85
92
 
86
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
87
- device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
88
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
93
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
94
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
95
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
89
96
 
90
97
  for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
91
- ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0];
98
+ const int i10 = i0 % ne10;
99
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10));
100
+ }
101
+ }
92
102
 
93
- src0_ptr += ntg.x*nb00;
94
- src1_ptr += ntg.x*nb10;
95
- dst_ptr += ntg.x*nb0;
103
+ kernel void kernel_mul(
104
+ device const char * src0,
105
+ device const char * src1,
106
+ device char * dst,
107
+ constant int64_t & ne00,
108
+ constant int64_t & ne01,
109
+ constant int64_t & ne02,
110
+ constant int64_t & ne03,
111
+ constant int64_t & nb00,
112
+ constant int64_t & nb01,
113
+ constant int64_t & nb02,
114
+ constant int64_t & nb03,
115
+ constant int64_t & ne10,
116
+ constant int64_t & ne11,
117
+ constant int64_t & ne12,
118
+ constant int64_t & ne13,
119
+ constant int64_t & nb10,
120
+ constant int64_t & nb11,
121
+ constant int64_t & nb12,
122
+ constant int64_t & nb13,
123
+ constant int64_t & ne0,
124
+ constant int64_t & ne1,
125
+ constant int64_t & ne2,
126
+ constant int64_t & ne3,
127
+ constant int64_t & nb0,
128
+ constant int64_t & nb1,
129
+ constant int64_t & nb2,
130
+ constant int64_t & nb3,
131
+ uint3 tgpig[[threadgroup_position_in_grid]],
132
+ uint3 tpitg[[thread_position_in_threadgroup]],
133
+ uint3 ntg[[threads_per_threadgroup]]) {
134
+ const int64_t i03 = tgpig.z;
135
+ const int64_t i02 = tgpig.y;
136
+ const int64_t i01 = tgpig.x;
137
+
138
+ const int64_t i13 = i03 % ne13;
139
+ const int64_t i12 = i02 % ne12;
140
+ const int64_t i11 = i01 % ne11;
141
+
142
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
143
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
144
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
145
+
146
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
147
+ const int i10 = i0 % ne10;
148
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10));
149
+ }
150
+ }
151
+
152
+ kernel void kernel_div(
153
+ device const char * src0,
154
+ device const char * src1,
155
+ device char * dst,
156
+ constant int64_t & ne00,
157
+ constant int64_t & ne01,
158
+ constant int64_t & ne02,
159
+ constant int64_t & ne03,
160
+ constant int64_t & nb00,
161
+ constant int64_t & nb01,
162
+ constant int64_t & nb02,
163
+ constant int64_t & nb03,
164
+ constant int64_t & ne10,
165
+ constant int64_t & ne11,
166
+ constant int64_t & ne12,
167
+ constant int64_t & ne13,
168
+ constant int64_t & nb10,
169
+ constant int64_t & nb11,
170
+ constant int64_t & nb12,
171
+ constant int64_t & nb13,
172
+ constant int64_t & ne0,
173
+ constant int64_t & ne1,
174
+ constant int64_t & ne2,
175
+ constant int64_t & ne3,
176
+ constant int64_t & nb0,
177
+ constant int64_t & nb1,
178
+ constant int64_t & nb2,
179
+ constant int64_t & nb3,
180
+ uint3 tgpig[[threadgroup_position_in_grid]],
181
+ uint3 tpitg[[thread_position_in_threadgroup]],
182
+ uint3 ntg[[threads_per_threadgroup]]) {
183
+ const int64_t i03 = tgpig.z;
184
+ const int64_t i02 = tgpig.y;
185
+ const int64_t i01 = tgpig.x;
186
+
187
+ const int64_t i13 = i03 % ne13;
188
+ const int64_t i12 = i02 % ne12;
189
+ const int64_t i11 = i01 % ne11;
190
+
191
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
192
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
193
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
194
+
195
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
196
+ const int i10 = i0 % ne10;
197
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10));
96
198
  }
97
199
  }
98
200
 
@@ -107,23 +209,22 @@ kernel void kernel_add_row(
107
209
  dst[tpig] = src0[tpig] + src1[tpig % nb];
108
210
  }
109
211
 
110
- kernel void kernel_mul(
212
+ kernel void kernel_mul_row(
111
213
  device const float4 * src0,
112
214
  device const float4 * src1,
113
215
  device float4 * dst,
216
+ constant int64_t & nb [[buffer(27)]],
114
217
  uint tpig[[thread_position_in_grid]]) {
115
- dst[tpig] = src0[tpig] * src1[tpig];
218
+ dst[tpig] = src0[tpig] * src1[tpig % nb];
116
219
  }
117
220
 
118
- // assumption: src1 is a row
119
- // broadcast src1 into src0
120
- kernel void kernel_mul_row(
221
+ kernel void kernel_div_row(
121
222
  device const float4 * src0,
122
223
  device const float4 * src1,
123
224
  device float4 * dst,
124
- constant int64_t & nb,
225
+ constant int64_t & nb [[buffer(27)]],
125
226
  uint tpig[[thread_position_in_grid]]) {
126
- dst[tpig] = src0[tpig] * src1[tpig % nb];
227
+ dst[tpig] = src0[tpig] / src1[tpig % nb];
127
228
  }
128
229
 
129
230
  kernel void kernel_scale(
@@ -164,6 +265,54 @@ kernel void kernel_sqr(
164
265
  dst[tpig] = src0[tpig] * src0[tpig];
165
266
  }
166
267
 
268
+ kernel void kernel_sum_rows(
269
+ device const float * src0,
270
+ device float * dst,
271
+ constant int64_t & ne00,
272
+ constant int64_t & ne01,
273
+ constant int64_t & ne02,
274
+ constant int64_t & ne03,
275
+ constant int64_t & nb00,
276
+ constant int64_t & nb01,
277
+ constant int64_t & nb02,
278
+ constant int64_t & nb03,
279
+ constant int64_t & ne10,
280
+ constant int64_t & ne11,
281
+ constant int64_t & ne12,
282
+ constant int64_t & ne13,
283
+ constant int64_t & nb10,
284
+ constant int64_t & nb11,
285
+ constant int64_t & nb12,
286
+ constant int64_t & nb13,
287
+ constant int64_t & ne0,
288
+ constant int64_t & ne1,
289
+ constant int64_t & ne2,
290
+ constant int64_t & ne3,
291
+ constant int64_t & nb0,
292
+ constant int64_t & nb1,
293
+ constant int64_t & nb2,
294
+ constant int64_t & nb3,
295
+ uint3 tpig[[thread_position_in_grid]]) {
296
+ int64_t i3 = tpig.z;
297
+ int64_t i2 = tpig.y;
298
+ int64_t i1 = tpig.x;
299
+
300
+ if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
301
+ return;
302
+ }
303
+
304
+ device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
305
+ device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
306
+
307
+ float row_sum = 0;
308
+
309
+ for (int64_t i0 = 0; i0 < ne00; i0++) {
310
+ row_sum += src_row[i0];
311
+ }
312
+
313
+ dst_row[0] = row_sum;
314
+ }
315
+
167
316
  constant float GELU_COEF_A = 0.044715f;
168
317
  constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
169
318
 
@@ -582,9 +731,20 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
582
731
  // giard against the number of rows not being divisible by
583
732
  // N_DST, so this is another explicit assumption of the implementation.
584
733
  template<typename block_q_type, int nr, int nsg, int nw>
585
- void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
586
- int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
587
- uint3 tgpig, uint tiisg, uint sgitg) {
734
+ void mul_vec_q_n_f32(
735
+ device const void * src0,
736
+ device const float * src1,
737
+ device float * dst,
738
+ int64_t ne00,
739
+ int64_t ne01,
740
+ int64_t ne02,
741
+ int64_t ne10,
742
+ int64_t ne12,
743
+ int64_t ne0,
744
+ int64_t ne1,
745
+ uint r2,
746
+ uint r3,
747
+ uint3 tgpig, uint tiisg, uint sgitg) {
588
748
  const int nb = ne00/QK4_0;
589
749
 
590
750
  const int r0 = tgpig.x;
@@ -593,7 +753,10 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
593
753
 
594
754
  const int first_row = (r0 * nsg + sgitg) * nr;
595
755
 
596
- const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
756
+ const uint i12 = im%ne12;
757
+ const uint i13 = im/ne12;
758
+
759
+ const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
597
760
 
598
761
  device const block_q_type * x = (device const block_q_type *) src0 + offset0;
599
762
  device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
@@ -643,13 +806,14 @@ kernel void kernel_mul_mv_q4_0_f32(
643
806
  constant int64_t & ne02[[buffer(5)]],
644
807
  constant int64_t & ne10[[buffer(9)]],
645
808
  constant int64_t & ne12[[buffer(11)]],
646
- constant int64_t & ne0[[buffer(15)]],
647
- constant int64_t & ne1[[buffer(16)]],
648
- constant uint & gqa[[buffer(17)]],
809
+ constant int64_t & ne0 [[buffer(15)]],
810
+ constant int64_t & ne1 [[buffer(16)]],
811
+ constant uint & r2 [[buffer(17)]],
812
+ constant uint & r3 [[buffer(18)]],
649
813
  uint3 tgpig[[threadgroup_position_in_grid]],
650
814
  uint tiisg[[thread_index_in_simdgroup]],
651
815
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
652
- mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
816
+ mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
653
817
  }
654
818
 
655
819
  kernel void kernel_mul_mv_q4_1_f32(
@@ -661,13 +825,14 @@ kernel void kernel_mul_mv_q4_1_f32(
661
825
  constant int64_t & ne02[[buffer(5)]],
662
826
  constant int64_t & ne10[[buffer(9)]],
663
827
  constant int64_t & ne12[[buffer(11)]],
664
- constant int64_t & ne0[[buffer(15)]],
665
- constant int64_t & ne1[[buffer(16)]],
666
- constant uint & gqa[[buffer(17)]],
828
+ constant int64_t & ne0 [[buffer(15)]],
829
+ constant int64_t & ne1 [[buffer(16)]],
830
+ constant uint & r2 [[buffer(17)]],
831
+ constant uint & r3 [[buffer(18)]],
667
832
  uint3 tgpig[[threadgroup_position_in_grid]],
668
833
  uint tiisg[[thread_index_in_simdgroup]],
669
834
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
670
- mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
835
+ mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
671
836
  }
672
837
 
673
838
  kernel void kernel_mul_mv_q5_0_f32(
@@ -679,13 +844,14 @@ kernel void kernel_mul_mv_q5_0_f32(
679
844
  constant int64_t & ne02[[buffer(5)]],
680
845
  constant int64_t & ne10[[buffer(9)]],
681
846
  constant int64_t & ne12[[buffer(11)]],
682
- constant int64_t & ne0[[buffer(15)]],
683
- constant int64_t & ne1[[buffer(16)]],
684
- constant uint & gqa[[buffer(17)]],
847
+ constant int64_t & ne0 [[buffer(15)]],
848
+ constant int64_t & ne1 [[buffer(16)]],
849
+ constant uint & r2 [[buffer(17)]],
850
+ constant uint & r3 [[buffer(18)]],
685
851
  uint3 tgpig[[threadgroup_position_in_grid]],
686
852
  uint tiisg[[thread_index_in_simdgroup]],
687
853
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
688
- mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
854
+ mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
689
855
  }
690
856
 
691
857
  kernel void kernel_mul_mv_q5_1_f32(
@@ -697,13 +863,14 @@ kernel void kernel_mul_mv_q5_1_f32(
697
863
  constant int64_t & ne02[[buffer(5)]],
698
864
  constant int64_t & ne10[[buffer(9)]],
699
865
  constant int64_t & ne12[[buffer(11)]],
700
- constant int64_t & ne0[[buffer(15)]],
701
- constant int64_t & ne1[[buffer(16)]],
702
- constant uint & gqa[[buffer(17)]],
866
+ constant int64_t & ne0 [[buffer(15)]],
867
+ constant int64_t & ne1 [[buffer(16)]],
868
+ constant uint & r2 [[buffer(17)]],
869
+ constant uint & r3 [[buffer(18)]],
703
870
  uint3 tgpig[[threadgroup_position_in_grid]],
704
871
  uint tiisg[[thread_index_in_simdgroup]],
705
872
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
706
- mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
873
+ mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
707
874
  }
708
875
 
709
876
 
@@ -718,9 +885,10 @@ kernel void kernel_mul_mv_q8_0_f32(
718
885
  constant int64_t & ne02[[buffer(5)]],
719
886
  constant int64_t & ne10[[buffer(9)]],
720
887
  constant int64_t & ne12[[buffer(11)]],
721
- constant int64_t & ne0[[buffer(15)]],
722
- constant int64_t & ne1[[buffer(16)]],
723
- constant uint & gqa[[buffer(17)]],
888
+ constant int64_t & ne0 [[buffer(15)]],
889
+ constant int64_t & ne1 [[buffer(16)]],
890
+ constant uint & r2 [[buffer(17)]],
891
+ constant uint & r3 [[buffer(18)]],
724
892
  uint3 tgpig[[threadgroup_position_in_grid]],
725
893
  uint tiisg[[thread_index_in_simdgroup]],
726
894
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -732,8 +900,14 @@ kernel void kernel_mul_mv_q8_0_f32(
732
900
  const int r0 = tgpig.x;
733
901
  const int r1 = tgpig.y;
734
902
  const int im = tgpig.z;
903
+
735
904
  const int first_row = (r0 * nsg + sgitg) * nr;
736
- const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
905
+
906
+ const uint i12 = im%ne12;
907
+ const uint i13 = im/ne12;
908
+
909
+ const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
910
+
737
911
  device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
738
912
  device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
739
913
 
@@ -791,6 +965,8 @@ kernel void kernel_mul_mv_f32_f32(
791
965
  constant uint64_t & nb12,
792
966
  constant int64_t & ne0,
793
967
  constant int64_t & ne1,
968
+ constant uint & r2 [[buffer(17)]],
969
+ constant uint & r3 [[buffer(18)]],
794
970
  uint3 tgpig[[threadgroup_position_in_grid]],
795
971
  uint tiisg[[thread_index_in_simdgroup]]) {
796
972
 
@@ -798,7 +974,12 @@ kernel void kernel_mul_mv_f32_f32(
798
974
  const int64_t rb = tgpig.y*N_F32_F32;
799
975
  const int64_t im = tgpig.z;
800
976
 
801
- device const float * x = (device const float *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
977
+ const uint i12 = im%ne12;
978
+ const uint i13 = im/ne12;
979
+
980
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
981
+
982
+ device const float * x = (device const float *) (src0 + offset0);
802
983
 
803
984
  if (ne00 < 128) {
804
985
  for (int row = 0; row < N_F32_F32; ++row) {
@@ -864,6 +1045,8 @@ kernel void kernel_mul_mv_f16_f16(
864
1045
  constant uint64_t & nb12,
865
1046
  constant int64_t & ne0,
866
1047
  constant int64_t & ne1,
1048
+ constant uint & r2 [[buffer(17)]],
1049
+ constant uint & r3 [[buffer(18)]],
867
1050
  uint3 tgpig[[threadgroup_position_in_grid]],
868
1051
  uint tiisg[[thread_index_in_simdgroup]]) {
869
1052
 
@@ -871,7 +1054,12 @@ kernel void kernel_mul_mv_f16_f16(
871
1054
  const int64_t rb = tgpig.y*N_F16_F16;
872
1055
  const int64_t im = tgpig.z;
873
1056
 
874
- device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
1057
+ const uint i12 = im%ne12;
1058
+ const uint i13 = im/ne12;
1059
+
1060
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1061
+
1062
+ device const half * x = (device const half *) (src0 + offset0);
875
1063
 
876
1064
  if (ne00 < 128) {
877
1065
  for (int row = 0; row < N_F16_F16; ++row) {
@@ -935,6 +1123,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
935
1123
  constant uint64_t & nb12,
936
1124
  constant int64_t & ne0,
937
1125
  constant int64_t & ne1,
1126
+ constant uint & r2 [[buffer(17)]],
1127
+ constant uint & r3 [[buffer(18)]],
938
1128
  uint3 tgpig[[threadgroup_position_in_grid]],
939
1129
  uint tiisg[[thread_index_in_simdgroup]]) {
940
1130
 
@@ -942,7 +1132,12 @@ kernel void kernel_mul_mv_f16_f32_1row(
942
1132
  const int64_t r1 = tgpig.y;
943
1133
  const int64_t im = tgpig.z;
944
1134
 
945
- device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
1135
+ const uint i12 = im%ne12;
1136
+ const uint i13 = im/ne12;
1137
+
1138
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1139
+
1140
+ device const half * x = (device const half *) (src0 + offset0);
946
1141
  device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
947
1142
 
948
1143
  float sumf = 0;
@@ -989,6 +1184,8 @@ kernel void kernel_mul_mv_f16_f32(
989
1184
  constant uint64_t & nb12,
990
1185
  constant int64_t & ne0,
991
1186
  constant int64_t & ne1,
1187
+ constant uint & r2 [[buffer(17)]],
1188
+ constant uint & r3 [[buffer(18)]],
992
1189
  uint3 tgpig[[threadgroup_position_in_grid]],
993
1190
  uint tiisg[[thread_index_in_simdgroup]]) {
994
1191
 
@@ -996,7 +1193,12 @@ kernel void kernel_mul_mv_f16_f32(
996
1193
  const int64_t rb = tgpig.y*N_F16_F32;
997
1194
  const int64_t im = tgpig.z;
998
1195
 
999
- device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
1196
+ const uint i12 = im%ne12;
1197
+ const uint i13 = im/ne12;
1198
+
1199
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1200
+
1201
+ device const half * x = (device const half *) (src0 + offset0);
1000
1202
 
1001
1203
  if (ne00 < 128) {
1002
1204
  for (int row = 0; row < N_F16_F32; ++row) {
@@ -1061,6 +1263,8 @@ kernel void kernel_mul_mv_f16_f32_l4(
1061
1263
  constant uint64_t & nb12,
1062
1264
  constant int64_t & ne0,
1063
1265
  constant int64_t & ne1,
1266
+ constant uint & r2 [[buffer(17)]],
1267
+ constant uint & r3 [[buffer(18)]],
1064
1268
  uint3 tgpig[[threadgroup_position_in_grid]],
1065
1269
  uint tiisg[[thread_index_in_simdgroup]]) {
1066
1270
 
@@ -1068,7 +1272,12 @@ kernel void kernel_mul_mv_f16_f32_l4(
1068
1272
  const int64_t r0 = tgpig.x;
1069
1273
  const int64_t im = tgpig.z;
1070
1274
 
1071
- device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
1275
+ const uint i12 = im%ne12;
1276
+ const uint i13 = im/ne12;
1277
+
1278
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1279
+
1280
+ device const half4 * x4 = (device const half4 *) (src0 + offset0);
1072
1281
 
1073
1282
  for (int r1 = 0; r1 < nrows; ++r1) {
1074
1283
  device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
@@ -1120,17 +1329,21 @@ kernel void kernel_alibi_f32(
1120
1329
  const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1121
1330
  const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1122
1331
  const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1332
+ const int64_t k = i3*ne3 + i2;
1123
1333
 
1124
- device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1125
1334
  float m_k;
1126
- if (i2 < n_heads_log2_floor) {
1127
- m_k = pow(m0, i2 + 1);
1335
+ if (k < n_heads_log2_floor) {
1336
+ m_k = pow(m0, k + 1);
1128
1337
  } else {
1129
- m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1);
1338
+ m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
1130
1339
  }
1340
+
1341
+ device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
1342
+ device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
1131
1343
  for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
1132
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1133
- dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
1344
+ const float src_v = *(device float *)(src_row + i00*nb00);
1345
+ device float * dst_v = (device float *)(dst_row + i00*nb0);
1346
+ *dst_v = i00 * m_k + src_v;
1134
1347
  }
1135
1348
  }
1136
1349
 
@@ -1335,6 +1548,58 @@ kernel void kernel_im2col_f16(
1335
1548
  }
1336
1549
  }
1337
1550
 
1551
+ // bitonic sort implementation following the CUDA kernels as reference
1552
+ typedef void (argsort_t)(
1553
+ device const float * x,
1554
+ device int32_t * dst,
1555
+ constant int64_t & ncols,
1556
+ uint3 tgpig[[threadgroup_position_in_grid]],
1557
+ uint3 tpitg[[thread_position_in_threadgroup]]);
1558
+
1559
+ template<ggml_sort_order order>
1560
+ kernel void kernel_argsort_f32_i32(
1561
+ device const float * x,
1562
+ device int32_t * dst,
1563
+ constant int64_t & ncols,
1564
+ uint3 tgpig[[threadgroup_position_in_grid]],
1565
+ uint3 tpitg[[thread_position_in_threadgroup]]) {
1566
+ // bitonic sort
1567
+ int col = tpitg[0];
1568
+ int row = tgpig[1];
1569
+
1570
+ if (col >= ncols) return;
1571
+
1572
+ device const float * x_row = x + row * ncols;
1573
+ device int32_t * dst_row = dst + row * ncols;
1574
+
1575
+ // initialize indices
1576
+ if (col < ncols) {
1577
+ dst_row[col] = col;
1578
+ }
1579
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1580
+
1581
+ for (int k = 2; k <= ncols; k *= 2) {
1582
+ for (int j = k / 2; j > 0; j /= 2) {
1583
+ int ixj = col ^ j;
1584
+ if (ixj > col) {
1585
+ if ((col & k) == 0) {
1586
+ if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
1587
+ SWAP(dst_row[col], dst_row[ixj]);
1588
+ }
1589
+ } else {
1590
+ if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
1591
+ SWAP(dst_row[col], dst_row[ixj]);
1592
+ }
1593
+ }
1594
+ }
1595
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1596
+ }
1597
+ }
1598
+ }
1599
+
1600
+ template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
1601
+ template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
1602
+
1338
1603
  kernel void kernel_cpy_f16_f16(
1339
1604
  device const half * src0,
1340
1605
  device half * dst,
@@ -1460,6 +1725,197 @@ kernel void kernel_cpy_f32_f32(
1460
1725
  }
1461
1726
  }
1462
1727
 
1728
+ kernel void kernel_cpy_f32_q8_0(
1729
+ device const float * src0,
1730
+ device void * dst,
1731
+ constant int64_t & ne00,
1732
+ constant int64_t & ne01,
1733
+ constant int64_t & ne02,
1734
+ constant int64_t & ne03,
1735
+ constant uint64_t & nb00,
1736
+ constant uint64_t & nb01,
1737
+ constant uint64_t & nb02,
1738
+ constant uint64_t & nb03,
1739
+ constant int64_t & ne0,
1740
+ constant int64_t & ne1,
1741
+ constant int64_t & ne2,
1742
+ constant int64_t & ne3,
1743
+ constant uint64_t & nb0,
1744
+ constant uint64_t & nb1,
1745
+ constant uint64_t & nb2,
1746
+ constant uint64_t & nb3,
1747
+ uint3 tgpig[[threadgroup_position_in_grid]],
1748
+ uint3 tpitg[[thread_position_in_threadgroup]],
1749
+ uint3 ntg[[threads_per_threadgroup]]) {
1750
+ const int64_t i03 = tgpig[2];
1751
+ const int64_t i02 = tgpig[1];
1752
+ const int64_t i01 = tgpig[0];
1753
+
1754
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
1755
+
1756
+ const int64_t i3 = n / (ne2*ne1*ne0);
1757
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1758
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1759
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0;
1760
+
1761
+ device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1762
+
1763
+ for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) {
1764
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1765
+
1766
+ float amax = 0.0f; // absolute max
1767
+
1768
+ for (int j = 0; j < QK8_0; j++) {
1769
+ const float v = src[j];
1770
+ amax = MAX(amax, fabs(v));
1771
+ }
1772
+
1773
+ const float d = amax / ((1 << 7) - 1);
1774
+ const float id = d ? 1.0f/d : 0.0f;
1775
+
1776
+ dst_data[i00/QK8_0].d = d;
1777
+
1778
+ for (int j = 0; j < QK8_0; ++j) {
1779
+ const float x0 = src[j]*id;
1780
+
1781
+ dst_data[i00/QK8_0].qs[j] = round(x0);
1782
+ }
1783
+ }
1784
+ }
1785
+
1786
+ kernel void kernel_cpy_f32_q4_0(
1787
+ device const float * src0,
1788
+ device void * dst,
1789
+ constant int64_t & ne00,
1790
+ constant int64_t & ne01,
1791
+ constant int64_t & ne02,
1792
+ constant int64_t & ne03,
1793
+ constant uint64_t & nb00,
1794
+ constant uint64_t & nb01,
1795
+ constant uint64_t & nb02,
1796
+ constant uint64_t & nb03,
1797
+ constant int64_t & ne0,
1798
+ constant int64_t & ne1,
1799
+ constant int64_t & ne2,
1800
+ constant int64_t & ne3,
1801
+ constant uint64_t & nb0,
1802
+ constant uint64_t & nb1,
1803
+ constant uint64_t & nb2,
1804
+ constant uint64_t & nb3,
1805
+ uint3 tgpig[[threadgroup_position_in_grid]],
1806
+ uint3 tpitg[[thread_position_in_threadgroup]],
1807
+ uint3 ntg[[threads_per_threadgroup]]) {
1808
+ const int64_t i03 = tgpig[2];
1809
+ const int64_t i02 = tgpig[1];
1810
+ const int64_t i01 = tgpig[0];
1811
+
1812
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
1813
+
1814
+ const int64_t i3 = n / (ne2*ne1*ne0);
1815
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1816
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1817
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0;
1818
+
1819
+ device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1820
+
1821
+ for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) {
1822
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1823
+
1824
+ float amax = 0.0f; // absolute max
1825
+ float max = 0.0f;
1826
+
1827
+ for (int j = 0; j < QK4_0; j++) {
1828
+ const float v = src[j];
1829
+ if (amax < fabs(v)) {
1830
+ amax = fabs(v);
1831
+ max = v;
1832
+ }
1833
+ }
1834
+
1835
+ const float d = max / -8;
1836
+ const float id = d ? 1.0f/d : 0.0f;
1837
+
1838
+ dst_data[i00/QK4_0].d = d;
1839
+
1840
+ for (int j = 0; j < QK4_0/2; ++j) {
1841
+ const float x0 = src[0 + j]*id;
1842
+ const float x1 = src[QK4_0/2 + j]*id;
1843
+
1844
+ const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
1845
+ const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
1846
+
1847
+ dst_data[i00/QK4_0].qs[j] = xi0;
1848
+ dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
1849
+ }
1850
+ }
1851
+ }
1852
+
1853
+ kernel void kernel_cpy_f32_q4_1(
1854
+ device const float * src0,
1855
+ device void * dst,
1856
+ constant int64_t & ne00,
1857
+ constant int64_t & ne01,
1858
+ constant int64_t & ne02,
1859
+ constant int64_t & ne03,
1860
+ constant uint64_t & nb00,
1861
+ constant uint64_t & nb01,
1862
+ constant uint64_t & nb02,
1863
+ constant uint64_t & nb03,
1864
+ constant int64_t & ne0,
1865
+ constant int64_t & ne1,
1866
+ constant int64_t & ne2,
1867
+ constant int64_t & ne3,
1868
+ constant uint64_t & nb0,
1869
+ constant uint64_t & nb1,
1870
+ constant uint64_t & nb2,
1871
+ constant uint64_t & nb3,
1872
+ uint3 tgpig[[threadgroup_position_in_grid]],
1873
+ uint3 tpitg[[thread_position_in_threadgroup]],
1874
+ uint3 ntg[[threads_per_threadgroup]]) {
1875
+ const int64_t i03 = tgpig[2];
1876
+ const int64_t i02 = tgpig[1];
1877
+ const int64_t i01 = tgpig[0];
1878
+
1879
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
1880
+
1881
+ const int64_t i3 = n / (ne2*ne1*ne0);
1882
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1883
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1884
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1;
1885
+
1886
+ device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1887
+
1888
+ for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) {
1889
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1890
+
1891
+ float min = FLT_MAX;
1892
+ float max = -FLT_MAX;
1893
+
1894
+ for (int j = 0; j < QK4_1; j++) {
1895
+ const float v = src[j];
1896
+ if (min > v) min = v;
1897
+ if (max < v) max = v;
1898
+ }
1899
+
1900
+ const float d = (max - min) / ((1 << 4) - 1);
1901
+ const float id = d ? 1.0f/d : 0.0f;
1902
+
1903
+ dst_data[i00/QK4_1].d = d;
1904
+ dst_data[i00/QK4_1].m = min;
1905
+
1906
+ for (int j = 0; j < QK4_1/2; ++j) {
1907
+ const float x0 = (src[0 + j] - min)*id;
1908
+ const float x1 = (src[QK4_1/2 + j] - min)*id;
1909
+
1910
+ const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
1911
+ const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
1912
+
1913
+ dst_data[i00/QK4_1].qs[j] = xi0;
1914
+ dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
1915
+ }
1916
+ }
1917
+ }
1918
+
1463
1919
  kernel void kernel_concat(
1464
1920
  device const char * src0,
1465
1921
  device const char * src1,
@@ -1617,23 +2073,30 @@ kernel void kernel_mul_mv_q2_K_f32(
1617
2073
  constant int64_t & ne02[[buffer(5)]],
1618
2074
  constant int64_t & ne10[[buffer(9)]],
1619
2075
  constant int64_t & ne12[[buffer(11)]],
1620
- constant int64_t & ne0[[buffer(15)]],
1621
- constant int64_t & ne1[[buffer(16)]],
1622
- constant uint & gqa[[buffer(17)]],
2076
+ constant int64_t & ne0 [[buffer(15)]],
2077
+ constant int64_t & ne1 [[buffer(16)]],
2078
+ constant uint & r2 [[buffer(17)]],
2079
+ constant uint & r3 [[buffer(18)]],
1623
2080
  uint3 tgpig[[threadgroup_position_in_grid]],
1624
- uint tiisg[[thread_index_in_simdgroup]],
1625
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2081
+ uint tiisg[[thread_index_in_simdgroup]],
2082
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1626
2083
 
1627
2084
  const int nb = ne00/QK_K;
1628
2085
  const int r0 = tgpig.x;
1629
2086
  const int r1 = tgpig.y;
1630
- const int r2 = tgpig.z;
2087
+ const int im = tgpig.z;
1631
2088
 
1632
2089
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1633
2090
  const int ib_row = first_row * nb;
1634
- const uint offset0 = r2/gqa*(nb*ne0);
2091
+
2092
+ const uint i12 = im%ne12;
2093
+ const uint i13 = im/ne12;
2094
+
2095
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2096
+
1635
2097
  device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
1636
- device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2098
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2099
+
1637
2100
  float yl[32];
1638
2101
  float sumf[N_DST]={0.f}, all_sum;
1639
2102
 
@@ -1642,11 +2105,11 @@ kernel void kernel_mul_mv_q2_K_f32(
1642
2105
  #if QK_K == 256
1643
2106
  const int ix = tiisg/8; // 0...3
1644
2107
  const int it = tiisg%8; // 0...7
1645
- const int im = it/4; // 0 or 1
2108
+ const int iq = it/4; // 0 or 1
1646
2109
  const int ir = it%4; // 0...3
1647
2110
  const int is = (8*ir)/16;// 0 or 1
1648
2111
 
1649
- device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir;
2112
+ device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
1650
2113
 
1651
2114
  for (int ib = ix; ib < nb; ib += 4) {
1652
2115
 
@@ -1658,8 +2121,8 @@ kernel void kernel_mul_mv_q2_K_f32(
1658
2121
  yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
1659
2122
  }
1660
2123
 
1661
- device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*im + is;
1662
- device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
2124
+ device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*iq + is;
2125
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
1663
2126
  device const half * dh = &x[ib].d;
1664
2127
 
1665
2128
  for (int row = 0; row < N_DST; row++) {
@@ -1746,7 +2209,7 @@ kernel void kernel_mul_mv_q2_K_f32(
1746
2209
  for (int row = 0; row < N_DST; ++row) {
1747
2210
  all_sum = simd_sum(sumf[row]);
1748
2211
  if (tiisg == 0) {
1749
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
2212
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
1750
2213
  }
1751
2214
  }
1752
2215
  }
@@ -1761,9 +2224,10 @@ kernel void kernel_mul_mv_q3_K_f32(
1761
2224
  constant int64_t & ne02[[buffer(5)]],
1762
2225
  constant int64_t & ne10[[buffer(9)]],
1763
2226
  constant int64_t & ne12[[buffer(11)]],
1764
- constant int64_t & ne0[[buffer(15)]],
1765
- constant int64_t & ne1[[buffer(16)]],
1766
- constant uint & gqa[[buffer(17)]],
2227
+ constant int64_t & ne0 [[buffer(15)]],
2228
+ constant int64_t & ne1 [[buffer(16)]],
2229
+ constant uint & r2 [[buffer(17)]],
2230
+ constant uint & r3 [[buffer(18)]],
1767
2231
  uint3 tgpig[[threadgroup_position_in_grid]],
1768
2232
  uint tiisg[[thread_index_in_simdgroup]],
1769
2233
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1772,12 +2236,17 @@ kernel void kernel_mul_mv_q3_K_f32(
1772
2236
 
1773
2237
  const int64_t r0 = tgpig.x;
1774
2238
  const int64_t r1 = tgpig.y;
1775
- const int64_t r2 = tgpig.z;
2239
+ const int64_t im = tgpig.z;
1776
2240
 
1777
2241
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
1778
- const uint offset0 = r2/gqa*(nb*ne0);
2242
+
2243
+ const uint i12 = im%ne12;
2244
+ const uint i13 = im/ne12;
2245
+
2246
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2247
+
1779
2248
  device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
1780
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2249
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
1781
2250
 
1782
2251
  float yl[32];
1783
2252
 
@@ -1899,7 +2368,7 @@ kernel void kernel_mul_mv_q3_K_f32(
1899
2368
  }
1900
2369
  if (tiisg == 0) {
1901
2370
  for (int row = 0; row < 2; ++row) {
1902
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
2371
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row];
1903
2372
  }
1904
2373
  }
1905
2374
  }
@@ -1913,26 +2382,33 @@ kernel void kernel_mul_mv_q3_K_f32(
1913
2382
  constant int64_t & ne02[[buffer(5)]],
1914
2383
  constant int64_t & ne10[[buffer(9)]],
1915
2384
  constant int64_t & ne12[[buffer(11)]],
1916
- constant int64_t & ne0[[buffer(15)]],
1917
- constant int64_t & ne1[[buffer(16)]],
1918
- constant uint & gqa[[buffer(17)]],
2385
+ constant int64_t & ne0 [[buffer(15)]],
2386
+ constant int64_t & ne1 [[buffer(16)]],
2387
+ constant uint & r2 [[buffer(17)]],
2388
+ constant uint & r3 [[buffer(18)]],
1919
2389
  uint3 tgpig[[threadgroup_position_in_grid]],
1920
- uint tiisg[[thread_index_in_simdgroup]],
1921
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2390
+ uint tiisg[[thread_index_in_simdgroup]],
2391
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1922
2392
 
1923
2393
  const int nb = ne00/QK_K;
1924
2394
 
1925
2395
  const int64_t r0 = tgpig.x;
1926
2396
  const int64_t r1 = tgpig.y;
1927
- const int64_t r2 = tgpig.z;
2397
+ const int64_t im = tgpig.z;
1928
2398
 
1929
2399
  const int row = 2 * r0 + sgitg;
1930
- const uint offset0 = r2/gqa*(nb*ne0);
2400
+
2401
+ const uint i12 = im%ne12;
2402
+ const uint i13 = im/ne12;
2403
+
2404
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2405
+
1931
2406
  device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
1932
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2407
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2408
+
1933
2409
  const int ix = tiisg/4;
1934
2410
  const int il = 4 * (tiisg%4);// 0, 4, 8, 12
1935
- const int im = il/8; // 0, 0, 1, 1
2411
+ const int iq = il/8; // 0, 0, 1, 1
1936
2412
  const int in = il%8; // 0, 4, 0, 4
1937
2413
 
1938
2414
  float2 sum = {0.f, 0.f};
@@ -1952,7 +2428,7 @@ kernel void kernel_mul_mv_q3_K_f32(
1952
2428
  const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
1953
2429
 
1954
2430
  for (int l = 0; l < 4; l += 2) {
1955
- const uint16_t hm = h[l/2] >> im;
2431
+ const uint16_t hm = h[l/2] >> iq;
1956
2432
  sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
1957
2433
  + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
1958
2434
  + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
@@ -1968,7 +2444,7 @@ kernel void kernel_mul_mv_q3_K_f32(
1968
2444
 
1969
2445
  const float tot = simd_sum(sumf);
1970
2446
  if (tiisg == 0) {
1971
- dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
2447
+ dst[r1*ne0 + im*ne0*ne1 + row] = tot;
1972
2448
  }
1973
2449
 
1974
2450
  }
@@ -1986,10 +2462,11 @@ kernel void kernel_mul_mv_q4_K_f32(
1986
2462
  constant int64_t & ne12 [[buffer(11)]],
1987
2463
  constant int64_t & ne0 [[buffer(15)]],
1988
2464
  constant int64_t & ne1 [[buffer(16)]],
1989
- constant uint & gqa [[buffer(17)]],
2465
+ constant uint & r2 [[buffer(17)]],
2466
+ constant uint & r3 [[buffer(18)]],
1990
2467
  uint3 tgpig[[threadgroup_position_in_grid]],
1991
- uint tiisg[[thread_index_in_simdgroup]],
1992
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2468
+ uint tiisg[[thread_index_in_simdgroup]],
2469
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1993
2470
 
1994
2471
  const uint16_t kmask1 = 0x3f3f;
1995
2472
  const uint16_t kmask2 = 0x0f0f;
@@ -1997,26 +2474,32 @@ kernel void kernel_mul_mv_q4_K_f32(
1997
2474
 
1998
2475
  const int ix = tiisg/8; // 0...3
1999
2476
  const int it = tiisg%8; // 0...7
2000
- const int im = it/4; // 0 or 1
2477
+ const int iq = it/4; // 0 or 1
2001
2478
  const int ir = it%4; // 0...3
2002
2479
 
2003
2480
  const int nb = ne00/QK_K;
2004
2481
  const int r0 = tgpig.x;
2005
2482
  const int r1 = tgpig.y;
2006
- const int r2 = tgpig.z;
2483
+ const int im = tgpig.z;
2007
2484
  //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
2008
2485
  const int first_row = r0 * N_DST;
2009
2486
  const int ib_row = first_row * nb;
2010
- const uint offset0 = r2/gqa*(nb*ne0);
2487
+
2488
+ const uint i12 = im%ne12;
2489
+ const uint i13 = im/ne12;
2490
+
2491
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2492
+
2011
2493
  device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
2012
- device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2494
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2495
+
2013
2496
  float yl[16];
2014
2497
  float yh[16];
2015
2498
  float sumf[N_DST]={0.f}, all_sum;
2016
2499
 
2017
2500
  const int step = sizeof(block_q4_K) * nb / 2;
2018
2501
 
2019
- device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir;
2502
+ device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
2020
2503
 
2021
2504
  uint16_t sc16[4];
2022
2505
  thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
@@ -2031,8 +2514,8 @@ kernel void kernel_mul_mv_q4_K_f32(
2031
2514
  yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
2032
2515
  }
2033
2516
 
2034
- device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im;
2035
- device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
2517
+ device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq;
2518
+ device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
2036
2519
  device const half * dh = &x[ib].d;
2037
2520
 
2038
2521
  for (int row = 0; row < N_DST; row++) {
@@ -2076,7 +2559,7 @@ kernel void kernel_mul_mv_q4_K_f32(
2076
2559
  for (int row = 0; row < N_DST; ++row) {
2077
2560
  all_sum = simd_sum(sumf[row]);
2078
2561
  if (tiisg == 0) {
2079
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
2562
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
2080
2563
  }
2081
2564
  }
2082
2565
  }
@@ -2090,9 +2573,10 @@ kernel void kernel_mul_mv_q4_K_f32(
2090
2573
  constant int64_t & ne02[[buffer(5)]],
2091
2574
  constant int64_t & ne10[[buffer(9)]],
2092
2575
  constant int64_t & ne12[[buffer(11)]],
2093
- constant int64_t & ne0[[buffer(15)]],
2094
- constant int64_t & ne1[[buffer(16)]],
2095
- constant uint & gqa[[buffer(17)]],
2576
+ constant int64_t & ne0 [[buffer(15)]],
2577
+ constant int64_t & ne1 [[buffer(16)]],
2578
+ constant uint & r2 [[buffer(17)]],
2579
+ constant uint & r3 [[buffer(18)]],
2096
2580
  uint3 tgpig[[threadgroup_position_in_grid]],
2097
2581
  uint tiisg[[thread_index_in_simdgroup]],
2098
2582
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2103,12 +2587,18 @@ kernel void kernel_mul_mv_q4_K_f32(
2103
2587
  const int nb = ne00/QK_K;
2104
2588
  const int r0 = tgpig.x;
2105
2589
  const int r1 = tgpig.y;
2106
- const int r2 = tgpig.z;
2590
+ const int im = tgpig.z;
2107
2591
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
2108
2592
  const int ib_row = first_row * nb;
2109
- const uint offset0 = r2/gqa*(nb*ne0);
2593
+
2594
+ const uint i12 = im%ne12;
2595
+ const uint i13 = im/ne12;
2596
+
2597
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2598
+
2110
2599
  device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
2111
- device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2600
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2601
+
2112
2602
  float yl[8];
2113
2603
  float yh[8];
2114
2604
  float sumf[N_DST]={0.f}, all_sum;
@@ -2164,7 +2654,7 @@ kernel void kernel_mul_mv_q4_K_f32(
2164
2654
  for (int row = 0; row < N_DST; ++row) {
2165
2655
  all_sum = simd_sum(sumf[row]);
2166
2656
  if (tiisg == 0) {
2167
- dst[r1*ne0+ r2*ne0*ne1 + first_row + row] = all_sum;
2657
+ dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum;
2168
2658
  }
2169
2659
  }
2170
2660
  }
@@ -2179,9 +2669,10 @@ kernel void kernel_mul_mv_q5_K_f32(
2179
2669
  constant int64_t & ne02[[buffer(5)]],
2180
2670
  constant int64_t & ne10[[buffer(9)]],
2181
2671
  constant int64_t & ne12[[buffer(11)]],
2182
- constant int64_t & ne0[[buffer(15)]],
2183
- constant int64_t & ne1[[buffer(16)]],
2184
- constant uint & gqa[[buffer(17)]],
2672
+ constant int64_t & ne0 [[buffer(15)]],
2673
+ constant int64_t & ne1 [[buffer(16)]],
2674
+ constant uint & r2 [[buffer(17)]],
2675
+ constant uint & r3 [[buffer(18)]],
2185
2676
  uint3 tgpig[[threadgroup_position_in_grid]],
2186
2677
  uint tiisg[[thread_index_in_simdgroup]],
2187
2678
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2190,12 +2681,17 @@ kernel void kernel_mul_mv_q5_K_f32(
2190
2681
 
2191
2682
  const int64_t r0 = tgpig.x;
2192
2683
  const int64_t r1 = tgpig.y;
2193
- const int r2 = tgpig.z;
2684
+ const int im = tgpig.z;
2194
2685
 
2195
2686
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
2196
- const uint offset0 = r2/gqa*(nb*ne0);
2687
+
2688
+ const uint i12 = im%ne12;
2689
+ const uint i13 = im/ne12;
2690
+
2691
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2692
+
2197
2693
  device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
2198
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2694
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2199
2695
 
2200
2696
  float sumf[2]={0.f};
2201
2697
 
@@ -2211,15 +2707,15 @@ kernel void kernel_mul_mv_q5_K_f32(
2211
2707
 
2212
2708
  const int tid = tiisg/4;
2213
2709
  const int ix = tiisg%4;
2214
- const int im = tid/4;
2710
+ const int iq = tid/4;
2215
2711
  const int ir = tid%4;
2216
2712
  const int n = 8;
2217
2713
 
2218
2714
  const int l0 = n*ir;
2219
- const int q_offset = 32*im + l0;
2220
- const int y_offset = 64*im + l0;
2715
+ const int q_offset = 32*iq + l0;
2716
+ const int y_offset = 64*iq + l0;
2221
2717
 
2222
- const uint8_t hm1 = 1u << (2*im);
2718
+ const uint8_t hm1 = 1u << (2*iq);
2223
2719
  const uint8_t hm2 = hm1 << 1;
2224
2720
  const uint8_t hm3 = hm1 << 4;
2225
2721
  const uint8_t hm4 = hm2 << 4;
@@ -2234,7 +2730,7 @@ kernel void kernel_mul_mv_q5_K_f32(
2234
2730
  device const uint8_t * q1 = x[i].qs + q_offset;
2235
2731
  device const uint8_t * qh = x[i].qh + l0;
2236
2732
  device const half * dh = &x[i].d;
2237
- device const uint16_t * a = (device const uint16_t *)x[i].scales + im;
2733
+ device const uint16_t * a = (device const uint16_t *)x[i].scales + iq;
2238
2734
 
2239
2735
  device const float * y2 = y1 + 128;
2240
2736
  float4 sumy = {0.f, 0.f, 0.f, 0.f};
@@ -2290,7 +2786,7 @@ kernel void kernel_mul_mv_q5_K_f32(
2290
2786
 
2291
2787
  const int il = 4 * (tiisg/8); // 0, 4, 8, 12
2292
2788
  const int ix = tiisg%8;
2293
- const int im = il/8; // 0, 0, 1, 1
2789
+ const int iq = il/8; // 0, 0, 1, 1
2294
2790
  const int in = il%8; // 0, 4, 0, 4
2295
2791
 
2296
2792
  device const float * y = yy + ix*QK_K + il;
@@ -2315,7 +2811,7 @@ kernel void kernel_mul_mv_q5_K_f32(
2315
2811
 
2316
2812
  float2 acc = {0.f, 0.f};
2317
2813
  for (int l = 0; l < 4; ++l) {
2318
- const uint8_t hl = h[l] >> im;
2814
+ const uint8_t hl = h[l] >> iq;
2319
2815
  acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
2320
2816
  + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
2321
2817
  acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
@@ -2337,7 +2833,7 @@ kernel void kernel_mul_mv_q5_K_f32(
2337
2833
  for (int row = 0; row < 2; ++row) {
2338
2834
  const float tot = simd_sum(sumf[row]);
2339
2835
  if (tiisg == 0) {
2340
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
2836
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
2341
2837
  }
2342
2838
  }
2343
2839
 
@@ -2352,9 +2848,10 @@ kernel void kernel_mul_mv_q6_K_f32(
2352
2848
  constant int64_t & ne02[[buffer(5)]],
2353
2849
  constant int64_t & ne10[[buffer(9)]],
2354
2850
  constant int64_t & ne12[[buffer(11)]],
2355
- constant int64_t & ne0[[buffer(15)]],
2356
- constant int64_t & ne1[[buffer(16)]],
2357
- constant uint & gqa[[buffer(17)]],
2851
+ constant int64_t & ne0 [[buffer(15)]],
2852
+ constant int64_t & ne1 [[buffer(16)]],
2853
+ constant uint & r2 [[buffer(17)]],
2854
+ constant uint & r3 [[buffer(18)]],
2358
2855
  uint3 tgpig[[threadgroup_position_in_grid]],
2359
2856
  uint tiisg[[thread_index_in_simdgroup]],
2360
2857
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2368,12 +2865,17 @@ kernel void kernel_mul_mv_q6_K_f32(
2368
2865
 
2369
2866
  const int64_t r0 = tgpig.x;
2370
2867
  const int64_t r1 = tgpig.y;
2371
- const int r2 = tgpig.z;
2868
+ const int im = tgpig.z;
2372
2869
 
2373
2870
  const int row = 2 * r0 + sgitg;
2374
- const uint offset0 = r2/gqa*(nb*ne0);
2871
+
2872
+ const uint i12 = im%ne12;
2873
+ const uint i13 = im/ne12;
2874
+
2875
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2876
+
2375
2877
  device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
2376
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2878
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2377
2879
 
2378
2880
  float sumf = 0;
2379
2881
 
@@ -2439,7 +2941,7 @@ kernel void kernel_mul_mv_q6_K_f32(
2439
2941
 
2440
2942
  const float tot = simd_sum(sumf);
2441
2943
  if (tiisg == 0) {
2442
- dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
2944
+ dst[r1*ne0 + im*ne0*ne1 + row] = tot;
2443
2945
  }
2444
2946
  }
2445
2947
 
@@ -2749,24 +3251,25 @@ kernel void kernel_get_rows(
2749
3251
 
2750
3252
  // each block_q contains 16*nl weights
2751
3253
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
2752
- kernel void kernel_mul_mm(device const uchar * src0,
2753
- device const uchar * src1,
2754
- device float * dst,
2755
- constant int64_t & ne00,
2756
- constant int64_t & ne02,
2757
- constant int64_t & nb01,
2758
- constant int64_t & nb02,
2759
- constant int64_t & ne12,
2760
- constant int64_t & nb10,
2761
- constant int64_t & nb11,
2762
- constant int64_t & nb12,
2763
- constant int64_t & ne0,
2764
- constant int64_t & ne1,
2765
- constant uint & gqa,
2766
- threadgroup uchar * shared_memory [[threadgroup(0)]],
2767
- uint3 tgpig[[threadgroup_position_in_grid]],
2768
- uint tiitg[[thread_index_in_threadgroup]],
2769
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3254
+ void kernel_mul_mm_impl(device const uchar * src0,
3255
+ device const uchar * src1,
3256
+ device float * dst,
3257
+ constant int64_t & ne00,
3258
+ constant int64_t & ne02,
3259
+ constant int64_t & nb01,
3260
+ constant int64_t & nb02,
3261
+ constant int64_t & ne12,
3262
+ constant int64_t & nb10,
3263
+ constant int64_t & nb11,
3264
+ constant int64_t & nb12,
3265
+ constant int64_t & ne0,
3266
+ constant int64_t & ne1,
3267
+ constant uint & r2,
3268
+ constant uint & r3,
3269
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
3270
+ uint3 tgpig[[threadgroup_position_in_grid]],
3271
+ uint tiitg[[thread_index_in_threadgroup]],
3272
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2770
3273
 
2771
3274
  threadgroup half * sa = (threadgroup half *)(shared_memory);
2772
3275
  threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
@@ -2792,7 +3295,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
2792
3295
 
2793
3296
  short il = (tiitg % THREAD_PER_ROW);
2794
3297
 
2795
- uint offset0 = im/gqa*nb02;
3298
+ const uint i12 = im%ne12;
3299
+ const uint i13 = im/ne12;
3300
+
3301
+ uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
2796
3302
  ushort offset1 = il/nl;
2797
3303
 
2798
3304
  device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
@@ -2876,14 +3382,116 @@ kernel void kernel_mul_mm(device const uchar * src0,
2876
3382
  }
2877
3383
  }
2878
3384
 
3385
+ template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
3386
+ kernel void kernel_mul_mm(device const uchar * src0,
3387
+ device const uchar * src1,
3388
+ device float * dst,
3389
+ constant int64_t & ne00,
3390
+ constant int64_t & ne02,
3391
+ constant int64_t & nb01,
3392
+ constant int64_t & nb02,
3393
+ constant int64_t & ne12,
3394
+ constant int64_t & nb10,
3395
+ constant int64_t & nb11,
3396
+ constant int64_t & nb12,
3397
+ constant int64_t & ne0,
3398
+ constant int64_t & ne1,
3399
+ constant uint & r2,
3400
+ constant uint & r3,
3401
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
3402
+ uint3 tgpig[[threadgroup_position_in_grid]],
3403
+ uint tiitg[[thread_index_in_threadgroup]],
3404
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3405
+ kernel_mul_mm_impl<block_q, nl, dequantize_func>(
3406
+ src0,
3407
+ src1,
3408
+ dst,
3409
+ ne00,
3410
+ ne02,
3411
+ nb01,
3412
+ nb02,
3413
+ ne12,
3414
+ nb10,
3415
+ nb11,
3416
+ nb12,
3417
+ ne0,
3418
+ ne1,
3419
+ r2,
3420
+ r3,
3421
+ shared_memory,
3422
+ tgpig,
3423
+ tiitg,
3424
+ sgitg);
3425
+ }
3426
+
3427
+ template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
3428
+ kernel void kernel_mul_mm_id(
3429
+ device const int32_t * ids,
3430
+ device const uchar * src1,
3431
+ device float * dst,
3432
+ constant int64_t & ne00,
3433
+ constant int64_t & ne02,
3434
+ constant int64_t & nb01,
3435
+ constant int64_t & nb02,
3436
+ constant int64_t & ne12,
3437
+ constant int64_t & nb10,
3438
+ constant int64_t & nb11,
3439
+ constant int64_t & nb12,
3440
+ constant int64_t & ne0,
3441
+ constant int64_t & ne1,
3442
+ constant uint & r2,
3443
+ constant uint & r3,
3444
+ constant int & idx,
3445
+ device const uchar * src00,
3446
+ device const uchar * src01,
3447
+ device const uchar * src02,
3448
+ device const uchar * src03,
3449
+ device const uchar * src04,
3450
+ device const uchar * src05,
3451
+ device const uchar * src06,
3452
+ device const uchar * src07,
3453
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
3454
+ uint3 tgpig[[threadgroup_position_in_grid]],
3455
+ uint tiitg[[thread_index_in_threadgroup]],
3456
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3457
+ device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
3458
+
3459
+ kernel_mul_mm_impl<block_q, nl, dequantize_func>(
3460
+ src0[ids[idx]],
3461
+ src1,
3462
+ dst,
3463
+ ne00,
3464
+ ne02,
3465
+ nb01,
3466
+ nb02,
3467
+ ne12,
3468
+ nb10,
3469
+ nb11,
3470
+ nb12,
3471
+ ne0,
3472
+ ne1,
3473
+ r2,
3474
+ r3,
3475
+ shared_memory,
3476
+ tgpig,
3477
+ tiitg,
3478
+ sgitg);
3479
+ }
3480
+
2879
3481
  #if QK_K == 256
2880
3482
  #define QK_NL 16
2881
3483
  #else
2882
3484
  #define QK_NL 4
2883
3485
  #endif
2884
3486
 
2885
- typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
2886
- constant uint64_t &, constant uint64_t &, uint, uint, uint);
3487
+ typedef void (get_rows_t)(
3488
+ device const void * src0,
3489
+ device const int * src1,
3490
+ device float * dst,
3491
+ constant int64_t & ne00,
3492
+ constant uint64_t & nb01,
3493
+ constant uint64_t & nb1,
3494
+ uint, uint, uint);
2887
3495
 
2888
3496
  template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
2889
3497
  template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
@@ -2912,8 +3520,10 @@ typedef void (mat_mm_t)(
2912
3520
  constant int64_t & nb12,
2913
3521
  constant int64_t & ne0,
2914
3522
  constant int64_t & ne1,
2915
- constant uint & gqa,
2916
- threadgroup uchar *, uint3, uint, uint);
3523
+ constant uint & r2,
3524
+ constant uint & r3,
3525
+ threadgroup uchar *,
3526
+ uint3, uint, uint);
2917
3527
 
2918
3528
  template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
2919
3529
  template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
@@ -2927,3 +3537,44 @@ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
2927
3537
  template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
2928
3538
  template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
2929
3539
  template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
3540
+
3541
+ typedef void (mat_mm_id_t)(
3542
+ device const int32_t * ids,
3543
+ device const uchar * src1,
3544
+ device float * dst,
3545
+ constant int64_t & ne00,
3546
+ constant int64_t & ne02,
3547
+ constant int64_t & nb01,
3548
+ constant int64_t & nb02,
3549
+ constant int64_t & ne12,
3550
+ constant int64_t & nb10,
3551
+ constant int64_t & nb11,
3552
+ constant int64_t & nb12,
3553
+ constant int64_t & ne0,
3554
+ constant int64_t & ne1,
3555
+ constant uint & r2,
3556
+ constant uint & r3,
3557
+ constant int & idx,
3558
+ device const uchar * src00,
3559
+ device const uchar * src01,
3560
+ device const uchar * src02,
3561
+ device const uchar * src03,
3562
+ device const uchar * src04,
3563
+ device const uchar * src05,
3564
+ device const uchar * src06,
3565
+ device const uchar * src07,
3566
+ threadgroup uchar *,
3567
+ uint3, uint, uint);
3568
+
3569
+ template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
3570
+ template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
3571
+ template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
3572
+ template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
3573
+ template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
3574
+ template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
3575
+ template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
3576
+ template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
3577
+ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
3578
+ template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
3579
+ template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
3580
+ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;