llama_cpp 0.9.4 → 0.10.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +14 -0
- data/ext/llama_cpp/llama_cpp.cpp +121 -15
- data/ext/llama_cpp/src/ggml-alloc.c +43 -8
- data/ext/llama_cpp/src/ggml-alloc.h +7 -0
- data/ext/llama_cpp/src/ggml-backend-impl.h +46 -21
- data/ext/llama_cpp/src/ggml-backend.c +563 -156
- data/ext/llama_cpp/src/ggml-backend.h +62 -17
- data/ext/llama_cpp/src/ggml-cuda.cu +1270 -434
- data/ext/llama_cpp/src/ggml-cuda.h +9 -1
- data/ext/llama_cpp/src/ggml-impl.h +1 -1
- data/ext/llama_cpp/src/ggml-metal.h +6 -0
- data/ext/llama_cpp/src/ggml-metal.m +535 -175
- data/ext/llama_cpp/src/ggml-metal.metal +888 -237
- data/ext/llama_cpp/src/ggml-opencl.cpp +5 -7
- data/ext/llama_cpp/src/ggml.c +393 -127
- data/ext/llama_cpp/src/ggml.h +59 -7
- data/ext/llama_cpp/src/llama.cpp +791 -357
- data/ext/llama_cpp/src/llama.h +29 -6
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +20 -2
- metadata +3 -3
@@ -3,6 +3,8 @@
|
|
3
3
|
using namespace metal;
|
4
4
|
|
5
5
|
#define MAX(x, y) ((x) > (y) ? (x) : (y))
|
6
|
+
#define MIN(x, y) ((x) < (y) ? (x) : (y))
|
7
|
+
#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
|
6
8
|
|
7
9
|
#define QK4_0 32
|
8
10
|
#define QR4_0 2
|
@@ -39,8 +41,15 @@ typedef struct {
|
|
39
41
|
int8_t qs[QK8_0]; // quants
|
40
42
|
} block_q8_0;
|
41
43
|
|
42
|
-
//
|
43
|
-
|
44
|
+
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
45
|
+
|
46
|
+
enum ggml_sort_order {
|
47
|
+
GGML_SORT_ASC,
|
48
|
+
GGML_SORT_DESC,
|
49
|
+
};
|
50
|
+
|
51
|
+
// general-purpose kernel for addition, multiplication and division of two tensors
|
52
|
+
// pros: works for non-contiguous tensors, supports broadcast across all dims
|
44
53
|
// cons: not very efficient
|
45
54
|
kernel void kernel_add(
|
46
55
|
device const char * src0,
|
@@ -81,16 +90,111 @@ kernel void kernel_add(
|
|
81
90
|
const int64_t i12 = i02 % ne12;
|
82
91
|
const int64_t i11 = i01 % ne11;
|
83
92
|
|
84
|
-
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01
|
85
|
-
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11
|
86
|
-
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1
|
93
|
+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
94
|
+
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
|
95
|
+
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
|
96
|
+
|
97
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
98
|
+
const int i10 = i0 % ne10;
|
99
|
+
*((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10));
|
100
|
+
}
|
101
|
+
}
|
102
|
+
|
103
|
+
kernel void kernel_mul(
|
104
|
+
device const char * src0,
|
105
|
+
device const char * src1,
|
106
|
+
device char * dst,
|
107
|
+
constant int64_t & ne00,
|
108
|
+
constant int64_t & ne01,
|
109
|
+
constant int64_t & ne02,
|
110
|
+
constant int64_t & ne03,
|
111
|
+
constant int64_t & nb00,
|
112
|
+
constant int64_t & nb01,
|
113
|
+
constant int64_t & nb02,
|
114
|
+
constant int64_t & nb03,
|
115
|
+
constant int64_t & ne10,
|
116
|
+
constant int64_t & ne11,
|
117
|
+
constant int64_t & ne12,
|
118
|
+
constant int64_t & ne13,
|
119
|
+
constant int64_t & nb10,
|
120
|
+
constant int64_t & nb11,
|
121
|
+
constant int64_t & nb12,
|
122
|
+
constant int64_t & nb13,
|
123
|
+
constant int64_t & ne0,
|
124
|
+
constant int64_t & ne1,
|
125
|
+
constant int64_t & ne2,
|
126
|
+
constant int64_t & ne3,
|
127
|
+
constant int64_t & nb0,
|
128
|
+
constant int64_t & nb1,
|
129
|
+
constant int64_t & nb2,
|
130
|
+
constant int64_t & nb3,
|
131
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
132
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
133
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
134
|
+
const int64_t i03 = tgpig.z;
|
135
|
+
const int64_t i02 = tgpig.y;
|
136
|
+
const int64_t i01 = tgpig.x;
|
137
|
+
|
138
|
+
const int64_t i13 = i03 % ne13;
|
139
|
+
const int64_t i12 = i02 % ne12;
|
140
|
+
const int64_t i11 = i01 % ne11;
|
141
|
+
|
142
|
+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
143
|
+
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
|
144
|
+
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
|
87
145
|
|
88
146
|
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
89
|
-
|
147
|
+
const int i10 = i0 % ne10;
|
148
|
+
*((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10));
|
149
|
+
}
|
150
|
+
}
|
151
|
+
|
152
|
+
kernel void kernel_div(
|
153
|
+
device const char * src0,
|
154
|
+
device const char * src1,
|
155
|
+
device char * dst,
|
156
|
+
constant int64_t & ne00,
|
157
|
+
constant int64_t & ne01,
|
158
|
+
constant int64_t & ne02,
|
159
|
+
constant int64_t & ne03,
|
160
|
+
constant int64_t & nb00,
|
161
|
+
constant int64_t & nb01,
|
162
|
+
constant int64_t & nb02,
|
163
|
+
constant int64_t & nb03,
|
164
|
+
constant int64_t & ne10,
|
165
|
+
constant int64_t & ne11,
|
166
|
+
constant int64_t & ne12,
|
167
|
+
constant int64_t & ne13,
|
168
|
+
constant int64_t & nb10,
|
169
|
+
constant int64_t & nb11,
|
170
|
+
constant int64_t & nb12,
|
171
|
+
constant int64_t & nb13,
|
172
|
+
constant int64_t & ne0,
|
173
|
+
constant int64_t & ne1,
|
174
|
+
constant int64_t & ne2,
|
175
|
+
constant int64_t & ne3,
|
176
|
+
constant int64_t & nb0,
|
177
|
+
constant int64_t & nb1,
|
178
|
+
constant int64_t & nb2,
|
179
|
+
constant int64_t & nb3,
|
180
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
181
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
182
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
183
|
+
const int64_t i03 = tgpig.z;
|
184
|
+
const int64_t i02 = tgpig.y;
|
185
|
+
const int64_t i01 = tgpig.x;
|
186
|
+
|
187
|
+
const int64_t i13 = i03 % ne13;
|
188
|
+
const int64_t i12 = i02 % ne12;
|
189
|
+
const int64_t i11 = i01 % ne11;
|
190
|
+
|
191
|
+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
192
|
+
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
|
193
|
+
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
|
90
194
|
|
91
|
-
|
92
|
-
|
93
|
-
dst_ptr
|
195
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
196
|
+
const int i10 = i0 % ne10;
|
197
|
+
*((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10));
|
94
198
|
}
|
95
199
|
}
|
96
200
|
|
@@ -105,23 +209,22 @@ kernel void kernel_add_row(
|
|
105
209
|
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
106
210
|
}
|
107
211
|
|
108
|
-
kernel void
|
212
|
+
kernel void kernel_mul_row(
|
109
213
|
device const float4 * src0,
|
110
214
|
device const float4 * src1,
|
111
215
|
device float4 * dst,
|
216
|
+
constant int64_t & nb [[buffer(27)]],
|
112
217
|
uint tpig[[thread_position_in_grid]]) {
|
113
|
-
dst[tpig] = src0[tpig] * src1[tpig];
|
218
|
+
dst[tpig] = src0[tpig] * src1[tpig % nb];
|
114
219
|
}
|
115
220
|
|
116
|
-
|
117
|
-
// broadcast src1 into src0
|
118
|
-
kernel void kernel_mul_row(
|
221
|
+
kernel void kernel_div_row(
|
119
222
|
device const float4 * src0,
|
120
223
|
device const float4 * src1,
|
121
224
|
device float4 * dst,
|
122
|
-
constant int64_t & nb,
|
225
|
+
constant int64_t & nb [[buffer(27)]],
|
123
226
|
uint tpig[[thread_position_in_grid]]) {
|
124
|
-
dst[tpig] = src0[tpig]
|
227
|
+
dst[tpig] = src0[tpig] / src1[tpig % nb];
|
125
228
|
}
|
126
229
|
|
127
230
|
kernel void kernel_scale(
|
@@ -162,6 +265,54 @@ kernel void kernel_sqr(
|
|
162
265
|
dst[tpig] = src0[tpig] * src0[tpig];
|
163
266
|
}
|
164
267
|
|
268
|
+
kernel void kernel_sum_rows(
|
269
|
+
device const float * src0,
|
270
|
+
device float * dst,
|
271
|
+
constant int64_t & ne00,
|
272
|
+
constant int64_t & ne01,
|
273
|
+
constant int64_t & ne02,
|
274
|
+
constant int64_t & ne03,
|
275
|
+
constant int64_t & nb00,
|
276
|
+
constant int64_t & nb01,
|
277
|
+
constant int64_t & nb02,
|
278
|
+
constant int64_t & nb03,
|
279
|
+
constant int64_t & ne10,
|
280
|
+
constant int64_t & ne11,
|
281
|
+
constant int64_t & ne12,
|
282
|
+
constant int64_t & ne13,
|
283
|
+
constant int64_t & nb10,
|
284
|
+
constant int64_t & nb11,
|
285
|
+
constant int64_t & nb12,
|
286
|
+
constant int64_t & nb13,
|
287
|
+
constant int64_t & ne0,
|
288
|
+
constant int64_t & ne1,
|
289
|
+
constant int64_t & ne2,
|
290
|
+
constant int64_t & ne3,
|
291
|
+
constant int64_t & nb0,
|
292
|
+
constant int64_t & nb1,
|
293
|
+
constant int64_t & nb2,
|
294
|
+
constant int64_t & nb3,
|
295
|
+
uint3 tpig[[thread_position_in_grid]]) {
|
296
|
+
int64_t i3 = tpig.z;
|
297
|
+
int64_t i2 = tpig.y;
|
298
|
+
int64_t i1 = tpig.x;
|
299
|
+
|
300
|
+
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
|
301
|
+
return;
|
302
|
+
}
|
303
|
+
|
304
|
+
device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
|
305
|
+
device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
|
306
|
+
|
307
|
+
float row_sum = 0;
|
308
|
+
|
309
|
+
for (int64_t i0 = 0; i0 < ne00; i0++) {
|
310
|
+
row_sum += src_row[i0];
|
311
|
+
}
|
312
|
+
|
313
|
+
dst_row[0] = row_sum;
|
314
|
+
}
|
315
|
+
|
165
316
|
constant float GELU_COEF_A = 0.044715f;
|
166
317
|
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
167
318
|
|
@@ -180,10 +331,12 @@ kernel void kernel_gelu(
|
|
180
331
|
|
181
332
|
kernel void kernel_soft_max(
|
182
333
|
device const float * src0,
|
334
|
+
device const float * src1,
|
183
335
|
device float * dst,
|
184
336
|
constant int64_t & ne00,
|
185
337
|
constant int64_t & ne01,
|
186
338
|
constant int64_t & ne02,
|
339
|
+
constant float & scale,
|
187
340
|
threadgroup float * buf [[threadgroup(0)]],
|
188
341
|
uint tgpig[[threadgroup_position_in_grid]],
|
189
342
|
uint tpitg[[thread_position_in_threadgroup]],
|
@@ -194,73 +347,77 @@ kernel void kernel_soft_max(
|
|
194
347
|
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
195
348
|
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
196
349
|
|
197
|
-
device const float * psrc0 =
|
198
|
-
device
|
350
|
+
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
351
|
+
device const float * pmask = src1 ? src1 + i01*ne00 : nullptr;
|
352
|
+
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
199
353
|
|
200
354
|
// parallel max
|
201
|
-
float lmax =
|
355
|
+
float lmax = -INFINITY;
|
202
356
|
|
203
|
-
for (int i00 = tpitg
|
204
|
-
lmax = MAX(lmax, psrc0[i00]);
|
357
|
+
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
358
|
+
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
205
359
|
}
|
206
360
|
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
361
|
+
// find the max value in the block
|
362
|
+
float max_val = simd_max(lmax);
|
363
|
+
if (ntg > N_SIMDWIDTH) {
|
364
|
+
if (sgitg == 0) {
|
365
|
+
buf[tiisg] = -INFINITY;
|
366
|
+
}
|
211
367
|
|
212
|
-
|
368
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
213
369
|
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
|
218
|
-
}
|
219
|
-
}
|
370
|
+
if (tiisg == 0) {
|
371
|
+
buf[sgitg] = max_val;
|
372
|
+
}
|
220
373
|
|
221
|
-
|
374
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
222
375
|
|
223
|
-
|
376
|
+
max_val = buf[tiisg];
|
377
|
+
max_val = simd_max(max_val);
|
378
|
+
}
|
224
379
|
|
225
380
|
// parallel sum
|
226
381
|
float lsum = 0.0f;
|
227
382
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
228
|
-
const float exp_psrc0 = exp(psrc0[i00] -
|
383
|
+
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
|
229
384
|
lsum += exp_psrc0;
|
230
|
-
// Remember the result of exp here. exp is expensive, so we really do not
|
231
|
-
// wish to compute it twice.
|
232
385
|
pdst[i00] = exp_psrc0;
|
233
386
|
}
|
234
387
|
|
235
388
|
float sum = simd_sum(lsum);
|
236
|
-
if (
|
237
|
-
|
238
|
-
|
389
|
+
if (ntg > N_SIMDWIDTH) {
|
390
|
+
if (sgitg == 0) {
|
391
|
+
buf[tiisg] = 0.0f;
|
392
|
+
}
|
239
393
|
|
240
|
-
|
394
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
241
395
|
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
buf[tpitg] += buf[tpitg + i];
|
246
|
-
}
|
247
|
-
}
|
396
|
+
if (tiisg == 0) {
|
397
|
+
buf[sgitg] = sum;
|
398
|
+
}
|
248
399
|
|
249
|
-
|
400
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
401
|
+
|
402
|
+
sum = buf[tiisg];
|
403
|
+
sum = simd_sum(sum);
|
404
|
+
}
|
250
405
|
|
251
|
-
|
406
|
+
const float inv_sum = 1.0f/sum;
|
252
407
|
|
253
408
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
254
|
-
pdst[i00]
|
409
|
+
pdst[i00] *= inv_sum;
|
255
410
|
}
|
256
411
|
}
|
257
412
|
|
258
413
|
kernel void kernel_soft_max_4(
|
259
414
|
device const float * src0,
|
415
|
+
device const float * src1,
|
260
416
|
device float * dst,
|
261
417
|
constant int64_t & ne00,
|
262
418
|
constant int64_t & ne01,
|
263
419
|
constant int64_t & ne02,
|
420
|
+
constant float & scale,
|
264
421
|
threadgroup float * buf [[threadgroup(0)]],
|
265
422
|
uint tgpig[[threadgroup_position_in_grid]],
|
266
423
|
uint tpitg[[thread_position_in_threadgroup]],
|
@@ -271,64 +428,68 @@ kernel void kernel_soft_max_4(
|
|
271
428
|
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
272
429
|
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
273
430
|
|
274
|
-
device const float4 * psrc4 =
|
275
|
-
device
|
431
|
+
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
432
|
+
device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
|
433
|
+
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
276
434
|
|
277
435
|
// parallel max
|
278
|
-
float4 lmax4 =
|
436
|
+
float4 lmax4 = -INFINITY;
|
279
437
|
|
280
|
-
for (int i00 = tpitg
|
281
|
-
lmax4 = fmax(lmax4, psrc4[i00]);
|
438
|
+
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
439
|
+
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
282
440
|
}
|
283
441
|
|
284
442
|
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
285
|
-
float max = simd_max(lmax);
|
286
|
-
if (tiisg == 0) {
|
287
|
-
buf[sgitg] = max;
|
288
|
-
}
|
289
443
|
|
290
|
-
|
444
|
+
float max_val = simd_max(lmax);
|
445
|
+
if (ntg > N_SIMDWIDTH) {
|
446
|
+
if (sgitg == 0) {
|
447
|
+
buf[tiisg] = -INFINITY;
|
448
|
+
}
|
291
449
|
|
292
|
-
|
293
|
-
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
294
|
-
if (tpitg < i) {
|
295
|
-
buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
|
296
|
-
}
|
297
|
-
}
|
450
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
298
451
|
|
299
|
-
|
452
|
+
if (tiisg == 0) {
|
453
|
+
buf[sgitg] = max_val;
|
454
|
+
}
|
300
455
|
|
301
|
-
|
456
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
457
|
+
|
458
|
+
max_val = buf[tiisg];
|
459
|
+
max_val = simd_max(max_val);
|
460
|
+
}
|
302
461
|
|
303
462
|
// parallel sum
|
304
463
|
float4 lsum4 = 0.0f;
|
305
464
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
306
|
-
const float4 exp_psrc4 = exp(psrc4[i00] -
|
465
|
+
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
|
307
466
|
lsum4 += exp_psrc4;
|
308
467
|
pdst4[i00] = exp_psrc4;
|
309
468
|
}
|
310
469
|
|
311
470
|
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
312
471
|
float sum = simd_sum(lsum);
|
313
|
-
if (
|
314
|
-
|
315
|
-
|
472
|
+
if (ntg > N_SIMDWIDTH) {
|
473
|
+
if (sgitg == 0) {
|
474
|
+
buf[tiisg] = 0.0f;
|
475
|
+
}
|
316
476
|
|
317
|
-
|
477
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
318
478
|
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
buf[tpitg] += buf[tpitg + i];
|
323
|
-
}
|
324
|
-
}
|
479
|
+
if (tiisg == 0) {
|
480
|
+
buf[sgitg] = sum;
|
481
|
+
}
|
325
482
|
|
326
|
-
|
483
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
484
|
+
|
485
|
+
sum = buf[tiisg];
|
486
|
+
sum = simd_sum(sum);
|
487
|
+
}
|
327
488
|
|
328
|
-
|
489
|
+
const float inv_sum = 1.0f/sum;
|
329
490
|
|
330
491
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
331
|
-
pdst4[i00]
|
492
|
+
pdst4[i00] *= inv_sum;
|
332
493
|
}
|
333
494
|
}
|
334
495
|
|
@@ -435,14 +596,13 @@ kernel void kernel_rms_norm(
|
|
435
596
|
constant int64_t & ne00,
|
436
597
|
constant uint64_t & nb01,
|
437
598
|
constant float & eps,
|
438
|
-
threadgroup float *
|
599
|
+
threadgroup float * buf [[threadgroup(0)]],
|
439
600
|
uint tgpig[[threadgroup_position_in_grid]],
|
440
601
|
uint tpitg[[thread_position_in_threadgroup]],
|
441
602
|
uint sgitg[[simdgroup_index_in_threadgroup]],
|
442
603
|
uint tiisg[[thread_index_in_simdgroup]],
|
443
604
|
uint ntg[[threads_per_threadgroup]]) {
|
444
|
-
device const float4 * x
|
445
|
-
device const float * x_scalar = (device const float *) x;
|
605
|
+
device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
|
446
606
|
|
447
607
|
float4 sumf = 0;
|
448
608
|
float all_sum = 0;
|
@@ -453,40 +613,30 @@ kernel void kernel_rms_norm(
|
|
453
613
|
}
|
454
614
|
all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
|
455
615
|
all_sum = simd_sum(all_sum);
|
456
|
-
if (
|
457
|
-
|
458
|
-
|
616
|
+
if (ntg > N_SIMDWIDTH) {
|
617
|
+
if (sgitg == 0) {
|
618
|
+
buf[tiisg] = 0.0f;
|
619
|
+
}
|
459
620
|
|
460
|
-
|
621
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
461
622
|
|
462
|
-
|
463
|
-
|
464
|
-
if (tpitg < i) {
|
465
|
-
sum[tpitg] += sum[tpitg + i];
|
466
|
-
}
|
467
|
-
}
|
468
|
-
if (tpitg == 0) {
|
469
|
-
for (int i = 4 * (ne00 / 4); i < ne00; i++) {
|
470
|
-
sum[0] += x_scalar[i];
|
623
|
+
if (tiisg == 0) {
|
624
|
+
buf[sgitg] = all_sum;
|
471
625
|
}
|
472
|
-
sum[0] /= ne00;
|
473
|
-
}
|
474
626
|
|
475
|
-
|
627
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
628
|
+
|
629
|
+
all_sum = buf[tiisg];
|
630
|
+
all_sum = simd_sum(all_sum);
|
631
|
+
}
|
476
632
|
|
477
|
-
const float mean =
|
633
|
+
const float mean = all_sum/ne00;
|
478
634
|
const float scale = 1.0f/sqrt(mean + eps);
|
479
635
|
|
480
636
|
device float4 * y = (device float4 *) (dst + tgpig*ne00);
|
481
|
-
device float * y_scalar = (device float *) y;
|
482
637
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
483
638
|
y[i00] = x[i00] * scale;
|
484
639
|
}
|
485
|
-
if (tpitg == 0) {
|
486
|
-
for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
|
487
|
-
y_scalar[i00] = x_scalar[i00] * scale;
|
488
|
-
}
|
489
|
-
}
|
490
640
|
}
|
491
641
|
|
492
642
|
// function to calculate the inner product between half a q4_0 block and 16 floats (yl); sumy is SUM(yl[i])
|
@@ -576,15 +726,25 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
|
|
576
726
|
// putting them in the kernel causes a significant performance penalty
|
577
727
|
#define N_DST 4 // each SIMD group works on 4 rows
|
578
728
|
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
|
579
|
-
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
580
729
|
//Note: This is a template, but strictly speaking it only applies to
|
581
730
|
// quantizations where the block size is 32. It also does not
|
582
731
|
// guard against the number of rows not being divisible by
|
583
732
|
// N_DST, so this is another explicit assumption of the implementation.
|
584
733
|
template<typename block_q_type, int nr, int nsg, int nw>
|
585
|
-
void mul_vec_q_n_f32(
|
586
|
-
|
587
|
-
|
734
|
+
void mul_vec_q_n_f32(
|
735
|
+
device const void * src0,
|
736
|
+
device const float * src1,
|
737
|
+
device float * dst,
|
738
|
+
int64_t ne00,
|
739
|
+
int64_t ne01,
|
740
|
+
int64_t ne02,
|
741
|
+
int64_t ne10,
|
742
|
+
int64_t ne12,
|
743
|
+
int64_t ne0,
|
744
|
+
int64_t ne1,
|
745
|
+
uint r2,
|
746
|
+
uint r3,
|
747
|
+
uint3 tgpig, uint tiisg, uint sgitg) {
|
588
748
|
const int nb = ne00/QK4_0;
|
589
749
|
|
590
750
|
const int r0 = tgpig.x;
|
@@ -593,7 +753,10 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
|
|
593
753
|
|
594
754
|
const int first_row = (r0 * nsg + sgitg) * nr;
|
595
755
|
|
596
|
-
const uint
|
756
|
+
const uint i12 = im%ne12;
|
757
|
+
const uint i13 = im/ne12;
|
758
|
+
|
759
|
+
const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
597
760
|
|
598
761
|
device const block_q_type * x = (device const block_q_type *) src0 + offset0;
|
599
762
|
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
@@ -643,13 +806,14 @@ kernel void kernel_mul_mv_q4_0_f32(
|
|
643
806
|
constant int64_t & ne02[[buffer(5)]],
|
644
807
|
constant int64_t & ne10[[buffer(9)]],
|
645
808
|
constant int64_t & ne12[[buffer(11)]],
|
646
|
-
constant int64_t & ne0[[buffer(15)]],
|
647
|
-
constant int64_t & ne1[[buffer(16)]],
|
648
|
-
constant uint &
|
809
|
+
constant int64_t & ne0 [[buffer(15)]],
|
810
|
+
constant int64_t & ne1 [[buffer(16)]],
|
811
|
+
constant uint & r2 [[buffer(17)]],
|
812
|
+
constant uint & r3 [[buffer(18)]],
|
649
813
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
650
814
|
uint tiisg[[thread_index_in_simdgroup]],
|
651
815
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
652
|
-
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,
|
816
|
+
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
653
817
|
}
|
654
818
|
|
655
819
|
kernel void kernel_mul_mv_q4_1_f32(
|
@@ -661,13 +825,14 @@ kernel void kernel_mul_mv_q4_1_f32(
|
|
661
825
|
constant int64_t & ne02[[buffer(5)]],
|
662
826
|
constant int64_t & ne10[[buffer(9)]],
|
663
827
|
constant int64_t & ne12[[buffer(11)]],
|
664
|
-
constant int64_t & ne0[[buffer(15)]],
|
665
|
-
constant int64_t & ne1[[buffer(16)]],
|
666
|
-
constant uint &
|
828
|
+
constant int64_t & ne0 [[buffer(15)]],
|
829
|
+
constant int64_t & ne1 [[buffer(16)]],
|
830
|
+
constant uint & r2 [[buffer(17)]],
|
831
|
+
constant uint & r3 [[buffer(18)]],
|
667
832
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
668
833
|
uint tiisg[[thread_index_in_simdgroup]],
|
669
834
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
670
|
-
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,
|
835
|
+
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
671
836
|
}
|
672
837
|
|
673
838
|
kernel void kernel_mul_mv_q5_0_f32(
|
@@ -679,13 +844,14 @@ kernel void kernel_mul_mv_q5_0_f32(
|
|
679
844
|
constant int64_t & ne02[[buffer(5)]],
|
680
845
|
constant int64_t & ne10[[buffer(9)]],
|
681
846
|
constant int64_t & ne12[[buffer(11)]],
|
682
|
-
constant int64_t & ne0[[buffer(15)]],
|
683
|
-
constant int64_t & ne1[[buffer(16)]],
|
684
|
-
constant uint &
|
847
|
+
constant int64_t & ne0 [[buffer(15)]],
|
848
|
+
constant int64_t & ne1 [[buffer(16)]],
|
849
|
+
constant uint & r2 [[buffer(17)]],
|
850
|
+
constant uint & r3 [[buffer(18)]],
|
685
851
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
686
852
|
uint tiisg[[thread_index_in_simdgroup]],
|
687
853
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
688
|
-
mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,
|
854
|
+
mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
689
855
|
}
|
690
856
|
|
691
857
|
kernel void kernel_mul_mv_q5_1_f32(
|
@@ -697,13 +863,14 @@ kernel void kernel_mul_mv_q5_1_f32(
|
|
697
863
|
constant int64_t & ne02[[buffer(5)]],
|
698
864
|
constant int64_t & ne10[[buffer(9)]],
|
699
865
|
constant int64_t & ne12[[buffer(11)]],
|
700
|
-
constant int64_t & ne0[[buffer(15)]],
|
701
|
-
constant int64_t & ne1[[buffer(16)]],
|
702
|
-
constant uint &
|
866
|
+
constant int64_t & ne0 [[buffer(15)]],
|
867
|
+
constant int64_t & ne1 [[buffer(16)]],
|
868
|
+
constant uint & r2 [[buffer(17)]],
|
869
|
+
constant uint & r3 [[buffer(18)]],
|
703
870
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
704
871
|
uint tiisg[[thread_index_in_simdgroup]],
|
705
872
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
706
|
-
mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,
|
873
|
+
mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
707
874
|
}
|
708
875
|
|
709
876
|
|
@@ -718,9 +885,10 @@ kernel void kernel_mul_mv_q8_0_f32(
|
|
718
885
|
constant int64_t & ne02[[buffer(5)]],
|
719
886
|
constant int64_t & ne10[[buffer(9)]],
|
720
887
|
constant int64_t & ne12[[buffer(11)]],
|
721
|
-
constant int64_t & ne0[[buffer(15)]],
|
722
|
-
constant int64_t & ne1[[buffer(16)]],
|
723
|
-
constant uint &
|
888
|
+
constant int64_t & ne0 [[buffer(15)]],
|
889
|
+
constant int64_t & ne1 [[buffer(16)]],
|
890
|
+
constant uint & r2 [[buffer(17)]],
|
891
|
+
constant uint & r3 [[buffer(18)]],
|
724
892
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
725
893
|
uint tiisg[[thread_index_in_simdgroup]],
|
726
894
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -732,8 +900,14 @@ kernel void kernel_mul_mv_q8_0_f32(
|
|
732
900
|
const int r0 = tgpig.x;
|
733
901
|
const int r1 = tgpig.y;
|
734
902
|
const int im = tgpig.z;
|
903
|
+
|
735
904
|
const int first_row = (r0 * nsg + sgitg) * nr;
|
736
|
-
|
905
|
+
|
906
|
+
const uint i12 = im%ne12;
|
907
|
+
const uint i13 = im/ne12;
|
908
|
+
|
909
|
+
const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
910
|
+
|
737
911
|
device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
|
738
912
|
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
739
913
|
|
@@ -791,6 +965,8 @@ kernel void kernel_mul_mv_f32_f32(
|
|
791
965
|
constant uint64_t & nb12,
|
792
966
|
constant int64_t & ne0,
|
793
967
|
constant int64_t & ne1,
|
968
|
+
constant uint & r2 [[buffer(17)]],
|
969
|
+
constant uint & r3 [[buffer(18)]],
|
794
970
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
795
971
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
796
972
|
|
@@ -798,7 +974,12 @@ kernel void kernel_mul_mv_f32_f32(
|
|
798
974
|
const int64_t rb = tgpig.y*N_F32_F32;
|
799
975
|
const int64_t im = tgpig.z;
|
800
976
|
|
801
|
-
|
977
|
+
const uint i12 = im%ne12;
|
978
|
+
const uint i13 = im/ne12;
|
979
|
+
|
980
|
+
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
981
|
+
|
982
|
+
device const float * x = (device const float *) (src0 + offset0);
|
802
983
|
|
803
984
|
if (ne00 < 128) {
|
804
985
|
for (int row = 0; row < N_F32_F32; ++row) {
|
@@ -864,6 +1045,8 @@ kernel void kernel_mul_mv_f16_f16(
|
|
864
1045
|
constant uint64_t & nb12,
|
865
1046
|
constant int64_t & ne0,
|
866
1047
|
constant int64_t & ne1,
|
1048
|
+
constant uint & r2 [[buffer(17)]],
|
1049
|
+
constant uint & r3 [[buffer(18)]],
|
867
1050
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
868
1051
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
869
1052
|
|
@@ -871,7 +1054,12 @@ kernel void kernel_mul_mv_f16_f16(
|
|
871
1054
|
const int64_t rb = tgpig.y*N_F16_F16;
|
872
1055
|
const int64_t im = tgpig.z;
|
873
1056
|
|
874
|
-
|
1057
|
+
const uint i12 = im%ne12;
|
1058
|
+
const uint i13 = im/ne12;
|
1059
|
+
|
1060
|
+
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
1061
|
+
|
1062
|
+
device const half * x = (device const half *) (src0 + offset0);
|
875
1063
|
|
876
1064
|
if (ne00 < 128) {
|
877
1065
|
for (int row = 0; row < N_F16_F16; ++row) {
|
@@ -935,6 +1123,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
|
|
935
1123
|
constant uint64_t & nb12,
|
936
1124
|
constant int64_t & ne0,
|
937
1125
|
constant int64_t & ne1,
|
1126
|
+
constant uint & r2 [[buffer(17)]],
|
1127
|
+
constant uint & r3 [[buffer(18)]],
|
938
1128
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
939
1129
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
940
1130
|
|
@@ -942,7 +1132,12 @@ kernel void kernel_mul_mv_f16_f32_1row(
|
|
942
1132
|
const int64_t r1 = tgpig.y;
|
943
1133
|
const int64_t im = tgpig.z;
|
944
1134
|
|
945
|
-
|
1135
|
+
const uint i12 = im%ne12;
|
1136
|
+
const uint i13 = im/ne12;
|
1137
|
+
|
1138
|
+
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
1139
|
+
|
1140
|
+
device const half * x = (device const half *) (src0 + offset0);
|
946
1141
|
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
947
1142
|
|
948
1143
|
float sumf = 0;
|
@@ -989,6 +1184,8 @@ kernel void kernel_mul_mv_f16_f32(
|
|
989
1184
|
constant uint64_t & nb12,
|
990
1185
|
constant int64_t & ne0,
|
991
1186
|
constant int64_t & ne1,
|
1187
|
+
constant uint & r2 [[buffer(17)]],
|
1188
|
+
constant uint & r3 [[buffer(18)]],
|
992
1189
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
993
1190
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
994
1191
|
|
@@ -996,7 +1193,12 @@ kernel void kernel_mul_mv_f16_f32(
|
|
996
1193
|
const int64_t rb = tgpig.y*N_F16_F32;
|
997
1194
|
const int64_t im = tgpig.z;
|
998
1195
|
|
999
|
-
|
1196
|
+
const uint i12 = im%ne12;
|
1197
|
+
const uint i13 = im/ne12;
|
1198
|
+
|
1199
|
+
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
1200
|
+
|
1201
|
+
device const half * x = (device const half *) (src0 + offset0);
|
1000
1202
|
|
1001
1203
|
if (ne00 < 128) {
|
1002
1204
|
for (int row = 0; row < N_F16_F32; ++row) {
|
@@ -1061,6 +1263,8 @@ kernel void kernel_mul_mv_f16_f32_l4(
|
|
1061
1263
|
constant uint64_t & nb12,
|
1062
1264
|
constant int64_t & ne0,
|
1063
1265
|
constant int64_t & ne1,
|
1266
|
+
constant uint & r2 [[buffer(17)]],
|
1267
|
+
constant uint & r3 [[buffer(18)]],
|
1064
1268
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1065
1269
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
1066
1270
|
|
@@ -1068,7 +1272,12 @@ kernel void kernel_mul_mv_f16_f32_l4(
|
|
1068
1272
|
const int64_t r0 = tgpig.x;
|
1069
1273
|
const int64_t im = tgpig.z;
|
1070
1274
|
|
1071
|
-
|
1275
|
+
const uint i12 = im%ne12;
|
1276
|
+
const uint i13 = im/ne12;
|
1277
|
+
|
1278
|
+
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
1279
|
+
|
1280
|
+
device const half4 * x4 = (device const half4 *) (src0 + offset0);
|
1072
1281
|
|
1073
1282
|
for (int r1 = 0; r1 < nrows; ++r1) {
|
1074
1283
|
device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
|
@@ -1120,17 +1329,21 @@ kernel void kernel_alibi_f32(
|
|
1120
1329
|
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
1121
1330
|
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
1122
1331
|
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
1332
|
+
const int64_t k = i3*ne3 + i2;
|
1123
1333
|
|
1124
|
-
device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
1125
1334
|
float m_k;
|
1126
|
-
if (
|
1127
|
-
m_k = pow(m0,
|
1335
|
+
if (k < n_heads_log2_floor) {
|
1336
|
+
m_k = pow(m0, k + 1);
|
1128
1337
|
} else {
|
1129
|
-
m_k = pow(m1, 2 * (
|
1338
|
+
m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
|
1130
1339
|
}
|
1340
|
+
|
1341
|
+
device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
|
1342
|
+
device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
1131
1343
|
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
1132
|
-
|
1133
|
-
|
1344
|
+
const float src_v = *(device float *)(src_row + i00*nb00);
|
1345
|
+
device float * dst_v = (device float *)(dst_row + i00*nb0);
|
1346
|
+
*dst_v = i00 * m_k + src_v;
|
1134
1347
|
}
|
1135
1348
|
}
|
1136
1349
|
|
@@ -1335,6 +1548,58 @@ kernel void kernel_im2col_f16(
|
|
1335
1548
|
}
|
1336
1549
|
}
|
1337
1550
|
|
1551
|
+
// bitonic sort implementation following the CUDA kernels as reference
|
1552
|
+
typedef void (argsort_t)(
|
1553
|
+
device const float * x,
|
1554
|
+
device int32_t * dst,
|
1555
|
+
constant int64_t & ncols,
|
1556
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1557
|
+
uint3 tpitg[[thread_position_in_threadgroup]]);
|
1558
|
+
|
1559
|
+
template<ggml_sort_order order>
|
1560
|
+
kernel void kernel_argsort_f32_i32(
|
1561
|
+
device const float * x,
|
1562
|
+
device int32_t * dst,
|
1563
|
+
constant int64_t & ncols,
|
1564
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1565
|
+
uint3 tpitg[[thread_position_in_threadgroup]]) {
|
1566
|
+
// bitonic sort
|
1567
|
+
int col = tpitg[0];
|
1568
|
+
int row = tgpig[1];
|
1569
|
+
|
1570
|
+
if (col >= ncols) return;
|
1571
|
+
|
1572
|
+
device const float * x_row = x + row * ncols;
|
1573
|
+
device int32_t * dst_row = dst + row * ncols;
|
1574
|
+
|
1575
|
+
// initialize indices
|
1576
|
+
if (col < ncols) {
|
1577
|
+
dst_row[col] = col;
|
1578
|
+
}
|
1579
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1580
|
+
|
1581
|
+
for (int k = 2; k <= ncols; k *= 2) {
|
1582
|
+
for (int j = k / 2; j > 0; j /= 2) {
|
1583
|
+
int ixj = col ^ j;
|
1584
|
+
if (ixj > col) {
|
1585
|
+
if ((col & k) == 0) {
|
1586
|
+
if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
|
1587
|
+
SWAP(dst_row[col], dst_row[ixj]);
|
1588
|
+
}
|
1589
|
+
} else {
|
1590
|
+
if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
|
1591
|
+
SWAP(dst_row[col], dst_row[ixj]);
|
1592
|
+
}
|
1593
|
+
}
|
1594
|
+
}
|
1595
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1596
|
+
}
|
1597
|
+
}
|
1598
|
+
}
|
1599
|
+
|
1600
|
+
template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
|
1601
|
+
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
|
1602
|
+
|
1338
1603
|
kernel void kernel_cpy_f16_f16(
|
1339
1604
|
device const half * src0,
|
1340
1605
|
device half * dst,
|
@@ -1460,6 +1725,197 @@ kernel void kernel_cpy_f32_f32(
|
|
1460
1725
|
}
|
1461
1726
|
}
|
1462
1727
|
|
1728
|
+
kernel void kernel_cpy_f32_q8_0(
|
1729
|
+
device const float * src0,
|
1730
|
+
device void * dst,
|
1731
|
+
constant int64_t & ne00,
|
1732
|
+
constant int64_t & ne01,
|
1733
|
+
constant int64_t & ne02,
|
1734
|
+
constant int64_t & ne03,
|
1735
|
+
constant uint64_t & nb00,
|
1736
|
+
constant uint64_t & nb01,
|
1737
|
+
constant uint64_t & nb02,
|
1738
|
+
constant uint64_t & nb03,
|
1739
|
+
constant int64_t & ne0,
|
1740
|
+
constant int64_t & ne1,
|
1741
|
+
constant int64_t & ne2,
|
1742
|
+
constant int64_t & ne3,
|
1743
|
+
constant uint64_t & nb0,
|
1744
|
+
constant uint64_t & nb1,
|
1745
|
+
constant uint64_t & nb2,
|
1746
|
+
constant uint64_t & nb3,
|
1747
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1748
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
1749
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
1750
|
+
const int64_t i03 = tgpig[2];
|
1751
|
+
const int64_t i02 = tgpig[1];
|
1752
|
+
const int64_t i01 = tgpig[0];
|
1753
|
+
|
1754
|
+
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
1755
|
+
|
1756
|
+
const int64_t i3 = n / (ne2*ne1*ne0);
|
1757
|
+
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
1758
|
+
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
1759
|
+
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0;
|
1760
|
+
|
1761
|
+
device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
1762
|
+
|
1763
|
+
for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) {
|
1764
|
+
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
1765
|
+
|
1766
|
+
float amax = 0.0f; // absolute max
|
1767
|
+
|
1768
|
+
for (int j = 0; j < QK8_0; j++) {
|
1769
|
+
const float v = src[j];
|
1770
|
+
amax = MAX(amax, fabs(v));
|
1771
|
+
}
|
1772
|
+
|
1773
|
+
const float d = amax / ((1 << 7) - 1);
|
1774
|
+
const float id = d ? 1.0f/d : 0.0f;
|
1775
|
+
|
1776
|
+
dst_data[i00/QK8_0].d = d;
|
1777
|
+
|
1778
|
+
for (int j = 0; j < QK8_0; ++j) {
|
1779
|
+
const float x0 = src[j]*id;
|
1780
|
+
|
1781
|
+
dst_data[i00/QK8_0].qs[j] = round(x0);
|
1782
|
+
}
|
1783
|
+
}
|
1784
|
+
}
|
1785
|
+
|
1786
|
+
kernel void kernel_cpy_f32_q4_0(
|
1787
|
+
device const float * src0,
|
1788
|
+
device void * dst,
|
1789
|
+
constant int64_t & ne00,
|
1790
|
+
constant int64_t & ne01,
|
1791
|
+
constant int64_t & ne02,
|
1792
|
+
constant int64_t & ne03,
|
1793
|
+
constant uint64_t & nb00,
|
1794
|
+
constant uint64_t & nb01,
|
1795
|
+
constant uint64_t & nb02,
|
1796
|
+
constant uint64_t & nb03,
|
1797
|
+
constant int64_t & ne0,
|
1798
|
+
constant int64_t & ne1,
|
1799
|
+
constant int64_t & ne2,
|
1800
|
+
constant int64_t & ne3,
|
1801
|
+
constant uint64_t & nb0,
|
1802
|
+
constant uint64_t & nb1,
|
1803
|
+
constant uint64_t & nb2,
|
1804
|
+
constant uint64_t & nb3,
|
1805
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1806
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
1807
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
1808
|
+
const int64_t i03 = tgpig[2];
|
1809
|
+
const int64_t i02 = tgpig[1];
|
1810
|
+
const int64_t i01 = tgpig[0];
|
1811
|
+
|
1812
|
+
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
1813
|
+
|
1814
|
+
const int64_t i3 = n / (ne2*ne1*ne0);
|
1815
|
+
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
1816
|
+
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
1817
|
+
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0;
|
1818
|
+
|
1819
|
+
device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
1820
|
+
|
1821
|
+
for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) {
|
1822
|
+
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
1823
|
+
|
1824
|
+
float amax = 0.0f; // absolute max
|
1825
|
+
float max = 0.0f;
|
1826
|
+
|
1827
|
+
for (int j = 0; j < QK4_0; j++) {
|
1828
|
+
const float v = src[j];
|
1829
|
+
if (amax < fabs(v)) {
|
1830
|
+
amax = fabs(v);
|
1831
|
+
max = v;
|
1832
|
+
}
|
1833
|
+
}
|
1834
|
+
|
1835
|
+
const float d = max / -8;
|
1836
|
+
const float id = d ? 1.0f/d : 0.0f;
|
1837
|
+
|
1838
|
+
dst_data[i00/QK4_0].d = d;
|
1839
|
+
|
1840
|
+
for (int j = 0; j < QK4_0/2; ++j) {
|
1841
|
+
const float x0 = src[0 + j]*id;
|
1842
|
+
const float x1 = src[QK4_0/2 + j]*id;
|
1843
|
+
|
1844
|
+
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
|
1845
|
+
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
|
1846
|
+
|
1847
|
+
dst_data[i00/QK4_0].qs[j] = xi0;
|
1848
|
+
dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
|
1849
|
+
}
|
1850
|
+
}
|
1851
|
+
}
|
1852
|
+
|
1853
|
+
kernel void kernel_cpy_f32_q4_1(
|
1854
|
+
device const float * src0,
|
1855
|
+
device void * dst,
|
1856
|
+
constant int64_t & ne00,
|
1857
|
+
constant int64_t & ne01,
|
1858
|
+
constant int64_t & ne02,
|
1859
|
+
constant int64_t & ne03,
|
1860
|
+
constant uint64_t & nb00,
|
1861
|
+
constant uint64_t & nb01,
|
1862
|
+
constant uint64_t & nb02,
|
1863
|
+
constant uint64_t & nb03,
|
1864
|
+
constant int64_t & ne0,
|
1865
|
+
constant int64_t & ne1,
|
1866
|
+
constant int64_t & ne2,
|
1867
|
+
constant int64_t & ne3,
|
1868
|
+
constant uint64_t & nb0,
|
1869
|
+
constant uint64_t & nb1,
|
1870
|
+
constant uint64_t & nb2,
|
1871
|
+
constant uint64_t & nb3,
|
1872
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1873
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
1874
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
1875
|
+
const int64_t i03 = tgpig[2];
|
1876
|
+
const int64_t i02 = tgpig[1];
|
1877
|
+
const int64_t i01 = tgpig[0];
|
1878
|
+
|
1879
|
+
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
1880
|
+
|
1881
|
+
const int64_t i3 = n / (ne2*ne1*ne0);
|
1882
|
+
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
1883
|
+
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
1884
|
+
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1;
|
1885
|
+
|
1886
|
+
device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
1887
|
+
|
1888
|
+
for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) {
|
1889
|
+
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
1890
|
+
|
1891
|
+
float min = FLT_MAX;
|
1892
|
+
float max = -FLT_MAX;
|
1893
|
+
|
1894
|
+
for (int j = 0; j < QK4_1; j++) {
|
1895
|
+
const float v = src[j];
|
1896
|
+
if (min > v) min = v;
|
1897
|
+
if (max < v) max = v;
|
1898
|
+
}
|
1899
|
+
|
1900
|
+
const float d = (max - min) / ((1 << 4) - 1);
|
1901
|
+
const float id = d ? 1.0f/d : 0.0f;
|
1902
|
+
|
1903
|
+
dst_data[i00/QK4_1].d = d;
|
1904
|
+
dst_data[i00/QK4_1].m = min;
|
1905
|
+
|
1906
|
+
for (int j = 0; j < QK4_1/2; ++j) {
|
1907
|
+
const float x0 = (src[0 + j] - min)*id;
|
1908
|
+
const float x1 = (src[QK4_1/2 + j] - min)*id;
|
1909
|
+
|
1910
|
+
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
|
1911
|
+
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
|
1912
|
+
|
1913
|
+
dst_data[i00/QK4_1].qs[j] = xi0;
|
1914
|
+
dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
|
1915
|
+
}
|
1916
|
+
}
|
1917
|
+
}
|
1918
|
+
|
1463
1919
|
kernel void kernel_concat(
|
1464
1920
|
device const char * src0,
|
1465
1921
|
device const char * src1,
|
@@ -1617,23 +2073,30 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
1617
2073
|
constant int64_t & ne02[[buffer(5)]],
|
1618
2074
|
constant int64_t & ne10[[buffer(9)]],
|
1619
2075
|
constant int64_t & ne12[[buffer(11)]],
|
1620
|
-
constant int64_t & ne0[[buffer(15)]],
|
1621
|
-
constant int64_t & ne1[[buffer(16)]],
|
1622
|
-
constant uint &
|
2076
|
+
constant int64_t & ne0 [[buffer(15)]],
|
2077
|
+
constant int64_t & ne1 [[buffer(16)]],
|
2078
|
+
constant uint & r2 [[buffer(17)]],
|
2079
|
+
constant uint & r3 [[buffer(18)]],
|
1623
2080
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1624
|
-
uint
|
1625
|
-
uint
|
2081
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
2082
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1626
2083
|
|
1627
2084
|
const int nb = ne00/QK_K;
|
1628
2085
|
const int r0 = tgpig.x;
|
1629
2086
|
const int r1 = tgpig.y;
|
1630
|
-
const int
|
2087
|
+
const int im = tgpig.z;
|
1631
2088
|
|
1632
2089
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
1633
2090
|
const int ib_row = first_row * nb;
|
1634
|
-
|
2091
|
+
|
2092
|
+
const uint i12 = im%ne12;
|
2093
|
+
const uint i13 = im/ne12;
|
2094
|
+
|
2095
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
2096
|
+
|
1635
2097
|
device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
|
1636
|
-
device const float * y = (device const float *) src1 + r1*ne10 +
|
2098
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
2099
|
+
|
1637
2100
|
float yl[32];
|
1638
2101
|
float sumf[N_DST]={0.f}, all_sum;
|
1639
2102
|
|
@@ -1642,11 +2105,11 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
1642
2105
|
#if QK_K == 256
|
1643
2106
|
const int ix = tiisg/8; // 0...3
|
1644
2107
|
const int it = tiisg%8; // 0...7
|
1645
|
-
const int
|
2108
|
+
const int iq = it/4; // 0 or 1
|
1646
2109
|
const int ir = it%4; // 0...3
|
1647
2110
|
const int is = (8*ir)/16;// 0 or 1
|
1648
2111
|
|
1649
|
-
device const float * y4 = y + ix * QK_K + 128 *
|
2112
|
+
device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
|
1650
2113
|
|
1651
2114
|
for (int ib = ix; ib < nb; ib += 4) {
|
1652
2115
|
|
@@ -1658,8 +2121,8 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
1658
2121
|
yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
|
1659
2122
|
}
|
1660
2123
|
|
1661
|
-
device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*
|
1662
|
-
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 *
|
2124
|
+
device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*iq + is;
|
2125
|
+
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
|
1663
2126
|
device const half * dh = &x[ib].d;
|
1664
2127
|
|
1665
2128
|
for (int row = 0; row < N_DST; row++) {
|
@@ -1746,7 +2209,7 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
1746
2209
|
for (int row = 0; row < N_DST; ++row) {
|
1747
2210
|
all_sum = simd_sum(sumf[row]);
|
1748
2211
|
if (tiisg == 0) {
|
1749
|
-
dst[r1*ne0 +
|
2212
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
1750
2213
|
}
|
1751
2214
|
}
|
1752
2215
|
}
|
@@ -1761,9 +2224,10 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
1761
2224
|
constant int64_t & ne02[[buffer(5)]],
|
1762
2225
|
constant int64_t & ne10[[buffer(9)]],
|
1763
2226
|
constant int64_t & ne12[[buffer(11)]],
|
1764
|
-
constant int64_t & ne0[[buffer(15)]],
|
1765
|
-
constant int64_t & ne1[[buffer(16)]],
|
1766
|
-
constant uint &
|
2227
|
+
constant int64_t & ne0 [[buffer(15)]],
|
2228
|
+
constant int64_t & ne1 [[buffer(16)]],
|
2229
|
+
constant uint & r2 [[buffer(17)]],
|
2230
|
+
constant uint & r3 [[buffer(18)]],
|
1767
2231
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1768
2232
|
uint tiisg[[thread_index_in_simdgroup]],
|
1769
2233
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -1772,12 +2236,17 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
1772
2236
|
|
1773
2237
|
const int64_t r0 = tgpig.x;
|
1774
2238
|
const int64_t r1 = tgpig.y;
|
1775
|
-
const int64_t
|
2239
|
+
const int64_t im = tgpig.z;
|
1776
2240
|
|
1777
2241
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
1778
|
-
|
2242
|
+
|
2243
|
+
const uint i12 = im%ne12;
|
2244
|
+
const uint i13 = im/ne12;
|
2245
|
+
|
2246
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
2247
|
+
|
1779
2248
|
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
|
1780
|
-
device const float * yy = (device const float *) src1 + r1*ne10 +
|
2249
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
1781
2250
|
|
1782
2251
|
float yl[32];
|
1783
2252
|
|
@@ -1899,7 +2368,7 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
1899
2368
|
}
|
1900
2369
|
if (tiisg == 0) {
|
1901
2370
|
for (int row = 0; row < 2; ++row) {
|
1902
|
-
dst[r1*ne0 +
|
2371
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row];
|
1903
2372
|
}
|
1904
2373
|
}
|
1905
2374
|
}
|
@@ -1913,26 +2382,33 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
1913
2382
|
constant int64_t & ne02[[buffer(5)]],
|
1914
2383
|
constant int64_t & ne10[[buffer(9)]],
|
1915
2384
|
constant int64_t & ne12[[buffer(11)]],
|
1916
|
-
constant int64_t & ne0[[buffer(15)]],
|
1917
|
-
constant int64_t & ne1[[buffer(16)]],
|
1918
|
-
constant uint &
|
2385
|
+
constant int64_t & ne0 [[buffer(15)]],
|
2386
|
+
constant int64_t & ne1 [[buffer(16)]],
|
2387
|
+
constant uint & r2 [[buffer(17)]],
|
2388
|
+
constant uint & r3 [[buffer(18)]],
|
1919
2389
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1920
|
-
uint
|
1921
|
-
uint
|
2390
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
2391
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1922
2392
|
|
1923
2393
|
const int nb = ne00/QK_K;
|
1924
2394
|
|
1925
2395
|
const int64_t r0 = tgpig.x;
|
1926
2396
|
const int64_t r1 = tgpig.y;
|
1927
|
-
const int64_t
|
2397
|
+
const int64_t im = tgpig.z;
|
1928
2398
|
|
1929
2399
|
const int row = 2 * r0 + sgitg;
|
1930
|
-
|
2400
|
+
|
2401
|
+
const uint i12 = im%ne12;
|
2402
|
+
const uint i13 = im/ne12;
|
2403
|
+
|
2404
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
2405
|
+
|
1931
2406
|
device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
|
1932
|
-
device const float * yy = (device const float *) src1 + r1*ne10 +
|
2407
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
2408
|
+
|
1933
2409
|
const int ix = tiisg/4;
|
1934
2410
|
const int il = 4 * (tiisg%4);// 0, 4, 8, 12
|
1935
|
-
const int
|
2411
|
+
const int iq = il/8; // 0, 0, 1, 1
|
1936
2412
|
const int in = il%8; // 0, 4, 0, 4
|
1937
2413
|
|
1938
2414
|
float2 sum = {0.f, 0.f};
|
@@ -1952,7 +2428,7 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
1952
2428
|
const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
|
1953
2429
|
|
1954
2430
|
for (int l = 0; l < 4; l += 2) {
|
1955
|
-
const uint16_t hm = h[l/2] >>
|
2431
|
+
const uint16_t hm = h[l/2] >> iq;
|
1956
2432
|
sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
|
1957
2433
|
+ y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
|
1958
2434
|
+ y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
|
@@ -1968,7 +2444,7 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
1968
2444
|
|
1969
2445
|
const float tot = simd_sum(sumf);
|
1970
2446
|
if (tiisg == 0) {
|
1971
|
-
dst[r1*ne0 +
|
2447
|
+
dst[r1*ne0 + im*ne0*ne1 + row] = tot;
|
1972
2448
|
}
|
1973
2449
|
|
1974
2450
|
}
|
@@ -1986,10 +2462,11 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
1986
2462
|
constant int64_t & ne12 [[buffer(11)]],
|
1987
2463
|
constant int64_t & ne0 [[buffer(15)]],
|
1988
2464
|
constant int64_t & ne1 [[buffer(16)]],
|
1989
|
-
constant uint &
|
2465
|
+
constant uint & r2 [[buffer(17)]],
|
2466
|
+
constant uint & r3 [[buffer(18)]],
|
1990
2467
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1991
|
-
uint
|
1992
|
-
uint
|
2468
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
2469
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
1993
2470
|
|
1994
2471
|
const uint16_t kmask1 = 0x3f3f;
|
1995
2472
|
const uint16_t kmask2 = 0x0f0f;
|
@@ -1997,26 +2474,32 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
1997
2474
|
|
1998
2475
|
const int ix = tiisg/8; // 0...3
|
1999
2476
|
const int it = tiisg%8; // 0...7
|
2000
|
-
const int
|
2477
|
+
const int iq = it/4; // 0 or 1
|
2001
2478
|
const int ir = it%4; // 0...3
|
2002
2479
|
|
2003
2480
|
const int nb = ne00/QK_K;
|
2004
2481
|
const int r0 = tgpig.x;
|
2005
2482
|
const int r1 = tgpig.y;
|
2006
|
-
const int
|
2483
|
+
const int im = tgpig.z;
|
2007
2484
|
//const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
2008
2485
|
const int first_row = r0 * N_DST;
|
2009
2486
|
const int ib_row = first_row * nb;
|
2010
|
-
|
2487
|
+
|
2488
|
+
const uint i12 = im%ne12;
|
2489
|
+
const uint i13 = im/ne12;
|
2490
|
+
|
2491
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
2492
|
+
|
2011
2493
|
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
|
2012
|
-
device const float * y = (device const float *) src1 + r1*ne10 +
|
2494
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
2495
|
+
|
2013
2496
|
float yl[16];
|
2014
2497
|
float yh[16];
|
2015
2498
|
float sumf[N_DST]={0.f}, all_sum;
|
2016
2499
|
|
2017
2500
|
const int step = sizeof(block_q4_K) * nb / 2;
|
2018
2501
|
|
2019
|
-
device const float * y4 = y + ix * QK_K + 64 *
|
2502
|
+
device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
|
2020
2503
|
|
2021
2504
|
uint16_t sc16[4];
|
2022
2505
|
thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
|
@@ -2031,8 +2514,8 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
2031
2514
|
yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
|
2032
2515
|
}
|
2033
2516
|
|
2034
|
-
device const uint16_t * sc = (device const uint16_t *)x[ib].scales +
|
2035
|
-
device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 *
|
2517
|
+
device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq;
|
2518
|
+
device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
|
2036
2519
|
device const half * dh = &x[ib].d;
|
2037
2520
|
|
2038
2521
|
for (int row = 0; row < N_DST; row++) {
|
@@ -2076,7 +2559,7 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
2076
2559
|
for (int row = 0; row < N_DST; ++row) {
|
2077
2560
|
all_sum = simd_sum(sumf[row]);
|
2078
2561
|
if (tiisg == 0) {
|
2079
|
-
dst[r1*ne0 +
|
2562
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
2080
2563
|
}
|
2081
2564
|
}
|
2082
2565
|
}
|
@@ -2090,9 +2573,10 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
2090
2573
|
constant int64_t & ne02[[buffer(5)]],
|
2091
2574
|
constant int64_t & ne10[[buffer(9)]],
|
2092
2575
|
constant int64_t & ne12[[buffer(11)]],
|
2093
|
-
constant int64_t & ne0[[buffer(15)]],
|
2094
|
-
constant int64_t & ne1[[buffer(16)]],
|
2095
|
-
constant uint &
|
2576
|
+
constant int64_t & ne0 [[buffer(15)]],
|
2577
|
+
constant int64_t & ne1 [[buffer(16)]],
|
2578
|
+
constant uint & r2 [[buffer(17)]],
|
2579
|
+
constant uint & r3 [[buffer(18)]],
|
2096
2580
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2097
2581
|
uint tiisg[[thread_index_in_simdgroup]],
|
2098
2582
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -2103,12 +2587,18 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
2103
2587
|
const int nb = ne00/QK_K;
|
2104
2588
|
const int r0 = tgpig.x;
|
2105
2589
|
const int r1 = tgpig.y;
|
2106
|
-
const int
|
2590
|
+
const int im = tgpig.z;
|
2107
2591
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
2108
2592
|
const int ib_row = first_row * nb;
|
2109
|
-
|
2593
|
+
|
2594
|
+
const uint i12 = im%ne12;
|
2595
|
+
const uint i13 = im/ne12;
|
2596
|
+
|
2597
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
2598
|
+
|
2110
2599
|
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
|
2111
|
-
device const float * y = (device const float *) src1 + r1*ne10 +
|
2600
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
2601
|
+
|
2112
2602
|
float yl[8];
|
2113
2603
|
float yh[8];
|
2114
2604
|
float sumf[N_DST]={0.f}, all_sum;
|
@@ -2164,7 +2654,7 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
2164
2654
|
for (int row = 0; row < N_DST; ++row) {
|
2165
2655
|
all_sum = simd_sum(sumf[row]);
|
2166
2656
|
if (tiisg == 0) {
|
2167
|
-
dst[r1*ne0+
|
2657
|
+
dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum;
|
2168
2658
|
}
|
2169
2659
|
}
|
2170
2660
|
}
|
@@ -2179,9 +2669,10 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
2179
2669
|
constant int64_t & ne02[[buffer(5)]],
|
2180
2670
|
constant int64_t & ne10[[buffer(9)]],
|
2181
2671
|
constant int64_t & ne12[[buffer(11)]],
|
2182
|
-
constant int64_t & ne0[[buffer(15)]],
|
2183
|
-
constant int64_t & ne1[[buffer(16)]],
|
2184
|
-
constant uint &
|
2672
|
+
constant int64_t & ne0 [[buffer(15)]],
|
2673
|
+
constant int64_t & ne1 [[buffer(16)]],
|
2674
|
+
constant uint & r2 [[buffer(17)]],
|
2675
|
+
constant uint & r3 [[buffer(18)]],
|
2185
2676
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2186
2677
|
uint tiisg[[thread_index_in_simdgroup]],
|
2187
2678
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -2190,12 +2681,17 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
2190
2681
|
|
2191
2682
|
const int64_t r0 = tgpig.x;
|
2192
2683
|
const int64_t r1 = tgpig.y;
|
2193
|
-
const int
|
2684
|
+
const int im = tgpig.z;
|
2194
2685
|
|
2195
2686
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
2196
|
-
|
2687
|
+
|
2688
|
+
const uint i12 = im%ne12;
|
2689
|
+
const uint i13 = im/ne12;
|
2690
|
+
|
2691
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
2692
|
+
|
2197
2693
|
device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
|
2198
|
-
device const float * yy = (device const float *) src1 + r1*ne10 +
|
2694
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
2199
2695
|
|
2200
2696
|
float sumf[2]={0.f};
|
2201
2697
|
|
@@ -2211,15 +2707,15 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
2211
2707
|
|
2212
2708
|
const int tid = tiisg/4;
|
2213
2709
|
const int ix = tiisg%4;
|
2214
|
-
const int
|
2710
|
+
const int iq = tid/4;
|
2215
2711
|
const int ir = tid%4;
|
2216
2712
|
const int n = 8;
|
2217
2713
|
|
2218
2714
|
const int l0 = n*ir;
|
2219
|
-
const int q_offset = 32*
|
2220
|
-
const int y_offset = 64*
|
2715
|
+
const int q_offset = 32*iq + l0;
|
2716
|
+
const int y_offset = 64*iq + l0;
|
2221
2717
|
|
2222
|
-
const uint8_t hm1 = 1u << (2*
|
2718
|
+
const uint8_t hm1 = 1u << (2*iq);
|
2223
2719
|
const uint8_t hm2 = hm1 << 1;
|
2224
2720
|
const uint8_t hm3 = hm1 << 4;
|
2225
2721
|
const uint8_t hm4 = hm2 << 4;
|
@@ -2234,7 +2730,7 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
2234
2730
|
device const uint8_t * q1 = x[i].qs + q_offset;
|
2235
2731
|
device const uint8_t * qh = x[i].qh + l0;
|
2236
2732
|
device const half * dh = &x[i].d;
|
2237
|
-
device const uint16_t * a = (device const uint16_t *)x[i].scales +
|
2733
|
+
device const uint16_t * a = (device const uint16_t *)x[i].scales + iq;
|
2238
2734
|
|
2239
2735
|
device const float * y2 = y1 + 128;
|
2240
2736
|
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
@@ -2290,7 +2786,7 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
2290
2786
|
|
2291
2787
|
const int il = 4 * (tiisg/8); // 0, 4, 8, 12
|
2292
2788
|
const int ix = tiisg%8;
|
2293
|
-
const int
|
2789
|
+
const int iq = il/8; // 0, 0, 1, 1
|
2294
2790
|
const int in = il%8; // 0, 4, 0, 4
|
2295
2791
|
|
2296
2792
|
device const float * y = yy + ix*QK_K + il;
|
@@ -2315,7 +2811,7 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
2315
2811
|
|
2316
2812
|
float2 acc = {0.f, 0.f};
|
2317
2813
|
for (int l = 0; l < 4; ++l) {
|
2318
|
-
const uint8_t hl = h[l] >>
|
2814
|
+
const uint8_t hl = h[l] >> iq;
|
2319
2815
|
acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
|
2320
2816
|
+ yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
|
2321
2817
|
acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
|
@@ -2337,7 +2833,7 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
2337
2833
|
for (int row = 0; row < 2; ++row) {
|
2338
2834
|
const float tot = simd_sum(sumf[row]);
|
2339
2835
|
if (tiisg == 0) {
|
2340
|
-
dst[r1*ne0 +
|
2836
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
|
2341
2837
|
}
|
2342
2838
|
}
|
2343
2839
|
|
@@ -2352,9 +2848,10 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
2352
2848
|
constant int64_t & ne02[[buffer(5)]],
|
2353
2849
|
constant int64_t & ne10[[buffer(9)]],
|
2354
2850
|
constant int64_t & ne12[[buffer(11)]],
|
2355
|
-
constant int64_t & ne0[[buffer(15)]],
|
2356
|
-
constant int64_t & ne1[[buffer(16)]],
|
2357
|
-
constant uint &
|
2851
|
+
constant int64_t & ne0 [[buffer(15)]],
|
2852
|
+
constant int64_t & ne1 [[buffer(16)]],
|
2853
|
+
constant uint & r2 [[buffer(17)]],
|
2854
|
+
constant uint & r3 [[buffer(18)]],
|
2358
2855
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2359
2856
|
uint tiisg[[thread_index_in_simdgroup]],
|
2360
2857
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -2368,12 +2865,17 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
2368
2865
|
|
2369
2866
|
const int64_t r0 = tgpig.x;
|
2370
2867
|
const int64_t r1 = tgpig.y;
|
2371
|
-
const int
|
2868
|
+
const int im = tgpig.z;
|
2372
2869
|
|
2373
2870
|
const int row = 2 * r0 + sgitg;
|
2374
|
-
|
2871
|
+
|
2872
|
+
const uint i12 = im%ne12;
|
2873
|
+
const uint i13 = im/ne12;
|
2874
|
+
|
2875
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
2876
|
+
|
2375
2877
|
device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
|
2376
|
-
device const float * yy = (device const float *) src1 + r1*ne10 +
|
2878
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
2377
2879
|
|
2378
2880
|
float sumf = 0;
|
2379
2881
|
|
@@ -2439,7 +2941,7 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
2439
2941
|
|
2440
2942
|
const float tot = simd_sum(sumf);
|
2441
2943
|
if (tiisg == 0) {
|
2442
|
-
dst[r1*ne0 +
|
2944
|
+
dst[r1*ne0 + im*ne0*ne1 + row] = tot;
|
2443
2945
|
}
|
2444
2946
|
}
|
2445
2947
|
|
@@ -2749,24 +3251,25 @@ kernel void kernel_get_rows(
|
|
2749
3251
|
|
2750
3252
|
// each block_q contains 16*nl weights
|
2751
3253
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
2752
|
-
|
2753
|
-
|
2754
|
-
|
2755
|
-
|
2756
|
-
|
2757
|
-
|
2758
|
-
|
2759
|
-
|
2760
|
-
|
2761
|
-
|
2762
|
-
|
2763
|
-
|
2764
|
-
|
2765
|
-
|
2766
|
-
|
2767
|
-
|
2768
|
-
|
2769
|
-
|
3254
|
+
void kernel_mul_mm_impl(device const uchar * src0,
|
3255
|
+
device const uchar * src1,
|
3256
|
+
device float * dst,
|
3257
|
+
constant int64_t & ne00,
|
3258
|
+
constant int64_t & ne02,
|
3259
|
+
constant int64_t & nb01,
|
3260
|
+
constant int64_t & nb02,
|
3261
|
+
constant int64_t & ne12,
|
3262
|
+
constant int64_t & nb10,
|
3263
|
+
constant int64_t & nb11,
|
3264
|
+
constant int64_t & nb12,
|
3265
|
+
constant int64_t & ne0,
|
3266
|
+
constant int64_t & ne1,
|
3267
|
+
constant uint & r2,
|
3268
|
+
constant uint & r3,
|
3269
|
+
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
3270
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3271
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
3272
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
2770
3273
|
|
2771
3274
|
threadgroup half * sa = (threadgroup half *)(shared_memory);
|
2772
3275
|
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
|
@@ -2792,7 +3295,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
2792
3295
|
|
2793
3296
|
short il = (tiitg % THREAD_PER_ROW);
|
2794
3297
|
|
2795
|
-
uint
|
3298
|
+
const uint i12 = im%ne12;
|
3299
|
+
const uint i13 = im/ne12;
|
3300
|
+
|
3301
|
+
uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
|
2796
3302
|
ushort offset1 = il/nl;
|
2797
3303
|
|
2798
3304
|
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
|
@@ -2876,14 +3382,116 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
2876
3382
|
}
|
2877
3383
|
}
|
2878
3384
|
|
3385
|
+
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
3386
|
+
kernel void kernel_mul_mm(device const uchar * src0,
|
3387
|
+
device const uchar * src1,
|
3388
|
+
device float * dst,
|
3389
|
+
constant int64_t & ne00,
|
3390
|
+
constant int64_t & ne02,
|
3391
|
+
constant int64_t & nb01,
|
3392
|
+
constant int64_t & nb02,
|
3393
|
+
constant int64_t & ne12,
|
3394
|
+
constant int64_t & nb10,
|
3395
|
+
constant int64_t & nb11,
|
3396
|
+
constant int64_t & nb12,
|
3397
|
+
constant int64_t & ne0,
|
3398
|
+
constant int64_t & ne1,
|
3399
|
+
constant uint & r2,
|
3400
|
+
constant uint & r3,
|
3401
|
+
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
3402
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3403
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
3404
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3405
|
+
kernel_mul_mm_impl<block_q, nl, dequantize_func>(
|
3406
|
+
src0,
|
3407
|
+
src1,
|
3408
|
+
dst,
|
3409
|
+
ne00,
|
3410
|
+
ne02,
|
3411
|
+
nb01,
|
3412
|
+
nb02,
|
3413
|
+
ne12,
|
3414
|
+
nb10,
|
3415
|
+
nb11,
|
3416
|
+
nb12,
|
3417
|
+
ne0,
|
3418
|
+
ne1,
|
3419
|
+
r2,
|
3420
|
+
r3,
|
3421
|
+
shared_memory,
|
3422
|
+
tgpig,
|
3423
|
+
tiitg,
|
3424
|
+
sgitg);
|
3425
|
+
}
|
3426
|
+
|
3427
|
+
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
3428
|
+
kernel void kernel_mul_mm_id(
|
3429
|
+
device const int32_t * ids,
|
3430
|
+
device const uchar * src1,
|
3431
|
+
device float * dst,
|
3432
|
+
constant int64_t & ne00,
|
3433
|
+
constant int64_t & ne02,
|
3434
|
+
constant int64_t & nb01,
|
3435
|
+
constant int64_t & nb02,
|
3436
|
+
constant int64_t & ne12,
|
3437
|
+
constant int64_t & nb10,
|
3438
|
+
constant int64_t & nb11,
|
3439
|
+
constant int64_t & nb12,
|
3440
|
+
constant int64_t & ne0,
|
3441
|
+
constant int64_t & ne1,
|
3442
|
+
constant uint & r2,
|
3443
|
+
constant uint & r3,
|
3444
|
+
constant int & idx,
|
3445
|
+
device const uchar * src00,
|
3446
|
+
device const uchar * src01,
|
3447
|
+
device const uchar * src02,
|
3448
|
+
device const uchar * src03,
|
3449
|
+
device const uchar * src04,
|
3450
|
+
device const uchar * src05,
|
3451
|
+
device const uchar * src06,
|
3452
|
+
device const uchar * src07,
|
3453
|
+
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
3454
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3455
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
3456
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3457
|
+
device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
3458
|
+
|
3459
|
+
kernel_mul_mm_impl<block_q, nl, dequantize_func>(
|
3460
|
+
src0[ids[idx]],
|
3461
|
+
src1,
|
3462
|
+
dst,
|
3463
|
+
ne00,
|
3464
|
+
ne02,
|
3465
|
+
nb01,
|
3466
|
+
nb02,
|
3467
|
+
ne12,
|
3468
|
+
nb10,
|
3469
|
+
nb11,
|
3470
|
+
nb12,
|
3471
|
+
ne0,
|
3472
|
+
ne1,
|
3473
|
+
r2,
|
3474
|
+
r3,
|
3475
|
+
shared_memory,
|
3476
|
+
tgpig,
|
3477
|
+
tiitg,
|
3478
|
+
sgitg);
|
3479
|
+
}
|
3480
|
+
|
2879
3481
|
#if QK_K == 256
|
2880
3482
|
#define QK_NL 16
|
2881
3483
|
#else
|
2882
3484
|
#define QK_NL 4
|
2883
3485
|
#endif
|
2884
3486
|
|
2885
|
-
typedef void (get_rows_t)(
|
2886
|
-
|
3487
|
+
typedef void (get_rows_t)(
|
3488
|
+
device const void * src0,
|
3489
|
+
device const int * src1,
|
3490
|
+
device float * dst,
|
3491
|
+
constant int64_t & ne00,
|
3492
|
+
constant uint64_t & nb01,
|
3493
|
+
constant uint64_t & nb1,
|
3494
|
+
uint, uint, uint);
|
2887
3495
|
|
2888
3496
|
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
|
2889
3497
|
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
@@ -2912,8 +3520,10 @@ typedef void (mat_mm_t)(
|
|
2912
3520
|
constant int64_t & nb12,
|
2913
3521
|
constant int64_t & ne0,
|
2914
3522
|
constant int64_t & ne1,
|
2915
|
-
constant uint &
|
2916
|
-
|
3523
|
+
constant uint & r2,
|
3524
|
+
constant uint & r3,
|
3525
|
+
threadgroup uchar *,
|
3526
|
+
uint3, uint, uint);
|
2917
3527
|
|
2918
3528
|
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
|
2919
3529
|
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
@@ -2927,3 +3537,44 @@ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
|
|
2927
3537
|
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
|
2928
3538
|
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
|
2929
3539
|
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|
3540
|
+
|
3541
|
+
typedef void (mat_mm_id_t)(
|
3542
|
+
device const int32_t * ids,
|
3543
|
+
device const uchar * src1,
|
3544
|
+
device float * dst,
|
3545
|
+
constant int64_t & ne00,
|
3546
|
+
constant int64_t & ne02,
|
3547
|
+
constant int64_t & nb01,
|
3548
|
+
constant int64_t & nb02,
|
3549
|
+
constant int64_t & ne12,
|
3550
|
+
constant int64_t & nb10,
|
3551
|
+
constant int64_t & nb11,
|
3552
|
+
constant int64_t & nb12,
|
3553
|
+
constant int64_t & ne0,
|
3554
|
+
constant int64_t & ne1,
|
3555
|
+
constant uint & r2,
|
3556
|
+
constant uint & r3,
|
3557
|
+
constant int & idx,
|
3558
|
+
device const uchar * src00,
|
3559
|
+
device const uchar * src01,
|
3560
|
+
device const uchar * src02,
|
3561
|
+
device const uchar * src03,
|
3562
|
+
device const uchar * src04,
|
3563
|
+
device const uchar * src05,
|
3564
|
+
device const uchar * src06,
|
3565
|
+
device const uchar * src07,
|
3566
|
+
threadgroup uchar *,
|
3567
|
+
uint3, uint, uint);
|
3568
|
+
|
3569
|
+
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
|
3570
|
+
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
|
3571
|
+
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
|
3572
|
+
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
|
3573
|
+
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
|
3574
|
+
template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
|
3575
|
+
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
|
3576
|
+
template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
|
3577
|
+
template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
|
3578
|
+
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
|
3579
|
+
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
|
3580
|
+
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
|