llama_cpp 0.9.5 → 0.10.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,6 +3,8 @@
3
3
  using namespace metal;
4
4
 
5
5
  #define MAX(x, y) ((x) > (y) ? (x) : (y))
6
// NOTE: MIN (like MAX) evaluates its arguments more than once; do not pass
// expressions with side effects.
#define MIN(x, y) ((x) < (y) ? (x) : (y))
// Exchange the values of two lvalues in place. Wrapped in do/while(0) so that
// `SWAP(a, b);` is a single statement and stays correct inside unbraced
// if/else branches.
#define SWAP(x, y) do { auto swap_tmp_ = (x); (x) = (y); (y) = swap_tmp_; } while (0)
6
8
 
7
9
  #define QK4_0 32
8
10
  #define QR4_0 2
@@ -41,8 +43,13 @@ typedef struct {
41
43
 
42
44
  #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
43
45
 
44
- // general-purpose kernel for addition of two tensors
45
- // pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
46
// Sort direction selecting the kernel_argsort_f32_i32 template instantiation.
// The numeric values are spelled out; presumably they must agree with the
// host-side ggml enum of the same name (TODO confirm).
enum ggml_sort_order {
    GGML_SORT_ASC  = 0, // smallest value first
    GGML_SORT_DESC = 1, // largest value first
};
50
+
51
+ // general-purpose kernel for addition, multiplication and division of two tensors
52
+ // pros: works for non-contiguous tensors, supports broadcast across all dims
46
53
  // cons: not very efficient
47
54
  kernel void kernel_add(
48
55
  device const char * src0,
@@ -83,16 +90,111 @@ kernel void kernel_add(
83
90
  const int64_t i12 = i02 % ne12;
84
91
  const int64_t i11 = i01 % ne11;
85
92
 
86
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
87
- device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
88
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
93
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
94
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
95
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
89
96
 
90
97
  for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
91
- ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0];
98
+ const int i10 = i0 % ne10;
99
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10));
100
+ }
101
+ }
92
102
 
93
- src0_ptr += ntg.x*nb00;
94
- src1_ptr += ntg.x*nb10;
95
- dst_ptr += ntg.x*nb0;
103
// Element-wise product dst = src0 * src1 (f32 data), with src1 broadcast by
// wrapping every src1 index modulo its extent (ne10..ne13).
// Dispatch: one threadgroup per row, tgpig = (i01, i02, i03); the threads of
// the group walk the ne0 columns with a stride of ntg.x.
// All addressing goes through the byte strides nb*, so non-contiguous
// tensors are handled; the price is a modulo per element.
kernel void kernel_mul(
        device const char * src0,
        device const char * src1,
        device       char * dst,
        constant  int64_t & ne00,
        constant  int64_t & ne01,
        constant  int64_t & ne02,
        constant  int64_t & ne03,
        constant  int64_t & nb00,
        constant  int64_t & nb01,
        constant  int64_t & nb02,
        constant  int64_t & nb03,
        constant  int64_t & ne10,
        constant  int64_t & ne11,
        constant  int64_t & ne12,
        constant  int64_t & ne13,
        constant  int64_t & nb10,
        constant  int64_t & nb11,
        constant  int64_t & nb12,
        constant  int64_t & nb13,
        constant  int64_t & ne0,
        constant  int64_t & ne1,
        constant  int64_t & ne2,
        constant  int64_t & ne3,
        constant  int64_t & nb0,
        constant  int64_t & nb1,
        constant  int64_t & nb2,
        constant  int64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    // row coordinates handled by this threadgroup
    const int64_t r1 = tgpig.x;
    const int64_t r2 = tgpig.y;
    const int64_t r3 = tgpig.z;

    // src1 row is chosen by wrapping the row coordinates (broadcast)
    device const char * a_row = src0 + r3*nb03 + r2*nb02 + r1*nb01;
    device const char * b_row = src1 + (r3 % ne13)*nb13 + (r2 % ne12)*nb12 + (r1 % ne11)*nb11;
    device       char * d_row = dst  + r3*nb3  + r2*nb2  + r1*nb1;

    int col = tpitg.x;
    while (col < ne0) {
        const float a = *((device const float *) (a_row + col*nb00));
        const float b = *((device const float *) (b_row + (col % ne10)*nb10));
        *((device float *) (d_row + col*nb0)) = a * b;
        col += ntg.x;
    }
}
151
+
152
// Element-wise quotient dst = src0 / src1 (f32 data), with src1 broadcast by
// wrapping every src1 index modulo its extent (ne10..ne13).
// Dispatch: one threadgroup per row, tgpig = (i01, i02, i03); the threads of
// the group walk the ne0 columns with a stride of ntg.x.
// All addressing goes through the byte strides nb*, so non-contiguous
// tensors are handled; the price is a modulo per element.
kernel void kernel_div(
        device const char * src0,
        device const char * src1,
        device       char * dst,
        constant  int64_t & ne00,
        constant  int64_t & ne01,
        constant  int64_t & ne02,
        constant  int64_t & ne03,
        constant  int64_t & nb00,
        constant  int64_t & nb01,
        constant  int64_t & nb02,
        constant  int64_t & nb03,
        constant  int64_t & ne10,
        constant  int64_t & ne11,
        constant  int64_t & ne12,
        constant  int64_t & ne13,
        constant  int64_t & nb10,
        constant  int64_t & nb11,
        constant  int64_t & nb12,
        constant  int64_t & nb13,
        constant  int64_t & ne0,
        constant  int64_t & ne1,
        constant  int64_t & ne2,
        constant  int64_t & ne3,
        constant  int64_t & nb0,
        constant  int64_t & nb1,
        constant  int64_t & nb2,
        constant  int64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    // row coordinates handled by this threadgroup
    const int64_t r1 = tgpig.x;
    const int64_t r2 = tgpig.y;
    const int64_t r3 = tgpig.z;

    // src1 row is chosen by wrapping the row coordinates (broadcast)
    device const char * num_row = src0 + r3*nb03 + r2*nb02 + r1*nb01;
    device const char * den_row = src1 + (r3 % ne13)*nb13 + (r2 % ne12)*nb12 + (r1 % ne11)*nb11;
    device       char * out_row = dst  + r3*nb3  + r2*nb2  + r1*nb1;

    int col = tpitg.x;
    while (col < ne0) {
        const float num = *((device const float *) (num_row + col*nb00));
        const float den = *((device const float *) (den_row + (col % ne10)*nb10));
        *((device float *) (out_row + col*nb0)) = num / den;
        col += ntg.x;
    }
}
98
200
 
@@ -107,23 +209,22 @@ kernel void kernel_add_row(
107
209
  dst[tpig] = src0[tpig] + src1[tpig % nb];
108
210
  }
109
211
 
110
- kernel void kernel_mul(
212
+ kernel void kernel_mul_row(
111
213
  device const float4 * src0,
112
214
  device const float4 * src1,
113
215
  device float4 * dst,
216
+ constant int64_t & nb [[buffer(27)]],
114
217
  uint tpig[[thread_position_in_grid]]) {
115
- dst[tpig] = src0[tpig] * src1[tpig];
218
+ dst[tpig] = src0[tpig] * src1[tpig % nb];
116
219
  }
117
220
 
118
- // assumption: src1 is a row
119
- // broadcast src1 into src0
120
- kernel void kernel_mul_row(
221
+ kernel void kernel_div_row(
121
222
  device const float4 * src0,
122
223
  device const float4 * src1,
123
224
  device float4 * dst,
124
- constant int64_t & nb,
225
+ constant int64_t & nb [[buffer(27)]],
125
226
  uint tpig[[thread_position_in_grid]]) {
126
- dst[tpig] = src0[tpig] * src1[tpig % nb];
227
+ dst[tpig] = src0[tpig] / src1[tpig % nb];
127
228
  }
128
229
 
129
230
  kernel void kernel_scale(
@@ -164,6 +265,54 @@ kernel void kernel_sqr(
164
265
  dst[tpig] = src0[tpig] * src0[tpig];
165
266
  }
166
267
 
268
// Sum each row of src0 (f32) into a single value: dst[i1, i2, i3] = sum_i0 src0[i0, i1, i2, i3].
// Dispatch: one thread per row, thread_position_in_grid = (i1, i2, i3);
// threads outside (ne01, ne02, ne03) exit immediately (no barriers are used).
// Every element is addressed through the byte strides nb00..nb03, so
// non-contiguous src0 rows (e.g. permuted/transposed views) are summed
// correctly. Previously nb00 was ignored and elements were read as if densely
// packed. Only element 0 of each dst row is written. The ne1x/nb1x arguments
// are accepted for a uniform signature but not used here.
kernel void kernel_sum_rows(
        device const float * src0,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant   int64_t & nb00,
        constant   int64_t & nb01,
        constant   int64_t & nb02,
        constant   int64_t & nb03,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant   int64_t & ne12,
        constant   int64_t & ne13,
        constant   int64_t & nb10,
        constant   int64_t & nb11,
        constant   int64_t & nb12,
        constant   int64_t & nb13,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant   int64_t & nb0,
        constant   int64_t & nb1,
        constant   int64_t & nb2,
        constant   int64_t & nb3,
        uint3 tpig[[thread_position_in_grid]]) {
    const int64_t i3 = tpig.z;
    const int64_t i2 = tpig.y;
    const int64_t i1 = tpig.x;

    // guard against a grid larger than the number of rows
    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
        return;
    }

    device const char  * src_row = (device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03;
    device       float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3);

    float row_sum = 0.0f;
    for (int64_t i0 = 0; i0 < ne00; ++i0) {
        // step by the element stride nb00 (bytes) instead of assuming sizeof(float)
        row_sum += *((device const float *) (src_row + i0*nb00));
    }

    dst_row[0] = row_sum;
}
315
+
167
316
  constant float GELU_COEF_A = 0.044715f;
168
317
  constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
169
318
 
@@ -582,9 +731,20 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
582
731
  // giard against the number of rows not being divisible by
583
732
  // N_DST, so this is another explicit assumption of the implementation.
584
733
  template<typename block_q_type, int nr, int nsg, int nw>
585
- void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
586
- int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
587
- uint3 tgpig, uint tiisg, uint sgitg) {
734
+ void mul_vec_q_n_f32(
735
+ device const void * src0,
736
+ device const float * src1,
737
+ device float * dst,
738
+ int64_t ne00,
739
+ int64_t ne01,
740
+ int64_t ne02,
741
+ int64_t ne10,
742
+ int64_t ne12,
743
+ int64_t ne0,
744
+ int64_t ne1,
745
+ uint r2,
746
+ uint r3,
747
+ uint3 tgpig, uint tiisg, uint sgitg) {
588
748
  const int nb = ne00/QK4_0;
589
749
 
590
750
  const int r0 = tgpig.x;
@@ -593,7 +753,10 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
593
753
 
594
754
  const int first_row = (r0 * nsg + sgitg) * nr;
595
755
 
596
- const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
756
+ const uint i12 = im%ne12;
757
+ const uint i13 = im/ne12;
758
+
759
+ const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
597
760
 
598
761
  device const block_q_type * x = (device const block_q_type *) src0 + offset0;
599
762
  device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
@@ -643,13 +806,14 @@ kernel void kernel_mul_mv_q4_0_f32(
643
806
  constant int64_t & ne02[[buffer(5)]],
644
807
  constant int64_t & ne10[[buffer(9)]],
645
808
  constant int64_t & ne12[[buffer(11)]],
646
- constant int64_t & ne0[[buffer(15)]],
647
- constant int64_t & ne1[[buffer(16)]],
648
- constant uint & gqa[[buffer(17)]],
809
+ constant int64_t & ne0 [[buffer(15)]],
810
+ constant int64_t & ne1 [[buffer(16)]],
811
+ constant uint & r2 [[buffer(17)]],
812
+ constant uint & r3 [[buffer(18)]],
649
813
  uint3 tgpig[[threadgroup_position_in_grid]],
650
814
  uint tiisg[[thread_index_in_simdgroup]],
651
815
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
652
- mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
816
+ mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
653
817
  }
654
818
 
655
819
  kernel void kernel_mul_mv_q4_1_f32(
@@ -661,13 +825,14 @@ kernel void kernel_mul_mv_q4_1_f32(
661
825
  constant int64_t & ne02[[buffer(5)]],
662
826
  constant int64_t & ne10[[buffer(9)]],
663
827
  constant int64_t & ne12[[buffer(11)]],
664
- constant int64_t & ne0[[buffer(15)]],
665
- constant int64_t & ne1[[buffer(16)]],
666
- constant uint & gqa[[buffer(17)]],
828
+ constant int64_t & ne0 [[buffer(15)]],
829
+ constant int64_t & ne1 [[buffer(16)]],
830
+ constant uint & r2 [[buffer(17)]],
831
+ constant uint & r3 [[buffer(18)]],
667
832
  uint3 tgpig[[threadgroup_position_in_grid]],
668
833
  uint tiisg[[thread_index_in_simdgroup]],
669
834
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
670
- mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
835
+ mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
671
836
  }
672
837
 
673
838
  kernel void kernel_mul_mv_q5_0_f32(
@@ -679,13 +844,14 @@ kernel void kernel_mul_mv_q5_0_f32(
679
844
  constant int64_t & ne02[[buffer(5)]],
680
845
  constant int64_t & ne10[[buffer(9)]],
681
846
  constant int64_t & ne12[[buffer(11)]],
682
- constant int64_t & ne0[[buffer(15)]],
683
- constant int64_t & ne1[[buffer(16)]],
684
- constant uint & gqa[[buffer(17)]],
847
+ constant int64_t & ne0 [[buffer(15)]],
848
+ constant int64_t & ne1 [[buffer(16)]],
849
+ constant uint & r2 [[buffer(17)]],
850
+ constant uint & r3 [[buffer(18)]],
685
851
  uint3 tgpig[[threadgroup_position_in_grid]],
686
852
  uint tiisg[[thread_index_in_simdgroup]],
687
853
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
688
- mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
854
+ mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
689
855
  }
690
856
 
691
857
  kernel void kernel_mul_mv_q5_1_f32(
@@ -697,13 +863,14 @@ kernel void kernel_mul_mv_q5_1_f32(
697
863
  constant int64_t & ne02[[buffer(5)]],
698
864
  constant int64_t & ne10[[buffer(9)]],
699
865
  constant int64_t & ne12[[buffer(11)]],
700
- constant int64_t & ne0[[buffer(15)]],
701
- constant int64_t & ne1[[buffer(16)]],
702
- constant uint & gqa[[buffer(17)]],
866
+ constant int64_t & ne0 [[buffer(15)]],
867
+ constant int64_t & ne1 [[buffer(16)]],
868
+ constant uint & r2 [[buffer(17)]],
869
+ constant uint & r3 [[buffer(18)]],
703
870
  uint3 tgpig[[threadgroup_position_in_grid]],
704
871
  uint tiisg[[thread_index_in_simdgroup]],
705
872
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
706
- mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
873
+ mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
707
874
  }
708
875
 
709
876
 
@@ -718,9 +885,10 @@ kernel void kernel_mul_mv_q8_0_f32(
718
885
  constant int64_t & ne02[[buffer(5)]],
719
886
  constant int64_t & ne10[[buffer(9)]],
720
887
  constant int64_t & ne12[[buffer(11)]],
721
- constant int64_t & ne0[[buffer(15)]],
722
- constant int64_t & ne1[[buffer(16)]],
723
- constant uint & gqa[[buffer(17)]],
888
+ constant int64_t & ne0 [[buffer(15)]],
889
+ constant int64_t & ne1 [[buffer(16)]],
890
+ constant uint & r2 [[buffer(17)]],
891
+ constant uint & r3 [[buffer(18)]],
724
892
  uint3 tgpig[[threadgroup_position_in_grid]],
725
893
  uint tiisg[[thread_index_in_simdgroup]],
726
894
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -732,8 +900,14 @@ kernel void kernel_mul_mv_q8_0_f32(
732
900
  const int r0 = tgpig.x;
733
901
  const int r1 = tgpig.y;
734
902
  const int im = tgpig.z;
903
+
735
904
  const int first_row = (r0 * nsg + sgitg) * nr;
736
- const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
905
+
906
+ const uint i12 = im%ne12;
907
+ const uint i13 = im/ne12;
908
+
909
+ const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
910
+
737
911
  device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
738
912
  device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
739
913
 
@@ -791,6 +965,8 @@ kernel void kernel_mul_mv_f32_f32(
791
965
  constant uint64_t & nb12,
792
966
  constant int64_t & ne0,
793
967
  constant int64_t & ne1,
968
+ constant uint & r2 [[buffer(17)]],
969
+ constant uint & r3 [[buffer(18)]],
794
970
  uint3 tgpig[[threadgroup_position_in_grid]],
795
971
  uint tiisg[[thread_index_in_simdgroup]]) {
796
972
 
@@ -798,7 +974,12 @@ kernel void kernel_mul_mv_f32_f32(
798
974
  const int64_t rb = tgpig.y*N_F32_F32;
799
975
  const int64_t im = tgpig.z;
800
976
 
801
- device const float * x = (device const float *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
977
+ const uint i12 = im%ne12;
978
+ const uint i13 = im/ne12;
979
+
980
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
981
+
982
+ device const float * x = (device const float *) (src0 + offset0);
802
983
 
803
984
  if (ne00 < 128) {
804
985
  for (int row = 0; row < N_F32_F32; ++row) {
@@ -864,6 +1045,8 @@ kernel void kernel_mul_mv_f16_f16(
864
1045
  constant uint64_t & nb12,
865
1046
  constant int64_t & ne0,
866
1047
  constant int64_t & ne1,
1048
+ constant uint & r2 [[buffer(17)]],
1049
+ constant uint & r3 [[buffer(18)]],
867
1050
  uint3 tgpig[[threadgroup_position_in_grid]],
868
1051
  uint tiisg[[thread_index_in_simdgroup]]) {
869
1052
 
@@ -871,7 +1054,12 @@ kernel void kernel_mul_mv_f16_f16(
871
1054
  const int64_t rb = tgpig.y*N_F16_F16;
872
1055
  const int64_t im = tgpig.z;
873
1056
 
874
- device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
1057
+ const uint i12 = im%ne12;
1058
+ const uint i13 = im/ne12;
1059
+
1060
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1061
+
1062
+ device const half * x = (device const half *) (src0 + offset0);
875
1063
 
876
1064
  if (ne00 < 128) {
877
1065
  for (int row = 0; row < N_F16_F16; ++row) {
@@ -935,6 +1123,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
935
1123
  constant uint64_t & nb12,
936
1124
  constant int64_t & ne0,
937
1125
  constant int64_t & ne1,
1126
+ constant uint & r2 [[buffer(17)]],
1127
+ constant uint & r3 [[buffer(18)]],
938
1128
  uint3 tgpig[[threadgroup_position_in_grid]],
939
1129
  uint tiisg[[thread_index_in_simdgroup]]) {
940
1130
 
@@ -942,7 +1132,12 @@ kernel void kernel_mul_mv_f16_f32_1row(
942
1132
  const int64_t r1 = tgpig.y;
943
1133
  const int64_t im = tgpig.z;
944
1134
 
945
- device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
1135
+ const uint i12 = im%ne12;
1136
+ const uint i13 = im/ne12;
1137
+
1138
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1139
+
1140
+ device const half * x = (device const half *) (src0 + offset0);
946
1141
  device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
947
1142
 
948
1143
  float sumf = 0;
@@ -989,6 +1184,8 @@ kernel void kernel_mul_mv_f16_f32(
989
1184
  constant uint64_t & nb12,
990
1185
  constant int64_t & ne0,
991
1186
  constant int64_t & ne1,
1187
+ constant uint & r2 [[buffer(17)]],
1188
+ constant uint & r3 [[buffer(18)]],
992
1189
  uint3 tgpig[[threadgroup_position_in_grid]],
993
1190
  uint tiisg[[thread_index_in_simdgroup]]) {
994
1191
 
@@ -996,7 +1193,12 @@ kernel void kernel_mul_mv_f16_f32(
996
1193
  const int64_t rb = tgpig.y*N_F16_F32;
997
1194
  const int64_t im = tgpig.z;
998
1195
 
999
- device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
1196
+ const uint i12 = im%ne12;
1197
+ const uint i13 = im/ne12;
1198
+
1199
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1200
+
1201
+ device const half * x = (device const half *) (src0 + offset0);
1000
1202
 
1001
1203
  if (ne00 < 128) {
1002
1204
  for (int row = 0; row < N_F16_F32; ++row) {
@@ -1061,6 +1263,8 @@ kernel void kernel_mul_mv_f16_f32_l4(
1061
1263
  constant uint64_t & nb12,
1062
1264
  constant int64_t & ne0,
1063
1265
  constant int64_t & ne1,
1266
+ constant uint & r2 [[buffer(17)]],
1267
+ constant uint & r3 [[buffer(18)]],
1064
1268
  uint3 tgpig[[threadgroup_position_in_grid]],
1065
1269
  uint tiisg[[thread_index_in_simdgroup]]) {
1066
1270
 
@@ -1068,7 +1272,12 @@ kernel void kernel_mul_mv_f16_f32_l4(
1068
1272
  const int64_t r0 = tgpig.x;
1069
1273
  const int64_t im = tgpig.z;
1070
1274
 
1071
- device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
1275
+ const uint i12 = im%ne12;
1276
+ const uint i13 = im/ne12;
1277
+
1278
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
1279
+
1280
+ device const half4 * x4 = (device const half4 *) (src0 + offset0);
1072
1281
 
1073
1282
  for (int r1 = 0; r1 < nrows; ++r1) {
1074
1283
  device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
@@ -1120,17 +1329,21 @@ kernel void kernel_alibi_f32(
1120
1329
  const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1121
1330
  const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1122
1331
  const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1332
+ const int64_t k = i3*ne3 + i2;
1123
1333
 
1124
- device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1125
1334
  float m_k;
1126
- if (i2 < n_heads_log2_floor) {
1127
- m_k = pow(m0, i2 + 1);
1335
+ if (k < n_heads_log2_floor) {
1336
+ m_k = pow(m0, k + 1);
1128
1337
  } else {
1129
- m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1);
1338
+ m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
1130
1339
  }
1340
+
1341
+ device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
1342
+ device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
1131
1343
  for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
1132
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1133
- dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
1344
+ const float src_v = *(device float *)(src_row + i00*nb00);
1345
+ device float * dst_v = (device float *)(dst_row + i00*nb0);
1346
+ *dst_v = i00 * m_k + src_v;
1134
1347
  }
1135
1348
  }
1136
1349
 
@@ -1335,6 +1548,58 @@ kernel void kernel_im2col_f16(
1335
1548
  }
1336
1549
  }
1337
1550
 
1551
// Bitonic argsort, following the CUDA kernels as reference.
//
// Dispatch (as implied by the indexing below): one threadgroup per row
// (row = threadgroup_position_in_grid.y) and one thread per column
// (col = thread_position_in_threadgroup.x). The permutation of column
// indices is built and sorted in place in `dst` (device memory); `x` is only
// read.
//
// Preconditions:
//   * rows of x and dst are contiguous, ncols elements each;
//   * the threadgroup has at least ncols threads (extra threads idle but
//     still take part in every barrier);
//   * ncols must be a power of two for the output to be fully sorted.
//     Compare partners beyond the row are skipped, so other sizes no longer
//     access memory out of bounds, but their output is not guaranteed sorted.
typedef void (argsort_t)(
        device const float   * x,
        device       int32_t * dst,
        constant     int64_t & ncols,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]]);

template<ggml_sort_order order>
kernel void kernel_argsort_f32_i32(
        device const float   * x,
        device       int32_t * dst,
        constant     int64_t & ncols,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]]) {
    const int col = tpitg[0];
    const int row = tgpig[1];

    device const float   * x_row   = x   + row * ncols;
    device       int32_t * dst_row = dst + row * ncols;

    // NOTE: no early return for col >= ncols: every thread of the threadgroup
    // must reach each threadgroup_barrier below, otherwise behavior is undefined.

    // initialize indices
    if (col < ncols) {
        dst_row[col] = col;
    }
    // dst is device memory, so the barrier must order device-memory accesses
    threadgroup_barrier(mem_flags::mem_device);

    for (int k = 2; k <= ncols; k *= 2) {
        for (int j = k / 2; j > 0; j /= 2) {
            const int ixj = col ^ j;
            // the lower index of each pair performs the compare-exchange; a
            // partner at or past ncols only occurs for non-power-of-two rows
            if (ixj > col && ixj < ncols) {
                const float a = x_row[dst_row[col]];
                const float b = x_row[dst_row[ixj]];
                // bit k of col selects whether this sub-sequence runs in the
                // requested order or the opposite one
                const bool in_order_block = (col & k) == 0;
                const bool need_swap = in_order_block
                    ? (order == GGML_SORT_ASC ? a > b : a < b)
                    : (order == GGML_SORT_ASC ? a < b : a > b);
                if (need_swap) {
                    SWAP(dst_row[col], dst_row[ixj]);
                }
            }
            threadgroup_barrier(mem_flags::mem_device);
        }
    }
}

template [[host_name("kernel_argsort_f32_i32_asc")]]  kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
1602
+
1338
1603
  kernel void kernel_cpy_f16_f16(
1339
1604
  device const half * src0,
1340
1605
  device half * dst,
@@ -1460,6 +1725,197 @@ kernel void kernel_cpy_f32_f32(
1460
1725
  }
1461
1726
  }
1462
1727
 
1728
// Copy/quantize f32 src0 into Q8_0 blocks in dst.
// Dispatch: one threadgroup per src0 row, tgpig = (i01, i02, i03); each
// thread quantizes whole QK8_0-element blocks, striding by ntg.x blocks.
// The destination block is located by flattening the source row to a linear
// element index and unflattening that against dst's shape (ne0..ne3).
// Assumes ne00 and ne0 are multiples of QK8_0 (TODO confirm on the host side).
// Per block: d = max|x| / 127, q[j] = round(x[j] / d) (q = 0 when d == 0).
kernel void kernel_cpy_f32_q8_0(
        device const float * src0,
        device        void * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int64_t i03 = tgpig[2];
    const int64_t i02 = tgpig[1];
    const int64_t i01 = tgpig[0];

    // linear index of the first element of this source row
    const int64_t n = ((i03*ne02 + i02)*ne01 + i01)*ne00;

    // unflatten n against the destination shape
    const int64_t plane  = ne1*ne0;
    const int64_t volume = ne2*plane;

    const int64_t i3 = n / volume;
    int64_t rest = n - i3*volume;
    const int64_t i2 = rest / plane;
    rest -= i2*plane;
    const int64_t i1 = rest / ne0;
    rest -= i1*ne0;
    const int64_t i0 = rest / QK8_0;   // block index within the dst row

    device block_q8_0 * blocks  = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
    device const char * src_row = (device const char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;

    for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) {
        device const float * xs  = (device const float *) (src_row + i00*nb00);
        device block_q8_0  * blk = blocks + i00/QK8_0;

        // largest magnitude in the block
        float amax = 0.0f;
        for (int j = 0; j < QK8_0; ++j) {
            amax = MAX(amax, fabs(xs[j]));
        }

        const float d  = amax / ((1 << 7) - 1);
        const float id = (d != 0.0f) ? 1.0f/d : 0.0f;

        blk->d = d;
        for (int j = 0; j < QK8_0; ++j) {
            blk->qs[j] = round(xs[j]*id);
        }
    }
}
1785
+
1786
// Copy/quantize f32 src0 into Q4_0 blocks in dst.
// Dispatch: one threadgroup per src0 row, tgpig = (i01, i02, i03); each
// thread quantizes whole QK4_0-element blocks, striding by ntg.x blocks.
// The destination block is located by flattening the source row to a linear
// element index and unflattening that against dst's shape (ne0..ne3).
// Assumes ne00 and ne0 are multiples of QK4_0 (TODO confirm on the host side).
// Per block: d = v / -8 where v is the signed value of largest magnitude;
// each x is stored as a 4-bit code min(15, trunc(x/d + 8.5)), with the first
// half of the block in the low nibbles and the second half in the high nibbles.
kernel void kernel_cpy_f32_q4_0(
        device const float * src0,
        device        void * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int64_t i03 = tgpig[2];
    const int64_t i02 = tgpig[1];
    const int64_t i01 = tgpig[0];

    // linear index of the first element of this source row
    const int64_t n = ((i03*ne02 + i02)*ne01 + i01)*ne00;

    // unflatten n against the destination shape
    const int64_t plane  = ne1*ne0;
    const int64_t volume = ne2*plane;

    const int64_t i3 = n / volume;
    int64_t rest = n - i3*volume;
    const int64_t i2 = rest / plane;
    rest -= i2*plane;
    const int64_t i1 = rest / ne0;
    rest -= i1*ne0;
    const int64_t i0 = rest / QK4_0;   // block index within the dst row

    device block_q4_0 * blocks  = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
    device const char * src_row = (device const char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;

    for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) {
        device const float * xs  = (device const float *) (src_row + i00*nb00);
        device block_q4_0  * blk = blocks + i00/QK4_0;

        // find the (first) element of largest magnitude, keeping its sign
        float amax = 0.0f;
        float vext = 0.0f;
        for (int j = 0; j < QK4_0; ++j) {
            const float v  = xs[j];
            const float av = fabs(v);
            if (av > amax) {
                amax = av;
                vext = v;
            }
        }

        const float d  = vext / -8;
        const float id = (d != 0.0f) ? 1.0f/d : 0.0f;

        blk->d = d;

        const int half_block = QK4_0/2;
        for (int j = 0; j < half_block; ++j) {
            const uint8_t lo = MIN(15, (int8_t)(xs[j]*id + 8.5f));
            const uint8_t hi = MIN(15, (int8_t)(xs[half_block + j]*id + 8.5f));
            blk->qs[j] = lo | (hi << 4);
        }
    }
}
1852
+
1853
// Copy/quantize f32 src0 into Q4_1 blocks in dst.
// Dispatch: one threadgroup per src0 row, tgpig = (i01, i02, i03); each
// thread quantizes whole QK4_1-element blocks, striding by ntg.x blocks.
// The destination block is located by flattening the source row to a linear
// element index and unflattening that against dst's shape (ne0..ne3).
// Assumes ne00 and ne0 are multiples of QK4_1 (TODO confirm on the host side).
// Per block: m = min(x), d = (max(x) - m) / 15; each x is stored as a 4-bit
// code min(15, trunc((x - m)/d + 0.5)), with the first half of the block in
// the low nibbles and the second half in the high nibbles.
kernel void kernel_cpy_f32_q4_1(
        device const float * src0,
        device        void * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int64_t i03 = tgpig[2];
    const int64_t i02 = tgpig[1];
    const int64_t i01 = tgpig[0];

    // linear index of the first element of this source row
    const int64_t n = ((i03*ne02 + i02)*ne01 + i01)*ne00;

    // unflatten n against the destination shape
    const int64_t plane  = ne1*ne0;
    const int64_t volume = ne2*plane;

    const int64_t i3 = n / volume;
    int64_t rest = n - i3*volume;
    const int64_t i2 = rest / plane;
    rest -= i2*plane;
    const int64_t i1 = rest / ne0;
    rest -= i1*ne0;
    const int64_t i0 = rest / QK4_1;   // block index within the dst row

    device block_q4_1 * blocks  = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
    device const char * src_row = (device const char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;

    for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) {
        device const float * xs  = (device const float *) (src_row + i00*nb00);
        device block_q4_1  * blk = blocks + i00/QK4_1;

        // value range of the block
        float lo_v =  FLT_MAX;
        float hi_v = -FLT_MAX;
        for (int j = 0; j < QK4_1; ++j) {
            const float v = xs[j];
            lo_v = (v < lo_v) ? v : lo_v;
            hi_v = (v > hi_v) ? v : hi_v;
        }

        const float d  = (hi_v - lo_v) / ((1 << 4) - 1);
        const float id = (d != 0.0f) ? 1.0f/d : 0.0f;

        blk->d = d;
        blk->m = lo_v;

        const int half_block = QK4_1/2;
        for (int j = 0; j < half_block; ++j) {
            const uint8_t lo = MIN(15, (int8_t)((xs[j] - lo_v)*id + 0.5f));
            const uint8_t hi = MIN(15, (int8_t)((xs[half_block + j] - lo_v)*id + 0.5f));
            blk->qs[j] = lo | (hi << 4);
        }
    }
}
1918
+
1463
1919
  kernel void kernel_concat(
1464
1920
  device const char * src0,
1465
1921
  device const char * src1,
@@ -1617,23 +2073,30 @@ kernel void kernel_mul_mv_q2_K_f32(
1617
2073
  constant int64_t & ne02[[buffer(5)]],
1618
2074
  constant int64_t & ne10[[buffer(9)]],
1619
2075
  constant int64_t & ne12[[buffer(11)]],
1620
- constant int64_t & ne0[[buffer(15)]],
1621
- constant int64_t & ne1[[buffer(16)]],
1622
- constant uint & gqa[[buffer(17)]],
2076
+ constant int64_t & ne0 [[buffer(15)]],
2077
+ constant int64_t & ne1 [[buffer(16)]],
2078
+ constant uint & r2 [[buffer(17)]],
2079
+ constant uint & r3 [[buffer(18)]],
1623
2080
  uint3 tgpig[[threadgroup_position_in_grid]],
1624
- uint tiisg[[thread_index_in_simdgroup]],
1625
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2081
+ uint tiisg[[thread_index_in_simdgroup]],
2082
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1626
2083
 
1627
2084
  const int nb = ne00/QK_K;
1628
2085
  const int r0 = tgpig.x;
1629
2086
  const int r1 = tgpig.y;
1630
- const int r2 = tgpig.z;
2087
+ const int im = tgpig.z;
1631
2088
 
1632
2089
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1633
2090
  const int ib_row = first_row * nb;
1634
- const uint offset0 = r2/gqa*(nb*ne0);
2091
+
2092
+ const uint i12 = im%ne12;
2093
+ const uint i13 = im/ne12;
2094
+
2095
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2096
+
1635
2097
  device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
1636
- device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2098
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2099
+
1637
2100
  float yl[32];
1638
2101
  float sumf[N_DST]={0.f}, all_sum;
1639
2102
 
@@ -1642,11 +2105,11 @@ kernel void kernel_mul_mv_q2_K_f32(
1642
2105
  #if QK_K == 256
1643
2106
  const int ix = tiisg/8; // 0...3
1644
2107
  const int it = tiisg%8; // 0...7
1645
- const int im = it/4; // 0 or 1
2108
+ const int iq = it/4; // 0 or 1
1646
2109
  const int ir = it%4; // 0...3
1647
2110
  const int is = (8*ir)/16;// 0 or 1
1648
2111
 
1649
- device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir;
2112
+ device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
1650
2113
 
1651
2114
  for (int ib = ix; ib < nb; ib += 4) {
1652
2115
 
@@ -1658,8 +2121,8 @@ kernel void kernel_mul_mv_q2_K_f32(
1658
2121
  yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
1659
2122
  }
1660
2123
 
1661
- device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*im + is;
1662
- device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
2124
+ device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*iq + is;
2125
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
1663
2126
  device const half * dh = &x[ib].d;
1664
2127
 
1665
2128
  for (int row = 0; row < N_DST; row++) {
@@ -1746,7 +2209,7 @@ kernel void kernel_mul_mv_q2_K_f32(
1746
2209
  for (int row = 0; row < N_DST; ++row) {
1747
2210
  all_sum = simd_sum(sumf[row]);
1748
2211
  if (tiisg == 0) {
1749
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
2212
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
1750
2213
  }
1751
2214
  }
1752
2215
  }
@@ -1761,9 +2224,10 @@ kernel void kernel_mul_mv_q3_K_f32(
1761
2224
  constant int64_t & ne02[[buffer(5)]],
1762
2225
  constant int64_t & ne10[[buffer(9)]],
1763
2226
  constant int64_t & ne12[[buffer(11)]],
1764
- constant int64_t & ne0[[buffer(15)]],
1765
- constant int64_t & ne1[[buffer(16)]],
1766
- constant uint & gqa[[buffer(17)]],
2227
+ constant int64_t & ne0 [[buffer(15)]],
2228
+ constant int64_t & ne1 [[buffer(16)]],
2229
+ constant uint & r2 [[buffer(17)]],
2230
+ constant uint & r3 [[buffer(18)]],
1767
2231
  uint3 tgpig[[threadgroup_position_in_grid]],
1768
2232
  uint tiisg[[thread_index_in_simdgroup]],
1769
2233
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1772,12 +2236,17 @@ kernel void kernel_mul_mv_q3_K_f32(
1772
2236
 
1773
2237
  const int64_t r0 = tgpig.x;
1774
2238
  const int64_t r1 = tgpig.y;
1775
- const int64_t r2 = tgpig.z;
2239
+ const int64_t im = tgpig.z;
1776
2240
 
1777
2241
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
1778
- const uint offset0 = r2/gqa*(nb*ne0);
2242
+
2243
+ const uint i12 = im%ne12;
2244
+ const uint i13 = im/ne12;
2245
+
2246
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2247
+
1779
2248
  device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
1780
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2249
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
1781
2250
 
1782
2251
  float yl[32];
1783
2252
 
@@ -1899,7 +2368,7 @@ kernel void kernel_mul_mv_q3_K_f32(
1899
2368
  }
1900
2369
  if (tiisg == 0) {
1901
2370
  for (int row = 0; row < 2; ++row) {
1902
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
2371
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row];
1903
2372
  }
1904
2373
  }
1905
2374
  }
@@ -1913,26 +2382,33 @@ kernel void kernel_mul_mv_q3_K_f32(
1913
2382
  constant int64_t & ne02[[buffer(5)]],
1914
2383
  constant int64_t & ne10[[buffer(9)]],
1915
2384
  constant int64_t & ne12[[buffer(11)]],
1916
- constant int64_t & ne0[[buffer(15)]],
1917
- constant int64_t & ne1[[buffer(16)]],
1918
- constant uint & gqa[[buffer(17)]],
2385
+ constant int64_t & ne0 [[buffer(15)]],
2386
+ constant int64_t & ne1 [[buffer(16)]],
2387
+ constant uint & r2 [[buffer(17)]],
2388
+ constant uint & r3 [[buffer(18)]],
1919
2389
  uint3 tgpig[[threadgroup_position_in_grid]],
1920
- uint tiisg[[thread_index_in_simdgroup]],
1921
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2390
+ uint tiisg[[thread_index_in_simdgroup]],
2391
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1922
2392
 
1923
2393
  const int nb = ne00/QK_K;
1924
2394
 
1925
2395
  const int64_t r0 = tgpig.x;
1926
2396
  const int64_t r1 = tgpig.y;
1927
- const int64_t r2 = tgpig.z;
2397
+ const int64_t im = tgpig.z;
1928
2398
 
1929
2399
  const int row = 2 * r0 + sgitg;
1930
- const uint offset0 = r2/gqa*(nb*ne0);
2400
+
2401
+ const uint i12 = im%ne12;
2402
+ const uint i13 = im/ne12;
2403
+
2404
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2405
+
1931
2406
  device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
1932
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2407
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2408
+
1933
2409
  const int ix = tiisg/4;
1934
2410
  const int il = 4 * (tiisg%4);// 0, 4, 8, 12
1935
- const int im = il/8; // 0, 0, 1, 1
2411
+ const int iq = il/8; // 0, 0, 1, 1
1936
2412
  const int in = il%8; // 0, 4, 0, 4
1937
2413
 
1938
2414
  float2 sum = {0.f, 0.f};
@@ -1952,7 +2428,7 @@ kernel void kernel_mul_mv_q3_K_f32(
1952
2428
  const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
1953
2429
 
1954
2430
  for (int l = 0; l < 4; l += 2) {
1955
- const uint16_t hm = h[l/2] >> im;
2431
+ const uint16_t hm = h[l/2] >> iq;
1956
2432
  sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
1957
2433
  + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
1958
2434
  + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
@@ -1968,7 +2444,7 @@ kernel void kernel_mul_mv_q3_K_f32(
1968
2444
 
1969
2445
  const float tot = simd_sum(sumf);
1970
2446
  if (tiisg == 0) {
1971
- dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
2447
+ dst[r1*ne0 + im*ne0*ne1 + row] = tot;
1972
2448
  }
1973
2449
 
1974
2450
  }
@@ -1986,10 +2462,11 @@ kernel void kernel_mul_mv_q4_K_f32(
1986
2462
  constant int64_t & ne12 [[buffer(11)]],
1987
2463
  constant int64_t & ne0 [[buffer(15)]],
1988
2464
  constant int64_t & ne1 [[buffer(16)]],
1989
- constant uint & gqa [[buffer(17)]],
2465
+ constant uint & r2 [[buffer(17)]],
2466
+ constant uint & r3 [[buffer(18)]],
1990
2467
  uint3 tgpig[[threadgroup_position_in_grid]],
1991
- uint tiisg[[thread_index_in_simdgroup]],
1992
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2468
+ uint tiisg[[thread_index_in_simdgroup]],
2469
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1993
2470
 
1994
2471
  const uint16_t kmask1 = 0x3f3f;
1995
2472
  const uint16_t kmask2 = 0x0f0f;
@@ -1997,26 +2474,32 @@ kernel void kernel_mul_mv_q4_K_f32(
1997
2474
 
1998
2475
  const int ix = tiisg/8; // 0...3
1999
2476
  const int it = tiisg%8; // 0...7
2000
- const int im = it/4; // 0 or 1
2477
+ const int iq = it/4; // 0 or 1
2001
2478
  const int ir = it%4; // 0...3
2002
2479
 
2003
2480
  const int nb = ne00/QK_K;
2004
2481
  const int r0 = tgpig.x;
2005
2482
  const int r1 = tgpig.y;
2006
- const int r2 = tgpig.z;
2483
+ const int im = tgpig.z;
2007
2484
  //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
2008
2485
  const int first_row = r0 * N_DST;
2009
2486
  const int ib_row = first_row * nb;
2010
- const uint offset0 = r2/gqa*(nb*ne0);
2487
+
2488
+ const uint i12 = im%ne12;
2489
+ const uint i13 = im/ne12;
2490
+
2491
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2492
+
2011
2493
  device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
2012
- device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2494
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2495
+
2013
2496
  float yl[16];
2014
2497
  float yh[16];
2015
2498
  float sumf[N_DST]={0.f}, all_sum;
2016
2499
 
2017
2500
  const int step = sizeof(block_q4_K) * nb / 2;
2018
2501
 
2019
- device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir;
2502
+ device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
2020
2503
 
2021
2504
  uint16_t sc16[4];
2022
2505
  thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
@@ -2031,8 +2514,8 @@ kernel void kernel_mul_mv_q4_K_f32(
2031
2514
  yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
2032
2515
  }
2033
2516
 
2034
- device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im;
2035
- device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
2517
+ device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq;
2518
+ device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
2036
2519
  device const half * dh = &x[ib].d;
2037
2520
 
2038
2521
  for (int row = 0; row < N_DST; row++) {
@@ -2076,7 +2559,7 @@ kernel void kernel_mul_mv_q4_K_f32(
2076
2559
  for (int row = 0; row < N_DST; ++row) {
2077
2560
  all_sum = simd_sum(sumf[row]);
2078
2561
  if (tiisg == 0) {
2079
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
2562
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
2080
2563
  }
2081
2564
  }
2082
2565
  }
@@ -2090,9 +2573,10 @@ kernel void kernel_mul_mv_q4_K_f32(
2090
2573
  constant int64_t & ne02[[buffer(5)]],
2091
2574
  constant int64_t & ne10[[buffer(9)]],
2092
2575
  constant int64_t & ne12[[buffer(11)]],
2093
- constant int64_t & ne0[[buffer(15)]],
2094
- constant int64_t & ne1[[buffer(16)]],
2095
- constant uint & gqa[[buffer(17)]],
2576
+ constant int64_t & ne0 [[buffer(15)]],
2577
+ constant int64_t & ne1 [[buffer(16)]],
2578
+ constant uint & r2 [[buffer(17)]],
2579
+ constant uint & r3 [[buffer(18)]],
2096
2580
  uint3 tgpig[[threadgroup_position_in_grid]],
2097
2581
  uint tiisg[[thread_index_in_simdgroup]],
2098
2582
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2103,12 +2587,18 @@ kernel void kernel_mul_mv_q4_K_f32(
2103
2587
  const int nb = ne00/QK_K;
2104
2588
  const int r0 = tgpig.x;
2105
2589
  const int r1 = tgpig.y;
2106
- const int r2 = tgpig.z;
2590
+ const int im = tgpig.z;
2107
2591
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
2108
2592
  const int ib_row = first_row * nb;
2109
- const uint offset0 = r2/gqa*(nb*ne0);
2593
+
2594
+ const uint i12 = im%ne12;
2595
+ const uint i13 = im/ne12;
2596
+
2597
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2598
+
2110
2599
  device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
2111
- device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2600
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2601
+
2112
2602
  float yl[8];
2113
2603
  float yh[8];
2114
2604
  float sumf[N_DST]={0.f}, all_sum;
@@ -2164,7 +2654,7 @@ kernel void kernel_mul_mv_q4_K_f32(
2164
2654
  for (int row = 0; row < N_DST; ++row) {
2165
2655
  all_sum = simd_sum(sumf[row]);
2166
2656
  if (tiisg == 0) {
2167
- dst[r1*ne0+ r2*ne0*ne1 + first_row + row] = all_sum;
2657
+ dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum;
2168
2658
  }
2169
2659
  }
2170
2660
  }
@@ -2179,9 +2669,10 @@ kernel void kernel_mul_mv_q5_K_f32(
2179
2669
  constant int64_t & ne02[[buffer(5)]],
2180
2670
  constant int64_t & ne10[[buffer(9)]],
2181
2671
  constant int64_t & ne12[[buffer(11)]],
2182
- constant int64_t & ne0[[buffer(15)]],
2183
- constant int64_t & ne1[[buffer(16)]],
2184
- constant uint & gqa[[buffer(17)]],
2672
+ constant int64_t & ne0 [[buffer(15)]],
2673
+ constant int64_t & ne1 [[buffer(16)]],
2674
+ constant uint & r2 [[buffer(17)]],
2675
+ constant uint & r3 [[buffer(18)]],
2185
2676
  uint3 tgpig[[threadgroup_position_in_grid]],
2186
2677
  uint tiisg[[thread_index_in_simdgroup]],
2187
2678
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2190,12 +2681,17 @@ kernel void kernel_mul_mv_q5_K_f32(
2190
2681
 
2191
2682
  const int64_t r0 = tgpig.x;
2192
2683
  const int64_t r1 = tgpig.y;
2193
- const int r2 = tgpig.z;
2684
+ const int im = tgpig.z;
2194
2685
 
2195
2686
  const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
2196
- const uint offset0 = r2/gqa*(nb*ne0);
2687
+
2688
+ const uint i12 = im%ne12;
2689
+ const uint i13 = im/ne12;
2690
+
2691
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2692
+
2197
2693
  device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
2198
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2694
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2199
2695
 
2200
2696
  float sumf[2]={0.f};
2201
2697
 
@@ -2211,15 +2707,15 @@ kernel void kernel_mul_mv_q5_K_f32(
2211
2707
 
2212
2708
  const int tid = tiisg/4;
2213
2709
  const int ix = tiisg%4;
2214
- const int im = tid/4;
2710
+ const int iq = tid/4;
2215
2711
  const int ir = tid%4;
2216
2712
  const int n = 8;
2217
2713
 
2218
2714
  const int l0 = n*ir;
2219
- const int q_offset = 32*im + l0;
2220
- const int y_offset = 64*im + l0;
2715
+ const int q_offset = 32*iq + l0;
2716
+ const int y_offset = 64*iq + l0;
2221
2717
 
2222
- const uint8_t hm1 = 1u << (2*im);
2718
+ const uint8_t hm1 = 1u << (2*iq);
2223
2719
  const uint8_t hm2 = hm1 << 1;
2224
2720
  const uint8_t hm3 = hm1 << 4;
2225
2721
  const uint8_t hm4 = hm2 << 4;
@@ -2234,7 +2730,7 @@ kernel void kernel_mul_mv_q5_K_f32(
2234
2730
  device const uint8_t * q1 = x[i].qs + q_offset;
2235
2731
  device const uint8_t * qh = x[i].qh + l0;
2236
2732
  device const half * dh = &x[i].d;
2237
- device const uint16_t * a = (device const uint16_t *)x[i].scales + im;
2733
+ device const uint16_t * a = (device const uint16_t *)x[i].scales + iq;
2238
2734
 
2239
2735
  device const float * y2 = y1 + 128;
2240
2736
  float4 sumy = {0.f, 0.f, 0.f, 0.f};
@@ -2290,7 +2786,7 @@ kernel void kernel_mul_mv_q5_K_f32(
2290
2786
 
2291
2787
  const int il = 4 * (tiisg/8); // 0, 4, 8, 12
2292
2788
  const int ix = tiisg%8;
2293
- const int im = il/8; // 0, 0, 1, 1
2789
+ const int iq = il/8; // 0, 0, 1, 1
2294
2790
  const int in = il%8; // 0, 4, 0, 4
2295
2791
 
2296
2792
  device const float * y = yy + ix*QK_K + il;
@@ -2315,7 +2811,7 @@ kernel void kernel_mul_mv_q5_K_f32(
2315
2811
 
2316
2812
  float2 acc = {0.f, 0.f};
2317
2813
  for (int l = 0; l < 4; ++l) {
2318
- const uint8_t hl = h[l] >> im;
2814
+ const uint8_t hl = h[l] >> iq;
2319
2815
  acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
2320
2816
  + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
2321
2817
  acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
@@ -2337,7 +2833,7 @@ kernel void kernel_mul_mv_q5_K_f32(
2337
2833
  for (int row = 0; row < 2; ++row) {
2338
2834
  const float tot = simd_sum(sumf[row]);
2339
2835
  if (tiisg == 0) {
2340
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
2836
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
2341
2837
  }
2342
2838
  }
2343
2839
 
@@ -2352,9 +2848,10 @@ kernel void kernel_mul_mv_q6_K_f32(
2352
2848
  constant int64_t & ne02[[buffer(5)]],
2353
2849
  constant int64_t & ne10[[buffer(9)]],
2354
2850
  constant int64_t & ne12[[buffer(11)]],
2355
- constant int64_t & ne0[[buffer(15)]],
2356
- constant int64_t & ne1[[buffer(16)]],
2357
- constant uint & gqa[[buffer(17)]],
2851
+ constant int64_t & ne0 [[buffer(15)]],
2852
+ constant int64_t & ne1 [[buffer(16)]],
2853
+ constant uint & r2 [[buffer(17)]],
2854
+ constant uint & r3 [[buffer(18)]],
2358
2855
  uint3 tgpig[[threadgroup_position_in_grid]],
2359
2856
  uint tiisg[[thread_index_in_simdgroup]],
2360
2857
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2368,12 +2865,17 @@ kernel void kernel_mul_mv_q6_K_f32(
2368
2865
 
2369
2866
  const int64_t r0 = tgpig.x;
2370
2867
  const int64_t r1 = tgpig.y;
2371
- const int r2 = tgpig.z;
2868
+ const int im = tgpig.z;
2372
2869
 
2373
2870
  const int row = 2 * r0 + sgitg;
2374
- const uint offset0 = r2/gqa*(nb*ne0);
2871
+
2872
+ const uint i12 = im%ne12;
2873
+ const uint i13 = im/ne12;
2874
+
2875
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
2876
+
2375
2877
  device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
2376
- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
2878
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
2377
2879
 
2378
2880
  float sumf = 0;
2379
2881
 
@@ -2439,7 +2941,7 @@ kernel void kernel_mul_mv_q6_K_f32(
2439
2941
 
2440
2942
  const float tot = simd_sum(sumf);
2441
2943
  if (tiisg == 0) {
2442
- dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
2944
+ dst[r1*ne0 + im*ne0*ne1 + row] = tot;
2443
2945
  }
2444
2946
  }
2445
2947
 
@@ -2749,24 +3251,25 @@ kernel void kernel_get_rows(
2749
3251
 
2750
3252
  // each block_q contains 16*nl weights
2751
3253
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
2752
- kernel void kernel_mul_mm(device const uchar * src0,
2753
- device const uchar * src1,
2754
- device float * dst,
2755
- constant int64_t & ne00,
2756
- constant int64_t & ne02,
2757
- constant int64_t & nb01,
2758
- constant int64_t & nb02,
2759
- constant int64_t & ne12,
2760
- constant int64_t & nb10,
2761
- constant int64_t & nb11,
2762
- constant int64_t & nb12,
2763
- constant int64_t & ne0,
2764
- constant int64_t & ne1,
2765
- constant uint & gqa,
2766
- threadgroup uchar * shared_memory [[threadgroup(0)]],
2767
- uint3 tgpig[[threadgroup_position_in_grid]],
2768
- uint tiitg[[thread_index_in_threadgroup]],
2769
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3254
+ void kernel_mul_mm_impl(device const uchar * src0,
3255
+ device const uchar * src1,
3256
+ device float * dst,
3257
+ constant int64_t & ne00,
3258
+ constant int64_t & ne02,
3259
+ constant int64_t & nb01,
3260
+ constant int64_t & nb02,
3261
+ constant int64_t & ne12,
3262
+ constant int64_t & nb10,
3263
+ constant int64_t & nb11,
3264
+ constant int64_t & nb12,
3265
+ constant int64_t & ne0,
3266
+ constant int64_t & ne1,
3267
+ constant uint & r2,
3268
+ constant uint & r3,
3269
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
3270
+ uint3 tgpig[[threadgroup_position_in_grid]],
3271
+ uint tiitg[[thread_index_in_threadgroup]],
3272
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2770
3273
 
2771
3274
  threadgroup half * sa = (threadgroup half *)(shared_memory);
2772
3275
  threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
@@ -2792,7 +3295,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
2792
3295
 
2793
3296
  short il = (tiitg % THREAD_PER_ROW);
2794
3297
 
2795
- uint offset0 = im/gqa*nb02;
3298
+ const uint i12 = im%ne12;
3299
+ const uint i13 = im/ne12;
3300
+
3301
+ uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
2796
3302
  ushort offset1 = il/nl;
2797
3303
 
2798
3304
  device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
@@ -2876,14 +3382,116 @@ kernel void kernel_mul_mm(device const uchar * src0,
2876
3382
  }
2877
3383
  }
2878
3384
 
3385
+ template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
3386
+ kernel void kernel_mul_mm(device const uchar * src0,
3387
+ device const uchar * src1,
3388
+ device float * dst,
3389
+ constant int64_t & ne00,
3390
+ constant int64_t & ne02,
3391
+ constant int64_t & nb01,
3392
+ constant int64_t & nb02,
3393
+ constant int64_t & ne12,
3394
+ constant int64_t & nb10,
3395
+ constant int64_t & nb11,
3396
+ constant int64_t & nb12,
3397
+ constant int64_t & ne0,
3398
+ constant int64_t & ne1,
3399
+ constant uint & r2,
3400
+ constant uint & r3,
3401
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
3402
+ uint3 tgpig[[threadgroup_position_in_grid]],
3403
+ uint tiitg[[thread_index_in_threadgroup]],
3404
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3405
+ kernel_mul_mm_impl<block_q, nl, dequantize_func>(
3406
+ src0,
3407
+ src1,
3408
+ dst,
3409
+ ne00,
3410
+ ne02,
3411
+ nb01,
3412
+ nb02,
3413
+ ne12,
3414
+ nb10,
3415
+ nb11,
3416
+ nb12,
3417
+ ne0,
3418
+ ne1,
3419
+ r2,
3420
+ r3,
3421
+ shared_memory,
3422
+ tgpig,
3423
+ tiitg,
3424
+ sgitg);
3425
+ }
3426
+
3427
+ template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
3428
+ kernel void kernel_mul_mm_id(
3429
+ device const int32_t * ids,
3430
+ device const uchar * src1,
3431
+ device float * dst,
3432
+ constant int64_t & ne00,
3433
+ constant int64_t & ne02,
3434
+ constant int64_t & nb01,
3435
+ constant int64_t & nb02,
3436
+ constant int64_t & ne12,
3437
+ constant int64_t & nb10,
3438
+ constant int64_t & nb11,
3439
+ constant int64_t & nb12,
3440
+ constant int64_t & ne0,
3441
+ constant int64_t & ne1,
3442
+ constant uint & r2,
3443
+ constant uint & r3,
3444
+ constant int & idx,
3445
+ device const uchar * src00,
3446
+ device const uchar * src01,
3447
+ device const uchar * src02,
3448
+ device const uchar * src03,
3449
+ device const uchar * src04,
3450
+ device const uchar * src05,
3451
+ device const uchar * src06,
3452
+ device const uchar * src07,
3453
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
3454
+ uint3 tgpig[[threadgroup_position_in_grid]],
3455
+ uint tiitg[[thread_index_in_threadgroup]],
3456
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
3457
+ device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
3458
+
3459
+ kernel_mul_mm_impl<block_q, nl, dequantize_func>(
3460
+ src0[ids[idx]],
3461
+ src1,
3462
+ dst,
3463
+ ne00,
3464
+ ne02,
3465
+ nb01,
3466
+ nb02,
3467
+ ne12,
3468
+ nb10,
3469
+ nb11,
3470
+ nb12,
3471
+ ne0,
3472
+ ne1,
3473
+ r2,
3474
+ r3,
3475
+ shared_memory,
3476
+ tgpig,
3477
+ tiitg,
3478
+ sgitg);
3479
+ }
3480
+
2879
3481
  #if QK_K == 256
2880
3482
  #define QK_NL 16
2881
3483
  #else
2882
3484
  #define QK_NL 4
2883
3485
  #endif
2884
3486
 
2885
- typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
2886
- constant uint64_t &, constant uint64_t &, uint, uint, uint);
3487
+ typedef void (get_rows_t)(
3488
+ device const void * src0,
3489
+ device const int * src1,
3490
+ device float * dst,
3491
+ constant int64_t & ne00,
3492
+ constant uint64_t & nb01,
3493
+ constant uint64_t & nb1,
3494
+ uint, uint, uint);
2887
3495
 
2888
3496
  template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
2889
3497
  template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
@@ -2912,8 +3520,10 @@ typedef void (mat_mm_t)(
2912
3520
  constant int64_t & nb12,
2913
3521
  constant int64_t & ne0,
2914
3522
  constant int64_t & ne1,
2915
- constant uint & gqa,
2916
- threadgroup uchar *, uint3, uint, uint);
3523
+ constant uint & r2,
3524
+ constant uint & r3,
3525
+ threadgroup uchar *,
3526
+ uint3, uint, uint);
2917
3527
 
2918
3528
  template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
2919
3529
  template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
@@ -2927,3 +3537,44 @@ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
2927
3537
  template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
2928
3538
  template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
2929
3539
  template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
3540
+
3541
+ typedef void (mat_mm_id_t)(
3542
+ device const int32_t * ids,
3543
+ device const uchar * src1,
3544
+ device float * dst,
3545
+ constant int64_t & ne00,
3546
+ constant int64_t & ne02,
3547
+ constant int64_t & nb01,
3548
+ constant int64_t & nb02,
3549
+ constant int64_t & ne12,
3550
+ constant int64_t & nb10,
3551
+ constant int64_t & nb11,
3552
+ constant int64_t & nb12,
3553
+ constant int64_t & ne0,
3554
+ constant int64_t & ne1,
3555
+ constant uint & r2,
3556
+ constant uint & r3,
3557
+ constant int & idx,
3558
+ device const uchar * src00,
3559
+ device const uchar * src01,
3560
+ device const uchar * src02,
3561
+ device const uchar * src03,
3562
+ device const uchar * src04,
3563
+ device const uchar * src05,
3564
+ device const uchar * src06,
3565
+ device const uchar * src07,
3566
+ threadgroup uchar *,
3567
+ uint3, uint, uint);
3568
+
3569
+ template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
3570
+ template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
3571
+ template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
3572
+ template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
3573
+ template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
3574
+ template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
3575
+ template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
3576
+ template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
3577
+ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
3578
+ template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
3579
+ template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
3580
+ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;