llama_cpp 0.9.5 → 0.10.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/ext/llama_cpp/llama_cpp.cpp +123 -15
- data/ext/llama_cpp/src/ggml-alloc.c +42 -7
- data/ext/llama_cpp/src/ggml-alloc.h +8 -1
- data/ext/llama_cpp/src/ggml-backend-impl.h +46 -21
- data/ext/llama_cpp/src/ggml-backend.c +563 -156
- data/ext/llama_cpp/src/ggml-backend.h +62 -17
- data/ext/llama_cpp/src/ggml-cuda.cu +1796 -413
- data/ext/llama_cpp/src/ggml-cuda.h +9 -1
- data/ext/llama_cpp/src/ggml-impl.h +1 -1
- data/ext/llama_cpp/src/ggml-metal.h +6 -0
- data/ext/llama_cpp/src/ggml-metal.m +998 -169
- data/ext/llama_cpp/src/ggml-metal.metal +2253 -274
- data/ext/llama_cpp/src/ggml-quants.c +2 -2
- data/ext/llama_cpp/src/ggml.c +634 -248
- data/ext/llama_cpp/src/ggml.h +81 -15
- data/ext/llama_cpp/src/llama.cpp +932 -352
- data/ext/llama_cpp/src/llama.h +28 -5
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +22 -2
- metadata +2 -2
@@ -3,6 +3,8 @@
|
|
3
3
|
using namespace metal;
|
4
4
|
|
5
5
|
#define MAX(x, y) ((x) > (y) ? (x) : (y))
|
6
|
+
#define MIN(x, y) ((x) < (y) ? (x) : (y))
|
7
|
+
#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
|
6
8
|
|
7
9
|
#define QK4_0 32
|
8
10
|
#define QR4_0 2
|
@@ -41,8 +43,13 @@ typedef struct {
|
|
41
43
|
|
42
44
|
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
43
45
|
|
44
|
-
|
45
|
-
|
46
|
+
enum ggml_sort_order {
|
47
|
+
GGML_SORT_ASC,
|
48
|
+
GGML_SORT_DESC,
|
49
|
+
};
|
50
|
+
|
51
|
+
// general-purpose kernel for addition, multiplication and division of two tensors
|
52
|
+
// pros: works for non-contiguous tensors, supports broadcast across all dims
|
46
53
|
// cons: not very efficient
|
47
54
|
kernel void kernel_add(
|
48
55
|
device const char * src0,
|
@@ -72,6 +79,7 @@ kernel void kernel_add(
|
|
72
79
|
constant int64_t & nb1,
|
73
80
|
constant int64_t & nb2,
|
74
81
|
constant int64_t & nb3,
|
82
|
+
constant int64_t & offs,
|
75
83
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
76
84
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
77
85
|
uint3 ntg[[threads_per_threadgroup]]) {
|
@@ -83,16 +91,111 @@ kernel void kernel_add(
|
|
83
91
|
const int64_t i12 = i02 % ne12;
|
84
92
|
const int64_t i11 = i01 % ne11;
|
85
93
|
|
86
|
-
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 +
|
87
|
-
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11
|
88
|
-
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 +
|
94
|
+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
|
95
|
+
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
|
96
|
+
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
|
97
|
+
|
98
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
99
|
+
const int i10 = i0 % ne10;
|
100
|
+
*((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10));
|
101
|
+
}
|
102
|
+
}
|
103
|
+
|
104
|
+
kernel void kernel_mul(
|
105
|
+
device const char * src0,
|
106
|
+
device const char * src1,
|
107
|
+
device char * dst,
|
108
|
+
constant int64_t & ne00,
|
109
|
+
constant int64_t & ne01,
|
110
|
+
constant int64_t & ne02,
|
111
|
+
constant int64_t & ne03,
|
112
|
+
constant int64_t & nb00,
|
113
|
+
constant int64_t & nb01,
|
114
|
+
constant int64_t & nb02,
|
115
|
+
constant int64_t & nb03,
|
116
|
+
constant int64_t & ne10,
|
117
|
+
constant int64_t & ne11,
|
118
|
+
constant int64_t & ne12,
|
119
|
+
constant int64_t & ne13,
|
120
|
+
constant int64_t & nb10,
|
121
|
+
constant int64_t & nb11,
|
122
|
+
constant int64_t & nb12,
|
123
|
+
constant int64_t & nb13,
|
124
|
+
constant int64_t & ne0,
|
125
|
+
constant int64_t & ne1,
|
126
|
+
constant int64_t & ne2,
|
127
|
+
constant int64_t & ne3,
|
128
|
+
constant int64_t & nb0,
|
129
|
+
constant int64_t & nb1,
|
130
|
+
constant int64_t & nb2,
|
131
|
+
constant int64_t & nb3,
|
132
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
133
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
134
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
135
|
+
const int64_t i03 = tgpig.z;
|
136
|
+
const int64_t i02 = tgpig.y;
|
137
|
+
const int64_t i01 = tgpig.x;
|
138
|
+
|
139
|
+
const int64_t i13 = i03 % ne13;
|
140
|
+
const int64_t i12 = i02 % ne12;
|
141
|
+
const int64_t i11 = i01 % ne11;
|
142
|
+
|
143
|
+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
144
|
+
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
|
145
|
+
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
|
89
146
|
|
90
147
|
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
91
|
-
|
148
|
+
const int i10 = i0 % ne10;
|
149
|
+
*((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10));
|
150
|
+
}
|
151
|
+
}
|
152
|
+
|
153
|
+
kernel void kernel_div(
|
154
|
+
device const char * src0,
|
155
|
+
device const char * src1,
|
156
|
+
device char * dst,
|
157
|
+
constant int64_t & ne00,
|
158
|
+
constant int64_t & ne01,
|
159
|
+
constant int64_t & ne02,
|
160
|
+
constant int64_t & ne03,
|
161
|
+
constant int64_t & nb00,
|
162
|
+
constant int64_t & nb01,
|
163
|
+
constant int64_t & nb02,
|
164
|
+
constant int64_t & nb03,
|
165
|
+
constant int64_t & ne10,
|
166
|
+
constant int64_t & ne11,
|
167
|
+
constant int64_t & ne12,
|
168
|
+
constant int64_t & ne13,
|
169
|
+
constant int64_t & nb10,
|
170
|
+
constant int64_t & nb11,
|
171
|
+
constant int64_t & nb12,
|
172
|
+
constant int64_t & nb13,
|
173
|
+
constant int64_t & ne0,
|
174
|
+
constant int64_t & ne1,
|
175
|
+
constant int64_t & ne2,
|
176
|
+
constant int64_t & ne3,
|
177
|
+
constant int64_t & nb0,
|
178
|
+
constant int64_t & nb1,
|
179
|
+
constant int64_t & nb2,
|
180
|
+
constant int64_t & nb3,
|
181
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
182
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
183
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
184
|
+
const int64_t i03 = tgpig.z;
|
185
|
+
const int64_t i02 = tgpig.y;
|
186
|
+
const int64_t i01 = tgpig.x;
|
187
|
+
|
188
|
+
const int64_t i13 = i03 % ne13;
|
189
|
+
const int64_t i12 = i02 % ne12;
|
190
|
+
const int64_t i11 = i01 % ne11;
|
92
191
|
|
93
|
-
|
94
|
-
|
95
|
-
|
192
|
+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
193
|
+
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
|
194
|
+
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
|
195
|
+
|
196
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
197
|
+
const int i10 = i0 % ne10;
|
198
|
+
*((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10));
|
96
199
|
}
|
97
200
|
}
|
98
201
|
|
@@ -102,28 +205,27 @@ kernel void kernel_add_row(
|
|
102
205
|
device const float4 * src0,
|
103
206
|
device const float4 * src1,
|
104
207
|
device float4 * dst,
|
105
|
-
constant int64_t & nb [[buffer(
|
208
|
+
constant int64_t & nb [[buffer(28)]],
|
106
209
|
uint tpig[[thread_position_in_grid]]) {
|
107
210
|
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
108
211
|
}
|
109
212
|
|
110
|
-
kernel void
|
213
|
+
kernel void kernel_mul_row(
|
111
214
|
device const float4 * src0,
|
112
215
|
device const float4 * src1,
|
113
216
|
device float4 * dst,
|
217
|
+
constant int64_t & nb [[buffer(28)]],
|
114
218
|
uint tpig[[thread_position_in_grid]]) {
|
115
|
-
dst[tpig] = src0[tpig] * src1[tpig];
|
219
|
+
dst[tpig] = src0[tpig] * src1[tpig % nb];
|
116
220
|
}
|
117
221
|
|
118
|
-
|
119
|
-
// broadcast src1 into src0
|
120
|
-
kernel void kernel_mul_row(
|
222
|
+
kernel void kernel_div_row(
|
121
223
|
device const float4 * src0,
|
122
224
|
device const float4 * src1,
|
123
225
|
device float4 * dst,
|
124
|
-
constant int64_t & nb,
|
226
|
+
constant int64_t & nb [[buffer(28)]],
|
125
227
|
uint tpig[[thread_position_in_grid]]) {
|
126
|
-
dst[tpig] = src0[tpig]
|
228
|
+
dst[tpig] = src0[tpig] / src1[tpig % nb];
|
127
229
|
}
|
128
230
|
|
129
231
|
kernel void kernel_scale(
|
@@ -142,14 +244,6 @@ kernel void kernel_scale_4(
|
|
142
244
|
dst[tpig] = src0[tpig] * scale;
|
143
245
|
}
|
144
246
|
|
145
|
-
kernel void kernel_silu(
|
146
|
-
device const float4 * src0,
|
147
|
-
device float4 * dst,
|
148
|
-
uint tpig[[thread_position_in_grid]]) {
|
149
|
-
device const float4 & x = src0[tpig];
|
150
|
-
dst[tpig] = x / (1.0f + exp(-x));
|
151
|
-
}
|
152
|
-
|
153
247
|
kernel void kernel_relu(
|
154
248
|
device const float * src0,
|
155
249
|
device float * dst,
|
@@ -157,15 +251,17 @@ kernel void kernel_relu(
|
|
157
251
|
dst[tpig] = max(0.0f, src0[tpig]);
|
158
252
|
}
|
159
253
|
|
160
|
-
kernel void
|
254
|
+
kernel void kernel_tanh(
|
161
255
|
device const float * src0,
|
162
256
|
device float * dst,
|
163
257
|
uint tpig[[thread_position_in_grid]]) {
|
164
|
-
|
258
|
+
device const float & x = src0[tpig];
|
259
|
+
dst[tpig] = precise::tanh(x);
|
165
260
|
}
|
166
261
|
|
167
|
-
constant float GELU_COEF_A
|
168
|
-
constant float
|
262
|
+
constant float GELU_COEF_A = 0.044715f;
|
263
|
+
constant float GELU_QUICK_COEF = -1.702f;
|
264
|
+
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
169
265
|
|
170
266
|
kernel void kernel_gelu(
|
171
267
|
device const float4 * src0,
|
@@ -180,6 +276,78 @@ kernel void kernel_gelu(
|
|
180
276
|
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
181
277
|
}
|
182
278
|
|
279
|
+
kernel void kernel_gelu_quick(
|
280
|
+
device const float4 * src0,
|
281
|
+
device float4 * dst,
|
282
|
+
uint tpig[[thread_position_in_grid]]) {
|
283
|
+
device const float4 & x = src0[tpig];
|
284
|
+
|
285
|
+
dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
|
286
|
+
}
|
287
|
+
|
288
|
+
kernel void kernel_silu(
|
289
|
+
device const float4 * src0,
|
290
|
+
device float4 * dst,
|
291
|
+
uint tpig[[thread_position_in_grid]]) {
|
292
|
+
device const float4 & x = src0[tpig];
|
293
|
+
dst[tpig] = x / (1.0f + exp(-x));
|
294
|
+
}
|
295
|
+
|
296
|
+
kernel void kernel_sqr(
|
297
|
+
device const float * src0,
|
298
|
+
device float * dst,
|
299
|
+
uint tpig[[thread_position_in_grid]]) {
|
300
|
+
dst[tpig] = src0[tpig] * src0[tpig];
|
301
|
+
}
|
302
|
+
|
303
|
+
kernel void kernel_sum_rows(
|
304
|
+
device const float * src0,
|
305
|
+
device float * dst,
|
306
|
+
constant int64_t & ne00,
|
307
|
+
constant int64_t & ne01,
|
308
|
+
constant int64_t & ne02,
|
309
|
+
constant int64_t & ne03,
|
310
|
+
constant int64_t & nb00,
|
311
|
+
constant int64_t & nb01,
|
312
|
+
constant int64_t & nb02,
|
313
|
+
constant int64_t & nb03,
|
314
|
+
constant int64_t & ne10,
|
315
|
+
constant int64_t & ne11,
|
316
|
+
constant int64_t & ne12,
|
317
|
+
constant int64_t & ne13,
|
318
|
+
constant int64_t & nb10,
|
319
|
+
constant int64_t & nb11,
|
320
|
+
constant int64_t & nb12,
|
321
|
+
constant int64_t & nb13,
|
322
|
+
constant int64_t & ne0,
|
323
|
+
constant int64_t & ne1,
|
324
|
+
constant int64_t & ne2,
|
325
|
+
constant int64_t & ne3,
|
326
|
+
constant int64_t & nb0,
|
327
|
+
constant int64_t & nb1,
|
328
|
+
constant int64_t & nb2,
|
329
|
+
constant int64_t & nb3,
|
330
|
+
uint3 tpig[[thread_position_in_grid]]) {
|
331
|
+
int64_t i3 = tpig.z;
|
332
|
+
int64_t i2 = tpig.y;
|
333
|
+
int64_t i1 = tpig.x;
|
334
|
+
|
335
|
+
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
|
336
|
+
return;
|
337
|
+
}
|
338
|
+
|
339
|
+
device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
|
340
|
+
device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
|
341
|
+
|
342
|
+
float row_sum = 0;
|
343
|
+
|
344
|
+
for (int64_t i0 = 0; i0 < ne00; i0++) {
|
345
|
+
row_sum += src_row[i0];
|
346
|
+
}
|
347
|
+
|
348
|
+
dst_row[0] = row_sum;
|
349
|
+
}
|
350
|
+
|
183
351
|
kernel void kernel_soft_max(
|
184
352
|
device const float * src0,
|
185
353
|
device const float * src1,
|
@@ -198,9 +366,9 @@ kernel void kernel_soft_max(
|
|
198
366
|
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
199
367
|
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
200
368
|
|
201
|
-
device const float * psrc0 =
|
202
|
-
device const float * pmask = src1 ? src1
|
203
|
-
device float * pdst =
|
369
|
+
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
370
|
+
device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
|
371
|
+
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
204
372
|
|
205
373
|
// parallel max
|
206
374
|
float lmax = -INFINITY;
|
@@ -236,7 +404,12 @@ kernel void kernel_soft_max(
|
|
236
404
|
pdst[i00] = exp_psrc0;
|
237
405
|
}
|
238
406
|
|
407
|
+
// This barrier fixes a failing test
|
408
|
+
// ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
|
409
|
+
threadgroup_barrier(mem_flags::mem_none);
|
410
|
+
|
239
411
|
float sum = simd_sum(lsum);
|
412
|
+
|
240
413
|
if (ntg > N_SIMDWIDTH) {
|
241
414
|
if (sgitg == 0) {
|
242
415
|
buf[tiisg] = 0.0f;
|
@@ -279,9 +452,9 @@ kernel void kernel_soft_max_4(
|
|
279
452
|
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
280
453
|
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
281
454
|
|
282
|
-
device const float4 * psrc4 =
|
283
|
-
device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
|
284
|
-
device float4 * pdst4 =
|
455
|
+
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
456
|
+
device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
|
457
|
+
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
285
458
|
|
286
459
|
// parallel max
|
287
460
|
float4 lmax4 = -INFINITY;
|
@@ -319,7 +492,13 @@ kernel void kernel_soft_max_4(
|
|
319
492
|
}
|
320
493
|
|
321
494
|
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
495
|
+
|
496
|
+
// This barrier fixes a failing test
|
497
|
+
// ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
|
498
|
+
threadgroup_barrier(mem_flags::mem_none);
|
499
|
+
|
322
500
|
float sum = simd_sum(lsum);
|
501
|
+
|
323
502
|
if (ntg > N_SIMDWIDTH) {
|
324
503
|
if (sgitg == 0) {
|
325
504
|
buf[tiisg] = 0.0f;
|
@@ -490,6 +669,94 @@ kernel void kernel_rms_norm(
|
|
490
669
|
}
|
491
670
|
}
|
492
671
|
|
672
|
+
kernel void kernel_group_norm(
|
673
|
+
device const float * src0,
|
674
|
+
device float * dst,
|
675
|
+
constant int64_t & ne00,
|
676
|
+
constant int64_t & ne01,
|
677
|
+
constant int64_t & ne02,
|
678
|
+
constant uint64_t & nb00,
|
679
|
+
constant uint64_t & nb01,
|
680
|
+
constant uint64_t & nb02,
|
681
|
+
constant int32_t & n_groups,
|
682
|
+
constant float & eps,
|
683
|
+
threadgroup float * buf [[threadgroup(0)]],
|
684
|
+
uint tgpig[[threadgroup_position_in_grid]],
|
685
|
+
uint tpitg[[thread_position_in_threadgroup]],
|
686
|
+
uint sgitg[[simdgroup_index_in_threadgroup]],
|
687
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
688
|
+
uint ntg[[threads_per_threadgroup]]) {
|
689
|
+
const int64_t ne = ne00*ne01*ne02;
|
690
|
+
const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
|
691
|
+
|
692
|
+
int start = tgpig * gs;
|
693
|
+
int end = start + gs;
|
694
|
+
|
695
|
+
start += tpitg;
|
696
|
+
|
697
|
+
if (end >= ne) {
|
698
|
+
end = ne;
|
699
|
+
}
|
700
|
+
|
701
|
+
float tmp = 0.0f; // partial sum for thread in warp
|
702
|
+
|
703
|
+
for (int j = start; j < end; j += ntg) {
|
704
|
+
tmp += src0[j];
|
705
|
+
}
|
706
|
+
|
707
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
708
|
+
tmp = simd_sum(tmp);
|
709
|
+
if (ntg > N_SIMDWIDTH) {
|
710
|
+
if (sgitg == 0) {
|
711
|
+
buf[tiisg] = 0.0f;
|
712
|
+
}
|
713
|
+
|
714
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
715
|
+
|
716
|
+
if (tiisg == 0) {
|
717
|
+
buf[sgitg] = tmp;
|
718
|
+
}
|
719
|
+
|
720
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
721
|
+
|
722
|
+
tmp = buf[tiisg];
|
723
|
+
tmp = simd_sum(tmp);
|
724
|
+
}
|
725
|
+
|
726
|
+
const float mean = tmp / gs;
|
727
|
+
tmp = 0.0f;
|
728
|
+
|
729
|
+
for (int j = start; j < end; j += ntg) {
|
730
|
+
float xi = src0[j] - mean;
|
731
|
+
dst[j] = xi;
|
732
|
+
tmp += xi * xi;
|
733
|
+
}
|
734
|
+
|
735
|
+
tmp = simd_sum(tmp);
|
736
|
+
if (ntg > N_SIMDWIDTH) {
|
737
|
+
if (sgitg == 0) {
|
738
|
+
buf[tiisg] = 0.0f;
|
739
|
+
}
|
740
|
+
|
741
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
742
|
+
|
743
|
+
if (tiisg == 0) {
|
744
|
+
buf[sgitg] = tmp;
|
745
|
+
}
|
746
|
+
|
747
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
748
|
+
|
749
|
+
tmp = buf[tiisg];
|
750
|
+
tmp = simd_sum(tmp);
|
751
|
+
}
|
752
|
+
|
753
|
+
const float variance = tmp / gs;
|
754
|
+
const float scale = 1.0f/sqrt(variance + eps);
|
755
|
+
for (int j = start; j < end; j += ntg) {
|
756
|
+
dst[j] *= scale;
|
757
|
+
}
|
758
|
+
}
|
759
|
+
|
493
760
|
// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
|
494
761
|
// il indicates where the q4 quants begin (0 or QK4_0/4)
|
495
762
|
// we assume that the yl's have been multiplied with the appropriate scale factor
|
@@ -582,9 +849,20 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
|
|
582
849
|
// giard against the number of rows not being divisible by
|
583
850
|
// N_DST, so this is another explicit assumption of the implementation.
|
584
851
|
template<typename block_q_type, int nr, int nsg, int nw>
|
585
|
-
void
|
586
|
-
|
587
|
-
|
852
|
+
void mul_vec_q_n_f32_impl(
|
853
|
+
device const void * src0,
|
854
|
+
device const float * src1,
|
855
|
+
device float * dst,
|
856
|
+
int64_t ne00,
|
857
|
+
int64_t ne01,
|
858
|
+
int64_t ne02,
|
859
|
+
int64_t ne10,
|
860
|
+
int64_t ne12,
|
861
|
+
int64_t ne0,
|
862
|
+
int64_t ne1,
|
863
|
+
uint r2,
|
864
|
+
uint r3,
|
865
|
+
uint3 tgpig, uint tiisg, uint sgitg) {
|
588
866
|
const int nb = ne00/QK4_0;
|
589
867
|
|
590
868
|
const int r0 = tgpig.x;
|
@@ -593,7 +871,10 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
|
|
593
871
|
|
594
872
|
const int first_row = (r0 * nsg + sgitg) * nr;
|
595
873
|
|
596
|
-
const uint
|
874
|
+
const uint i12 = im%ne12;
|
875
|
+
const uint i13 = im/ne12;
|
876
|
+
|
877
|
+
const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
597
878
|
|
598
879
|
device const block_q_type * x = (device const block_q_type *) src0 + offset0;
|
599
880
|
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
@@ -643,13 +924,14 @@ kernel void kernel_mul_mv_q4_0_f32(
|
|
643
924
|
constant int64_t & ne02[[buffer(5)]],
|
644
925
|
constant int64_t & ne10[[buffer(9)]],
|
645
926
|
constant int64_t & ne12[[buffer(11)]],
|
646
|
-
constant int64_t & ne0[[buffer(15)]],
|
647
|
-
constant int64_t & ne1[[buffer(16)]],
|
648
|
-
constant uint &
|
927
|
+
constant int64_t & ne0 [[buffer(15)]],
|
928
|
+
constant int64_t & ne1 [[buffer(16)]],
|
929
|
+
constant uint & r2 [[buffer(17)]],
|
930
|
+
constant uint & r3 [[buffer(18)]],
|
649
931
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
650
932
|
uint tiisg[[thread_index_in_simdgroup]],
|
651
933
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
652
|
-
|
934
|
+
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
653
935
|
}
|
654
936
|
|
655
937
|
kernel void kernel_mul_mv_q4_1_f32(
|
@@ -661,13 +943,14 @@ kernel void kernel_mul_mv_q4_1_f32(
|
|
661
943
|
constant int64_t & ne02[[buffer(5)]],
|
662
944
|
constant int64_t & ne10[[buffer(9)]],
|
663
945
|
constant int64_t & ne12[[buffer(11)]],
|
664
|
-
constant int64_t & ne0[[buffer(15)]],
|
665
|
-
constant int64_t & ne1[[buffer(16)]],
|
666
|
-
constant uint &
|
946
|
+
constant int64_t & ne0 [[buffer(15)]],
|
947
|
+
constant int64_t & ne1 [[buffer(16)]],
|
948
|
+
constant uint & r2 [[buffer(17)]],
|
949
|
+
constant uint & r3 [[buffer(18)]],
|
667
950
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
668
951
|
uint tiisg[[thread_index_in_simdgroup]],
|
669
952
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
670
|
-
|
953
|
+
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
671
954
|
}
|
672
955
|
|
673
956
|
kernel void kernel_mul_mv_q5_0_f32(
|
@@ -679,13 +962,14 @@ kernel void kernel_mul_mv_q5_0_f32(
|
|
679
962
|
constant int64_t & ne02[[buffer(5)]],
|
680
963
|
constant int64_t & ne10[[buffer(9)]],
|
681
964
|
constant int64_t & ne12[[buffer(11)]],
|
682
|
-
constant int64_t & ne0[[buffer(15)]],
|
683
|
-
constant int64_t & ne1[[buffer(16)]],
|
684
|
-
constant uint &
|
965
|
+
constant int64_t & ne0 [[buffer(15)]],
|
966
|
+
constant int64_t & ne1 [[buffer(16)]],
|
967
|
+
constant uint & r2 [[buffer(17)]],
|
968
|
+
constant uint & r3 [[buffer(18)]],
|
685
969
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
686
970
|
uint tiisg[[thread_index_in_simdgroup]],
|
687
971
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
688
|
-
|
972
|
+
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
689
973
|
}
|
690
974
|
|
691
975
|
kernel void kernel_mul_mv_q5_1_f32(
|
@@ -697,33 +981,35 @@ kernel void kernel_mul_mv_q5_1_f32(
|
|
697
981
|
constant int64_t & ne02[[buffer(5)]],
|
698
982
|
constant int64_t & ne10[[buffer(9)]],
|
699
983
|
constant int64_t & ne12[[buffer(11)]],
|
700
|
-
constant int64_t & ne0[[buffer(15)]],
|
701
|
-
constant int64_t & ne1[[buffer(16)]],
|
702
|
-
constant uint &
|
984
|
+
constant int64_t & ne0 [[buffer(15)]],
|
985
|
+
constant int64_t & ne1 [[buffer(16)]],
|
986
|
+
constant uint & r2 [[buffer(17)]],
|
987
|
+
constant uint & r3 [[buffer(18)]],
|
703
988
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
704
989
|
uint tiisg[[thread_index_in_simdgroup]],
|
705
990
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
706
|
-
|
991
|
+
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
707
992
|
}
|
708
993
|
|
709
994
|
|
710
995
|
#define NB_Q8_0 8
|
711
996
|
|
712
|
-
|
997
|
+
void kernel_mul_mv_q8_0_f32_impl(
|
713
998
|
device const void * src0,
|
714
999
|
device const float * src1,
|
715
1000
|
device float * dst,
|
716
1001
|
constant int64_t & ne00,
|
717
|
-
constant int64_t & ne01
|
718
|
-
constant int64_t & ne02
|
719
|
-
constant int64_t & ne10
|
720
|
-
constant int64_t & ne12
|
721
|
-
constant int64_t & ne0
|
722
|
-
constant int64_t & ne1
|
723
|
-
constant uint &
|
1002
|
+
constant int64_t & ne01,
|
1003
|
+
constant int64_t & ne02,
|
1004
|
+
constant int64_t & ne10,
|
1005
|
+
constant int64_t & ne12,
|
1006
|
+
constant int64_t & ne0,
|
1007
|
+
constant int64_t & ne1,
|
1008
|
+
constant uint & r2,
|
1009
|
+
constant uint & r3,
|
724
1010
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
725
|
-
uint
|
726
|
-
uint
|
1011
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
1012
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
727
1013
|
const int nr = N_DST;
|
728
1014
|
const int nsg = N_SIMDGROUP;
|
729
1015
|
const int nw = N_SIMDWIDTH;
|
@@ -732,8 +1018,14 @@ kernel void kernel_mul_mv_q8_0_f32(
|
|
732
1018
|
const int r0 = tgpig.x;
|
733
1019
|
const int r1 = tgpig.y;
|
734
1020
|
const int im = tgpig.z;
|
1021
|
+
|
735
1022
|
const int first_row = (r0 * nsg + sgitg) * nr;
|
736
|
-
|
1023
|
+
|
1024
|
+
const uint i12 = im%ne12;
|
1025
|
+
const uint i13 = im/ne12;
|
1026
|
+
|
1027
|
+
const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
1028
|
+
|
737
1029
|
device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
|
738
1030
|
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
739
1031
|
|
@@ -771,11 +1063,31 @@ kernel void kernel_mul_mv_q8_0_f32(
|
|
771
1063
|
}
|
772
1064
|
}
|
773
1065
|
|
774
|
-
|
775
|
-
|
776
|
-
|
777
|
-
device const
|
778
|
-
device
|
1066
|
+
[[host_name("kernel_mul_mv_q8_0_f32")]]
|
1067
|
+
kernel void kernel_mul_mv_q8_0_f32(
|
1068
|
+
device const void * src0,
|
1069
|
+
device const float * src1,
|
1070
|
+
device float * dst,
|
1071
|
+
constant int64_t & ne00,
|
1072
|
+
constant int64_t & ne01,
|
1073
|
+
constant int64_t & ne02,
|
1074
|
+
constant int64_t & ne10,
|
1075
|
+
constant int64_t & ne12,
|
1076
|
+
constant int64_t & ne0,
|
1077
|
+
constant int64_t & ne1,
|
1078
|
+
constant uint & r2 [[buffer(17)]],
|
1079
|
+
constant uint & r3 [[buffer(18)]],
|
1080
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1081
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
1082
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1083
|
+
kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
1084
|
+
}
|
1085
|
+
|
1086
|
+
#define N_F32_F32 4
|
1087
|
+
|
1088
|
+
void kernel_mul_mv_f32_f32_impl(
|
1089
|
+
device const char * src0,
|
1090
|
+
device const char * src1,
|
779
1091
|
device float * dst,
|
780
1092
|
constant int64_t & ne00,
|
781
1093
|
constant int64_t & ne01,
|
@@ -791,6 +1103,8 @@ kernel void kernel_mul_mv_f32_f32(
|
|
791
1103
|
constant uint64_t & nb12,
|
792
1104
|
constant int64_t & ne0,
|
793
1105
|
constant int64_t & ne1,
|
1106
|
+
constant uint & r2,
|
1107
|
+
constant uint & r3,
|
794
1108
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
795
1109
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
796
1110
|
|
@@ -798,7 +1112,12 @@ kernel void kernel_mul_mv_f32_f32(
|
|
798
1112
|
const int64_t rb = tgpig.y*N_F32_F32;
|
799
1113
|
const int64_t im = tgpig.z;
|
800
1114
|
|
801
|
-
|
1115
|
+
const uint i12 = im%ne12;
|
1116
|
+
const uint i13 = im/ne12;
|
1117
|
+
|
1118
|
+
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
1119
|
+
|
1120
|
+
device const float * x = (device const float *) (src0 + offset0);
|
802
1121
|
|
803
1122
|
if (ne00 < 128) {
|
804
1123
|
for (int row = 0; row < N_F32_F32; ++row) {
|
@@ -844,6 +1163,32 @@ kernel void kernel_mul_mv_f32_f32(
|
|
844
1163
|
}
|
845
1164
|
}
|
846
1165
|
|
1166
|
+
[[host_name("kernel_mul_mv_f32_f32")]]
|
1167
|
+
kernel void kernel_mul_mv_f32_f32(
|
1168
|
+
device const char * src0,
|
1169
|
+
device const char * src1,
|
1170
|
+
device float * dst,
|
1171
|
+
constant int64_t & ne00,
|
1172
|
+
constant int64_t & ne01,
|
1173
|
+
constant int64_t & ne02,
|
1174
|
+
constant uint64_t & nb00,
|
1175
|
+
constant uint64_t & nb01,
|
1176
|
+
constant uint64_t & nb02,
|
1177
|
+
constant int64_t & ne10,
|
1178
|
+
constant int64_t & ne11,
|
1179
|
+
constant int64_t & ne12,
|
1180
|
+
constant uint64_t & nb10,
|
1181
|
+
constant uint64_t & nb11,
|
1182
|
+
constant uint64_t & nb12,
|
1183
|
+
constant int64_t & ne0,
|
1184
|
+
constant int64_t & ne1,
|
1185
|
+
constant uint & r2 [[buffer(17)]],
|
1186
|
+
constant uint & r3 [[buffer(18)]],
|
1187
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1188
|
+
uint tiisg[[thread_index_in_simdgroup]]) {
|
1189
|
+
kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
1190
|
+
}
|
1191
|
+
|
847
1192
|
#define N_F16_F16 4
|
848
1193
|
|
849
1194
|
kernel void kernel_mul_mv_f16_f16(
|
@@ -864,6 +1209,8 @@ kernel void kernel_mul_mv_f16_f16(
|
|
864
1209
|
constant uint64_t & nb12,
|
865
1210
|
constant int64_t & ne0,
|
866
1211
|
constant int64_t & ne1,
|
1212
|
+
constant uint & r2 [[buffer(17)]],
|
1213
|
+
constant uint & r3 [[buffer(18)]],
|
867
1214
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
868
1215
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
869
1216
|
|
@@ -871,7 +1218,12 @@ kernel void kernel_mul_mv_f16_f16(
|
|
871
1218
|
const int64_t rb = tgpig.y*N_F16_F16;
|
872
1219
|
const int64_t im = tgpig.z;
|
873
1220
|
|
874
|
-
|
1221
|
+
const uint i12 = im%ne12;
|
1222
|
+
const uint i13 = im/ne12;
|
1223
|
+
|
1224
|
+
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
1225
|
+
|
1226
|
+
device const half * x = (device const half *) (src0 + offset0);
|
875
1227
|
|
876
1228
|
if (ne00 < 128) {
|
877
1229
|
for (int row = 0; row < N_F16_F16; ++row) {
|
@@ -917,7 +1269,7 @@ kernel void kernel_mul_mv_f16_f16(
|
|
917
1269
|
}
|
918
1270
|
}
|
919
1271
|
|
920
|
-
|
1272
|
+
void kernel_mul_mv_f16_f32_1row_impl(
|
921
1273
|
device const char * src0,
|
922
1274
|
device const char * src1,
|
923
1275
|
device float * dst,
|
@@ -935,6 +1287,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
|
|
935
1287
|
constant uint64_t & nb12,
|
936
1288
|
constant int64_t & ne0,
|
937
1289
|
constant int64_t & ne1,
|
1290
|
+
constant uint & r2,
|
1291
|
+
constant uint & r3,
|
938
1292
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
939
1293
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
940
1294
|
|
@@ -942,7 +1296,12 @@ kernel void kernel_mul_mv_f16_f32_1row(
|
|
942
1296
|
const int64_t r1 = tgpig.y;
|
943
1297
|
const int64_t im = tgpig.z;
|
944
1298
|
|
945
|
-
|
1299
|
+
const uint i12 = im%ne12;
|
1300
|
+
const uint i13 = im/ne12;
|
1301
|
+
|
1302
|
+
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
1303
|
+
|
1304
|
+
device const half * x = (device const half *) (src0 + offset0);
|
946
1305
|
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
947
1306
|
|
948
1307
|
float sumf = 0;
|
@@ -966,12 +1325,37 @@ kernel void kernel_mul_mv_f16_f32_1row(
|
|
966
1325
|
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
967
1326
|
}
|
968
1327
|
}
|
1328
|
+
}
|
969
1329
|
|
1330
|
+
[[host_name("kernel_mul_mv_f16_f32_1row")]]
|
1331
|
+
kernel void kernel_mul_mv_f16_f32_1row(
|
1332
|
+
device const char * src0,
|
1333
|
+
device const char * src1,
|
1334
|
+
device float * dst,
|
1335
|
+
constant int64_t & ne00,
|
1336
|
+
constant int64_t & ne01,
|
1337
|
+
constant int64_t & ne02,
|
1338
|
+
constant uint64_t & nb00,
|
1339
|
+
constant uint64_t & nb01,
|
1340
|
+
constant uint64_t & nb02,
|
1341
|
+
constant int64_t & ne10,
|
1342
|
+
constant int64_t & ne11,
|
1343
|
+
constant int64_t & ne12,
|
1344
|
+
constant uint64_t & nb10,
|
1345
|
+
constant uint64_t & nb11,
|
1346
|
+
constant uint64_t & nb12,
|
1347
|
+
constant int64_t & ne0,
|
1348
|
+
constant int64_t & ne1,
|
1349
|
+
constant uint & r2 [[buffer(17)]],
|
1350
|
+
constant uint & r3 [[buffer(18)]],
|
1351
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1352
|
+
uint tiisg[[thread_index_in_simdgroup]]) {
|
1353
|
+
kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
970
1354
|
}
|
971
1355
|
|
972
1356
|
#define N_F16_F32 4
|
973
1357
|
|
974
|
-
|
1358
|
+
void kernel_mul_mv_f16_f32_impl(
|
975
1359
|
device const char * src0,
|
976
1360
|
device const char * src1,
|
977
1361
|
device float * dst,
|
@@ -989,6 +1373,8 @@ kernel void kernel_mul_mv_f16_f32(
|
|
989
1373
|
constant uint64_t & nb12,
|
990
1374
|
constant int64_t & ne0,
|
991
1375
|
constant int64_t & ne1,
|
1376
|
+
constant uint & r2,
|
1377
|
+
constant uint & r3,
|
992
1378
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
993
1379
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
994
1380
|
|
@@ -996,7 +1382,12 @@ kernel void kernel_mul_mv_f16_f32(
|
|
996
1382
|
const int64_t rb = tgpig.y*N_F16_F32;
|
997
1383
|
const int64_t im = tgpig.z;
|
998
1384
|
|
999
|
-
|
1385
|
+
const uint i12 = im%ne12;
|
1386
|
+
const uint i13 = im/ne12;
|
1387
|
+
|
1388
|
+
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
1389
|
+
|
1390
|
+
device const half * x = (device const half *) (src0 + offset0);
|
1000
1391
|
|
1001
1392
|
if (ne00 < 128) {
|
1002
1393
|
for (int row = 0; row < N_F16_F32; ++row) {
|
@@ -1042,6 +1433,32 @@ kernel void kernel_mul_mv_f16_f32(
|
|
1042
1433
|
}
|
1043
1434
|
}
|
1044
1435
|
|
1436
|
+
[[host_name("kernel_mul_mv_f16_f32")]]
|
1437
|
+
kernel void kernel_mul_mv_f16_f32(
|
1438
|
+
device const char * src0,
|
1439
|
+
device const char * src1,
|
1440
|
+
device float * dst,
|
1441
|
+
constant int64_t & ne00,
|
1442
|
+
constant int64_t & ne01,
|
1443
|
+
constant int64_t & ne02,
|
1444
|
+
constant uint64_t & nb00,
|
1445
|
+
constant uint64_t & nb01,
|
1446
|
+
constant uint64_t & nb02,
|
1447
|
+
constant int64_t & ne10,
|
1448
|
+
constant int64_t & ne11,
|
1449
|
+
constant int64_t & ne12,
|
1450
|
+
constant uint64_t & nb10,
|
1451
|
+
constant uint64_t & nb11,
|
1452
|
+
constant uint64_t & nb12,
|
1453
|
+
constant int64_t & ne0,
|
1454
|
+
constant int64_t & ne1,
|
1455
|
+
constant uint & r2 [[buffer(17)]],
|
1456
|
+
constant uint & r3 [[buffer(18)]],
|
1457
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1458
|
+
uint tiisg[[thread_index_in_simdgroup]]) {
|
1459
|
+
kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
1460
|
+
}
|
1461
|
+
|
1045
1462
|
// Assumes row size (ne00) is a multiple of 4
|
1046
1463
|
kernel void kernel_mul_mv_f16_f32_l4(
|
1047
1464
|
device const char * src0,
|
@@ -1061,6 +1478,8 @@ kernel void kernel_mul_mv_f16_f32_l4(
|
|
1061
1478
|
constant uint64_t & nb12,
|
1062
1479
|
constant int64_t & ne0,
|
1063
1480
|
constant int64_t & ne1,
|
1481
|
+
constant uint & r2 [[buffer(17)]],
|
1482
|
+
constant uint & r3 [[buffer(18)]],
|
1064
1483
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1065
1484
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
1066
1485
|
|
@@ -1068,7 +1487,12 @@ kernel void kernel_mul_mv_f16_f32_l4(
|
|
1068
1487
|
const int64_t r0 = tgpig.x;
|
1069
1488
|
const int64_t im = tgpig.z;
|
1070
1489
|
|
1071
|
-
|
1490
|
+
const uint i12 = im%ne12;
|
1491
|
+
const uint i13 = im/ne12;
|
1492
|
+
|
1493
|
+
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
1494
|
+
|
1495
|
+
device const half4 * x4 = (device const half4 *) (src0 + offset0);
|
1072
1496
|
|
1073
1497
|
for (int r1 = 0; r1 < nrows; ++r1) {
|
1074
1498
|
device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
|
@@ -1120,17 +1544,21 @@ kernel void kernel_alibi_f32(
|
|
1120
1544
|
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
1121
1545
|
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
1122
1546
|
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
1547
|
+
const int64_t k = i3*ne3 + i2;
|
1123
1548
|
|
1124
|
-
device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
1125
1549
|
float m_k;
|
1126
|
-
if (
|
1127
|
-
m_k = pow(m0,
|
1550
|
+
if (k < n_heads_log2_floor) {
|
1551
|
+
m_k = pow(m0, k + 1);
|
1128
1552
|
} else {
|
1129
|
-
m_k = pow(m1, 2 * (
|
1553
|
+
m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
|
1130
1554
|
}
|
1555
|
+
|
1556
|
+
device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
|
1557
|
+
device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
1131
1558
|
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
1132
|
-
|
1133
|
-
|
1559
|
+
const float src_v = *(device float *)(src_row + i00*nb00);
|
1560
|
+
device float * dst_v = (device float *)(dst_row + i00*nb0);
|
1561
|
+
*dst_v = i00 * m_k + src_v;
|
1134
1562
|
}
|
1135
1563
|
}
|
1136
1564
|
|
@@ -1335,9 +1763,160 @@ kernel void kernel_im2col_f16(
|
|
1335
1763
|
}
|
1336
1764
|
}
|
1337
1765
|
|
1766
|
+
kernel void kernel_upscale_f32(
|
1767
|
+
device const char * src0,
|
1768
|
+
device char * dst,
|
1769
|
+
constant int64_t & ne00,
|
1770
|
+
constant int64_t & ne01,
|
1771
|
+
constant int64_t & ne02,
|
1772
|
+
constant int64_t & ne03,
|
1773
|
+
constant uint64_t & nb00,
|
1774
|
+
constant uint64_t & nb01,
|
1775
|
+
constant uint64_t & nb02,
|
1776
|
+
constant uint64_t & nb03,
|
1777
|
+
constant int64_t & ne0,
|
1778
|
+
constant int64_t & ne1,
|
1779
|
+
constant int64_t & ne2,
|
1780
|
+
constant int64_t & ne3,
|
1781
|
+
constant uint64_t & nb0,
|
1782
|
+
constant uint64_t & nb1,
|
1783
|
+
constant uint64_t & nb2,
|
1784
|
+
constant uint64_t & nb3,
|
1785
|
+
constant int32_t & sf,
|
1786
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1787
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
1788
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
1789
|
+
|
1790
|
+
const int64_t i3 = tgpig.z;
|
1791
|
+
const int64_t i2 = tgpig.y;
|
1792
|
+
const int64_t i1 = tgpig.x;
|
1793
|
+
|
1794
|
+
const int64_t i03 = i3;
|
1795
|
+
const int64_t i02 = i2;
|
1796
|
+
const int64_t i01 = i1/sf;
|
1797
|
+
|
1798
|
+
device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
|
1799
|
+
device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
|
1800
|
+
|
1801
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
1802
|
+
dst_ptr[i0] = src0_ptr[i0/sf];
|
1803
|
+
}
|
1804
|
+
}
|
1805
|
+
|
1806
|
+
kernel void kernel_pad_f32(
|
1807
|
+
device const char * src0,
|
1808
|
+
device char * dst,
|
1809
|
+
constant int64_t & ne00,
|
1810
|
+
constant int64_t & ne01,
|
1811
|
+
constant int64_t & ne02,
|
1812
|
+
constant int64_t & ne03,
|
1813
|
+
constant uint64_t & nb00,
|
1814
|
+
constant uint64_t & nb01,
|
1815
|
+
constant uint64_t & nb02,
|
1816
|
+
constant uint64_t & nb03,
|
1817
|
+
constant int64_t & ne0,
|
1818
|
+
constant int64_t & ne1,
|
1819
|
+
constant int64_t & ne2,
|
1820
|
+
constant int64_t & ne3,
|
1821
|
+
constant uint64_t & nb0,
|
1822
|
+
constant uint64_t & nb1,
|
1823
|
+
constant uint64_t & nb2,
|
1824
|
+
constant uint64_t & nb3,
|
1825
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1826
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
1827
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
1828
|
+
|
1829
|
+
const int64_t i3 = tgpig.z;
|
1830
|
+
const int64_t i2 = tgpig.y;
|
1831
|
+
const int64_t i1 = tgpig.x;
|
1832
|
+
|
1833
|
+
const int64_t i03 = i3;
|
1834
|
+
const int64_t i02 = i2;
|
1835
|
+
const int64_t i01 = i1;
|
1836
|
+
|
1837
|
+
device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
|
1838
|
+
device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
|
1839
|
+
|
1840
|
+
if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
|
1841
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
1842
|
+
if (i0 < ne00) {
|
1843
|
+
dst_ptr[i0] = src0_ptr[i0];
|
1844
|
+
} else {
|
1845
|
+
dst_ptr[i0] = 0.0f;
|
1846
|
+
}
|
1847
|
+
}
|
1848
|
+
|
1849
|
+
return;
|
1850
|
+
}
|
1851
|
+
|
1852
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
1853
|
+
dst_ptr[i0] = 0.0f;
|
1854
|
+
}
|
1855
|
+
}
|
1856
|
+
|
1857
|
+
// bitonic sort implementation following the CUDA kernels as reference
|
1858
|
+
typedef void (argsort_t)(
|
1859
|
+
device const float * x,
|
1860
|
+
device int32_t * dst,
|
1861
|
+
constant int64_t & ncols,
|
1862
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1863
|
+
uint3 tpitg[[thread_position_in_threadgroup]]);
|
1864
|
+
|
1865
|
+
template<ggml_sort_order order>
|
1866
|
+
kernel void kernel_argsort_f32_i32(
|
1867
|
+
device const float * x,
|
1868
|
+
device int32_t * dst,
|
1869
|
+
constant int64_t & ncols,
|
1870
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1871
|
+
uint3 tpitg[[thread_position_in_threadgroup]]) {
|
1872
|
+
// bitonic sort
|
1873
|
+
int col = tpitg[0];
|
1874
|
+
int row = tgpig[1];
|
1875
|
+
|
1876
|
+
if (col >= ncols) return;
|
1877
|
+
|
1878
|
+
device const float * x_row = x + row * ncols;
|
1879
|
+
device int32_t * dst_row = dst + row * ncols;
|
1880
|
+
|
1881
|
+
// initialize indices
|
1882
|
+
if (col < ncols) {
|
1883
|
+
dst_row[col] = col;
|
1884
|
+
}
|
1885
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1886
|
+
|
1887
|
+
for (int k = 2; k <= ncols; k *= 2) {
|
1888
|
+
for (int j = k / 2; j > 0; j /= 2) {
|
1889
|
+
int ixj = col ^ j;
|
1890
|
+
if (ixj > col) {
|
1891
|
+
if ((col & k) == 0) {
|
1892
|
+
if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
|
1893
|
+
SWAP(dst_row[col], dst_row[ixj]);
|
1894
|
+
}
|
1895
|
+
} else {
|
1896
|
+
if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
|
1897
|
+
SWAP(dst_row[col], dst_row[ixj]);
|
1898
|
+
}
|
1899
|
+
}
|
1900
|
+
}
|
1901
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1902
|
+
}
|
1903
|
+
}
|
1904
|
+
}
|
1905
|
+
|
1906
|
+
template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
|
1907
|
+
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
|
1908
|
+
|
1909
|
+
kernel void kernel_leaky_relu_f32(
|
1910
|
+
device const float * src0,
|
1911
|
+
device float * dst,
|
1912
|
+
constant float & slope,
|
1913
|
+
uint tpig[[thread_position_in_grid]]) {
|
1914
|
+
dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
|
1915
|
+
}
|
1916
|
+
|
1338
1917
|
kernel void kernel_cpy_f16_f16(
|
1339
|
-
device
|
1340
|
-
device
|
1918
|
+
device const half * src0,
|
1919
|
+
device half * dst,
|
1341
1920
|
constant int64_t & ne00,
|
1342
1921
|
constant int64_t & ne01,
|
1343
1922
|
constant int64_t & ne02,
|
@@ -1376,6 +1955,47 @@ kernel void kernel_cpy_f16_f16(
|
|
1376
1955
|
}
|
1377
1956
|
}
|
1378
1957
|
|
1958
|
+
kernel void kernel_cpy_f16_f32(
|
1959
|
+
device const half * src0,
|
1960
|
+
device float * dst,
|
1961
|
+
constant int64_t & ne00,
|
1962
|
+
constant int64_t & ne01,
|
1963
|
+
constant int64_t & ne02,
|
1964
|
+
constant int64_t & ne03,
|
1965
|
+
constant uint64_t & nb00,
|
1966
|
+
constant uint64_t & nb01,
|
1967
|
+
constant uint64_t & nb02,
|
1968
|
+
constant uint64_t & nb03,
|
1969
|
+
constant int64_t & ne0,
|
1970
|
+
constant int64_t & ne1,
|
1971
|
+
constant int64_t & ne2,
|
1972
|
+
constant int64_t & ne3,
|
1973
|
+
constant uint64_t & nb0,
|
1974
|
+
constant uint64_t & nb1,
|
1975
|
+
constant uint64_t & nb2,
|
1976
|
+
constant uint64_t & nb3,
|
1977
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1978
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
1979
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
1980
|
+
const int64_t i03 = tgpig[2];
|
1981
|
+
const int64_t i02 = tgpig[1];
|
1982
|
+
const int64_t i01 = tgpig[0];
|
1983
|
+
|
1984
|
+
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
1985
|
+
|
1986
|
+
const int64_t i3 = n / (ne2*ne1*ne0);
|
1987
|
+
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
1988
|
+
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
1989
|
+
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
1990
|
+
|
1991
|
+
device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
1992
|
+
|
1993
|
+
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
1994
|
+
device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
1995
|
+
dst_data[i00] = src[0];
|
1996
|
+
}
|
1997
|
+
}
|
1998
|
+
|
1379
1999
|
kernel void kernel_cpy_f32_f16(
|
1380
2000
|
device const float * src0,
|
1381
2001
|
device half * dst,
|
@@ -1460,47 +2080,238 @@ kernel void kernel_cpy_f32_f32(
|
|
1460
2080
|
}
|
1461
2081
|
}
|
1462
2082
|
|
1463
|
-
kernel void
|
1464
|
-
|
1465
|
-
|
1466
|
-
|
1467
|
-
|
1468
|
-
|
1469
|
-
|
1470
|
-
|
1471
|
-
|
1472
|
-
|
1473
|
-
|
1474
|
-
|
1475
|
-
|
1476
|
-
|
1477
|
-
|
1478
|
-
|
1479
|
-
|
1480
|
-
|
1481
|
-
|
1482
|
-
|
1483
|
-
|
1484
|
-
|
1485
|
-
|
1486
|
-
|
1487
|
-
|
1488
|
-
|
1489
|
-
|
1490
|
-
|
1491
|
-
|
1492
|
-
|
1493
|
-
|
1494
|
-
|
1495
|
-
|
1496
|
-
|
1497
|
-
|
1498
|
-
|
1499
|
-
|
2083
|
+
kernel void kernel_cpy_f32_q8_0(
|
2084
|
+
device const float * src0,
|
2085
|
+
device void * dst,
|
2086
|
+
constant int64_t & ne00,
|
2087
|
+
constant int64_t & ne01,
|
2088
|
+
constant int64_t & ne02,
|
2089
|
+
constant int64_t & ne03,
|
2090
|
+
constant uint64_t & nb00,
|
2091
|
+
constant uint64_t & nb01,
|
2092
|
+
constant uint64_t & nb02,
|
2093
|
+
constant uint64_t & nb03,
|
2094
|
+
constant int64_t & ne0,
|
2095
|
+
constant int64_t & ne1,
|
2096
|
+
constant int64_t & ne2,
|
2097
|
+
constant int64_t & ne3,
|
2098
|
+
constant uint64_t & nb0,
|
2099
|
+
constant uint64_t & nb1,
|
2100
|
+
constant uint64_t & nb2,
|
2101
|
+
constant uint64_t & nb3,
|
2102
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
2103
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
2104
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
2105
|
+
const int64_t i03 = tgpig[2];
|
2106
|
+
const int64_t i02 = tgpig[1];
|
2107
|
+
const int64_t i01 = tgpig[0];
|
2108
|
+
|
2109
|
+
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
2110
|
+
|
2111
|
+
const int64_t i3 = n / (ne2*ne1*ne0);
|
2112
|
+
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
2113
|
+
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
2114
|
+
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0;
|
2115
|
+
|
2116
|
+
device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
2117
|
+
|
2118
|
+
for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) {
|
2119
|
+
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
2120
|
+
|
2121
|
+
float amax = 0.0f; // absolute max
|
2122
|
+
|
2123
|
+
for (int j = 0; j < QK8_0; j++) {
|
2124
|
+
const float v = src[j];
|
2125
|
+
amax = MAX(amax, fabs(v));
|
2126
|
+
}
|
2127
|
+
|
2128
|
+
const float d = amax / ((1 << 7) - 1);
|
2129
|
+
const float id = d ? 1.0f/d : 0.0f;
|
2130
|
+
|
2131
|
+
dst_data[i00/QK8_0].d = d;
|
2132
|
+
|
2133
|
+
for (int j = 0; j < QK8_0; ++j) {
|
2134
|
+
const float x0 = src[j]*id;
|
2135
|
+
|
2136
|
+
dst_data[i00/QK8_0].qs[j] = round(x0);
|
2137
|
+
}
|
2138
|
+
}
|
2139
|
+
}
|
2140
|
+
|
2141
|
+
kernel void kernel_cpy_f32_q4_0(
|
2142
|
+
device const float * src0,
|
2143
|
+
device void * dst,
|
2144
|
+
constant int64_t & ne00,
|
2145
|
+
constant int64_t & ne01,
|
2146
|
+
constant int64_t & ne02,
|
2147
|
+
constant int64_t & ne03,
|
2148
|
+
constant uint64_t & nb00,
|
2149
|
+
constant uint64_t & nb01,
|
2150
|
+
constant uint64_t & nb02,
|
2151
|
+
constant uint64_t & nb03,
|
2152
|
+
constant int64_t & ne0,
|
2153
|
+
constant int64_t & ne1,
|
2154
|
+
constant int64_t & ne2,
|
2155
|
+
constant int64_t & ne3,
|
2156
|
+
constant uint64_t & nb0,
|
2157
|
+
constant uint64_t & nb1,
|
2158
|
+
constant uint64_t & nb2,
|
2159
|
+
constant uint64_t & nb3,
|
2160
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
2161
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
2162
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
2163
|
+
const int64_t i03 = tgpig[2];
|
2164
|
+
const int64_t i02 = tgpig[1];
|
2165
|
+
const int64_t i01 = tgpig[0];
|
2166
|
+
|
2167
|
+
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
2168
|
+
|
2169
|
+
const int64_t i3 = n / (ne2*ne1*ne0);
|
2170
|
+
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
2171
|
+
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
2172
|
+
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0;
|
2173
|
+
|
2174
|
+
device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
2175
|
+
|
2176
|
+
for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) {
|
2177
|
+
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
2178
|
+
|
2179
|
+
float amax = 0.0f; // absolute max
|
2180
|
+
float max = 0.0f;
|
2181
|
+
|
2182
|
+
for (int j = 0; j < QK4_0; j++) {
|
2183
|
+
const float v = src[j];
|
2184
|
+
if (amax < fabs(v)) {
|
2185
|
+
amax = fabs(v);
|
2186
|
+
max = v;
|
2187
|
+
}
|
2188
|
+
}
|
2189
|
+
|
2190
|
+
const float d = max / -8;
|
2191
|
+
const float id = d ? 1.0f/d : 0.0f;
|
2192
|
+
|
2193
|
+
dst_data[i00/QK4_0].d = d;
|
2194
|
+
|
2195
|
+
for (int j = 0; j < QK4_0/2; ++j) {
|
2196
|
+
const float x0 = src[0 + j]*id;
|
2197
|
+
const float x1 = src[QK4_0/2 + j]*id;
|
2198
|
+
|
2199
|
+
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
|
2200
|
+
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
|
2201
|
+
|
2202
|
+
dst_data[i00/QK4_0].qs[j] = xi0;
|
2203
|
+
dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
|
2204
|
+
}
|
2205
|
+
}
|
2206
|
+
}
|
2207
|
+
|
2208
|
+
kernel void kernel_cpy_f32_q4_1(
|
2209
|
+
device const float * src0,
|
2210
|
+
device void * dst,
|
2211
|
+
constant int64_t & ne00,
|
2212
|
+
constant int64_t & ne01,
|
2213
|
+
constant int64_t & ne02,
|
2214
|
+
constant int64_t & ne03,
|
2215
|
+
constant uint64_t & nb00,
|
2216
|
+
constant uint64_t & nb01,
|
2217
|
+
constant uint64_t & nb02,
|
2218
|
+
constant uint64_t & nb03,
|
2219
|
+
constant int64_t & ne0,
|
2220
|
+
constant int64_t & ne1,
|
2221
|
+
constant int64_t & ne2,
|
2222
|
+
constant int64_t & ne3,
|
2223
|
+
constant uint64_t & nb0,
|
2224
|
+
constant uint64_t & nb1,
|
2225
|
+
constant uint64_t & nb2,
|
2226
|
+
constant uint64_t & nb3,
|
2227
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
2228
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
2229
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
2230
|
+
const int64_t i03 = tgpig[2];
|
2231
|
+
const int64_t i02 = tgpig[1];
|
2232
|
+
const int64_t i01 = tgpig[0];
|
2233
|
+
|
2234
|
+
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
2235
|
+
|
2236
|
+
const int64_t i3 = n / (ne2*ne1*ne0);
|
2237
|
+
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
2238
|
+
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
2239
|
+
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1;
|
2240
|
+
|
2241
|
+
device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
2242
|
+
|
2243
|
+
for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) {
|
2244
|
+
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
2245
|
+
|
2246
|
+
float min = FLT_MAX;
|
2247
|
+
float max = -FLT_MAX;
|
2248
|
+
|
2249
|
+
for (int j = 0; j < QK4_1; j++) {
|
2250
|
+
const float v = src[j];
|
2251
|
+
if (min > v) min = v;
|
2252
|
+
if (max < v) max = v;
|
2253
|
+
}
|
2254
|
+
|
2255
|
+
const float d = (max - min) / ((1 << 4) - 1);
|
2256
|
+
const float id = d ? 1.0f/d : 0.0f;
|
2257
|
+
|
2258
|
+
dst_data[i00/QK4_1].d = d;
|
2259
|
+
dst_data[i00/QK4_1].m = min;
|
2260
|
+
|
2261
|
+
for (int j = 0; j < QK4_1/2; ++j) {
|
2262
|
+
const float x0 = (src[0 + j] - min)*id;
|
2263
|
+
const float x1 = (src[QK4_1/2 + j] - min)*id;
|
2264
|
+
|
2265
|
+
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
|
2266
|
+
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
|
2267
|
+
|
2268
|
+
dst_data[i00/QK4_1].qs[j] = xi0;
|
2269
|
+
dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
|
2270
|
+
}
|
2271
|
+
}
|
2272
|
+
}
|
2273
|
+
|
2274
|
+
kernel void kernel_concat(
|
2275
|
+
device const char * src0,
|
2276
|
+
device const char * src1,
|
2277
|
+
device char * dst,
|
2278
|
+
constant int64_t & ne00,
|
2279
|
+
constant int64_t & ne01,
|
2280
|
+
constant int64_t & ne02,
|
2281
|
+
constant int64_t & ne03,
|
2282
|
+
constant uint64_t & nb00,
|
2283
|
+
constant uint64_t & nb01,
|
2284
|
+
constant uint64_t & nb02,
|
2285
|
+
constant uint64_t & nb03,
|
2286
|
+
constant int64_t & ne10,
|
2287
|
+
constant int64_t & ne11,
|
2288
|
+
constant int64_t & ne12,
|
2289
|
+
constant int64_t & ne13,
|
2290
|
+
constant uint64_t & nb10,
|
2291
|
+
constant uint64_t & nb11,
|
2292
|
+
constant uint64_t & nb12,
|
2293
|
+
constant uint64_t & nb13,
|
2294
|
+
constant int64_t & ne0,
|
2295
|
+
constant int64_t & ne1,
|
2296
|
+
constant int64_t & ne2,
|
2297
|
+
constant int64_t & ne3,
|
2298
|
+
constant uint64_t & nb0,
|
2299
|
+
constant uint64_t & nb1,
|
2300
|
+
constant uint64_t & nb2,
|
2301
|
+
constant uint64_t & nb3,
|
2302
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
2303
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
2304
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
2305
|
+
|
2306
|
+
const int64_t i03 = tgpig.z;
|
2307
|
+
const int64_t i02 = tgpig.y;
|
2308
|
+
const int64_t i01 = tgpig.x;
|
2309
|
+
|
2310
|
+
const int64_t i13 = i03 % ne13;
|
1500
2311
|
const int64_t i12 = i02 % ne12;
|
1501
2312
|
const int64_t i11 = i01 % ne11;
|
1502
2313
|
|
1503
|
-
device const char * src0_ptr = src0 + i03
|
2314
|
+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
|
1504
2315
|
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
|
1505
2316
|
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
|
1506
2317
|
|
@@ -1608,32 +2419,39 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
|
|
1608
2419
|
|
1609
2420
|
//====================================== dot products =========================
|
1610
2421
|
|
1611
|
-
|
2422
|
+
void kernel_mul_mv_q2_K_f32_impl(
|
1612
2423
|
device const void * src0,
|
1613
2424
|
device const float * src1,
|
1614
2425
|
device float * dst,
|
1615
2426
|
constant int64_t & ne00,
|
1616
|
-
constant int64_t & ne01
|
1617
|
-
constant int64_t & ne02
|
1618
|
-
constant int64_t & ne10
|
1619
|
-
constant int64_t & ne12
|
1620
|
-
constant int64_t & ne0
|
1621
|
-
constant int64_t & ne1
|
1622
|
-
constant uint &
|
2427
|
+
constant int64_t & ne01,
|
2428
|
+
constant int64_t & ne02,
|
2429
|
+
constant int64_t & ne10,
|
2430
|
+
constant int64_t & ne12,
|
2431
|
+
constant int64_t & ne0,
|
2432
|
+
constant int64_t & ne1,
|
2433
|
+
constant uint & r2,
|
2434
|
+
constant uint & r3,
|
1623
2435
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1624
|
-
uint
|
1625
|
-
uint
|
2436
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
2437
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1626
2438
|
|
1627
2439
|
const int nb = ne00/QK_K;
|
1628
2440
|
const int r0 = tgpig.x;
|
1629
2441
|
const int r1 = tgpig.y;
|
1630
|
-
const int
|
2442
|
+
const int im = tgpig.z;
|
1631
2443
|
|
1632
2444
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
1633
2445
|
const int ib_row = first_row * nb;
|
1634
|
-
|
2446
|
+
|
2447
|
+
const uint i12 = im%ne12;
|
2448
|
+
const uint i13 = im/ne12;
|
2449
|
+
|
2450
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
2451
|
+
|
1635
2452
|
device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
|
1636
|
-
device const float * y = (device const float *) src1 + r1*ne10 +
|
2453
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
2454
|
+
|
1637
2455
|
float yl[32];
|
1638
2456
|
float sumf[N_DST]={0.f}, all_sum;
|
1639
2457
|
|
@@ -1642,11 +2460,11 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
1642
2460
|
#if QK_K == 256
|
1643
2461
|
const int ix = tiisg/8; // 0...3
|
1644
2462
|
const int it = tiisg%8; // 0...7
|
1645
|
-
const int
|
2463
|
+
const int iq = it/4; // 0 or 1
|
1646
2464
|
const int ir = it%4; // 0...3
|
1647
2465
|
const int is = (8*ir)/16;// 0 or 1
|
1648
2466
|
|
1649
|
-
device const float * y4 = y + ix * QK_K + 128 *
|
2467
|
+
device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
|
1650
2468
|
|
1651
2469
|
for (int ib = ix; ib < nb; ib += 4) {
|
1652
2470
|
|
@@ -1658,8 +2476,8 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
1658
2476
|
yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
|
1659
2477
|
}
|
1660
2478
|
|
1661
|
-
device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*
|
1662
|
-
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 *
|
2479
|
+
device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*iq + is;
|
2480
|
+
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
|
1663
2481
|
device const half * dh = &x[ib].d;
|
1664
2482
|
|
1665
2483
|
for (int row = 0; row < N_DST; row++) {
|
@@ -1746,13 +2564,13 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
1746
2564
|
for (int row = 0; row < N_DST; ++row) {
|
1747
2565
|
all_sum = simd_sum(sumf[row]);
|
1748
2566
|
if (tiisg == 0) {
|
1749
|
-
dst[r1*ne0 +
|
2567
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
1750
2568
|
}
|
1751
2569
|
}
|
1752
2570
|
}
|
1753
2571
|
|
1754
|
-
|
1755
|
-
kernel void
|
2572
|
+
[[host_name("kernel_mul_mv_q2_K_f32")]]
|
2573
|
+
kernel void kernel_mul_mv_q2_K_f32(
|
1756
2574
|
device const void * src0,
|
1757
2575
|
device const float * src1,
|
1758
2576
|
device float * dst,
|
@@ -1761,23 +2579,50 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
1761
2579
|
constant int64_t & ne02[[buffer(5)]],
|
1762
2580
|
constant int64_t & ne10[[buffer(9)]],
|
1763
2581
|
constant int64_t & ne12[[buffer(11)]],
|
1764
|
-
constant int64_t & ne0[[buffer(15)]],
|
1765
|
-
constant int64_t & ne1[[buffer(16)]],
|
1766
|
-
constant uint &
|
2582
|
+
constant int64_t & ne0 [[buffer(15)]],
|
2583
|
+
constant int64_t & ne1 [[buffer(16)]],
|
2584
|
+
constant uint & r2 [[buffer(17)]],
|
2585
|
+
constant uint & r3 [[buffer(18)]],
|
1767
2586
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1768
|
-
uint
|
1769
|
-
uint
|
2587
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
2588
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
2589
|
+
|
2590
|
+
kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
2591
|
+
}
|
2592
|
+
|
2593
|
+
#if QK_K == 256
|
2594
|
+
void kernel_mul_mv_q3_K_f32_impl(
|
2595
|
+
device const void * src0,
|
2596
|
+
device const float * src1,
|
2597
|
+
device float * dst,
|
2598
|
+
constant int64_t & ne00,
|
2599
|
+
constant int64_t & ne01,
|
2600
|
+
constant int64_t & ne02,
|
2601
|
+
constant int64_t & ne10,
|
2602
|
+
constant int64_t & ne12,
|
2603
|
+
constant int64_t & ne0,
|
2604
|
+
constant int64_t & ne1,
|
2605
|
+
constant uint & r2,
|
2606
|
+
constant uint & r3,
|
2607
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
2608
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
2609
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1770
2610
|
|
1771
2611
|
const int nb = ne00/QK_K;
|
1772
2612
|
|
1773
2613
|
const int64_t r0 = tgpig.x;
|
1774
2614
|
const int64_t r1 = tgpig.y;
|
1775
|
-
const int64_t
|
2615
|
+
const int64_t im = tgpig.z;
|
1776
2616
|
|
1777
2617
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
1778
|
-
|
2618
|
+
|
2619
|
+
const uint i12 = im%ne12;
|
2620
|
+
const uint i13 = im/ne12;
|
2621
|
+
|
2622
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
2623
|
+
|
1779
2624
|
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
|
1780
|
-
device const float * yy = (device const float *) src1 + r1*ne10 +
|
2625
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
1781
2626
|
|
1782
2627
|
float yl[32];
|
1783
2628
|
|
@@ -1899,40 +2744,47 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
1899
2744
|
}
|
1900
2745
|
if (tiisg == 0) {
|
1901
2746
|
for (int row = 0; row < 2; ++row) {
|
1902
|
-
dst[r1*ne0 +
|
2747
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row];
|
1903
2748
|
}
|
1904
2749
|
}
|
1905
2750
|
}
|
1906
2751
|
#else
|
1907
|
-
|
2752
|
+
void kernel_mul_mv_q3_K_f32_impl(
|
1908
2753
|
device const void * src0,
|
1909
2754
|
device const float * src1,
|
1910
2755
|
device float * dst,
|
1911
2756
|
constant int64_t & ne00,
|
1912
|
-
constant int64_t & ne01
|
1913
|
-
constant int64_t & ne02
|
1914
|
-
constant int64_t & ne10
|
1915
|
-
constant int64_t & ne12
|
1916
|
-
constant int64_t & ne0
|
1917
|
-
constant int64_t & ne1
|
1918
|
-
constant uint &
|
2757
|
+
constant int64_t & ne01,
|
2758
|
+
constant int64_t & ne02,
|
2759
|
+
constant int64_t & ne10,
|
2760
|
+
constant int64_t & ne12,
|
2761
|
+
constant int64_t & ne0,
|
2762
|
+
constant int64_t & ne1,
|
2763
|
+
constant uint & r2,
|
2764
|
+
constant uint & r3,
|
1919
2765
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1920
|
-
uint
|
1921
|
-
uint
|
2766
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
2767
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1922
2768
|
|
1923
2769
|
const int nb = ne00/QK_K;
|
1924
2770
|
|
1925
2771
|
const int64_t r0 = tgpig.x;
|
1926
2772
|
const int64_t r1 = tgpig.y;
|
1927
|
-
const int64_t
|
2773
|
+
const int64_t im = tgpig.z;
|
1928
2774
|
|
1929
2775
|
const int row = 2 * r0 + sgitg;
|
1930
|
-
|
2776
|
+
|
2777
|
+
const uint i12 = im%ne12;
|
2778
|
+
const uint i13 = im/ne12;
|
2779
|
+
|
2780
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
2781
|
+
|
1931
2782
|
device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
|
1932
|
-
device const float * yy = (device const float *) src1 + r1*ne10 +
|
2783
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
2784
|
+
|
1933
2785
|
const int ix = tiisg/4;
|
1934
2786
|
const int il = 4 * (tiisg%4);// 0, 4, 8, 12
|
1935
|
-
const int
|
2787
|
+
const int iq = il/8; // 0, 0, 1, 1
|
1936
2788
|
const int in = il%8; // 0, 4, 0, 4
|
1937
2789
|
|
1938
2790
|
float2 sum = {0.f, 0.f};
|
@@ -1952,7 +2804,7 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
1952
2804
|
const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
|
1953
2805
|
|
1954
2806
|
for (int l = 0; l < 4; l += 2) {
|
1955
|
-
const uint16_t hm = h[l/2] >>
|
2807
|
+
const uint16_t hm = h[l/2] >> iq;
|
1956
2808
|
sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
|
1957
2809
|
+ y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
|
1958
2810
|
+ y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
|
@@ -1968,28 +2820,50 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
1968
2820
|
|
1969
2821
|
const float tot = simd_sum(sumf);
|
1970
2822
|
if (tiisg == 0) {
|
1971
|
-
dst[r1*ne0 +
|
2823
|
+
dst[r1*ne0 + im*ne0*ne1 + row] = tot;
|
1972
2824
|
}
|
1973
2825
|
|
1974
2826
|
}
|
1975
2827
|
#endif
|
1976
2828
|
|
2829
|
+
[[host_name("kernel_mul_mv_q3_K_f32")]]
|
2830
|
+
kernel void kernel_mul_mv_q3_K_f32(
|
2831
|
+
device const void * src0,
|
2832
|
+
device const float * src1,
|
2833
|
+
device float * dst,
|
2834
|
+
constant int64_t & ne00,
|
2835
|
+
constant int64_t & ne01[[buffer(4)]],
|
2836
|
+
constant int64_t & ne02[[buffer(5)]],
|
2837
|
+
constant int64_t & ne10[[buffer(9)]],
|
2838
|
+
constant int64_t & ne12[[buffer(11)]],
|
2839
|
+
constant int64_t & ne0 [[buffer(15)]],
|
2840
|
+
constant int64_t & ne1 [[buffer(16)]],
|
2841
|
+
constant uint & r2 [[buffer(17)]],
|
2842
|
+
constant uint & r3 [[buffer(18)]],
|
2843
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
2844
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
2845
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
2846
|
+
|
2847
|
+
kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
2848
|
+
}
|
2849
|
+
|
1977
2850
|
#if QK_K == 256
|
1978
|
-
|
2851
|
+
void kernel_mul_mv_q4_K_f32_impl(
|
1979
2852
|
device const void * src0,
|
1980
2853
|
device const float * src1,
|
1981
2854
|
device float * dst,
|
1982
2855
|
constant int64_t & ne00,
|
1983
|
-
constant int64_t & ne01
|
1984
|
-
constant int64_t & ne02
|
1985
|
-
constant int64_t & ne10
|
1986
|
-
constant int64_t & ne12
|
1987
|
-
constant int64_t & ne0
|
1988
|
-
constant int64_t & ne1
|
1989
|
-
constant uint &
|
2856
|
+
constant int64_t & ne01,
|
2857
|
+
constant int64_t & ne02,
|
2858
|
+
constant int64_t & ne10,
|
2859
|
+
constant int64_t & ne12,
|
2860
|
+
constant int64_t & ne0,
|
2861
|
+
constant int64_t & ne1,
|
2862
|
+
constant uint & r2,
|
2863
|
+
constant uint & r3,
|
1990
2864
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1991
|
-
uint
|
1992
|
-
uint
|
2865
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
2866
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1993
2867
|
|
1994
2868
|
const uint16_t kmask1 = 0x3f3f;
|
1995
2869
|
const uint16_t kmask2 = 0x0f0f;
|
@@ -1997,26 +2871,32 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
1997
2871
|
|
1998
2872
|
const int ix = tiisg/8; // 0...3
|
1999
2873
|
const int it = tiisg%8; // 0...7
|
2000
|
-
const int
|
2874
|
+
const int iq = it/4; // 0 or 1
|
2001
2875
|
const int ir = it%4; // 0...3
|
2002
2876
|
|
2003
2877
|
const int nb = ne00/QK_K;
|
2004
2878
|
const int r0 = tgpig.x;
|
2005
2879
|
const int r1 = tgpig.y;
|
2006
|
-
const int
|
2880
|
+
const int im = tgpig.z;
|
2007
2881
|
//const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
2008
2882
|
const int first_row = r0 * N_DST;
|
2009
2883
|
const int ib_row = first_row * nb;
|
2010
|
-
|
2884
|
+
|
2885
|
+
const uint i12 = im%ne12;
|
2886
|
+
const uint i13 = im/ne12;
|
2887
|
+
|
2888
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
2889
|
+
|
2011
2890
|
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
|
2012
|
-
device const float * y = (device const float *) src1 + r1*ne10 +
|
2891
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
2892
|
+
|
2013
2893
|
float yl[16];
|
2014
2894
|
float yh[16];
|
2015
2895
|
float sumf[N_DST]={0.f}, all_sum;
|
2016
2896
|
|
2017
2897
|
const int step = sizeof(block_q4_K) * nb / 2;
|
2018
2898
|
|
2019
|
-
device const float * y4 = y + ix * QK_K + 64 *
|
2899
|
+
device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
|
2020
2900
|
|
2021
2901
|
uint16_t sc16[4];
|
2022
2902
|
thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
|
@@ -2031,8 +2911,8 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
2031
2911
|
yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
|
2032
2912
|
}
|
2033
2913
|
|
2034
|
-
device const uint16_t * sc = (device const uint16_t *)x[ib].scales +
|
2035
|
-
device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 *
|
2914
|
+
device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq;
|
2915
|
+
device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
|
2036
2916
|
device const half * dh = &x[ib].d;
|
2037
2917
|
|
2038
2918
|
for (int row = 0; row < N_DST; row++) {
|
@@ -2076,23 +2956,24 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
2076
2956
|
for (int row = 0; row < N_DST; ++row) {
|
2077
2957
|
all_sum = simd_sum(sumf[row]);
|
2078
2958
|
if (tiisg == 0) {
|
2079
|
-
dst[r1*ne0 +
|
2959
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
2080
2960
|
}
|
2081
2961
|
}
|
2082
2962
|
}
|
2083
2963
|
#else
|
2084
|
-
|
2964
|
+
void kernel_mul_mv_q4_K_f32_impl(
|
2085
2965
|
device const void * src0,
|
2086
2966
|
device const float * src1,
|
2087
2967
|
device float * dst,
|
2088
2968
|
constant int64_t & ne00,
|
2089
|
-
constant int64_t & ne01
|
2090
|
-
constant int64_t & ne02
|
2091
|
-
constant int64_t & ne10
|
2092
|
-
constant int64_t & ne12
|
2093
|
-
constant int64_t & ne0
|
2094
|
-
constant int64_t & ne1
|
2095
|
-
constant uint &
|
2969
|
+
constant int64_t & ne01,
|
2970
|
+
constant int64_t & ne02,
|
2971
|
+
constant int64_t & ne10,
|
2972
|
+
constant int64_t & ne12,
|
2973
|
+
constant int64_t & ne0,
|
2974
|
+
constant int64_t & ne1,
|
2975
|
+
constant uint & r2,
|
2976
|
+
constant uint & r3,
|
2096
2977
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2097
2978
|
uint tiisg[[thread_index_in_simdgroup]],
|
2098
2979
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -2103,12 +2984,18 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
2103
2984
|
const int nb = ne00/QK_K;
|
2104
2985
|
const int r0 = tgpig.x;
|
2105
2986
|
const int r1 = tgpig.y;
|
2106
|
-
const int
|
2987
|
+
const int im = tgpig.z;
|
2107
2988
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
2108
2989
|
const int ib_row = first_row * nb;
|
2109
|
-
|
2990
|
+
|
2991
|
+
const uint i12 = im%ne12;
|
2992
|
+
const uint i13 = im/ne12;
|
2993
|
+
|
2994
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
2995
|
+
|
2110
2996
|
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
|
2111
|
-
device const float * y = (device const float *) src1 + r1*ne10 +
|
2997
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
2998
|
+
|
2112
2999
|
float yl[8];
|
2113
3000
|
float yh[8];
|
2114
3001
|
float sumf[N_DST]={0.f}, all_sum;
|
@@ -2164,13 +3051,14 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
2164
3051
|
for (int row = 0; row < N_DST; ++row) {
|
2165
3052
|
all_sum = simd_sum(sumf[row]);
|
2166
3053
|
if (tiisg == 0) {
|
2167
|
-
dst[r1*ne0+
|
3054
|
+
dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum;
|
2168
3055
|
}
|
2169
3056
|
}
|
2170
3057
|
}
|
2171
3058
|
#endif
|
2172
3059
|
|
2173
|
-
|
3060
|
+
[[host_name("kernel_mul_mv_q4_K_f32")]]
|
3061
|
+
kernel void kernel_mul_mv_q4_K_f32(
|
2174
3062
|
device const void * src0,
|
2175
3063
|
device const float * src1,
|
2176
3064
|
device float * dst,
|
@@ -2179,23 +3067,49 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
2179
3067
|
constant int64_t & ne02[[buffer(5)]],
|
2180
3068
|
constant int64_t & ne10[[buffer(9)]],
|
2181
3069
|
constant int64_t & ne12[[buffer(11)]],
|
2182
|
-
constant int64_t & ne0[[buffer(15)]],
|
2183
|
-
constant int64_t & ne1[[buffer(16)]],
|
2184
|
-
constant uint &
|
3070
|
+
constant int64_t & ne0 [[buffer(15)]],
|
3071
|
+
constant int64_t & ne1 [[buffer(16)]],
|
3072
|
+
constant uint & r2 [[buffer(17)]],
|
3073
|
+
constant uint & r3 [[buffer(18)]],
|
2185
3074
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2186
3075
|
uint tiisg[[thread_index_in_simdgroup]],
|
2187
3076
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
2188
3077
|
|
3078
|
+
kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
3079
|
+
}
|
3080
|
+
|
3081
|
+
void kernel_mul_mv_q5_K_f32_impl(
|
3082
|
+
device const void * src0,
|
3083
|
+
device const float * src1,
|
3084
|
+
device float * dst,
|
3085
|
+
constant int64_t & ne00,
|
3086
|
+
constant int64_t & ne01,
|
3087
|
+
constant int64_t & ne02,
|
3088
|
+
constant int64_t & ne10,
|
3089
|
+
constant int64_t & ne12,
|
3090
|
+
constant int64_t & ne0,
|
3091
|
+
constant int64_t & ne1,
|
3092
|
+
constant uint & r2,
|
3093
|
+
constant uint & r3,
|
3094
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3095
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
3096
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3097
|
+
|
2189
3098
|
const int nb = ne00/QK_K;
|
2190
3099
|
|
2191
3100
|
const int64_t r0 = tgpig.x;
|
2192
3101
|
const int64_t r1 = tgpig.y;
|
2193
|
-
const int
|
3102
|
+
const int im = tgpig.z;
|
2194
3103
|
|
2195
3104
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
2196
|
-
|
3105
|
+
|
3106
|
+
const uint i12 = im%ne12;
|
3107
|
+
const uint i13 = im/ne12;
|
3108
|
+
|
3109
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
3110
|
+
|
2197
3111
|
device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
|
2198
|
-
device const float * yy = (device const float *) src1 + r1*ne10 +
|
3112
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
2199
3113
|
|
2200
3114
|
float sumf[2]={0.f};
|
2201
3115
|
|
@@ -2211,15 +3125,15 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
2211
3125
|
|
2212
3126
|
const int tid = tiisg/4;
|
2213
3127
|
const int ix = tiisg%4;
|
2214
|
-
const int
|
3128
|
+
const int iq = tid/4;
|
2215
3129
|
const int ir = tid%4;
|
2216
3130
|
const int n = 8;
|
2217
3131
|
|
2218
3132
|
const int l0 = n*ir;
|
2219
|
-
const int q_offset = 32*
|
2220
|
-
const int y_offset = 64*
|
3133
|
+
const int q_offset = 32*iq + l0;
|
3134
|
+
const int y_offset = 64*iq + l0;
|
2221
3135
|
|
2222
|
-
const uint8_t hm1 = 1u << (2*
|
3136
|
+
const uint8_t hm1 = 1u << (2*iq);
|
2223
3137
|
const uint8_t hm2 = hm1 << 1;
|
2224
3138
|
const uint8_t hm3 = hm1 << 4;
|
2225
3139
|
const uint8_t hm4 = hm2 << 4;
|
@@ -2234,7 +3148,7 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
2234
3148
|
device const uint8_t * q1 = x[i].qs + q_offset;
|
2235
3149
|
device const uint8_t * qh = x[i].qh + l0;
|
2236
3150
|
device const half * dh = &x[i].d;
|
2237
|
-
device const uint16_t * a = (device const uint16_t *)x[i].scales +
|
3151
|
+
device const uint16_t * a = (device const uint16_t *)x[i].scales + iq;
|
2238
3152
|
|
2239
3153
|
device const float * y2 = y1 + 128;
|
2240
3154
|
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
@@ -2290,7 +3204,7 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
2290
3204
|
|
2291
3205
|
const int il = 4 * (tiisg/8); // 0, 4, 8, 12
|
2292
3206
|
const int ix = tiisg%8;
|
2293
|
-
const int
|
3207
|
+
const int iq = il/8; // 0, 0, 1, 1
|
2294
3208
|
const int in = il%8; // 0, 4, 0, 4
|
2295
3209
|
|
2296
3210
|
device const float * y = yy + ix*QK_K + il;
|
@@ -2315,7 +3229,7 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
2315
3229
|
|
2316
3230
|
float2 acc = {0.f, 0.f};
|
2317
3231
|
for (int l = 0; l < 4; ++l) {
|
2318
|
-
const uint8_t hl = h[l] >>
|
3232
|
+
const uint8_t hl = h[l] >> iq;
|
2319
3233
|
acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
|
2320
3234
|
+ yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
|
2321
3235
|
acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
|
@@ -2337,27 +3251,48 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
2337
3251
|
for (int row = 0; row < 2; ++row) {
|
2338
3252
|
const float tot = simd_sum(sumf[row]);
|
2339
3253
|
if (tiisg == 0) {
|
2340
|
-
dst[r1*ne0 +
|
3254
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
|
2341
3255
|
}
|
2342
3256
|
}
|
3257
|
+
}
|
3258
|
+
|
3259
|
+
[[host_name("kernel_mul_mv_q5_K_f32")]]
|
3260
|
+
kernel void kernel_mul_mv_q5_K_f32(
|
3261
|
+
device const void * src0,
|
3262
|
+
device const float * src1,
|
3263
|
+
device float * dst,
|
3264
|
+
constant int64_t & ne00,
|
3265
|
+
constant int64_t & ne01[[buffer(4)]],
|
3266
|
+
constant int64_t & ne02[[buffer(5)]],
|
3267
|
+
constant int64_t & ne10[[buffer(9)]],
|
3268
|
+
constant int64_t & ne12[[buffer(11)]],
|
3269
|
+
constant int64_t & ne0 [[buffer(15)]],
|
3270
|
+
constant int64_t & ne1 [[buffer(16)]],
|
3271
|
+
constant uint & r2 [[buffer(17)]],
|
3272
|
+
constant uint & r3 [[buffer(18)]],
|
3273
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3274
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
3275
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
2343
3276
|
|
3277
|
+
kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
2344
3278
|
}
|
2345
3279
|
|
2346
|
-
|
3280
|
+
void kernel_mul_mv_q6_K_f32_impl(
|
2347
3281
|
device const void * src0,
|
2348
3282
|
device const float * src1,
|
2349
3283
|
device float * dst,
|
2350
3284
|
constant int64_t & ne00,
|
2351
|
-
constant int64_t & ne01
|
2352
|
-
constant int64_t & ne02
|
2353
|
-
constant int64_t & ne10
|
2354
|
-
constant int64_t & ne12
|
2355
|
-
constant int64_t & ne0
|
2356
|
-
constant int64_t & ne1
|
2357
|
-
constant uint &
|
3285
|
+
constant int64_t & ne01,
|
3286
|
+
constant int64_t & ne02,
|
3287
|
+
constant int64_t & ne10,
|
3288
|
+
constant int64_t & ne12,
|
3289
|
+
constant int64_t & ne0,
|
3290
|
+
constant int64_t & ne1,
|
3291
|
+
constant uint & r2,
|
3292
|
+
constant uint & r3,
|
2358
3293
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2359
|
-
uint
|
2360
|
-
uint
|
3294
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
3295
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
2361
3296
|
|
2362
3297
|
const uint8_t kmask1 = 0x03;
|
2363
3298
|
const uint8_t kmask2 = 0x0C;
|
@@ -2368,12 +3303,17 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
2368
3303
|
|
2369
3304
|
const int64_t r0 = tgpig.x;
|
2370
3305
|
const int64_t r1 = tgpig.y;
|
2371
|
-
const int
|
3306
|
+
const int im = tgpig.z;
|
2372
3307
|
|
2373
3308
|
const int row = 2 * r0 + sgitg;
|
2374
|
-
|
3309
|
+
|
3310
|
+
const uint i12 = im%ne12;
|
3311
|
+
const uint i13 = im/ne12;
|
3312
|
+
|
3313
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
3314
|
+
|
2375
3315
|
device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
|
2376
|
-
device const float * yy = (device const float *) src1 + r1*ne10 +
|
3316
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
2377
3317
|
|
2378
3318
|
float sumf = 0;
|
2379
3319
|
|
@@ -2439,10 +3379,31 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
2439
3379
|
|
2440
3380
|
const float tot = simd_sum(sumf);
|
2441
3381
|
if (tiisg == 0) {
|
2442
|
-
dst[r1*ne0 +
|
3382
|
+
dst[r1*ne0 + im*ne0*ne1 + row] = tot;
|
2443
3383
|
}
|
2444
3384
|
}
|
2445
3385
|
|
3386
|
+
[[host_name("kernel_mul_mv_q6_K_f32")]]
|
3387
|
+
kernel void kernel_mul_mv_q6_K_f32(
|
3388
|
+
device const void * src0,
|
3389
|
+
device const float * src1,
|
3390
|
+
device float * dst,
|
3391
|
+
constant int64_t & ne00,
|
3392
|
+
constant int64_t & ne01[[buffer(4)]],
|
3393
|
+
constant int64_t & ne02[[buffer(5)]],
|
3394
|
+
constant int64_t & ne10[[buffer(9)]],
|
3395
|
+
constant int64_t & ne12[[buffer(11)]],
|
3396
|
+
constant int64_t & ne0 [[buffer(15)]],
|
3397
|
+
constant int64_t & ne1 [[buffer(16)]],
|
3398
|
+
constant uint & r2 [[buffer(17)]],
|
3399
|
+
constant uint & r3 [[buffer(18)]],
|
3400
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3401
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
3402
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3403
|
+
|
3404
|
+
kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
3405
|
+
}
|
3406
|
+
|
2446
3407
|
//============================= templates and their specializations =============================
|
2447
3408
|
|
2448
3409
|
// NOTE: this is not dequantizing - we are simply fitting the template
|
@@ -2560,10 +3521,10 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
|
|
2560
3521
|
|
2561
3522
|
template <typename type4x4>
|
2562
3523
|
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
|
2563
|
-
const
|
2564
|
-
const
|
3524
|
+
const float d = xb->d;
|
3525
|
+
const float min = xb->dmin;
|
2565
3526
|
device const uint8_t * q = (device const uint8_t *)xb->qs;
|
2566
|
-
|
3527
|
+
float dl, ml;
|
2567
3528
|
uint8_t sc = xb->scales[il];
|
2568
3529
|
|
2569
3530
|
#if QK_K == 256
|
@@ -2633,10 +3594,10 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
|
|
2633
3594
|
q = q + (il/4) * 32 + 16 * (il&1);
|
2634
3595
|
il = il & 3;
|
2635
3596
|
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
|
2636
|
-
const
|
2637
|
-
const
|
2638
|
-
const
|
2639
|
-
const
|
3597
|
+
const float d = il < 2 ? xb->d : xb->d / 16.h;
|
3598
|
+
const float min = xb->dmin;
|
3599
|
+
const float dl = d * sc[0];
|
3600
|
+
const float ml = min * sc[1];
|
2640
3601
|
#else
|
2641
3602
|
q = q + 16 * (il&1);
|
2642
3603
|
device const uint8_t * s = xb->scales;
|
@@ -2663,13 +3624,13 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
|
|
2663
3624
|
uint8_t ul = 1 << (il/2);
|
2664
3625
|
il = il & 3;
|
2665
3626
|
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
|
2666
|
-
const
|
2667
|
-
const
|
2668
|
-
const
|
2669
|
-
const
|
3627
|
+
const float d = il < 2 ? xb->d : xb->d / 16.h;
|
3628
|
+
const float min = xb->dmin;
|
3629
|
+
const float dl = d * sc[0];
|
3630
|
+
const float ml = min * sc[1];
|
2670
3631
|
|
2671
|
-
const ushort mask
|
2672
|
-
const
|
3632
|
+
const ushort mask = il<2 ? 0x0F : 0xF0;
|
3633
|
+
const float qh_val = il<2 ? 16.f : 256.f;
|
2673
3634
|
for (int i = 0; i < 16; ++i) {
|
2674
3635
|
reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
|
2675
3636
|
}
|
@@ -2717,22 +3678,90 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
|
|
2717
3678
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
2718
3679
|
kernel void kernel_get_rows(
|
2719
3680
|
device const void * src0,
|
2720
|
-
device const
|
3681
|
+
device const char * src1,
|
2721
3682
|
device float * dst,
|
2722
3683
|
constant int64_t & ne00,
|
2723
3684
|
constant uint64_t & nb01,
|
3685
|
+
constant uint64_t & nb02,
|
3686
|
+
constant int64_t & ne10,
|
3687
|
+
constant uint64_t & nb10,
|
3688
|
+
constant uint64_t & nb11,
|
2724
3689
|
constant uint64_t & nb1,
|
2725
|
-
|
3690
|
+
constant uint64_t & nb2,
|
3691
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
2726
3692
|
uint tiitg[[thread_index_in_threadgroup]],
|
2727
|
-
|
2728
|
-
const
|
2729
|
-
const
|
3693
|
+
uint3 tptg [[threads_per_threadgroup]]) {
|
3694
|
+
//const int64_t i = tgpig;
|
3695
|
+
//const int64_t r = ((device int32_t *) src1)[i];
|
3696
|
+
|
3697
|
+
const int64_t i10 = tgpig.x;
|
3698
|
+
const int64_t i11 = tgpig.y;
|
2730
3699
|
|
2731
|
-
|
3700
|
+
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
3701
|
+
|
3702
|
+
const int64_t i02 = i11;
|
3703
|
+
|
3704
|
+
for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
|
2732
3705
|
float4x4 temp;
|
2733
3706
|
dequantize_func(
|
2734
|
-
((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
|
2735
|
-
*(((device float4x4 *) ((device char *) dst +
|
3707
|
+
((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
|
3708
|
+
*(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
|
3709
|
+
}
|
3710
|
+
}
|
3711
|
+
|
3712
|
+
kernel void kernel_get_rows_f32(
|
3713
|
+
device const void * src0,
|
3714
|
+
device const char * src1,
|
3715
|
+
device float * dst,
|
3716
|
+
constant int64_t & ne00,
|
3717
|
+
constant uint64_t & nb01,
|
3718
|
+
constant uint64_t & nb02,
|
3719
|
+
constant int64_t & ne10,
|
3720
|
+
constant uint64_t & nb10,
|
3721
|
+
constant uint64_t & nb11,
|
3722
|
+
constant uint64_t & nb1,
|
3723
|
+
constant uint64_t & nb2,
|
3724
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3725
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
3726
|
+
uint3 tptg [[threads_per_threadgroup]]) {
|
3727
|
+
const int64_t i10 = tgpig.x;
|
3728
|
+
const int64_t i11 = tgpig.y;
|
3729
|
+
|
3730
|
+
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
3731
|
+
|
3732
|
+
const int64_t i02 = i11;
|
3733
|
+
|
3734
|
+
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
3735
|
+
((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
3736
|
+
((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
3737
|
+
}
|
3738
|
+
}
|
3739
|
+
|
3740
|
+
kernel void kernel_get_rows_f16(
|
3741
|
+
device const void * src0,
|
3742
|
+
device const char * src1,
|
3743
|
+
device float * dst,
|
3744
|
+
constant int64_t & ne00,
|
3745
|
+
constant uint64_t & nb01,
|
3746
|
+
constant uint64_t & nb02,
|
3747
|
+
constant int64_t & ne10,
|
3748
|
+
constant uint64_t & nb10,
|
3749
|
+
constant uint64_t & nb11,
|
3750
|
+
constant uint64_t & nb1,
|
3751
|
+
constant uint64_t & nb2,
|
3752
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3753
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
3754
|
+
uint3 tptg [[threads_per_threadgroup]]) {
|
3755
|
+
const int64_t i10 = tgpig.x;
|
3756
|
+
const int64_t i11 = tgpig.y;
|
3757
|
+
|
3758
|
+
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
3759
|
+
|
3760
|
+
const int64_t i02 = i11;
|
3761
|
+
|
3762
|
+
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
3763
|
+
((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
3764
|
+
((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
2736
3765
|
}
|
2737
3766
|
}
|
2738
3767
|
|
@@ -2749,24 +3778,25 @@ kernel void kernel_get_rows(
|
|
2749
3778
|
|
2750
3779
|
// each block_q contains 16*nl weights
|
2751
3780
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
2752
|
-
|
2753
|
-
|
2754
|
-
|
2755
|
-
|
2756
|
-
|
2757
|
-
|
2758
|
-
|
2759
|
-
|
2760
|
-
|
2761
|
-
|
2762
|
-
|
2763
|
-
|
2764
|
-
|
2765
|
-
|
2766
|
-
|
2767
|
-
|
2768
|
-
|
2769
|
-
|
3781
|
+
void kernel_mul_mm_impl(device const uchar * src0,
|
3782
|
+
device const uchar * src1,
|
3783
|
+
device float * dst,
|
3784
|
+
constant int64_t & ne00,
|
3785
|
+
constant int64_t & ne02,
|
3786
|
+
constant int64_t & nb01,
|
3787
|
+
constant int64_t & nb02,
|
3788
|
+
constant int64_t & ne12,
|
3789
|
+
constant int64_t & nb10,
|
3790
|
+
constant int64_t & nb11,
|
3791
|
+
constant int64_t & nb12,
|
3792
|
+
constant int64_t & ne0,
|
3793
|
+
constant int64_t & ne1,
|
3794
|
+
constant uint & r2,
|
3795
|
+
constant uint & r3,
|
3796
|
+
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
3797
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3798
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
3799
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
2770
3800
|
|
2771
3801
|
threadgroup half * sa = (threadgroup half *)(shared_memory);
|
2772
3802
|
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
|
@@ -2792,7 +3822,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
2792
3822
|
|
2793
3823
|
short il = (tiitg % THREAD_PER_ROW);
|
2794
3824
|
|
2795
|
-
uint
|
3825
|
+
const uint i12 = im%ne12;
|
3826
|
+
const uint i13 = im/ne12;
|
3827
|
+
|
3828
|
+
uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
|
2796
3829
|
ushort offset1 = il/nl;
|
2797
3830
|
|
2798
3831
|
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
|
@@ -2876,17 +3909,137 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
2876
3909
|
}
|
2877
3910
|
}
|
2878
3911
|
|
3912
|
+
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
3913
|
+
kernel void kernel_mul_mm(device const uchar * src0,
|
3914
|
+
device const uchar * src1,
|
3915
|
+
device float * dst,
|
3916
|
+
constant int64_t & ne00,
|
3917
|
+
constant int64_t & ne02,
|
3918
|
+
constant int64_t & nb01,
|
3919
|
+
constant int64_t & nb02,
|
3920
|
+
constant int64_t & ne12,
|
3921
|
+
constant int64_t & nb10,
|
3922
|
+
constant int64_t & nb11,
|
3923
|
+
constant int64_t & nb12,
|
3924
|
+
constant int64_t & ne0,
|
3925
|
+
constant int64_t & ne1,
|
3926
|
+
constant uint & r2,
|
3927
|
+
constant uint & r3,
|
3928
|
+
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
3929
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3930
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
3931
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3932
|
+
kernel_mul_mm_impl<block_q, nl, dequantize_func>(
|
3933
|
+
src0,
|
3934
|
+
src1,
|
3935
|
+
dst,
|
3936
|
+
ne00,
|
3937
|
+
ne02,
|
3938
|
+
nb01,
|
3939
|
+
nb02,
|
3940
|
+
ne12,
|
3941
|
+
nb10,
|
3942
|
+
nb11,
|
3943
|
+
nb12,
|
3944
|
+
ne0,
|
3945
|
+
ne1,
|
3946
|
+
r2,
|
3947
|
+
r3,
|
3948
|
+
shared_memory,
|
3949
|
+
tgpig,
|
3950
|
+
tiitg,
|
3951
|
+
sgitg);
|
3952
|
+
}
|
3953
|
+
|
3954
|
+
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
3955
|
+
kernel void kernel_mul_mm_id(
|
3956
|
+
device const uchar * ids,
|
3957
|
+
device const uchar * src1,
|
3958
|
+
device uchar * dst,
|
3959
|
+
constant int64_t & nbi1,
|
3960
|
+
constant int64_t & ne00,
|
3961
|
+
constant int64_t & ne02,
|
3962
|
+
constant int64_t & nb01,
|
3963
|
+
constant int64_t & nb02,
|
3964
|
+
constant int64_t & ne12,
|
3965
|
+
constant int64_t & ne13,
|
3966
|
+
constant int64_t & nb10,
|
3967
|
+
constant int64_t & nb11,
|
3968
|
+
constant int64_t & nb12,
|
3969
|
+
constant int64_t & ne0,
|
3970
|
+
constant int64_t & ne1,
|
3971
|
+
constant int64_t & nb1,
|
3972
|
+
constant uint & r2,
|
3973
|
+
constant uint & r3,
|
3974
|
+
constant int & idx,
|
3975
|
+
device const uchar * src00,
|
3976
|
+
device const uchar * src01,
|
3977
|
+
device const uchar * src02,
|
3978
|
+
device const uchar * src03,
|
3979
|
+
device const uchar * src04,
|
3980
|
+
device const uchar * src05,
|
3981
|
+
device const uchar * src06,
|
3982
|
+
device const uchar * src07,
|
3983
|
+
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
3984
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3985
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
3986
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3987
|
+
device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
3988
|
+
|
3989
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
3990
|
+
|
3991
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
3992
|
+
|
3993
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
3994
|
+
|
3995
|
+
kernel_mul_mm_impl<block_q, nl, dequantize_func>(
|
3996
|
+
src0[id],
|
3997
|
+
src1 + bid*nb11,
|
3998
|
+
(device float *) (dst + bid*nb1),
|
3999
|
+
ne00,
|
4000
|
+
ne02,
|
4001
|
+
nb01,
|
4002
|
+
nb02,
|
4003
|
+
ne12,
|
4004
|
+
nb10,
|
4005
|
+
nb11,
|
4006
|
+
nb12,
|
4007
|
+
ne0,
|
4008
|
+
ne1,
|
4009
|
+
r2,
|
4010
|
+
r3,
|
4011
|
+
shared_memory,
|
4012
|
+
tgpig,
|
4013
|
+
tiitg,
|
4014
|
+
sgitg);
|
4015
|
+
}
|
4016
|
+
|
2879
4017
|
#if QK_K == 256
|
2880
4018
|
#define QK_NL 16
|
2881
4019
|
#else
|
2882
4020
|
#define QK_NL 4
|
2883
4021
|
#endif
|
2884
4022
|
|
2885
|
-
|
2886
|
-
|
4023
|
+
//
|
4024
|
+
// get rows
|
4025
|
+
//
|
2887
4026
|
|
2888
|
-
|
2889
|
-
|
4027
|
+
typedef void (get_rows_t)(
|
4028
|
+
device const void * src0,
|
4029
|
+
device const char * src1,
|
4030
|
+
device float * dst,
|
4031
|
+
constant int64_t & ne00,
|
4032
|
+
constant uint64_t & nb01,
|
4033
|
+
constant uint64_t & nb02,
|
4034
|
+
constant int64_t & ne10,
|
4035
|
+
constant uint64_t & nb10,
|
4036
|
+
constant uint64_t & nb11,
|
4037
|
+
constant uint64_t & nb1,
|
4038
|
+
constant uint64_t & nb2,
|
4039
|
+
uint3, uint, uint3);
|
4040
|
+
|
4041
|
+
//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
|
4042
|
+
//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
2890
4043
|
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
|
2891
4044
|
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
|
2892
4045
|
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
|
@@ -2898,6 +4051,10 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
|
|
2898
4051
|
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
|
2899
4052
|
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
|
2900
4053
|
|
4054
|
+
//
|
4055
|
+
// matrix-matrix multiplication
|
4056
|
+
//
|
4057
|
+
|
2901
4058
|
typedef void (mat_mm_t)(
|
2902
4059
|
device const uchar * src0,
|
2903
4060
|
device const uchar * src1,
|
@@ -2912,8 +4069,10 @@ typedef void (mat_mm_t)(
|
|
2912
4069
|
constant int64_t & nb12,
|
2913
4070
|
constant int64_t & ne0,
|
2914
4071
|
constant int64_t & ne1,
|
2915
|
-
constant uint &
|
2916
|
-
|
4072
|
+
constant uint & r2,
|
4073
|
+
constant uint & r3,
|
4074
|
+
threadgroup uchar *,
|
4075
|
+
uint3, uint, uint);
|
2917
4076
|
|
2918
4077
|
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
|
2919
4078
|
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
@@ -2927,3 +4086,823 @@ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
|
|
2927
4086
|
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
|
2928
4087
|
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
|
2929
4088
|
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|
4089
|
+
|
4090
|
+
//
|
4091
|
+
// indirect matrix-matrix multiplication
|
4092
|
+
//
|
4093
|
+
|
4094
|
+
typedef void (mat_mm_id_t)(
|
4095
|
+
device const uchar * ids,
|
4096
|
+
device const uchar * src1,
|
4097
|
+
device uchar * dst,
|
4098
|
+
constant int64_t & nbi1,
|
4099
|
+
constant int64_t & ne00,
|
4100
|
+
constant int64_t & ne02,
|
4101
|
+
constant int64_t & nb01,
|
4102
|
+
constant int64_t & nb02,
|
4103
|
+
constant int64_t & ne12,
|
4104
|
+
constant int64_t & ne13,
|
4105
|
+
constant int64_t & nb10,
|
4106
|
+
constant int64_t & nb11,
|
4107
|
+
constant int64_t & nb12,
|
4108
|
+
constant int64_t & ne0,
|
4109
|
+
constant int64_t & ne1,
|
4110
|
+
constant int64_t & nb1,
|
4111
|
+
constant uint & r2,
|
4112
|
+
constant uint & r3,
|
4113
|
+
constant int & idx,
|
4114
|
+
device const uchar * src00,
|
4115
|
+
device const uchar * src01,
|
4116
|
+
device const uchar * src02,
|
4117
|
+
device const uchar * src03,
|
4118
|
+
device const uchar * src04,
|
4119
|
+
device const uchar * src05,
|
4120
|
+
device const uchar * src06,
|
4121
|
+
device const uchar * src07,
|
4122
|
+
threadgroup uchar *,
|
4123
|
+
uint3, uint, uint);
|
4124
|
+
|
4125
|
+
// Explicit instantiations of the indirect matrix-matrix template kernel
// kernel_mul_mm_id, one per supported src0 storage type. The
// [[host_name("...")]] attribute fixes the symbol name the host looks up
// when creating its compute pipeline states, so these strings are ABI and
// must not change. The second template argument is the number of 4x4
// sub-blocks per quant block (1 for f32/f16, 2 for 32-wide quants, QK_NL
// for the K-quants).
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
|
4137
|
+
|
4138
|
+
//
|
4139
|
+
// matrix-vector multiplication
|
4140
|
+
//
|
4141
|
+
|
4142
|
+
// Indirect (expert-selecting) matrix-vector multiply, f32 src0 x f32 src1.
// Each batch `bid` reads an int32 index from the `ids` tensor and dispatches
// kernel_mul_mv_f32_f32_impl on the selected one of up to 8 candidate src0
// matrices (src00..src07). tgpig.z packs (batch, z-within-batch); the batch
// component is stripped off before forwarding.
[[host_name("kernel_mul_mv_id_f32_f32")]]
kernel void kernel_mul_mv_id_f32_f32(
        device const char * ids,        // rows of int32 expert indices, row stride nbi1 bytes
        device const char * src1,
        device uchar * dst,
        constant int64_t & nbi1,
        constant int64_t & ne00,
        constant int64_t & ne01,
        constant int64_t & ne02,
        constant uint64_t & nb00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant int64_t & ne10,
        constant int64_t & ne11,
        constant int64_t & ne12,
        constant int64_t & ne13,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant int64_t & ne0,
        constant int64_t & ne1,
        constant int64_t & nb1,
        constant uint & r2,
        constant uint & r3,
        constant int & idx,             // which column of the ids row to use
        device const char * src00,
        device const char * src01,
        device const char * src02,
        device const char * src03,
        device const char * src04,
        device const char * src05,
        device const char * src06,
        device const char * src07,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiitg[[thread_index_in_threadgroup]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]]) {
    // candidate expert matrices, made indexable
    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // batch index encoded in the grid's z dimension
    const int64_t bid = tgpig.z/(ne12*ne13);

    tgpig.z = tgpig.z%(ne12*ne13);

    // NOTE(review): `id` indexes the fixed 8-entry src0 array with no bounds
    // check — assumes the host guarantees 0 <= ids[...] < 8; confirm.
    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];

    kernel_mul_mv_f32_f32_impl(
        src0[id],
        src1 + bid*nb11,
        (device float *) (dst + bid*nb1),
        ne00,
        ne01,
        ne02,
        nb00,
        nb01,
        nb02,
        ne10,
        ne11,
        ne12,
        nb10,
        nb11,
        nb12,
        ne0,
        ne1,
        r2,
        r3,
        tgpig,
        tiisg);
}
|
4210
|
+
|
4211
|
+
// Indirect matrix-vector multiply, f16 src0 x f32 src1. Same dispatch scheme
// as kernel_mul_mv_id_f32_f32: the expert matrix for each batch is chosen by
// an int32 index read from `ids`, then kernel_mul_mv_f16_f32_impl does the
// actual work on the selected matrix.
[[host_name("kernel_mul_mv_id_f16_f32")]]
kernel void kernel_mul_mv_id_f16_f32(
        device const char * ids,        // rows of int32 expert indices, row stride nbi1 bytes
        device const char * src1,
        device uchar * dst,
        constant int64_t & nbi1,
        constant int64_t & ne00,
        constant int64_t & ne01,
        constant int64_t & ne02,
        constant uint64_t & nb00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant int64_t & ne10,
        constant int64_t & ne11,
        constant int64_t & ne12,
        constant int64_t & ne13,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant int64_t & ne0,
        constant int64_t & ne1,
        constant int64_t & nb1,
        constant uint & r2,
        constant uint & r3,
        constant int & idx,             // which column of the ids row to use
        device const char * src00,
        device const char * src01,
        device const char * src02,
        device const char * src03,
        device const char * src04,
        device const char * src05,
        device const char * src06,
        device const char * src07,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiitg[[thread_index_in_threadgroup]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]]) {
    // candidate expert matrices, made indexable
    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // batch index encoded in the grid's z dimension
    const int64_t bid = tgpig.z/(ne12*ne13);

    tgpig.z = tgpig.z%(ne12*ne13);

    // NOTE(review): unchecked index into the 8-entry src0 array — assumes the
    // host guarantees 0 <= ids[...] < 8; confirm.
    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];

    kernel_mul_mv_f16_f32_impl(
        src0[id],
        src1 + bid*nb11,
        (device float *) (dst + bid*nb1),
        ne00,
        ne01,
        ne02,
        nb00,
        nb01,
        nb02,
        ne10,
        ne11,
        ne12,
        nb10,
        nb11,
        nb12,
        ne0,
        ne1,
        r2,
        r3,
        tgpig,
        tiisg);
}
|
4279
|
+
|
4280
|
+
// Indirect matrix-vector multiply, q8_0-quantized src0 x f32 src1. Selects
// the per-batch expert matrix via `ids` and forwards to
// kernel_mul_mv_q8_0_f32_impl. Unlike the f32/f16 variants, the quantized
// impl takes fewer strides and also needs sgitg.
[[host_name("kernel_mul_mv_id_q8_0_f32")]]
kernel void kernel_mul_mv_id_q8_0_f32(
        device const char * ids,        // rows of int32 expert indices, row stride nbi1 bytes
        device const char * src1,
        device uchar * dst,
        constant int64_t & nbi1,
        constant int64_t & ne00,
        constant int64_t & ne01,
        constant int64_t & ne02,
        constant uint64_t & nb00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant int64_t & ne10,
        constant int64_t & ne11,
        constant int64_t & ne12,
        constant int64_t & ne13,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant int64_t & ne0,
        constant int64_t & ne1,
        constant int64_t & nb1,
        constant uint & r2,
        constant uint & r3,
        constant int & idx,             // which column of the ids row to use
        device const char * src00,
        device const char * src01,
        device const char * src02,
        device const char * src03,
        device const char * src04,
        device const char * src05,
        device const char * src06,
        device const char * src07,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiitg[[thread_index_in_threadgroup]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]]) {
    // candidate expert matrices, made indexable
    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // batch index encoded in the grid's z dimension
    const int64_t bid = tgpig.z/(ne12*ne13);

    tgpig.z = tgpig.z%(ne12*ne13);

    // NOTE(review): unchecked index into the 8-entry src0 array — assumes the
    // host guarantees 0 <= ids[...] < 8; confirm.
    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];

    kernel_mul_mv_q8_0_f32_impl(
        src0[id],
        (device const float *) (src1 + bid*nb11),
        (device float *) ( dst + bid*nb1),
        ne00,
        ne01,
        ne02,
        ne10,
        ne12,
        ne0,
        ne1,
        r2,
        r3,
        tgpig,
        tiisg,
        sgitg);
}
|
4342
|
+
|
4343
|
+
// Indirect matrix-vector multiply, q4_0-quantized src0 x f32 src1. Selects
// the per-batch expert matrix via `ids` and forwards to the shared
// mul_vec_q_n_f32_impl template instantiated for block_q4_0.
[[host_name("kernel_mul_mv_id_q4_0_f32")]]
kernel void kernel_mul_mv_id_q4_0_f32(
        device const char * ids,        // rows of int32 expert indices, row stride nbi1 bytes
        device const char * src1,
        device uchar * dst,
        constant int64_t & nbi1,
        constant int64_t & ne00,
        constant int64_t & ne01,
        constant int64_t & ne02,
        constant uint64_t & nb00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant int64_t & ne10,
        constant int64_t & ne11,
        constant int64_t & ne12,
        constant int64_t & ne13,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant int64_t & ne0,
        constant int64_t & ne1,
        constant int64_t & nb1,
        constant uint & r2,
        constant uint & r3,
        constant int & idx,             // which column of the ids row to use
        device const char * src00,
        device const char * src01,
        device const char * src02,
        device const char * src03,
        device const char * src04,
        device const char * src05,
        device const char * src06,
        device const char * src07,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiitg[[thread_index_in_threadgroup]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]]) {
    // candidate expert matrices, made indexable
    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // batch index encoded in the grid's z dimension
    const int64_t bid = tgpig.z/(ne12*ne13);

    tgpig.z = tgpig.z%(ne12*ne13);

    // NOTE(review): unchecked index into the 8-entry src0 array — assumes the
    // host guarantees 0 <= ids[...] < 8; confirm.
    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];

    mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
        src0[id],
        (device const float *) (src1 + bid*nb11),
        (device float *) ( dst + bid*nb1),
        ne00,
        ne01,
        ne02,
        ne10,
        ne12,
        ne0,
        ne1,
        r2,
        r3,
        tgpig,
        tiisg,
        sgitg);
}
|
4405
|
+
|
4406
|
+
// Indirect matrix-vector multiply, q4_1-quantized src0 x f32 src1. Selects
// the per-batch expert matrix via `ids` and forwards to the shared
// mul_vec_q_n_f32_impl template instantiated for block_q4_1.
[[host_name("kernel_mul_mv_id_q4_1_f32")]]
kernel void kernel_mul_mv_id_q4_1_f32(
        device const char * ids,        // rows of int32 expert indices, row stride nbi1 bytes
        device const char * src1,
        device uchar * dst,
        constant int64_t & nbi1,
        constant int64_t & ne00,
        constant int64_t & ne01,
        constant int64_t & ne02,
        constant uint64_t & nb00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant int64_t & ne10,
        constant int64_t & ne11,
        constant int64_t & ne12,
        constant int64_t & ne13,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant int64_t & ne0,
        constant int64_t & ne1,
        constant int64_t & nb1,
        constant uint & r2,
        constant uint & r3,
        constant int & idx,             // which column of the ids row to use
        device const char * src00,
        device const char * src01,
        device const char * src02,
        device const char * src03,
        device const char * src04,
        device const char * src05,
        device const char * src06,
        device const char * src07,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiitg[[thread_index_in_threadgroup]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]]) {
    // candidate expert matrices, made indexable
    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // batch index encoded in the grid's z dimension
    const int64_t bid = tgpig.z/(ne12*ne13);

    tgpig.z = tgpig.z%(ne12*ne13);

    // NOTE(review): unchecked index into the 8-entry src0 array — assumes the
    // host guarantees 0 <= ids[...] < 8; confirm.
    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];

    mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
        src0[id],
        (device const float *) (src1 + bid*nb11),
        (device float *) ( dst + bid*nb1),
        ne00,
        ne01,
        ne02,
        ne10,
        ne12,
        ne0,
        ne1,
        r2,
        r3,
        tgpig,
        tiisg,
        sgitg);
}
|
4468
|
+
|
4469
|
+
// Indirect matrix-vector multiply, q5_0-quantized src0 x f32 src1. Selects
// the per-batch expert matrix via `ids` and forwards to the shared
// mul_vec_q_n_f32_impl template instantiated for block_q5_0.
[[host_name("kernel_mul_mv_id_q5_0_f32")]]
kernel void kernel_mul_mv_id_q5_0_f32(
        device const char * ids,        // rows of int32 expert indices, row stride nbi1 bytes
        device const char * src1,
        device uchar * dst,
        constant int64_t & nbi1,
        constant int64_t & ne00,
        constant int64_t & ne01,
        constant int64_t & ne02,
        constant uint64_t & nb00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant int64_t & ne10,
        constant int64_t & ne11,
        constant int64_t & ne12,
        constant int64_t & ne13,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant int64_t & ne0,
        constant int64_t & ne1,
        constant int64_t & nb1,
        constant uint & r2,
        constant uint & r3,
        constant int & idx,             // which column of the ids row to use
        device const char * src00,
        device const char * src01,
        device const char * src02,
        device const char * src03,
        device const char * src04,
        device const char * src05,
        device const char * src06,
        device const char * src07,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiitg[[thread_index_in_threadgroup]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]]) {
    // candidate expert matrices, made indexable
    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // batch index encoded in the grid's z dimension
    const int64_t bid = tgpig.z/(ne12*ne13);

    tgpig.z = tgpig.z%(ne12*ne13);

    // NOTE(review): unchecked index into the 8-entry src0 array — assumes the
    // host guarantees 0 <= ids[...] < 8; confirm.
    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];

    mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
        src0[id],
        (device const float *) (src1 + bid*nb11),
        (device float *) ( dst + bid*nb1),
        ne00,
        ne01,
        ne02,
        ne10,
        ne12,
        ne0,
        ne1,
        r2,
        r3,
        tgpig,
        tiisg,
        sgitg);
}
|
4531
|
+
|
4532
|
+
// Indirect matrix-vector multiply, q5_1-quantized src0 x f32 src1. Selects
// the per-batch expert matrix via `ids` and forwards to the shared
// mul_vec_q_n_f32_impl template instantiated for block_q5_1.
[[host_name("kernel_mul_mv_id_q5_1_f32")]]
kernel void kernel_mul_mv_id_q5_1_f32(
        device const char * ids,        // rows of int32 expert indices, row stride nbi1 bytes
        device const char * src1,
        device uchar * dst,
        constant int64_t & nbi1,
        constant int64_t & ne00,
        constant int64_t & ne01,
        constant int64_t & ne02,
        constant uint64_t & nb00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant int64_t & ne10,
        constant int64_t & ne11,
        constant int64_t & ne12,
        constant int64_t & ne13,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant int64_t & ne0,
        constant int64_t & ne1,
        constant int64_t & nb1,
        constant uint & r2,
        constant uint & r3,
        constant int & idx,             // which column of the ids row to use
        device const char * src00,
        device const char * src01,
        device const char * src02,
        device const char * src03,
        device const char * src04,
        device const char * src05,
        device const char * src06,
        device const char * src07,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiitg[[thread_index_in_threadgroup]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]]) {
    // candidate expert matrices, made indexable
    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // batch index encoded in the grid's z dimension
    const int64_t bid = tgpig.z/(ne12*ne13);

    tgpig.z = tgpig.z%(ne12*ne13);

    // NOTE(review): unchecked index into the 8-entry src0 array — assumes the
    // host guarantees 0 <= ids[...] < 8; confirm.
    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];

    mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
        src0[id],
        (device const float *) (src1 + bid*nb11),
        (device float *) ( dst + bid*nb1),
        ne00,
        ne01,
        ne02,
        ne10,
        ne12,
        ne0,
        ne1,
        r2,
        r3,
        tgpig,
        tiisg,
        sgitg);
}
|
4594
|
+
|
4595
|
+
// Indirect matrix-vector multiply, q2_K-quantized src0 x f32 src1. Selects
// the per-batch expert matrix via `ids` and forwards to
// kernel_mul_mv_q2_K_f32_impl.
[[host_name("kernel_mul_mv_id_q2_K_f32")]]
kernel void kernel_mul_mv_id_q2_K_f32(
        device const char * ids,        // rows of int32 expert indices, row stride nbi1 bytes
        device const char * src1,
        device uchar * dst,
        constant int64_t & nbi1,
        constant int64_t & ne00,
        constant int64_t & ne01,
        constant int64_t & ne02,
        constant uint64_t & nb00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant int64_t & ne10,
        constant int64_t & ne11,
        constant int64_t & ne12,
        constant int64_t & ne13,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant int64_t & ne0,
        constant int64_t & ne1,
        constant int64_t & nb1,
        constant uint & r2,
        constant uint & r3,
        constant int & idx,             // which column of the ids row to use
        device const char * src00,
        device const char * src01,
        device const char * src02,
        device const char * src03,
        device const char * src04,
        device const char * src05,
        device const char * src06,
        device const char * src07,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiitg[[thread_index_in_threadgroup]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]]) {
    // candidate expert matrices, made indexable
    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // batch index encoded in the grid's z dimension
    const int64_t bid = tgpig.z/(ne12*ne13);

    tgpig.z = tgpig.z%(ne12*ne13);

    // NOTE(review): unchecked index into the 8-entry src0 array — assumes the
    // host guarantees 0 <= ids[...] < 8; confirm.
    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];

    kernel_mul_mv_q2_K_f32_impl(
        src0[id],
        (device const float *) (src1 + bid*nb11),
        (device float *) ( dst + bid*nb1),
        ne00,
        ne01,
        ne02,
        ne10,
        ne12,
        ne0,
        ne1,
        r2,
        r3,
        tgpig,
        tiisg,
        sgitg);
}
|
4657
|
+
|
4658
|
+
// Indirect matrix-vector multiply, q3_K-quantized src0 x f32 src1. Selects
// the per-batch expert matrix via `ids` and forwards to
// kernel_mul_mv_q3_K_f32_impl.
[[host_name("kernel_mul_mv_id_q3_K_f32")]]
kernel void kernel_mul_mv_id_q3_K_f32(
        device const char * ids,        // rows of int32 expert indices, row stride nbi1 bytes
        device const char * src1,
        device uchar * dst,
        constant int64_t & nbi1,
        constant int64_t & ne00,
        constant int64_t & ne01,
        constant int64_t & ne02,
        constant uint64_t & nb00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant int64_t & ne10,
        constant int64_t & ne11,
        constant int64_t & ne12,
        constant int64_t & ne13,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant int64_t & ne0,
        constant int64_t & ne1,
        constant int64_t & nb1,
        constant uint & r2,
        constant uint & r3,
        constant int & idx,             // which column of the ids row to use
        device const char * src00,
        device const char * src01,
        device const char * src02,
        device const char * src03,
        device const char * src04,
        device const char * src05,
        device const char * src06,
        device const char * src07,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiitg[[thread_index_in_threadgroup]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]]) {
    // candidate expert matrices, made indexable
    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // batch index encoded in the grid's z dimension
    const int64_t bid = tgpig.z/(ne12*ne13);

    tgpig.z = tgpig.z%(ne12*ne13);

    // NOTE(review): unchecked index into the 8-entry src0 array — assumes the
    // host guarantees 0 <= ids[...] < 8; confirm.
    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];

    kernel_mul_mv_q3_K_f32_impl(
        src0[id],
        (device const float *) (src1 + bid*nb11),
        (device float *) ( dst + bid*nb1),
        ne00,
        ne01,
        ne02,
        ne10,
        ne12,
        ne0,
        ne1,
        r2,
        r3,
        tgpig,
        tiisg,
        sgitg);
}
|
4720
|
+
|
4721
|
+
// Indirect matrix-vector multiply, q4_K-quantized src0 x f32 src1. Selects
// the per-batch expert matrix via `ids` and forwards to
// kernel_mul_mv_q4_K_f32_impl.
[[host_name("kernel_mul_mv_id_q4_K_f32")]]
kernel void kernel_mul_mv_id_q4_K_f32(
        device const char * ids,        // rows of int32 expert indices, row stride nbi1 bytes
        device const char * src1,
        device uchar * dst,
        constant int64_t & nbi1,
        constant int64_t & ne00,
        constant int64_t & ne01,
        constant int64_t & ne02,
        constant uint64_t & nb00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant int64_t & ne10,
        constant int64_t & ne11,
        constant int64_t & ne12,
        constant int64_t & ne13,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant int64_t & ne0,
        constant int64_t & ne1,
        constant int64_t & nb1,
        constant uint & r2,
        constant uint & r3,
        constant int & idx,             // which column of the ids row to use
        device const char * src00,
        device const char * src01,
        device const char * src02,
        device const char * src03,
        device const char * src04,
        device const char * src05,
        device const char * src06,
        device const char * src07,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiitg[[thread_index_in_threadgroup]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]]) {
    // candidate expert matrices, made indexable
    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // batch index encoded in the grid's z dimension
    const int64_t bid = tgpig.z/(ne12*ne13);

    tgpig.z = tgpig.z%(ne12*ne13);

    // NOTE(review): unchecked index into the 8-entry src0 array — assumes the
    // host guarantees 0 <= ids[...] < 8; confirm.
    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];

    kernel_mul_mv_q4_K_f32_impl(
        src0[id],
        (device const float *) (src1 + bid*nb11),
        (device float *) ( dst + bid*nb1),
        ne00,
        ne01,
        ne02,
        ne10,
        ne12,
        ne0,
        ne1,
        r2,
        r3,
        tgpig,
        tiisg,
        sgitg);
}
|
4783
|
+
|
4784
|
+
// Indirect matrix-vector multiply, q5_K-quantized src0 x f32 src1. Selects
// the per-batch expert matrix via `ids` and forwards to
// kernel_mul_mv_q5_K_f32_impl.
[[host_name("kernel_mul_mv_id_q5_K_f32")]]
kernel void kernel_mul_mv_id_q5_K_f32(
        device const char * ids,        // rows of int32 expert indices, row stride nbi1 bytes
        device const char * src1,
        device uchar * dst,
        constant int64_t & nbi1,
        constant int64_t & ne00,
        constant int64_t & ne01,
        constant int64_t & ne02,
        constant uint64_t & nb00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant int64_t & ne10,
        constant int64_t & ne11,
        constant int64_t & ne12,
        constant int64_t & ne13,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant int64_t & ne0,
        constant int64_t & ne1,
        constant int64_t & nb1,
        constant uint & r2,
        constant uint & r3,
        constant int & idx,             // which column of the ids row to use
        device const char * src00,
        device const char * src01,
        device const char * src02,
        device const char * src03,
        device const char * src04,
        device const char * src05,
        device const char * src06,
        device const char * src07,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiitg[[thread_index_in_threadgroup]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]]) {
    // candidate expert matrices, made indexable
    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // batch index encoded in the grid's z dimension
    const int64_t bid = tgpig.z/(ne12*ne13);

    tgpig.z = tgpig.z%(ne12*ne13);

    // NOTE(review): unchecked index into the 8-entry src0 array — assumes the
    // host guarantees 0 <= ids[...] < 8; confirm.
    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];

    kernel_mul_mv_q5_K_f32_impl(
        src0[id],
        (device const float *) (src1 + bid*nb11),
        (device float *) ( dst + bid*nb1),
        ne00,
        ne01,
        ne02,
        ne10,
        ne12,
        ne0,
        ne1,
        r2,
        r3,
        tgpig,
        tiisg,
        sgitg);
}
|
4846
|
+
|
4847
|
+
// Indirect matrix-vector multiply, q6_K-quantized src0 x f32 src1. Selects
// the per-batch expert matrix via `ids` and forwards to
// kernel_mul_mv_q6_K_f32_impl.
[[host_name("kernel_mul_mv_id_q6_K_f32")]]
kernel void kernel_mul_mv_id_q6_K_f32(
        device const char * ids,        // rows of int32 expert indices, row stride nbi1 bytes
        device const char * src1,
        device uchar * dst,
        constant int64_t & nbi1,
        constant int64_t & ne00,
        constant int64_t & ne01,
        constant int64_t & ne02,
        constant uint64_t & nb00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant int64_t & ne10,
        constant int64_t & ne11,
        constant int64_t & ne12,
        constant int64_t & ne13,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant int64_t & ne0,
        constant int64_t & ne1,
        constant int64_t & nb1,
        constant uint & r2,
        constant uint & r3,
        constant int & idx,             // which column of the ids row to use
        device const char * src00,
        device const char * src01,
        device const char * src02,
        device const char * src03,
        device const char * src04,
        device const char * src05,
        device const char * src06,
        device const char * src07,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiitg[[thread_index_in_threadgroup]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]]) {
    // candidate expert matrices, made indexable
    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // batch index encoded in the grid's z dimension
    const int64_t bid = tgpig.z/(ne12*ne13);

    tgpig.z = tgpig.z%(ne12*ne13);

    // NOTE(review): unchecked index into the 8-entry src0 array — assumes the
    // host guarantees 0 <= ids[...] < 8; confirm.
    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];

    kernel_mul_mv_q6_K_f32_impl(
        src0[id],
        (device const float *) (src1 + bid*nb11),
        (device float *) ( dst + bid*nb1),
        ne00,
        ne01,
        ne02,
        ne10,
        ne12,
        ne0,
        ne1,
        r2,
        r3,
        tgpig,
        tiisg,
        sgitg);
}
|