llama_cpp 0.9.5 → 0.10.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/ext/llama_cpp/llama_cpp.cpp +123 -15
- data/ext/llama_cpp/src/ggml-alloc.c +42 -7
- data/ext/llama_cpp/src/ggml-alloc.h +8 -1
- data/ext/llama_cpp/src/ggml-backend-impl.h +46 -21
- data/ext/llama_cpp/src/ggml-backend.c +563 -156
- data/ext/llama_cpp/src/ggml-backend.h +62 -17
- data/ext/llama_cpp/src/ggml-cuda.cu +1796 -413
- data/ext/llama_cpp/src/ggml-cuda.h +9 -1
- data/ext/llama_cpp/src/ggml-impl.h +1 -1
- data/ext/llama_cpp/src/ggml-metal.h +6 -0
- data/ext/llama_cpp/src/ggml-metal.m +998 -169
- data/ext/llama_cpp/src/ggml-metal.metal +2253 -274
- data/ext/llama_cpp/src/ggml-quants.c +2 -2
- data/ext/llama_cpp/src/ggml.c +634 -248
- data/ext/llama_cpp/src/ggml.h +81 -15
- data/ext/llama_cpp/src/llama.cpp +932 -352
- data/ext/llama_cpp/src/llama.h +28 -5
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +22 -2
- metadata +2 -2
data/ext/llama_cpp/src/ggml-metal.metal

@@ -3,6 +3,8 @@
 using namespace metal;
 
 #define MAX(x, y) ((x) > (y) ? (x) : (y))
+#define MIN(x, y) ((x) < (y) ? (x) : (y))
+#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
 
 #define QK4_0 32
 #define QR4_0 2
@@ -41,8 +43,13 @@ typedef struct {
 
 #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
 
-
-
+enum ggml_sort_order {
+    GGML_SORT_ASC,
+    GGML_SORT_DESC,
+};
+
+// general-purpose kernel for addition, multiplication and division of two tensors
+// pros: works for non-contiguous tensors, supports broadcast across all dims
 // cons: not very efficient
 kernel void kernel_add(
         device const char * src0,
@@ -72,6 +79,7 @@ kernel void kernel_add(
         constant int64_t & nb1,
         constant int64_t & nb2,
         constant int64_t & nb3,
+        constant int64_t & offs,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3 ntg[[threads_per_threadgroup]]) {
@@ -83,16 +91,111 @@ kernel void kernel_add(
     const int64_t i12 = i02 % ne12;
     const int64_t i11 = i01 % ne11;
 
-    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 +
-    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11
-    device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 +
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
+    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        const int i10 = i0 % ne10;
+        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10));
+    }
+}
+
+kernel void kernel_mul(
+        device const char * src0,
+        device const char * src1,
+        device char * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne03,
+        constant int64_t & nb00,
+        constant int64_t & nb01,
+        constant int64_t & nb02,
+        constant int64_t & nb03,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant int64_t & ne13,
+        constant int64_t & nb10,
+        constant int64_t & nb11,
+        constant int64_t & nb12,
+        constant int64_t & nb13,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & ne2,
+        constant int64_t & ne3,
+        constant int64_t & nb0,
+        constant int64_t & nb1,
+        constant int64_t & nb2,
+        constant int64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig.z;
+    const int64_t i02 = tgpig.y;
+    const int64_t i01 = tgpig.x;
+
+    const int64_t i13 = i03 % ne13;
+    const int64_t i12 = i02 % ne12;
+    const int64_t i11 = i01 % ne11;
+
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
 
     for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-
+        const int i10 = i0 % ne10;
+        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10));
+    }
+}
+
+kernel void kernel_div(
+        device const char * src0,
+        device const char * src1,
+        device char * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne03,
+        constant int64_t & nb00,
+        constant int64_t & nb01,
+        constant int64_t & nb02,
+        constant int64_t & nb03,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant int64_t & ne13,
+        constant int64_t & nb10,
+        constant int64_t & nb11,
+        constant int64_t & nb12,
+        constant int64_t & nb13,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & ne2,
+        constant int64_t & ne3,
+        constant int64_t & nb0,
+        constant int64_t & nb1,
+        constant int64_t & nb2,
+        constant int64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig.z;
+    const int64_t i02 = tgpig.y;
+    const int64_t i01 = tgpig.x;
+
+    const int64_t i13 = i03 % ne13;
+    const int64_t i12 = i02 % ne12;
+    const int64_t i11 = i01 % ne11;
 
-
-
-
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        const int i10 = i0 % ne10;
+        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10));
     }
 }
 
@@ -102,28 +205,27 @@ kernel void kernel_add_row(
         device const float4 * src0,
         device const float4 * src1,
         device float4 * dst,
-        constant int64_t & nb [[buffer(
+        constant int64_t & nb [[buffer(28)]],
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] + src1[tpig % nb];
 }
 
-kernel void
+kernel void kernel_mul_row(
         device const float4 * src0,
         device const float4 * src1,
         device float4 * dst,
+        constant int64_t & nb [[buffer(28)]],
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * src1[tpig];
+    dst[tpig] = src0[tpig] * src1[tpig % nb];
 }
 
-
-// broadcast src1 into src0
-kernel void kernel_mul_row(
+kernel void kernel_div_row(
         device const float4 * src0,
         device const float4 * src1,
         device float4 * dst,
-        constant int64_t & nb,
+        constant int64_t & nb [[buffer(28)]],
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig]
+    dst[tpig] = src0[tpig] / src1[tpig % nb];
 }
 
 kernel void kernel_scale(
@@ -142,14 +244,6 @@ kernel void kernel_scale_4(
     dst[tpig] = src0[tpig] * scale;
 }
 
-kernel void kernel_silu(
-        device const float4 * src0,
-        device float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
-    dst[tpig] = x / (1.0f + exp(-x));
-}
-
 kernel void kernel_relu(
         device const float * src0,
         device float * dst,
@@ -157,15 +251,17 @@ kernel void kernel_relu(
     dst[tpig] = max(0.0f, src0[tpig]);
 }
 
-kernel void
+kernel void kernel_tanh(
         device const float * src0,
         device float * dst,
         uint tpig[[thread_position_in_grid]]) {
-
+    device const float & x = src0[tpig];
+    dst[tpig] = precise::tanh(x);
 }
 
-constant float GELU_COEF_A
-constant float
+constant float GELU_COEF_A = 0.044715f;
+constant float GELU_QUICK_COEF = -1.702f;
+constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
 
 kernel void kernel_gelu(
         device const float4 * src0,
@@ -180,6 +276,78 @@ kernel void kernel_gelu(
     dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
 }
 
+kernel void kernel_gelu_quick(
+        device const float4 * src0,
+        device float4 * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+
+    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
+}
+
+kernel void kernel_silu(
+        device const float4 * src0,
+        device float4 * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+    dst[tpig] = x / (1.0f + exp(-x));
+}
+
+kernel void kernel_sqr(
+        device const float * src0,
+        device float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * src0[tpig];
+}
+
+kernel void kernel_sum_rows(
+        device const float * src0,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne03,
+        constant int64_t & nb00,
+        constant int64_t & nb01,
+        constant int64_t & nb02,
+        constant int64_t & nb03,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant int64_t & ne13,
+        constant int64_t & nb10,
+        constant int64_t & nb11,
+        constant int64_t & nb12,
+        constant int64_t & nb13,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & ne2,
+        constant int64_t & ne3,
+        constant int64_t & nb0,
+        constant int64_t & nb1,
+        constant int64_t & nb2,
+        constant int64_t & nb3,
+        uint3 tpig[[thread_position_in_grid]]) {
+    int64_t i3 = tpig.z;
+    int64_t i2 = tpig.y;
+    int64_t i1 = tpig.x;
+
+    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+        return;
+    }
+
+    device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
+    device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
+
+    float row_sum = 0;
+
+    for (int64_t i0 = 0; i0 < ne00; i0++) {
+        row_sum += src_row[i0];
+    }
+
+    dst_row[0] = row_sum;
+}
+
 kernel void kernel_soft_max(
         device const float * src0,
         device const float * src1,
|
|
198
366
|
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
199
367
|
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
200
368
|
|
201
|
-
device const float * psrc0 =
|
202
|
-
device const float * pmask = src1 ? src1
|
203
|
-
device float * pdst =
|
369
|
+
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
370
|
+
device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
|
371
|
+
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
204
372
|
|
205
373
|
// parallel max
|
206
374
|
float lmax = -INFINITY;
|
@@ -236,7 +404,12 @@ kernel void kernel_soft_max(
         pdst[i00] = exp_psrc0;
     }
 
+    // This barrier fixes a failing test
+    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+    threadgroup_barrier(mem_flags::mem_none);
+
     float sum = simd_sum(lsum);
+
     if (ntg > N_SIMDWIDTH) {
         if (sgitg == 0) {
             buf[tiisg] = 0.0f;
@@ -279,9 +452,9 @@ kernel void kernel_soft_max_4(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
-    device const float4 * psrc4 =
-    device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
-    device float4 * pdst4 =
+    device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
+    device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
     // parallel max
     float4 lmax4 = -INFINITY;
@@ -319,7 +492,13 @@ kernel void kernel_soft_max_4(
     }
 
     const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+
+    // This barrier fixes a failing test
+    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+    threadgroup_barrier(mem_flags::mem_none);
+
     float sum = simd_sum(lsum);
+
     if (ntg > N_SIMDWIDTH) {
         if (sgitg == 0) {
             buf[tiisg] = 0.0f;
@@ -490,6 +669,94 @@ kernel void kernel_rms_norm(
     }
 }
 
+kernel void kernel_group_norm(
+        device const float * src0,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int32_t & n_groups,
+        constant float & eps,
+        threadgroup float * buf [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint ntg[[threads_per_threadgroup]]) {
+    const int64_t ne = ne00*ne01*ne02;
+    const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
+
+    int start = tgpig * gs;
+    int end = start + gs;
+
+    start += tpitg;
+
+    if (end >= ne) {
+        end = ne;
+    }
+
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int j = start; j < end; j += ntg) {
+        tmp += src0[j];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    tmp = simd_sum(tmp);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = tmp;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        tmp = buf[tiisg];
+        tmp = simd_sum(tmp);
+    }
+
+    const float mean = tmp / gs;
+    tmp = 0.0f;
+
+    for (int j = start; j < end; j += ntg) {
+        float xi = src0[j] - mean;
+        dst[j] = xi;
+        tmp += xi * xi;
+    }
+
+    tmp = simd_sum(tmp);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = tmp;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        tmp = buf[tiisg];
+        tmp = simd_sum(tmp);
+    }
+
+    const float variance = tmp / gs;
+    const float scale = 1.0f/sqrt(variance + eps);
+    for (int j = start; j < end; j += ntg) {
+        dst[j] *= scale;
+    }
+}
+
 // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
 // il indicates where the q4 quants begin (0 or QK4_0/4)
 // we assume that the yl's have been multiplied with the appropriate scale factor
@@ -582,9 +849,20 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
 // giard against the number of rows not being divisible by
 // N_DST, so this is another explicit assumption of the implementation.
 template<typename block_q_type, int nr, int nsg, int nw>
-void
-
-
+void mul_vec_q_n_f32_impl(
+        device const void * src0,
+        device const float * src1,
+        device float * dst,
+        int64_t ne00,
+        int64_t ne01,
+        int64_t ne02,
+        int64_t ne10,
+        int64_t ne12,
+        int64_t ne0,
+        int64_t ne1,
+        uint r2,
+        uint r3,
+        uint3 tgpig, uint tiisg, uint sgitg) {
     const int nb = ne00/QK4_0;
 
     const int r0 = tgpig.x;
@@ -593,7 +871,10 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
 
     const int first_row = (r0 * nsg + sgitg) * nr;
 
-    const uint
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
 
     device const block_q_type * x = (device const block_q_type *) src0 + offset0;
     device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
@@ -643,13 +924,14 @@ kernel void kernel_mul_mv_q4_0_f32(
         constant int64_t & ne02[[buffer(5)]],
         constant int64_t & ne10[[buffer(9)]],
         constant int64_t & ne12[[buffer(11)]],
-        constant int64_t & ne0[[buffer(15)]],
-        constant int64_t & ne1[[buffer(16)]],
-        constant uint &
+        constant int64_t & ne0 [[buffer(15)]],
+        constant int64_t & ne1 [[buffer(16)]],
+        constant uint & r2 [[buffer(17)]],
+        constant uint & r3 [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
+    mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mv_q4_1_f32(
@@ -661,13 +943,14 @@ kernel void kernel_mul_mv_q4_1_f32(
         constant int64_t & ne02[[buffer(5)]],
         constant int64_t & ne10[[buffer(9)]],
         constant int64_t & ne12[[buffer(11)]],
-        constant int64_t & ne0[[buffer(15)]],
-        constant int64_t & ne1[[buffer(16)]],
-        constant uint &
+        constant int64_t & ne0 [[buffer(15)]],
+        constant int64_t & ne1 [[buffer(16)]],
+        constant uint & r2 [[buffer(17)]],
+        constant uint & r3 [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
+    mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mv_q5_0_f32(
@@ -679,13 +962,14 @@ kernel void kernel_mul_mv_q5_0_f32(
         constant int64_t & ne02[[buffer(5)]],
         constant int64_t & ne10[[buffer(9)]],
        constant int64_t & ne12[[buffer(11)]],
-        constant int64_t & ne0[[buffer(15)]],
-        constant int64_t & ne1[[buffer(16)]],
-        constant uint &
+        constant int64_t & ne0 [[buffer(15)]],
+        constant int64_t & ne1 [[buffer(16)]],
+        constant uint & r2 [[buffer(17)]],
+        constant uint & r3 [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
+    mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mv_q5_1_f32(
@@ -697,33 +981,35 @@ kernel void kernel_mul_mv_q5_1_f32(
         constant int64_t & ne02[[buffer(5)]],
         constant int64_t & ne10[[buffer(9)]],
         constant int64_t & ne12[[buffer(11)]],
-        constant int64_t & ne0[[buffer(15)]],
-        constant int64_t & ne1[[buffer(16)]],
-        constant uint &
+        constant int64_t & ne0 [[buffer(15)]],
+        constant int64_t & ne1 [[buffer(16)]],
+        constant uint & r2 [[buffer(17)]],
+        constant uint & r3 [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
+    mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }
 
 
 #define NB_Q8_0 8
 
-
+void kernel_mul_mv_q8_0_f32_impl(
         device const void * src0,
         device const float * src1,
         device float * dst,
         constant int64_t & ne00,
-        constant int64_t & ne01
-        constant int64_t & ne02
-        constant int64_t & ne10
-        constant int64_t & ne12
-        constant int64_t & ne0
-        constant int64_t & ne1
-        constant uint &
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne10,
+        constant int64_t & ne12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant uint & r2,
+        constant uint & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint
-        uint
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
     const int nr = N_DST;
     const int nsg = N_SIMDGROUP;
     const int nw = N_SIMDWIDTH;
@@ -732,8 +1018,14 @@ kernel void kernel_mul_mv_q8_0_f32(
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
+
     const int first_row = (r0 * nsg + sgitg) * nr;
-
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
     device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
     device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
 
@@ -771,11 +1063,31 @@ kernel void kernel_mul_mv_q8_0_f32(
     }
 }
 
-
-
-
-        device const
-        device
+[[host_name("kernel_mul_mv_q8_0_f32")]]
+kernel void kernel_mul_mv_q8_0_f32(
+        device const void * src0,
+        device const float * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne10,
+        constant int64_t & ne12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant uint & r2 [[buffer(17)]],
+        constant uint & r3 [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+}
+
+#define N_F32_F32 4
+
+void kernel_mul_mv_f32_f32_impl(
+        device const char * src0,
+        device const char * src1,
         device float * dst,
         constant int64_t & ne00,
         constant int64_t & ne01,
@@ -791,6 +1103,8 @@ kernel void kernel_mul_mv_f32_f32(
         constant uint64_t & nb12,
         constant int64_t & ne0,
         constant int64_t & ne1,
+        constant uint & r2,
+        constant uint & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]]) {
 
@@ -798,7 +1112,12 @@ kernel void kernel_mul_mv_f32_f32(
     const int64_t rb = tgpig.y*N_F32_F32;
     const int64_t im = tgpig.z;
 
-
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+    device const float * x = (device const float *) (src0 + offset0);
 
     if (ne00 < 128) {
         for (int row = 0; row < N_F32_F32; ++row) {
@@ -844,6 +1163,32 @@ kernel void kernel_mul_mv_f32_f32(
     }
 }
 
+[[host_name("kernel_mul_mv_f32_f32")]]
+kernel void kernel_mul_mv_f32_f32(
+        device const char * src0,
+        device const char * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant uint & r2 [[buffer(17)]],
+        constant uint & r3 [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]]) {
+    kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+}
+
 #define N_F16_F16 4
 
 kernel void kernel_mul_mv_f16_f16(
@@ -864,6 +1209,8 @@ kernel void kernel_mul_mv_f16_f16(
         constant uint64_t & nb12,
         constant int64_t & ne0,
         constant int64_t & ne1,
+        constant uint & r2 [[buffer(17)]],
+        constant uint & r3 [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]]) {
 
@@ -871,7 +1218,12 @@ kernel void kernel_mul_mv_f16_f16(
     const int64_t rb = tgpig.y*N_F16_F16;
     const int64_t im = tgpig.z;
 
-
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+    device const half * x = (device const half *) (src0 + offset0);
 
     if (ne00 < 128) {
         for (int row = 0; row < N_F16_F16; ++row) {
@@ -917,7 +1269,7 @@ kernel void kernel_mul_mv_f16_f16(
     }
 }
 
-
+void kernel_mul_mv_f16_f32_1row_impl(
         device const char * src0,
         device const char * src1,
         device float * dst,
@@ -935,6 +1287,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
         constant uint64_t & nb12,
         constant int64_t & ne0,
         constant int64_t & ne1,
+        constant uint & r2,
+        constant uint & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]]) {
 
@@ -942,7 +1296,12 @@ kernel void kernel_mul_mv_f16_f32_1row(
     const int64_t r1 = tgpig.y;
     const int64_t im = tgpig.z;
 
-
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+    device const half * x = (device const half *) (src0 + offset0);
     device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
 
     float sumf = 0;
@@ -966,12 +1325,37 @@ kernel void kernel_mul_mv_f16_f32_1row(
         dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
     }
 }
+}
 
+[[host_name("kernel_mul_mv_f16_f32_1row")]]
+kernel void kernel_mul_mv_f16_f32_1row(
+        device const char * src0,
+        device const char * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant uint & r2 [[buffer(17)]],
+        constant uint & r3 [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]]) {
+    kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
 }
 
 #define N_F16_F32 4
 
-
+void kernel_mul_mv_f16_f32_impl(
         device const char * src0,
         device const char * src1,
         device float * dst,
@@ -989,6 +1373,8 @@ kernel void kernel_mul_mv_f16_f32(
         constant uint64_t & nb12,
         constant int64_t & ne0,
         constant int64_t & ne1,
+        constant uint & r2,
+        constant uint & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]]) {
 
@@ -996,7 +1382,12 @@ kernel void kernel_mul_mv_f16_f32(
     const int64_t rb = tgpig.y*N_F16_F32;
     const int64_t im = tgpig.z;
 
-
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+    device const half * x = (device const half *) (src0 + offset0);
 
     if (ne00 < 128) {
         for (int row = 0; row < N_F16_F32; ++row) {
@@ -1042,6 +1433,32 @@ kernel void kernel_mul_mv_f16_f32(
     }
 }
 
+[[host_name("kernel_mul_mv_f16_f32")]]
+kernel void kernel_mul_mv_f16_f32(
+        device const char * src0,
+        device const char * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant uint & r2 [[buffer(17)]],
+        constant uint & r3 [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]]) {
+    kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+}
+
 // Assumes row size (ne00) is a multiple of 4
 kernel void kernel_mul_mv_f16_f32_l4(
         device const char * src0,
@@ -1061,6 +1478,8 @@ kernel void kernel_mul_mv_f16_f32_l4(
         constant uint64_t & nb12,
         constant int64_t & ne0,
         constant int64_t & ne1,
+        constant uint & r2 [[buffer(17)]],
+        constant uint & r3 [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]]) {
 
@@ -1068,7 +1487,12 @@ kernel void kernel_mul_mv_f16_f32_l4(
     const int64_t r0 = tgpig.x;
     const int64_t im = tgpig.z;
 
-
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+    device const half4 * x4 = (device const half4 *) (src0 + offset0);
 
     for (int r1 = 0; r1 < nrows; ++r1) {
         device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
@@ -1120,17 +1544,21 @@ kernel void kernel_alibi_f32(
     const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
     const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
     const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+    const int64_t k = i3*ne3 + i2;
 
-    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
     float m_k;
-    if (
-        m_k = pow(m0,
+    if (k < n_heads_log2_floor) {
+        m_k = pow(m0, k + 1);
     } else {
-        m_k = pow(m1, 2 * (
+        m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
     }
+
+    device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
+    device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
     for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
-
-
+        const float src_v = *(device float *)(src_row + i00*nb00);
+        device float * dst_v = (device float *)(dst_row + i00*nb0);
+        *dst_v = i00 * m_k + src_v;
     }
 }
 
@@ -1335,9 +1763,160 @@ kernel void kernel_im2col_f16(
     }
 }
 
+kernel void kernel_upscale_f32(
+        device const char * src0,
+        device char * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & ne2,
+        constant int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        constant int32_t & sf,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3;
+    const int64_t i02 = i2;
+    const int64_t i01 = i1/sf;
+
+    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        dst_ptr[i0] = src0_ptr[i0/sf];
+    }
+}
+
+kernel void kernel_pad_f32(
+        device const char * src0,
+        device char * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & ne2,
+        constant int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3;
+    const int64_t i02 = i2;
+    const int64_t i01 = i1;
+
+    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
+
+    if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
+        for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+            if (i0 < ne00) {
+                dst_ptr[i0] = src0_ptr[i0];
+            } else {
+                dst_ptr[i0] = 0.0f;
+            }
+        }
+
+        return;
+    }
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        dst_ptr[i0] = 0.0f;
+    }
+}
+
+// bitonic sort implementation following the CUDA kernels as reference
+typedef void (argsort_t)(
+        device const float * x,
+        device int32_t * dst,
+        constant int64_t & ncols,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]]);
+
+template<ggml_sort_order order>
+kernel void kernel_argsort_f32_i32(
+        device const float * x,
+        device int32_t * dst,
+        constant int64_t & ncols,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]]) {
+    // bitonic sort
+    int col = tpitg[0];
+    int row = tgpig[1];
+
+    if (col >= ncols) return;
+
+    device const float * x_row = x + row * ncols;
+    device int32_t * dst_row = dst + row * ncols;
+
+    // initialize indices
+    if (col < ncols) {
+        dst_row[col] = col;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    for (int k = 2; k <= ncols; k *= 2) {
+        for (int j = k / 2; j > 0; j /= 2) {
+            int ixj = col ^ j;
+            if (ixj > col) {
+                if ((col & k) == 0) {
+                    if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
+                        SWAP(dst_row[col], dst_row[ixj]);
+                    }
+                } else {
+                    if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
+                        SWAP(dst_row[col], dst_row[ixj]);
+                    }
+                }
+            }
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+        }
+    }
+}
+
+template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
+template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
+
+kernel void kernel_leaky_relu_f32(
+        device const float * src0,
+        device float * dst,
+        constant float & slope,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
+}
+
 kernel void kernel_cpy_f16_f16(
-        device
-        device
+        device const half * src0,
+        device half * dst,
         constant int64_t & ne00,
         constant int64_t & ne01,
         constant int64_t & ne02,
@@ -1376,6 +1955,47 @@ kernel void kernel_cpy_f16_f16(
     }
 }
 
+kernel void kernel_cpy_f16_f32(
+        device const half * src0,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & ne2,
+        constant int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0];
+    }
+}
+
 kernel void kernel_cpy_f32_f16(
         device const float * src0,
         device half * dst,
@@ -1460,47 +2080,238 @@ kernel void kernel_cpy_f32_f32(
     }
 }
 
-kernel void
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+kernel void kernel_cpy_f32_q8_0(
+        device const float * src0,
+        device void * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & ne2,
+        constant int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0;
+
+    device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        float amax = 0.0f; // absolute max
+
+        for (int j = 0; j < QK8_0; j++) {
+            const float v = src[j];
+            amax = MAX(amax, fabs(v));
+        }
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        dst_data[i00/QK8_0].d = d;
+
+        for (int j = 0; j < QK8_0; ++j) {
+            const float x0 = src[j]*id;
+
+            dst_data[i00/QK8_0].qs[j] = round(x0);
+        }
+    }
+}
+
+kernel void kernel_cpy_f32_q4_0(
+        device const float * src0,
+        device void * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & ne2,
+        constant int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0;
+
+    device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        float amax = 0.0f; // absolute max
+        float max = 0.0f;
+
+        for (int j = 0; j < QK4_0; j++) {
+            const float v = src[j];
+            if (amax < fabs(v)) {
+                amax = fabs(v);
+                max = v;
+            }
+        }
+
+        const float d = max / -8;
+        const float id = d ? 1.0f/d : 0.0f;
+
+        dst_data[i00/QK4_0].d = d;
+
+        for (int j = 0; j < QK4_0/2; ++j) {
+            const float x0 = src[0 + j]*id;
+            const float x1 = src[QK4_0/2 + j]*id;
+
+            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
+            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
+
+            dst_data[i00/QK4_0].qs[j] = xi0;
+            dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
+        }
+    }
+}
+
+kernel void kernel_cpy_f32_q4_1(
+        device const float * src0,
+        device void * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & ne2,
+        constant int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1;
+
+    device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        float min = FLT_MAX;
+        float max = -FLT_MAX;
+
+        for (int j = 0; j < QK4_1; j++) {
+            const float v = src[j];
+            if (min > v) min = v;
+            if (max < v) max = v;
+        }
+
+        const float d = (max - min) / ((1 << 4) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        dst_data[i00/QK4_1].d = d;
+        dst_data[i00/QK4_1].m = min;
+
+        for (int j = 0; j < QK4_1/2; ++j) {
+            const float x0 = (src[0 + j] - min)*id;
+            const float x1 = (src[QK4_1/2 + j] - min)*id;
+
+            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
+            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
+
+            dst_data[i00/QK4_1].qs[j] = xi0;
+            dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
+        }
+    }
+}
+
+kernel void kernel_concat(
+        device const char * src0,
+        device const char * src1,
+        device char * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant uint64_t & nb13,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & ne2,
+        constant int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i03 = tgpig.z;
+    const int64_t i02 = tgpig.y;
+    const int64_t i01 = tgpig.x;
+
+    const int64_t i13 = i03 % ne13;
     const int64_t i12 = i02 % ne12;
     const int64_t i11 = i01 % ne11;
 
-    device const char * src0_ptr = src0 + i03
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
     device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
     device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
 
@@ -1608,32 +2419,39 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
 
 //====================================== dot products =========================
 
-
+void kernel_mul_mv_q2_K_f32_impl(
         device const void * src0,
         device const float * src1,
         device float * dst,
         constant int64_t & ne00,
-        constant int64_t & ne01
-        constant int64_t & ne02
-        constant int64_t & ne10
-        constant int64_t & ne12
-        constant int64_t & ne0
-        constant int64_t & ne1
-        constant uint &
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne10,
+        constant int64_t & ne12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant uint & r2,
+        constant uint & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint
-        uint
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
     const int nb = ne00/QK_K;
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
-    const int
+    const int im = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
     const int ib_row = first_row * nb;
-
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
     device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
-    device const float * y = (device const float *) src1 + r1*ne10 +
+    device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
     float yl[32];
     float sumf[N_DST]={0.f}, all_sum;
 
|
|
1642
2460
|
#if QK_K == 256
|
1643
2461
|
const int ix = tiisg/8; // 0...3
|
1644
2462
|
const int it = tiisg%8; // 0...7
|
1645
|
-
const int
|
2463
|
+
const int iq = it/4; // 0 or 1
|
1646
2464
|
const int ir = it%4; // 0...3
|
1647
2465
|
const int is = (8*ir)/16;// 0 or 1
|
1648
2466
|
|
1649
|
-
device const float * y4 = y + ix * QK_K + 128 *
|
2467
|
+
device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
|
1650
2468
|
|
1651
2469
|
for (int ib = ix; ib < nb; ib += 4) {
|
1652
2470
|
|
@@ -1658,8 +2476,8 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
1658
2476
|
yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
|
1659
2477
|
}
|
1660
2478
|
|
1661
|
-
device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*
|
1662
|
-
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 *
|
2479
|
+
device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*iq + is;
|
2480
|
+
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
|
1663
2481
|
device const half * dh = &x[ib].d;
|
1664
2482
|
|
1665
2483
|
for (int row = 0; row < N_DST; row++) {
|
@@ -1746,13 +2564,13 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
1746
2564
|
for (int row = 0; row < N_DST; ++row) {
|
1747
2565
|
all_sum = simd_sum(sumf[row]);
|
1748
2566
|
if (tiisg == 0) {
|
1749
|
-
dst[r1*ne0 +
|
2567
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
1750
2568
|
}
|
1751
2569
|
}
|
1752
2570
|
}
|
1753
2571
|
|
1754
|
-
|
1755
|
-
kernel void
|
2572
|
+
[[host_name("kernel_mul_mv_q2_K_f32")]]
|
2573
|
+
kernel void kernel_mul_mv_q2_K_f32(
|
1756
2574
|
device const void * src0,
|
1757
2575
|
device const float * src1,
|
1758
2576
|
device float * dst,
|
@@ -1761,23 +2579,50 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
1761
2579
|
constant int64_t & ne02[[buffer(5)]],
|
1762
2580
|
constant int64_t & ne10[[buffer(9)]],
|
1763
2581
|
constant int64_t & ne12[[buffer(11)]],
|
1764
|
-
constant int64_t & ne0[[buffer(15)]],
|
1765
|
-
constant int64_t & ne1[[buffer(16)]],
|
1766
|
-
constant uint &
|
2582
|
+
constant int64_t & ne0 [[buffer(15)]],
|
2583
|
+
constant int64_t & ne1 [[buffer(16)]],
|
2584
|
+
constant uint & r2 [[buffer(17)]],
|
2585
|
+
constant uint & r3 [[buffer(18)]],
|
1767
2586
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1768
|
-
uint
|
1769
|
-
uint
|
2587
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
2588
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
2589
|
+
|
2590
|
+
kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
2591
|
+
}
|
2592
|
+
|
2593
|
+
#if QK_K == 256
|
2594
|
+
void kernel_mul_mv_q3_K_f32_impl(
|
2595
|
+
device const void * src0,
|
2596
|
+
device const float * src1,
|
2597
|
+
device float * dst,
|
2598
|
+
constant int64_t & ne00,
|
2599
|
+
constant int64_t & ne01,
|
2600
|
+
constant int64_t & ne02,
|
2601
|
+
constant int64_t & ne10,
|
2602
|
+
constant int64_t & ne12,
|
2603
|
+
constant int64_t & ne0,
|
2604
|
+
constant int64_t & ne1,
|
2605
|
+
constant uint & r2,
|
2606
|
+
constant uint & r3,
|
2607
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
2608
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
2609
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1770
2610
|
|
1771
2611
|
const int nb = ne00/QK_K;
|
1772
2612
|
|
1773
2613
|
const int64_t r0 = tgpig.x;
|
1774
2614
|
const int64_t r1 = tgpig.y;
|
1775
|
-
const int64_t
|
2615
|
+
const int64_t im = tgpig.z;
|
1776
2616
|
|
1777
2617
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
1778
|
-
|
2618
|
+
|
2619
|
+
const uint i12 = im%ne12;
|
2620
|
+
const uint i13 = im/ne12;
|
2621
|
+
|
2622
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
2623
|
+
|
1779
2624
|
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
|
1780
|
-
device const float * yy = (device const float *) src1 + r1*ne10 +
|
2625
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
1781
2626
|
|
1782
2627
|
float yl[32];
|
1783
2628
|
|
@@ -1899,40 +2744,47 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
1899
2744
|
}
|
1900
2745
|
if (tiisg == 0) {
|
1901
2746
|
for (int row = 0; row < 2; ++row) {
|
1902
|
-
dst[r1*ne0 +
|
2747
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row];
|
1903
2748
|
}
|
1904
2749
|
}
|
1905
2750
|
}
|
1906
2751
|
#else
|
1907
|
-
|
2752
|
+
void kernel_mul_mv_q3_K_f32_impl(
|
1908
2753
|
device const void * src0,
|
1909
2754
|
device const float * src1,
|
1910
2755
|
device float * dst,
|
1911
2756
|
constant int64_t & ne00,
|
1912
|
-
constant int64_t & ne01
|
1913
|
-
constant int64_t & ne02
|
1914
|
-
constant int64_t & ne10
|
1915
|
-
constant int64_t & ne12
|
1916
|
-
constant int64_t & ne0
|
1917
|
-
constant int64_t & ne1
|
1918
|
-
constant uint &
|
2757
|
+
constant int64_t & ne01,
|
2758
|
+
constant int64_t & ne02,
|
2759
|
+
constant int64_t & ne10,
|
2760
|
+
constant int64_t & ne12,
|
2761
|
+
constant int64_t & ne0,
|
2762
|
+
constant int64_t & ne1,
|
2763
|
+
constant uint & r2,
|
2764
|
+
constant uint & r3,
|
1919
2765
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1920
|
-
uint
|
1921
|
-
uint
|
2766
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
2767
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1922
2768
|
|
1923
2769
|
const int nb = ne00/QK_K;
|
1924
2770
|
|
1925
2771
|
const int64_t r0 = tgpig.x;
|
1926
2772
|
const int64_t r1 = tgpig.y;
|
1927
|
-
const int64_t
|
2773
|
+
const int64_t im = tgpig.z;
|
1928
2774
|
|
1929
2775
|
const int row = 2 * r0 + sgitg;
|
1930
|
-
|
2776
|
+
|
2777
|
+
const uint i12 = im%ne12;
|
2778
|
+
const uint i13 = im/ne12;
|
2779
|
+
|
2780
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
2781
|
+
|
1931
2782
|
device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
|
1932
|
-
device const float * yy = (device const float *) src1 + r1*ne10 +
|
2783
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
2784
|
+
|
1933
2785
|
const int ix = tiisg/4;
|
1934
2786
|
const int il = 4 * (tiisg%4);// 0, 4, 8, 12
|
1935
|
-
const int
|
2787
|
+
const int iq = il/8; // 0, 0, 1, 1
|
1936
2788
|
const int in = il%8; // 0, 4, 0, 4
|
1937
2789
|
|
1938
2790
|
float2 sum = {0.f, 0.f};
|
@@ -1952,7 +2804,7 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
1952
2804
|
const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
|
1953
2805
|
|
1954
2806
|
for (int l = 0; l < 4; l += 2) {
|
1955
|
-
const uint16_t hm = h[l/2] >>
|
2807
|
+
const uint16_t hm = h[l/2] >> iq;
|
1956
2808
|
sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
|
1957
2809
|
+ y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
|
1958
2810
|
+ y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
|
@@ -1968,28 +2820,50 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
1968
2820
|
|
1969
2821
|
const float tot = simd_sum(sumf);
|
1970
2822
|
if (tiisg == 0) {
|
1971
|
-
dst[r1*ne0 +
|
2823
|
+
dst[r1*ne0 + im*ne0*ne1 + row] = tot;
|
1972
2824
|
}
|
1973
2825
|
|
1974
2826
|
}
|
1975
2827
|
#endif
|
1976
2828
|
|
2829
|
+
[[host_name("kernel_mul_mv_q3_K_f32")]]
|
2830
|
+
kernel void kernel_mul_mv_q3_K_f32(
|
2831
|
+
device const void * src0,
|
2832
|
+
device const float * src1,
|
2833
|
+
device float * dst,
|
2834
|
+
constant int64_t & ne00,
|
2835
|
+
constant int64_t & ne01[[buffer(4)]],
|
2836
|
+
constant int64_t & ne02[[buffer(5)]],
|
2837
|
+
constant int64_t & ne10[[buffer(9)]],
|
2838
|
+
constant int64_t & ne12[[buffer(11)]],
|
2839
|
+
constant int64_t & ne0 [[buffer(15)]],
|
2840
|
+
constant int64_t & ne1 [[buffer(16)]],
|
2841
|
+
constant uint & r2 [[buffer(17)]],
|
2842
|
+
constant uint & r3 [[buffer(18)]],
|
2843
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
2844
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
2845
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
2846
|
+
|
2847
|
+
kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
2848
|
+
}
|
2849
|
+
|
1977
2850
|
#if QK_K == 256
|
1978
|
-
|
2851
|
+
void kernel_mul_mv_q4_K_f32_impl(
|
1979
2852
|
device const void * src0,
|
1980
2853
|
device const float * src1,
|
1981
2854
|
device float * dst,
|
1982
2855
|
constant int64_t & ne00,
|
1983
|
-
constant int64_t & ne01
|
1984
|
-
constant int64_t & ne02
|
1985
|
-
constant int64_t & ne10
|
1986
|
-
constant int64_t & ne12
|
1987
|
-
constant int64_t & ne0
|
1988
|
-
constant int64_t & ne1
|
1989
|
-
constant uint &
|
2856
|
+
constant int64_t & ne01,
|
2857
|
+
constant int64_t & ne02,
|
2858
|
+
constant int64_t & ne10,
|
2859
|
+
constant int64_t & ne12,
|
2860
|
+
constant int64_t & ne0,
|
2861
|
+
constant int64_t & ne1,
|
2862
|
+
constant uint & r2,
|
2863
|
+
constant uint & r3,
|
1990
2864
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1991
|
-
uint
|
1992
|
-
uint
|
2865
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
2866
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1993
2867
|
|
1994
2868
|
const uint16_t kmask1 = 0x3f3f;
|
1995
2869
|
const uint16_t kmask2 = 0x0f0f;
|
@@ -1997,26 +2871,32 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
1997
2871
|
|
1998
2872
|
const int ix = tiisg/8; // 0...3
|
1999
2873
|
const int it = tiisg%8; // 0...7
|
2000
|
-
const int
|
2874
|
+
const int iq = it/4; // 0 or 1
|
2001
2875
|
const int ir = it%4; // 0...3
|
2002
2876
|
|
2003
2877
|
const int nb = ne00/QK_K;
|
2004
2878
|
const int r0 = tgpig.x;
|
2005
2879
|
const int r1 = tgpig.y;
|
2006
|
-
const int
|
2880
|
+
const int im = tgpig.z;
|
2007
2881
|
//const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
2008
2882
|
const int first_row = r0 * N_DST;
|
2009
2883
|
const int ib_row = first_row * nb;
|
2010
|
-
|
2884
|
+
|
2885
|
+
const uint i12 = im%ne12;
|
2886
|
+
const uint i13 = im/ne12;
|
2887
|
+
|
2888
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
2889
|
+
|
2011
2890
|
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
|
2012
|
-
device const float * y = (device const float *) src1 + r1*ne10 +
|
2891
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
2892
|
+
|
2013
2893
|
float yl[16];
|
2014
2894
|
float yh[16];
|
2015
2895
|
float sumf[N_DST]={0.f}, all_sum;
|
2016
2896
|
|
2017
2897
|
const int step = sizeof(block_q4_K) * nb / 2;
|
2018
2898
|
|
2019
|
-
device const float * y4 = y + ix * QK_K + 64 *
|
2899
|
+
device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
|
2020
2900
|
|
2021
2901
|
uint16_t sc16[4];
|
2022
2902
|
thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
|
@@ -2031,8 +2911,8 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
2031
2911
|
yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
|
2032
2912
|
}
|
2033
2913
|
|
2034
|
-
device const uint16_t * sc = (device const uint16_t *)x[ib].scales +
|
2035
|
-
device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 *
|
2914
|
+
device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq;
|
2915
|
+
device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
|
2036
2916
|
device const half * dh = &x[ib].d;
|
2037
2917
|
|
2038
2918
|
for (int row = 0; row < N_DST; row++) {
|
@@ -2076,23 +2956,24 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
2076
2956
|
for (int row = 0; row < N_DST; ++row) {
|
2077
2957
|
all_sum = simd_sum(sumf[row]);
|
2078
2958
|
if (tiisg == 0) {
|
2079
|
-
dst[r1*ne0 +
|
2959
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
2080
2960
|
}
|
2081
2961
|
}
|
2082
2962
|
}
|
2083
2963
|
#else
|
2084
|
-
|
2964
|
+
void kernel_mul_mv_q4_K_f32_impl(
|
2085
2965
|
device const void * src0,
|
2086
2966
|
device const float * src1,
|
2087
2967
|
device float * dst,
|
2088
2968
|
constant int64_t & ne00,
|
2089
|
-
constant int64_t & ne01
|
2090
|
-
constant int64_t & ne02
|
2091
|
-
constant int64_t & ne10
|
2092
|
-
constant int64_t & ne12
|
2093
|
-
constant int64_t & ne0
|
2094
|
-
constant int64_t & ne1
|
2095
|
-
constant uint &
|
2969
|
+
constant int64_t & ne01,
|
2970
|
+
constant int64_t & ne02,
|
2971
|
+
constant int64_t & ne10,
|
2972
|
+
constant int64_t & ne12,
|
2973
|
+
constant int64_t & ne0,
|
2974
|
+
constant int64_t & ne1,
|
2975
|
+
constant uint & r2,
|
2976
|
+
constant uint & r3,
|
2096
2977
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2097
2978
|
uint tiisg[[thread_index_in_simdgroup]],
|
2098
2979
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -2103,12 +2984,18 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
2103
2984
|
const int nb = ne00/QK_K;
|
2104
2985
|
const int r0 = tgpig.x;
|
2105
2986
|
const int r1 = tgpig.y;
|
2106
|
-
const int
|
2987
|
+
const int im = tgpig.z;
|
2107
2988
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
2108
2989
|
const int ib_row = first_row * nb;
|
2109
|
-
|
2990
|
+
|
2991
|
+
const uint i12 = im%ne12;
|
2992
|
+
const uint i13 = im/ne12;
|
2993
|
+
|
2994
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
2995
|
+
|
2110
2996
|
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
|
2111
|
-
device const float * y = (device const float *) src1 + r1*ne10 +
|
2997
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
2998
|
+
|
2112
2999
|
float yl[8];
|
2113
3000
|
float yh[8];
|
2114
3001
|
float sumf[N_DST]={0.f}, all_sum;
|
@@ -2164,13 +3051,14 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
2164
3051
|
for (int row = 0; row < N_DST; ++row) {
|
2165
3052
|
all_sum = simd_sum(sumf[row]);
|
2166
3053
|
if (tiisg == 0) {
|
2167
|
-
dst[r1*ne0+
|
3054
|
+
dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum;
|
2168
3055
|
}
|
2169
3056
|
}
|
2170
3057
|
}
|
2171
3058
|
#endif
|
2172
3059
|
|
2173
|
-
|
3060
|
+
[[host_name("kernel_mul_mv_q4_K_f32")]]
|
3061
|
+
kernel void kernel_mul_mv_q4_K_f32(
|
2174
3062
|
device const void * src0,
|
2175
3063
|
device const float * src1,
|
2176
3064
|
device float * dst,
|
@@ -2179,23 +3067,49 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
2179
3067
|
constant int64_t & ne02[[buffer(5)]],
|
2180
3068
|
constant int64_t & ne10[[buffer(9)]],
|
2181
3069
|
constant int64_t & ne12[[buffer(11)]],
|
2182
|
-
constant int64_t & ne0[[buffer(15)]],
|
2183
|
-
constant int64_t & ne1[[buffer(16)]],
|
2184
|
-
constant uint &
|
3070
|
+
constant int64_t & ne0 [[buffer(15)]],
|
3071
|
+
constant int64_t & ne1 [[buffer(16)]],
|
3072
|
+
constant uint & r2 [[buffer(17)]],
|
3073
|
+
constant uint & r3 [[buffer(18)]],
|
2185
3074
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2186
3075
|
uint tiisg[[thread_index_in_simdgroup]],
|
2187
3076
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
2188
3077
|
|
3078
|
+
kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
3079
|
+
}
|
3080
|
+
|
3081
|
+
void kernel_mul_mv_q5_K_f32_impl(
|
3082
|
+
device const void * src0,
|
3083
|
+
device const float * src1,
|
3084
|
+
device float * dst,
|
3085
|
+
constant int64_t & ne00,
|
3086
|
+
constant int64_t & ne01,
|
3087
|
+
constant int64_t & ne02,
|
3088
|
+
constant int64_t & ne10,
|
3089
|
+
constant int64_t & ne12,
|
3090
|
+
constant int64_t & ne0,
|
3091
|
+
constant int64_t & ne1,
|
3092
|
+
constant uint & r2,
|
3093
|
+
constant uint & r3,
|
3094
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3095
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
3096
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3097
|
+
|
2189
3098
|
const int nb = ne00/QK_K;
|
2190
3099
|
|
2191
3100
|
const int64_t r0 = tgpig.x;
|
2192
3101
|
const int64_t r1 = tgpig.y;
|
2193
|
-
const int
|
3102
|
+
const int im = tgpig.z;
|
2194
3103
|
|
2195
3104
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
2196
|
-
|
3105
|
+
|
3106
|
+
const uint i12 = im%ne12;
|
3107
|
+
const uint i13 = im/ne12;
|
3108
|
+
|
3109
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
3110
|
+
|
2197
3111
|
device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
|
2198
|
-
device const float * yy = (device const float *) src1 + r1*ne10 +
|
3112
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
2199
3113
|
|
2200
3114
|
float sumf[2]={0.f};
|
2201
3115
|
|
@@ -2211,15 +3125,15 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
2211
3125
|
|
2212
3126
|
const int tid = tiisg/4;
|
2213
3127
|
const int ix = tiisg%4;
|
2214
|
-
const int
|
3128
|
+
const int iq = tid/4;
|
2215
3129
|
const int ir = tid%4;
|
2216
3130
|
const int n = 8;
|
2217
3131
|
|
2218
3132
|
const int l0 = n*ir;
|
2219
|
-
const int q_offset = 32*
|
2220
|
-
const int y_offset = 64*
|
3133
|
+
const int q_offset = 32*iq + l0;
|
3134
|
+
const int y_offset = 64*iq + l0;
|
2221
3135
|
|
2222
|
-
const uint8_t hm1 = 1u << (2*
|
3136
|
+
const uint8_t hm1 = 1u << (2*iq);
|
2223
3137
|
const uint8_t hm2 = hm1 << 1;
|
2224
3138
|
const uint8_t hm3 = hm1 << 4;
|
2225
3139
|
const uint8_t hm4 = hm2 << 4;
|
@@ -2234,7 +3148,7 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
2234
3148
|
device const uint8_t * q1 = x[i].qs + q_offset;
|
2235
3149
|
device const uint8_t * qh = x[i].qh + l0;
|
2236
3150
|
device const half * dh = &x[i].d;
|
2237
|
-
device const uint16_t * a = (device const uint16_t *)x[i].scales +
|
3151
|
+
device const uint16_t * a = (device const uint16_t *)x[i].scales + iq;
|
2238
3152
|
|
2239
3153
|
device const float * y2 = y1 + 128;
|
2240
3154
|
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
@@ -2290,7 +3204,7 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
2290
3204
|
|
2291
3205
|
const int il = 4 * (tiisg/8); // 0, 4, 8, 12
|
2292
3206
|
const int ix = tiisg%8;
|
2293
|
-
const int
|
3207
|
+
const int iq = il/8; // 0, 0, 1, 1
|
2294
3208
|
const int in = il%8; // 0, 4, 0, 4
|
2295
3209
|
|
2296
3210
|
device const float * y = yy + ix*QK_K + il;
|
@@ -2315,7 +3229,7 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
2315
3229
|
|
2316
3230
|
float2 acc = {0.f, 0.f};
|
2317
3231
|
for (int l = 0; l < 4; ++l) {
|
2318
|
-
const uint8_t hl = h[l] >>
|
3232
|
+
const uint8_t hl = h[l] >> iq;
|
2319
3233
|
acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
|
2320
3234
|
+ yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
|
2321
3235
|
acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
|
@@ -2337,27 +3251,48 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
2337
3251
|
for (int row = 0; row < 2; ++row) {
|
2338
3252
|
const float tot = simd_sum(sumf[row]);
|
2339
3253
|
if (tiisg == 0) {
|
2340
|
-
dst[r1*ne0 +
|
3254
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
|
2341
3255
|
}
|
2342
3256
|
}
|
3257
|
+
}
|
3258
|
+
|
3259
|
+
[[host_name("kernel_mul_mv_q5_K_f32")]]
|
3260
|
+
kernel void kernel_mul_mv_q5_K_f32(
|
3261
|
+
device const void * src0,
|
3262
|
+
device const float * src1,
|
3263
|
+
device float * dst,
|
3264
|
+
constant int64_t & ne00,
|
3265
|
+
constant int64_t & ne01[[buffer(4)]],
|
3266
|
+
constant int64_t & ne02[[buffer(5)]],
|
3267
|
+
constant int64_t & ne10[[buffer(9)]],
|
3268
|
+
constant int64_t & ne12[[buffer(11)]],
|
3269
|
+
constant int64_t & ne0 [[buffer(15)]],
|
3270
|
+
constant int64_t & ne1 [[buffer(16)]],
|
3271
|
+
constant uint & r2 [[buffer(17)]],
|
3272
|
+
constant uint & r3 [[buffer(18)]],
|
3273
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3274
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
3275
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
2343
3276
|
|
3277
|
+
kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
2344
3278
|
}
|
2345
3279
|
|
2346
|
-
|
3280
|
+
void kernel_mul_mv_q6_K_f32_impl(
|
2347
3281
|
device const void * src0,
|
2348
3282
|
device const float * src1,
|
2349
3283
|
device float * dst,
|
2350
3284
|
constant int64_t & ne00,
|
2351
|
-
constant int64_t & ne01
|
2352
|
-
constant int64_t & ne02
|
2353
|
-
constant int64_t & ne10
|
2354
|
-
constant int64_t & ne12
|
2355
|
-
constant int64_t & ne0
|
2356
|
-
constant int64_t & ne1
|
2357
|
-
constant uint &
|
3285
|
+
constant int64_t & ne01,
|
3286
|
+
constant int64_t & ne02,
|
3287
|
+
constant int64_t & ne10,
|
3288
|
+
constant int64_t & ne12,
|
3289
|
+
constant int64_t & ne0,
|
3290
|
+
constant int64_t & ne1,
|
3291
|
+
constant uint & r2,
|
3292
|
+
constant uint & r3,
|
2358
3293
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2359
|
-
uint
|
2360
|
-
uint
|
3294
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
3295
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
2361
3296
|
|
2362
3297
|
const uint8_t kmask1 = 0x03;
|
2363
3298
|
const uint8_t kmask2 = 0x0C;
|
@@ -2368,12 +3303,17 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
2368
3303
|
|
2369
3304
|
const int64_t r0 = tgpig.x;
|
2370
3305
|
const int64_t r1 = tgpig.y;
|
2371
|
-
const int
|
3306
|
+
const int im = tgpig.z;
|
2372
3307
|
|
2373
3308
|
const int row = 2 * r0 + sgitg;
|
2374
|
-
|
3309
|
+
|
3310
|
+
const uint i12 = im%ne12;
|
3311
|
+
const uint i13 = im/ne12;
|
3312
|
+
|
3313
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
3314
|
+
|
2375
3315
|
device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
|
2376
|
-
device const float * yy = (device const float *) src1 + r1*ne10 +
|
3316
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
2377
3317
|
|
2378
3318
|
float sumf = 0;
|
2379
3319
|
|
@@ -2439,10 +3379,31 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
2439
3379
|
|
2440
3380
|
const float tot = simd_sum(sumf);
|
2441
3381
|
if (tiisg == 0) {
|
2442
|
-
dst[r1*ne0 +
|
3382
|
+
dst[r1*ne0 + im*ne0*ne1 + row] = tot;
|
2443
3383
|
}
|
2444
3384
|
}
|
2445
3385
|
|
3386
|
+
[[host_name("kernel_mul_mv_q6_K_f32")]]
|
3387
|
+
kernel void kernel_mul_mv_q6_K_f32(
|
3388
|
+
device const void * src0,
|
3389
|
+
device const float * src1,
|
3390
|
+
device float * dst,
|
3391
|
+
constant int64_t & ne00,
|
3392
|
+
constant int64_t & ne01[[buffer(4)]],
|
3393
|
+
constant int64_t & ne02[[buffer(5)]],
|
3394
|
+
constant int64_t & ne10[[buffer(9)]],
|
3395
|
+
constant int64_t & ne12[[buffer(11)]],
|
3396
|
+
constant int64_t & ne0 [[buffer(15)]],
|
3397
|
+
constant int64_t & ne1 [[buffer(16)]],
|
3398
|
+
constant uint & r2 [[buffer(17)]],
|
3399
|
+
constant uint & r3 [[buffer(18)]],
|
3400
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3401
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
3402
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3403
|
+
|
3404
|
+
kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
3405
|
+
}
|
3406
|
+
|
2446
3407
|
//============================= templates and their specializations =============================
|
2447
3408
|
|
2448
3409
|
// NOTE: this is not dequantizing - we are simply fitting the template
|
@@ -2560,10 +3521,10 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
|
|
2560
3521
|
|
2561
3522
|
template <typename type4x4>
|
2562
3523
|
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
|
2563
|
-
const
|
2564
|
-
const
|
3524
|
+
const float d = xb->d;
|
3525
|
+
const float min = xb->dmin;
|
2565
3526
|
device const uint8_t * q = (device const uint8_t *)xb->qs;
|
2566
|
-
|
3527
|
+
float dl, ml;
|
2567
3528
|
uint8_t sc = xb->scales[il];
|
2568
3529
|
|
2569
3530
|
#if QK_K == 256
|
@@ -2633,10 +3594,10 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
|
|
2633
3594
|
q = q + (il/4) * 32 + 16 * (il&1);
|
2634
3595
|
il = il & 3;
|
2635
3596
|
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
|
2636
|
-
const
|
2637
|
-
const
|
2638
|
-
const
|
2639
|
-
const
|
3597
|
+
const float d = il < 2 ? xb->d : xb->d / 16.h;
|
3598
|
+
const float min = xb->dmin;
|
3599
|
+
const float dl = d * sc[0];
|
3600
|
+
const float ml = min * sc[1];
|
2640
3601
|
#else
|
2641
3602
|
q = q + 16 * (il&1);
|
2642
3603
|
device const uint8_t * s = xb->scales;
|
@@ -2663,13 +3624,13 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
|
|
2663
3624
|
uint8_t ul = 1 << (il/2);
|
2664
3625
|
il = il & 3;
|
2665
3626
|
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
|
2666
|
-
const
|
2667
|
-
const
|
2668
|
-
const
|
2669
|
-
const
|
3627
|
+
const float d = il < 2 ? xb->d : xb->d / 16.h;
|
3628
|
+
const float min = xb->dmin;
|
3629
|
+
const float dl = d * sc[0];
|
3630
|
+
const float ml = min * sc[1];
|
2670
3631
|
|
2671
|
-
const ushort mask
|
2672
|
-
const
|
3632
|
+
const ushort mask = il<2 ? 0x0F : 0xF0;
|
3633
|
+
const float qh_val = il<2 ? 16.f : 256.f;
|
2673
3634
|
for (int i = 0; i < 16; ++i) {
|
2674
3635
|
reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
|
2675
3636
|
}
|
@@ -2717,22 +3678,90 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
|
|
2717
3678
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
2718
3679
|
kernel void kernel_get_rows(
|
2719
3680
|
device const void * src0,
|
2720
|
-
device const
|
3681
|
+
device const char * src1,
|
2721
3682
|
device float * dst,
|
2722
3683
|
constant int64_t & ne00,
|
2723
3684
|
constant uint64_t & nb01,
|
3685
|
+
constant uint64_t & nb02,
|
3686
|
+
constant int64_t & ne10,
|
3687
|
+
constant uint64_t & nb10,
|
3688
|
+
constant uint64_t & nb11,
|
2724
3689
|
constant uint64_t & nb1,
|
2725
|
-
|
3690
|
+
constant uint64_t & nb2,
|
3691
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
2726
3692
|
uint tiitg[[thread_index_in_threadgroup]],
|
2727
|
-
|
2728
|
-
const
|
2729
|
-
const
|
3693
|
+
uint3 tptg [[threads_per_threadgroup]]) {
|
3694
|
+
//const int64_t i = tgpig;
|
3695
|
+
//const int64_t r = ((device int32_t *) src1)[i];
|
3696
|
+
|
3697
|
+
const int64_t i10 = tgpig.x;
|
3698
|
+
const int64_t i11 = tgpig.y;
|
2730
3699
|
|
2731
|
-
|
3700
|
+
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
3701
|
+
|
3702
|
+
const int64_t i02 = i11;
|
3703
|
+
|
3704
|
+
for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
|
2732
3705
|
float4x4 temp;
|
2733
3706
|
dequantize_func(
|
2734
|
-
((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
|
2735
|
-
*(((device float4x4 *) ((device char *) dst +
|
3707
|
+
((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
|
3708
|
+
*(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
|
3709
|
+
}
|
3710
|
+
}
|
3711
|
+
|
3712
|
+
kernel void kernel_get_rows_f32(
|
3713
|
+
device const void * src0,
|
3714
|
+
device const char * src1,
|
3715
|
+
device float * dst,
|
3716
|
+
constant int64_t & ne00,
|
3717
|
+
constant uint64_t & nb01,
|
3718
|
+
constant uint64_t & nb02,
|
3719
|
+
constant int64_t & ne10,
|
3720
|
+
constant uint64_t & nb10,
|
3721
|
+
constant uint64_t & nb11,
|
3722
|
+
constant uint64_t & nb1,
|
3723
|
+
constant uint64_t & nb2,
|
3724
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3725
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
3726
|
+
uint3 tptg [[threads_per_threadgroup]]) {
|
3727
|
+
const int64_t i10 = tgpig.x;
|
3728
|
+
const int64_t i11 = tgpig.y;
|
3729
|
+
|
3730
|
+
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
3731
|
+
|
3732
|
+
const int64_t i02 = i11;
|
3733
|
+
|
3734
|
+
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
3735
|
+
((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
3736
|
+
((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
3737
|
+
}
|
3738
|
+
}
|
3739
|
+
|
3740
|
+
kernel void kernel_get_rows_f16(
|
3741
|
+
device const void * src0,
|
3742
|
+
device const char * src1,
|
3743
|
+
device float * dst,
|
3744
|
+
constant int64_t & ne00,
|
3745
|
+
constant uint64_t & nb01,
|
3746
|
+
constant uint64_t & nb02,
|
3747
|
+
constant int64_t & ne10,
|
3748
|
+
constant uint64_t & nb10,
|
3749
|
+
constant uint64_t & nb11,
|
3750
|
+
constant uint64_t & nb1,
|
3751
|
+
constant uint64_t & nb2,
|
3752
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3753
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
3754
|
+
uint3 tptg [[threads_per_threadgroup]]) {
|
3755
|
+
const int64_t i10 = tgpig.x;
|
3756
|
+
const int64_t i11 = tgpig.y;
|
3757
|
+
|
3758
|
+
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
3759
|
+
|
3760
|
+
const int64_t i02 = i11;
|
3761
|
+
|
3762
|
+
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
3763
|
+
((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
3764
|
+
((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
2736
3765
|
}
|
2737
3766
|
}
|
2738
3767
|
|
@@ -2749,24 +3778,25 @@ kernel void kernel_get_rows(
|
|
2749
3778
|
|
2750
3779
|
// each block_q contains 16*nl weights
|
2751
3780
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
2752
|
-
|
2753
|
-
|
2754
|
-
|
2755
|
-
|
2756
|
-
|
2757
|
-
|
2758
|
-
|
2759
|
-
|
2760
|
-
|
2761
|
-
|
2762
|
-
|
2763
|
-
|
2764
|
-
|
2765
|
-
|
2766
|
-
|
2767
|
-
|
2768
|
-
|
2769
|
-
|
3781
|
+
void kernel_mul_mm_impl(device const uchar * src0,
|
3782
|
+
device const uchar * src1,
|
3783
|
+
device float * dst,
|
3784
|
+
constant int64_t & ne00,
|
3785
|
+
constant int64_t & ne02,
|
3786
|
+
constant int64_t & nb01,
|
3787
|
+
constant int64_t & nb02,
|
3788
|
+
constant int64_t & ne12,
|
3789
|
+
constant int64_t & nb10,
|
3790
|
+
constant int64_t & nb11,
|
3791
|
+
constant int64_t & nb12,
|
3792
|
+
constant int64_t & ne0,
|
3793
|
+
constant int64_t & ne1,
|
3794
|
+
constant uint & r2,
|
3795
|
+
constant uint & r3,
|
3796
|
+
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
3797
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3798
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
3799
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
2770
3800
|
|
2771
3801
|
threadgroup half * sa = (threadgroup half *)(shared_memory);
|
2772
3802
|
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
|
@@ -2792,7 +3822,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
2792
3822
|
|
2793
3823
|
short il = (tiitg % THREAD_PER_ROW);
|
2794
3824
|
|
2795
|
-
uint
|
3825
|
+
const uint i12 = im%ne12;
|
3826
|
+
const uint i13 = im/ne12;
|
3827
|
+
|
3828
|
+
uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
|
2796
3829
|
ushort offset1 = il/nl;
|
2797
3830
|
|
2798
3831
|
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
|
@@ -2876,17 +3909,137 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
2876
3909
|
}
|
2877
3910
|
}
|
2878
3911
|
|
3912
|
+
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
3913
|
+
kernel void kernel_mul_mm(device const uchar * src0,
|
3914
|
+
device const uchar * src1,
|
3915
|
+
device float * dst,
|
3916
|
+
constant int64_t & ne00,
|
3917
|
+
constant int64_t & ne02,
|
3918
|
+
constant int64_t & nb01,
|
3919
|
+
constant int64_t & nb02,
|
3920
|
+
constant int64_t & ne12,
|
3921
|
+
constant int64_t & nb10,
|
3922
|
+
constant int64_t & nb11,
|
3923
|
+
constant int64_t & nb12,
|
3924
|
+
constant int64_t & ne0,
|
3925
|
+
constant int64_t & ne1,
|
3926
|
+
constant uint & r2,
|
3927
|
+
constant uint & r3,
|
3928
|
+
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
3929
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3930
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
3931
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3932
|
+
kernel_mul_mm_impl<block_q, nl, dequantize_func>(
|
3933
|
+
src0,
|
3934
|
+
src1,
|
3935
|
+
dst,
|
3936
|
+
ne00,
|
3937
|
+
ne02,
|
3938
|
+
nb01,
|
3939
|
+
nb02,
|
3940
|
+
ne12,
|
3941
|
+
nb10,
|
3942
|
+
nb11,
|
3943
|
+
nb12,
|
3944
|
+
ne0,
|
3945
|
+
ne1,
|
3946
|
+
r2,
|
3947
|
+
r3,
|
3948
|
+
shared_memory,
|
3949
|
+
tgpig,
|
3950
|
+
tiitg,
|
3951
|
+
sgitg);
|
3952
|
+
}
|
3953
|
+
|
3954
|
+
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
3955
|
+
kernel void kernel_mul_mm_id(
|
3956
|
+
device const uchar * ids,
|
3957
|
+
device const uchar * src1,
|
3958
|
+
device uchar * dst,
|
3959
|
+
constant int64_t & nbi1,
|
3960
|
+
constant int64_t & ne00,
|
3961
|
+
constant int64_t & ne02,
|
3962
|
+
constant int64_t & nb01,
|
3963
|
+
constant int64_t & nb02,
|
3964
|
+
constant int64_t & ne12,
|
3965
|
+
constant int64_t & ne13,
|
3966
|
+
constant int64_t & nb10,
|
3967
|
+
constant int64_t & nb11,
|
3968
|
+
constant int64_t & nb12,
|
3969
|
+
constant int64_t & ne0,
|
3970
|
+
constant int64_t & ne1,
|
3971
|
+
constant int64_t & nb1,
|
3972
|
+
constant uint & r2,
|
3973
|
+
constant uint & r3,
|
3974
|
+
constant int & idx,
|
3975
|
+
device const uchar * src00,
|
3976
|
+
device const uchar * src01,
|
3977
|
+
device const uchar * src02,
|
3978
|
+
device const uchar * src03,
|
3979
|
+
device const uchar * src04,
|
3980
|
+
device const uchar * src05,
|
3981
|
+
device const uchar * src06,
|
3982
|
+
device const uchar * src07,
|
3983
|
+
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
3984
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3985
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
3986
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3987
|
+
device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
3988
|
+
|
3989
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
3990
|
+
|
3991
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
3992
|
+
|
3993
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
3994
|
+
|
3995
|
+
kernel_mul_mm_impl<block_q, nl, dequantize_func>(
|
3996
|
+
src0[id],
|
3997
|
+
src1 + bid*nb11,
|
3998
|
+
(device float *) (dst + bid*nb1),
|
3999
|
+
ne00,
|
4000
|
+
ne02,
|
4001
|
+
nb01,
|
4002
|
+
nb02,
|
4003
|
+
ne12,
|
4004
|
+
nb10,
|
4005
|
+
nb11,
|
4006
|
+
nb12,
|
4007
|
+
ne0,
|
4008
|
+
ne1,
|
4009
|
+
r2,
|
4010
|
+
r3,
|
4011
|
+
shared_memory,
|
4012
|
+
tgpig,
|
4013
|
+
tiitg,
|
4014
|
+
sgitg);
|
4015
|
+
}
|
4016
|
+
|
2879
4017
|
#if QK_K == 256
|
2880
4018
|
#define QK_NL 16
|
2881
4019
|
#else
|
2882
4020
|
#define QK_NL 4
|
2883
4021
|
#endif
|
2884
4022
|
|
2885
|
-
|
2886
|
-
|
4023
|
+
//
|
4024
|
+
// get rows
|
4025
|
+
//
|
2887
4026
|
|
2888
|
-
|
2889
|
-
|
4027
|
+
typedef void (get_rows_t)(
|
4028
|
+
device const void * src0,
|
4029
|
+
device const char * src1,
|
4030
|
+
device float * dst,
|
4031
|
+
constant int64_t & ne00,
|
4032
|
+
constant uint64_t & nb01,
|
4033
|
+
constant uint64_t & nb02,
|
4034
|
+
constant int64_t & ne10,
|
4035
|
+
constant uint64_t & nb10,
|
4036
|
+
constant uint64_t & nb11,
|
4037
|
+
constant uint64_t & nb1,
|
4038
|
+
constant uint64_t & nb2,
|
4039
|
+
uint3, uint, uint3);
|
4040
|
+
|
4041
|
+
//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
|
4042
|
+
//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
2890
4043
|
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
|
2891
4044
|
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
|
2892
4045
|
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
|
@@ -2898,6 +4051,10 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
|
|
2898
4051
|
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
|
2899
4052
|
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
|
2900
4053
|
|
4054
|
+
//
|
4055
|
+
// matrix-matrix multiplication
|
4056
|
+
//
|
4057
|
+
|
2901
4058
|
typedef void (mat_mm_t)(
|
2902
4059
|
device const uchar * src0,
|
2903
4060
|
device const uchar * src1,
|
@@ -2912,8 +4069,10 @@ typedef void (mat_mm_t)(
|
|
2912
4069
|
constant int64_t & nb12,
|
2913
4070
|
constant int64_t & ne0,
|
2914
4071
|
constant int64_t & ne1,
|
2915
|
-
constant uint &
|
2916
|
-
|
4072
|
+
constant uint & r2,
|
4073
|
+
constant uint & r3,
|
4074
|
+
threadgroup uchar *,
|
4075
|
+
uint3, uint, uint);
|
2917
4076
|
|
2918
4077
|
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
|
2919
4078
|
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
@@ -2927,3 +4086,823 @@ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
|
|
2927
4086
|
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
|
2928
4087
|
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
|
2929
4088
|
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|
4089
|
+
|
4090
|
+
//
|
4091
|
+
// indirect matrix-matrix multiplication
|
4092
|
+
//
|
4093
|
+
|
4094
|
+
typedef void (mat_mm_id_t)(
|
4095
|
+
device const uchar * ids,
|
4096
|
+
device const uchar * src1,
|
4097
|
+
device uchar * dst,
|
4098
|
+
constant int64_t & nbi1,
|
4099
|
+
constant int64_t & ne00,
|
4100
|
+
constant int64_t & ne02,
|
4101
|
+
constant int64_t & nb01,
|
4102
|
+
constant int64_t & nb02,
|
4103
|
+
constant int64_t & ne12,
|
4104
|
+
constant int64_t & ne13,
|
4105
|
+
constant int64_t & nb10,
|
4106
|
+
constant int64_t & nb11,
|
4107
|
+
constant int64_t & nb12,
|
4108
|
+
constant int64_t & ne0,
|
4109
|
+
constant int64_t & ne1,
|
4110
|
+
constant int64_t & nb1,
|
4111
|
+
constant uint & r2,
|
4112
|
+
constant uint & r3,
|
4113
|
+
constant int & idx,
|
4114
|
+
device const uchar * src00,
|
4115
|
+
device const uchar * src01,
|
4116
|
+
device const uchar * src02,
|
4117
|
+
device const uchar * src03,
|
4118
|
+
device const uchar * src04,
|
4119
|
+
device const uchar * src05,
|
4120
|
+
device const uchar * src06,
|
4121
|
+
device const uchar * src07,
|
4122
|
+
threadgroup uchar *,
|
4123
|
+
uint3, uint, uint);
|
4124
|
+
|
4125
|
+
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
|
4126
|
+
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
|
4127
|
+
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
|
4128
|
+
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
|
4129
|
+
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
|
4130
|
+
template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
|
4131
|
+
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
|
4132
|
+
template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
|
4133
|
+
template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
|
4134
|
+
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
|
4135
|
+
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
|
4136
|
+
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
|
4137
|
+
|
4138
|
+
//
|
4139
|
+
// matrix-vector multiplication
|
4140
|
+
//
|
4141
|
+
|
4142
|
+
[[host_name("kernel_mul_mv_id_f32_f32")]]
|
4143
|
+
kernel void kernel_mul_mv_id_f32_f32(
|
4144
|
+
device const char * ids,
|
4145
|
+
device const char * src1,
|
4146
|
+
device uchar * dst,
|
4147
|
+
constant int64_t & nbi1,
|
4148
|
+
constant int64_t & ne00,
|
4149
|
+
constant int64_t & ne01,
|
4150
|
+
constant int64_t & ne02,
|
4151
|
+
constant uint64_t & nb00,
|
4152
|
+
constant uint64_t & nb01,
|
4153
|
+
constant uint64_t & nb02,
|
4154
|
+
constant int64_t & ne10,
|
4155
|
+
constant int64_t & ne11,
|
4156
|
+
constant int64_t & ne12,
|
4157
|
+
constant int64_t & ne13,
|
4158
|
+
constant uint64_t & nb10,
|
4159
|
+
constant uint64_t & nb11,
|
4160
|
+
constant uint64_t & nb12,
|
4161
|
+
constant int64_t & ne0,
|
4162
|
+
constant int64_t & ne1,
|
4163
|
+
constant int64_t & nb1,
|
4164
|
+
constant uint & r2,
|
4165
|
+
constant uint & r3,
|
4166
|
+
constant int & idx,
|
4167
|
+
device const char * src00,
|
4168
|
+
device const char * src01,
|
4169
|
+
device const char * src02,
|
4170
|
+
device const char * src03,
|
4171
|
+
device const char * src04,
|
4172
|
+
device const char * src05,
|
4173
|
+
device const char * src06,
|
4174
|
+
device const char * src07,
|
4175
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4176
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4177
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4178
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4179
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
4180
|
+
|
4181
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
4182
|
+
|
4183
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4184
|
+
|
4185
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4186
|
+
|
4187
|
+
kernel_mul_mv_f32_f32_impl(
|
4188
|
+
src0[id],
|
4189
|
+
src1 + bid*nb11,
|
4190
|
+
(device float *) (dst + bid*nb1),
|
4191
|
+
ne00,
|
4192
|
+
ne01,
|
4193
|
+
ne02,
|
4194
|
+
nb00,
|
4195
|
+
nb01,
|
4196
|
+
nb02,
|
4197
|
+
ne10,
|
4198
|
+
ne11,
|
4199
|
+
ne12,
|
4200
|
+
nb10,
|
4201
|
+
nb11,
|
4202
|
+
nb12,
|
4203
|
+
ne0,
|
4204
|
+
ne1,
|
4205
|
+
r2,
|
4206
|
+
r3,
|
4207
|
+
tgpig,
|
4208
|
+
tiisg);
|
4209
|
+
}
|
4210
|
+
|
4211
|
+
[[host_name("kernel_mul_mv_id_f16_f32")]]
|
4212
|
+
kernel void kernel_mul_mv_id_f16_f32(
|
4213
|
+
device const char * ids,
|
4214
|
+
device const char * src1,
|
4215
|
+
device uchar * dst,
|
4216
|
+
constant int64_t & nbi1,
|
4217
|
+
constant int64_t & ne00,
|
4218
|
+
constant int64_t & ne01,
|
4219
|
+
constant int64_t & ne02,
|
4220
|
+
constant uint64_t & nb00,
|
4221
|
+
constant uint64_t & nb01,
|
4222
|
+
constant uint64_t & nb02,
|
4223
|
+
constant int64_t & ne10,
|
4224
|
+
constant int64_t & ne11,
|
4225
|
+
constant int64_t & ne12,
|
4226
|
+
constant int64_t & ne13,
|
4227
|
+
constant uint64_t & nb10,
|
4228
|
+
constant uint64_t & nb11,
|
4229
|
+
constant uint64_t & nb12,
|
4230
|
+
constant int64_t & ne0,
|
4231
|
+
constant int64_t & ne1,
|
4232
|
+
constant int64_t & nb1,
|
4233
|
+
constant uint & r2,
|
4234
|
+
constant uint & r3,
|
4235
|
+
constant int & idx,
|
4236
|
+
device const char * src00,
|
4237
|
+
device const char * src01,
|
4238
|
+
device const char * src02,
|
4239
|
+
device const char * src03,
|
4240
|
+
device const char * src04,
|
4241
|
+
device const char * src05,
|
4242
|
+
device const char * src06,
|
4243
|
+
device const char * src07,
|
4244
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4245
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4246
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4247
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4248
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
4249
|
+
|
4250
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
4251
|
+
|
4252
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4253
|
+
|
4254
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4255
|
+
|
4256
|
+
kernel_mul_mv_f16_f32_impl(
|
4257
|
+
src0[id],
|
4258
|
+
src1 + bid*nb11,
|
4259
|
+
(device float *) (dst + bid*nb1),
|
4260
|
+
ne00,
|
4261
|
+
ne01,
|
4262
|
+
ne02,
|
4263
|
+
nb00,
|
4264
|
+
nb01,
|
4265
|
+
nb02,
|
4266
|
+
ne10,
|
4267
|
+
ne11,
|
4268
|
+
ne12,
|
4269
|
+
nb10,
|
4270
|
+
nb11,
|
4271
|
+
nb12,
|
4272
|
+
ne0,
|
4273
|
+
ne1,
|
4274
|
+
r2,
|
4275
|
+
r3,
|
4276
|
+
tgpig,
|
4277
|
+
tiisg);
|
4278
|
+
}
|
4279
|
+
|
4280
|
+
[[host_name("kernel_mul_mv_id_q8_0_f32")]]
|
4281
|
+
kernel void kernel_mul_mv_id_q8_0_f32(
|
4282
|
+
device const char * ids,
|
4283
|
+
device const char * src1,
|
4284
|
+
device uchar * dst,
|
4285
|
+
constant int64_t & nbi1,
|
4286
|
+
constant int64_t & ne00,
|
4287
|
+
constant int64_t & ne01,
|
4288
|
+
constant int64_t & ne02,
|
4289
|
+
constant uint64_t & nb00,
|
4290
|
+
constant uint64_t & nb01,
|
4291
|
+
constant uint64_t & nb02,
|
4292
|
+
constant int64_t & ne10,
|
4293
|
+
constant int64_t & ne11,
|
4294
|
+
constant int64_t & ne12,
|
4295
|
+
constant int64_t & ne13,
|
4296
|
+
constant uint64_t & nb10,
|
4297
|
+
constant uint64_t & nb11,
|
4298
|
+
constant uint64_t & nb12,
|
4299
|
+
constant int64_t & ne0,
|
4300
|
+
constant int64_t & ne1,
|
4301
|
+
constant int64_t & nb1,
|
4302
|
+
constant uint & r2,
|
4303
|
+
constant uint & r3,
|
4304
|
+
constant int & idx,
|
4305
|
+
device const char * src00,
|
4306
|
+
device const char * src01,
|
4307
|
+
device const char * src02,
|
4308
|
+
device const char * src03,
|
4309
|
+
device const char * src04,
|
4310
|
+
device const char * src05,
|
4311
|
+
device const char * src06,
|
4312
|
+
device const char * src07,
|
4313
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4314
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4315
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4316
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4317
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
4318
|
+
|
4319
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
4320
|
+
|
4321
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4322
|
+
|
4323
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4324
|
+
|
4325
|
+
kernel_mul_mv_q8_0_f32_impl(
|
4326
|
+
src0[id],
|
4327
|
+
(device const float *) (src1 + bid*nb11),
|
4328
|
+
(device float *) ( dst + bid*nb1),
|
4329
|
+
ne00,
|
4330
|
+
ne01,
|
4331
|
+
ne02,
|
4332
|
+
ne10,
|
4333
|
+
ne12,
|
4334
|
+
ne0,
|
4335
|
+
ne1,
|
4336
|
+
r2,
|
4337
|
+
r3,
|
4338
|
+
tgpig,
|
4339
|
+
tiisg,
|
4340
|
+
sgitg);
|
4341
|
+
}
|
4342
|
+
|
4343
|
+
[[host_name("kernel_mul_mv_id_q4_0_f32")]]
|
4344
|
+
kernel void kernel_mul_mv_id_q4_0_f32(
|
4345
|
+
device const char * ids,
|
4346
|
+
device const char * src1,
|
4347
|
+
device uchar * dst,
|
4348
|
+
constant int64_t & nbi1,
|
4349
|
+
constant int64_t & ne00,
|
4350
|
+
constant int64_t & ne01,
|
4351
|
+
constant int64_t & ne02,
|
4352
|
+
constant uint64_t & nb00,
|
4353
|
+
constant uint64_t & nb01,
|
4354
|
+
constant uint64_t & nb02,
|
4355
|
+
constant int64_t & ne10,
|
4356
|
+
constant int64_t & ne11,
|
4357
|
+
constant int64_t & ne12,
|
4358
|
+
constant int64_t & ne13,
|
4359
|
+
constant uint64_t & nb10,
|
4360
|
+
constant uint64_t & nb11,
|
4361
|
+
constant uint64_t & nb12,
|
4362
|
+
constant int64_t & ne0,
|
4363
|
+
constant int64_t & ne1,
|
4364
|
+
constant int64_t & nb1,
|
4365
|
+
constant uint & r2,
|
4366
|
+
constant uint & r3,
|
4367
|
+
constant int & idx,
|
4368
|
+
device const char * src00,
|
4369
|
+
device const char * src01,
|
4370
|
+
device const char * src02,
|
4371
|
+
device const char * src03,
|
4372
|
+
device const char * src04,
|
4373
|
+
device const char * src05,
|
4374
|
+
device const char * src06,
|
4375
|
+
device const char * src07,
|
4376
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4377
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4378
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4379
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4380
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
4381
|
+
|
4382
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
4383
|
+
|
4384
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4385
|
+
|
4386
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4387
|
+
|
4388
|
+
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
4389
|
+
src0[id],
|
4390
|
+
(device const float *) (src1 + bid*nb11),
|
4391
|
+
(device float *) ( dst + bid*nb1),
|
4392
|
+
ne00,
|
4393
|
+
ne01,
|
4394
|
+
ne02,
|
4395
|
+
ne10,
|
4396
|
+
ne12,
|
4397
|
+
ne0,
|
4398
|
+
ne1,
|
4399
|
+
r2,
|
4400
|
+
r3,
|
4401
|
+
tgpig,
|
4402
|
+
tiisg,
|
4403
|
+
sgitg);
|
4404
|
+
}
|
4405
|
+
|
4406
|
+
[[host_name("kernel_mul_mv_id_q4_1_f32")]]
|
4407
|
+
kernel void kernel_mul_mv_id_q4_1_f32(
|
4408
|
+
device const char * ids,
|
4409
|
+
device const char * src1,
|
4410
|
+
device uchar * dst,
|
4411
|
+
constant int64_t & nbi1,
|
4412
|
+
constant int64_t & ne00,
|
4413
|
+
constant int64_t & ne01,
|
4414
|
+
constant int64_t & ne02,
|
4415
|
+
constant uint64_t & nb00,
|
4416
|
+
constant uint64_t & nb01,
|
4417
|
+
constant uint64_t & nb02,
|
4418
|
+
constant int64_t & ne10,
|
4419
|
+
constant int64_t & ne11,
|
4420
|
+
constant int64_t & ne12,
|
4421
|
+
constant int64_t & ne13,
|
4422
|
+
constant uint64_t & nb10,
|
4423
|
+
constant uint64_t & nb11,
|
4424
|
+
constant uint64_t & nb12,
|
4425
|
+
constant int64_t & ne0,
|
4426
|
+
constant int64_t & ne1,
|
4427
|
+
constant int64_t & nb1,
|
4428
|
+
constant uint & r2,
|
4429
|
+
constant uint & r3,
|
4430
|
+
constant int & idx,
|
4431
|
+
device const char * src00,
|
4432
|
+
device const char * src01,
|
4433
|
+
device const char * src02,
|
4434
|
+
device const char * src03,
|
4435
|
+
device const char * src04,
|
4436
|
+
device const char * src05,
|
4437
|
+
device const char * src06,
|
4438
|
+
device const char * src07,
|
4439
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4440
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4441
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4442
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4443
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
4444
|
+
|
4445
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
4446
|
+
|
4447
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4448
|
+
|
4449
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4450
|
+
|
4451
|
+
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
4452
|
+
src0[id],
|
4453
|
+
(device const float *) (src1 + bid*nb11),
|
4454
|
+
(device float *) ( dst + bid*nb1),
|
4455
|
+
ne00,
|
4456
|
+
ne01,
|
4457
|
+
ne02,
|
4458
|
+
ne10,
|
4459
|
+
ne12,
|
4460
|
+
ne0,
|
4461
|
+
ne1,
|
4462
|
+
r2,
|
4463
|
+
r3,
|
4464
|
+
tgpig,
|
4465
|
+
tiisg,
|
4466
|
+
sgitg);
|
4467
|
+
}
|
4468
|
+
|
4469
|
+
[[host_name("kernel_mul_mv_id_q5_0_f32")]]
|
4470
|
+
kernel void kernel_mul_mv_id_q5_0_f32(
|
4471
|
+
device const char * ids,
|
4472
|
+
device const char * src1,
|
4473
|
+
device uchar * dst,
|
4474
|
+
constant int64_t & nbi1,
|
4475
|
+
constant int64_t & ne00,
|
4476
|
+
constant int64_t & ne01,
|
4477
|
+
constant int64_t & ne02,
|
4478
|
+
constant uint64_t & nb00,
|
4479
|
+
constant uint64_t & nb01,
|
4480
|
+
constant uint64_t & nb02,
|
4481
|
+
constant int64_t & ne10,
|
4482
|
+
constant int64_t & ne11,
|
4483
|
+
constant int64_t & ne12,
|
4484
|
+
constant int64_t & ne13,
|
4485
|
+
constant uint64_t & nb10,
|
4486
|
+
constant uint64_t & nb11,
|
4487
|
+
constant uint64_t & nb12,
|
4488
|
+
constant int64_t & ne0,
|
4489
|
+
constant int64_t & ne1,
|
4490
|
+
constant int64_t & nb1,
|
4491
|
+
constant uint & r2,
|
4492
|
+
constant uint & r3,
|
4493
|
+
constant int & idx,
|
4494
|
+
device const char * src00,
|
4495
|
+
device const char * src01,
|
4496
|
+
device const char * src02,
|
4497
|
+
device const char * src03,
|
4498
|
+
device const char * src04,
|
4499
|
+
device const char * src05,
|
4500
|
+
device const char * src06,
|
4501
|
+
device const char * src07,
|
4502
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4503
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4504
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4505
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4506
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
4507
|
+
|
4508
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
4509
|
+
|
4510
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4511
|
+
|
4512
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4513
|
+
|
4514
|
+
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
4515
|
+
src0[id],
|
4516
|
+
(device const float *) (src1 + bid*nb11),
|
4517
|
+
(device float *) ( dst + bid*nb1),
|
4518
|
+
ne00,
|
4519
|
+
ne01,
|
4520
|
+
ne02,
|
4521
|
+
ne10,
|
4522
|
+
ne12,
|
4523
|
+
ne0,
|
4524
|
+
ne1,
|
4525
|
+
r2,
|
4526
|
+
r3,
|
4527
|
+
tgpig,
|
4528
|
+
tiisg,
|
4529
|
+
sgitg);
|
4530
|
+
}
|
4531
|
+
|
4532
|
+
[[host_name("kernel_mul_mv_id_q5_1_f32")]]
|
4533
|
+
kernel void kernel_mul_mv_id_q5_1_f32(
|
4534
|
+
device const char * ids,
|
4535
|
+
device const char * src1,
|
4536
|
+
device uchar * dst,
|
4537
|
+
constant int64_t & nbi1,
|
4538
|
+
constant int64_t & ne00,
|
4539
|
+
constant int64_t & ne01,
|
4540
|
+
constant int64_t & ne02,
|
4541
|
+
constant uint64_t & nb00,
|
4542
|
+
constant uint64_t & nb01,
|
4543
|
+
constant uint64_t & nb02,
|
4544
|
+
constant int64_t & ne10,
|
4545
|
+
constant int64_t & ne11,
|
4546
|
+
constant int64_t & ne12,
|
4547
|
+
constant int64_t & ne13,
|
4548
|
+
constant uint64_t & nb10,
|
4549
|
+
constant uint64_t & nb11,
|
4550
|
+
constant uint64_t & nb12,
|
4551
|
+
constant int64_t & ne0,
|
4552
|
+
constant int64_t & ne1,
|
4553
|
+
constant int64_t & nb1,
|
4554
|
+
constant uint & r2,
|
4555
|
+
constant uint & r3,
|
4556
|
+
constant int & idx,
|
4557
|
+
device const char * src00,
|
4558
|
+
device const char * src01,
|
4559
|
+
device const char * src02,
|
4560
|
+
device const char * src03,
|
4561
|
+
device const char * src04,
|
4562
|
+
device const char * src05,
|
4563
|
+
device const char * src06,
|
4564
|
+
device const char * src07,
|
4565
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4566
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4567
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4568
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4569
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
4570
|
+
|
4571
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
4572
|
+
|
4573
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
4574
|
+
|
4575
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
4576
|
+
|
4577
|
+
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
4578
|
+
src0[id],
|
4579
|
+
(device const float *) (src1 + bid*nb11),
|
4580
|
+
(device float *) ( dst + bid*nb1),
|
4581
|
+
ne00,
|
4582
|
+
ne01,
|
4583
|
+
ne02,
|
4584
|
+
ne10,
|
4585
|
+
ne12,
|
4586
|
+
ne0,
|
4587
|
+
ne1,
|
4588
|
+
r2,
|
4589
|
+
r3,
|
4590
|
+
tgpig,
|
4591
|
+
tiisg,
|
4592
|
+
sgitg);
|
4593
|
+
}
|
4594
|
+
|
4595
|
+[[host_name("kernel_mul_mv_id_q2_K_f32")]]
+kernel void kernel_mul_mv_id_q2_K_f32(
+        device const char * ids,
+        device const char * src1,
+        device uchar * dst,
+        constant int64_t & nbi1,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & nb1,
+        constant uint & r2,
+        constant uint & r3,
+        constant int & idx,
+        device const char * src00,
+        device const char * src01,
+        device const char * src02,
+        device const char * src03,
+        device const char * src04,
+        device const char * src05,
+        device const char * src06,
+        device const char * src07,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiitg[[thread_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_q2_K_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q3_K_f32")]]
+kernel void kernel_mul_mv_id_q3_K_f32(
+        device const char * ids,
+        device const char * src1,
+        device uchar * dst,
+        constant int64_t & nbi1,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & nb1,
+        constant uint & r2,
+        constant uint & r3,
+        constant int & idx,
+        device const char * src00,
+        device const char * src01,
+        device const char * src02,
+        device const char * src03,
+        device const char * src04,
+        device const char * src05,
+        device const char * src06,
+        device const char * src07,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiitg[[thread_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_q3_K_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q4_K_f32")]]
+kernel void kernel_mul_mv_id_q4_K_f32(
+        device const char * ids,
+        device const char * src1,
+        device uchar * dst,
+        constant int64_t & nbi1,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & nb1,
+        constant uint & r2,
+        constant uint & r3,
+        constant int & idx,
+        device const char * src00,
+        device const char * src01,
+        device const char * src02,
+        device const char * src03,
+        device const char * src04,
+        device const char * src05,
+        device const char * src06,
+        device const char * src07,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiitg[[thread_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_q4_K_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q5_K_f32")]]
+kernel void kernel_mul_mv_id_q5_K_f32(
+        device const char * ids,
+        device const char * src1,
+        device uchar * dst,
+        constant int64_t & nbi1,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & nb1,
+        constant uint & r2,
+        constant uint & r3,
+        constant int & idx,
+        device const char * src00,
+        device const char * src01,
+        device const char * src02,
+        device const char * src03,
+        device const char * src04,
+        device const char * src05,
+        device const char * src06,
+        device const char * src07,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiitg[[thread_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_q5_K_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q6_K_f32")]]
+kernel void kernel_mul_mv_id_q6_K_f32(
+        device const char * ids,
+        device const char * src1,
+        device uchar * dst,
+        constant int64_t & nbi1,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & nb1,
+        constant uint & r2,
+        constant uint & r3,
+        constant int & idx,
+        device const char * src00,
+        device const char * src01,
+        device const char * src02,
+        device const char * src03,
+        device const char * src04,
+        device const char * src05,
+        device const char * src06,
+        device const char * src07,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiitg[[thread_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_q6_K_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
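
Note: the kernel_mul_mv_id_* kernels added above all follow the same dispatch pattern: the host passes up to eight candidate weight matrices (src00 .. src07) plus an ids tensor, each threadgroup derives its batch index bid from tgpig.z, reads the selected matrix index id from the corresponding row of ids, and then forwards src0[id] together with that batch's src1/dst slices to the regular mul_mv implementation for the quantization type. The snippet below is not part of the package diff; it is a minimal host-side C++ analogue of that selection step, with hypothetical names (mul_mv, experts, ids) used only for illustration.

    #include <array>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Naive dense matrix-vector product standing in for the quantized
    // kernel_mul_mv_*_impl routines: w is ne01 x ne00, x has ne00 entries.
    static std::vector<float> mul_mv(const std::vector<float> & w,
                                     const std::vector<float> & x,
                                     int64_t ne00, int64_t ne01) {
        std::vector<float> y(ne01, 0.0f);
        for (int64_t r = 0; r < ne01; ++r) {
            for (int64_t c = 0; c < ne00; ++c) {
                y[r] += w[r*ne00 + c] * x[c];
            }
        }
        return y;
    }

    int main() {
        const int64_t ne00 = 4, ne01 = 2;            // input / output dimensions
        std::array<std::vector<float>, 2> experts = {
            std::vector<float>(ne00*ne01, 1.0f),     // stands in for src00
            std::vector<float>(ne00*ne01, 2.0f),     // stands in for src01
        };
        std::vector<int32_t> ids  = {1, 0};          // one matrix index per batch row
        std::vector<std::vector<float>> src1 = {     // one input vector per batch row
            {1.0f, 2.0f, 3.0f, 4.0f},
            {4.0f, 3.0f, 2.0f, 1.0f},
        };

        for (size_t bid = 0; bid < ids.size(); ++bid) {
            // Mirrors: const int32_t id = ((device int32_t *)(ids + bid*nbi1))[idx];
            const int32_t id = ids[bid];
            std::vector<float> y = mul_mv(experts[id], src1[bid], ne00, ne01);
            std::printf("row %zu -> matrix %d -> y[0] = %g\n", bid, id, y[0]);
        }
        return 0;
    }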