whisper.rn 0.4.0-rc.4 → 0.4.0-rc.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +6 -6
- package/android/build.gradle +4 -0
- package/android/src/main/CMakeLists.txt +5 -0
- package/android/src/main/java/com/rnwhisper/AudioUtils.java +0 -80
- package/android/src/main/java/com/rnwhisper/WhisperContext.java +51 -133
- package/android/src/main/jni-utils.h +76 -0
- package/android/src/main/jni.cpp +187 -112
- package/cpp/README.md +1 -1
- package/cpp/coreml/whisper-encoder-impl.h +1 -1
- package/cpp/coreml/whisper-encoder.h +4 -0
- package/cpp/coreml/whisper-encoder.mm +4 -2
- package/cpp/ggml-alloc.c +55 -19
- package/cpp/ggml-alloc.h +7 -0
- package/cpp/ggml-backend-impl.h +46 -21
- package/cpp/ggml-backend.c +563 -156
- package/cpp/ggml-backend.h +62 -17
- package/cpp/ggml-impl.h +1 -1
- package/cpp/ggml-metal-whisper.metal +1010 -253
- package/cpp/ggml-metal.h +7 -1
- package/cpp/ggml-metal.m +618 -187
- package/cpp/ggml-quants.c +64 -59
- package/cpp/ggml-quants.h +40 -40
- package/cpp/ggml.c +751 -1466
- package/cpp/ggml.h +90 -25
- package/cpp/rn-audioutils.cpp +68 -0
- package/cpp/rn-audioutils.h +14 -0
- package/cpp/rn-whisper-log.h +11 -0
- package/cpp/rn-whisper.cpp +141 -59
- package/cpp/rn-whisper.h +47 -15
- package/cpp/whisper.cpp +1635 -928
- package/cpp/whisper.h +55 -10
- package/ios/RNWhisper.mm +7 -7
- package/ios/RNWhisperAudioUtils.h +0 -2
- package/ios/RNWhisperAudioUtils.m +0 -56
- package/ios/RNWhisperContext.h +3 -11
- package/ios/RNWhisperContext.mm +62 -134
- package/lib/commonjs/version.json +1 -1
- package/lib/module/version.json +1 -1
- package/package.json +6 -5
- package/src/version.json +1 -1
|
@@ -3,6 +3,8 @@
|
|
|
3
3
|
using namespace metal;
|
|
4
4
|
|
|
5
5
|
#define MAX(x, y) ((x) > (y) ? (x) : (y))
|
|
6
|
+
#define MIN(x, y) ((x) < (y) ? (x) : (y))
|
|
7
|
+
#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
|
|
6
8
|
|
|
7
9
|
#define QK4_0 32
|
|
8
10
|
#define QR4_0 2
|
|
@@ -39,8 +41,15 @@ typedef struct {
|
|
|
39
41
|
int8_t qs[QK8_0]; // quants
|
|
40
42
|
} block_q8_0;
|
|
41
43
|
|
|
42
|
-
//
|
|
43
|
-
|
|
44
|
+
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
|
45
|
+
|
|
46
|
+
enum ggml_sort_order {
|
|
47
|
+
GGML_SORT_ASC,
|
|
48
|
+
GGML_SORT_DESC,
|
|
49
|
+
};
|
|
50
|
+
|
|
51
|
+
// general-purpose kernel for addition, multiplication and division of two tensors
|
|
52
|
+
// pros: works for non-contiguous tensors, supports broadcast across all dims
|
|
44
53
|
// cons: not very efficient
|
|
45
54
|
kernel void kernel_add(
|
|
46
55
|
device const char * src0,
|
|
@@ -81,16 +90,111 @@ kernel void kernel_add(
|
|
|
81
90
|
const int64_t i12 = i02 % ne12;
|
|
82
91
|
const int64_t i11 = i01 % ne11;
|
|
83
92
|
|
|
84
|
-
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01
|
|
85
|
-
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11
|
|
86
|
-
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1
|
|
93
|
+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
|
94
|
+
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
|
|
95
|
+
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
|
|
87
96
|
|
|
88
97
|
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
|
89
|
-
|
|
98
|
+
const int i10 = i0 % ne10;
|
|
99
|
+
*((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10));
|
|
100
|
+
}
|
|
101
|
+
}
|
|
102
|
+
|
|
103
|
+
kernel void kernel_mul(
|
|
104
|
+
device const char * src0,
|
|
105
|
+
device const char * src1,
|
|
106
|
+
device char * dst,
|
|
107
|
+
constant int64_t & ne00,
|
|
108
|
+
constant int64_t & ne01,
|
|
109
|
+
constant int64_t & ne02,
|
|
110
|
+
constant int64_t & ne03,
|
|
111
|
+
constant int64_t & nb00,
|
|
112
|
+
constant int64_t & nb01,
|
|
113
|
+
constant int64_t & nb02,
|
|
114
|
+
constant int64_t & nb03,
|
|
115
|
+
constant int64_t & ne10,
|
|
116
|
+
constant int64_t & ne11,
|
|
117
|
+
constant int64_t & ne12,
|
|
118
|
+
constant int64_t & ne13,
|
|
119
|
+
constant int64_t & nb10,
|
|
120
|
+
constant int64_t & nb11,
|
|
121
|
+
constant int64_t & nb12,
|
|
122
|
+
constant int64_t & nb13,
|
|
123
|
+
constant int64_t & ne0,
|
|
124
|
+
constant int64_t & ne1,
|
|
125
|
+
constant int64_t & ne2,
|
|
126
|
+
constant int64_t & ne3,
|
|
127
|
+
constant int64_t & nb0,
|
|
128
|
+
constant int64_t & nb1,
|
|
129
|
+
constant int64_t & nb2,
|
|
130
|
+
constant int64_t & nb3,
|
|
131
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
132
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
133
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
134
|
+
const int64_t i03 = tgpig.z;
|
|
135
|
+
const int64_t i02 = tgpig.y;
|
|
136
|
+
const int64_t i01 = tgpig.x;
|
|
137
|
+
|
|
138
|
+
const int64_t i13 = i03 % ne13;
|
|
139
|
+
const int64_t i12 = i02 % ne12;
|
|
140
|
+
const int64_t i11 = i01 % ne11;
|
|
141
|
+
|
|
142
|
+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
|
143
|
+
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
|
|
144
|
+
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
|
|
90
145
|
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
dst_ptr
|
|
146
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
|
147
|
+
const int i10 = i0 % ne10;
|
|
148
|
+
*((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10));
|
|
149
|
+
}
|
|
150
|
+
}
|
|
151
|
+
|
|
152
|
+
kernel void kernel_div(
|
|
153
|
+
device const char * src0,
|
|
154
|
+
device const char * src1,
|
|
155
|
+
device char * dst,
|
|
156
|
+
constant int64_t & ne00,
|
|
157
|
+
constant int64_t & ne01,
|
|
158
|
+
constant int64_t & ne02,
|
|
159
|
+
constant int64_t & ne03,
|
|
160
|
+
constant int64_t & nb00,
|
|
161
|
+
constant int64_t & nb01,
|
|
162
|
+
constant int64_t & nb02,
|
|
163
|
+
constant int64_t & nb03,
|
|
164
|
+
constant int64_t & ne10,
|
|
165
|
+
constant int64_t & ne11,
|
|
166
|
+
constant int64_t & ne12,
|
|
167
|
+
constant int64_t & ne13,
|
|
168
|
+
constant int64_t & nb10,
|
|
169
|
+
constant int64_t & nb11,
|
|
170
|
+
constant int64_t & nb12,
|
|
171
|
+
constant int64_t & nb13,
|
|
172
|
+
constant int64_t & ne0,
|
|
173
|
+
constant int64_t & ne1,
|
|
174
|
+
constant int64_t & ne2,
|
|
175
|
+
constant int64_t & ne3,
|
|
176
|
+
constant int64_t & nb0,
|
|
177
|
+
constant int64_t & nb1,
|
|
178
|
+
constant int64_t & nb2,
|
|
179
|
+
constant int64_t & nb3,
|
|
180
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
181
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
182
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
183
|
+
const int64_t i03 = tgpig.z;
|
|
184
|
+
const int64_t i02 = tgpig.y;
|
|
185
|
+
const int64_t i01 = tgpig.x;
|
|
186
|
+
|
|
187
|
+
const int64_t i13 = i03 % ne13;
|
|
188
|
+
const int64_t i12 = i02 % ne12;
|
|
189
|
+
const int64_t i11 = i01 % ne11;
|
|
190
|
+
|
|
191
|
+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
|
192
|
+
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
|
|
193
|
+
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
|
|
194
|
+
|
|
195
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
|
196
|
+
const int i10 = i0 % ne10;
|
|
197
|
+
*((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10));
|
|
94
198
|
}
|
|
95
199
|
}
|
|
96
200
|
|
|
@@ -105,23 +209,22 @@ kernel void kernel_add_row(
|
|
|
105
209
|
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
|
106
210
|
}
|
|
107
211
|
|
|
108
|
-
kernel void
|
|
212
|
+
kernel void kernel_mul_row(
|
|
109
213
|
device const float4 * src0,
|
|
110
214
|
device const float4 * src1,
|
|
111
215
|
device float4 * dst,
|
|
216
|
+
constant int64_t & nb [[buffer(27)]],
|
|
112
217
|
uint tpig[[thread_position_in_grid]]) {
|
|
113
|
-
dst[tpig] = src0[tpig] * src1[tpig];
|
|
218
|
+
dst[tpig] = src0[tpig] * src1[tpig % nb];
|
|
114
219
|
}
|
|
115
220
|
|
|
116
|
-
|
|
117
|
-
// broadcast src1 into src0
|
|
118
|
-
kernel void kernel_mul_row(
|
|
221
|
+
kernel void kernel_div_row(
|
|
119
222
|
device const float4 * src0,
|
|
120
223
|
device const float4 * src1,
|
|
121
224
|
device float4 * dst,
|
|
122
|
-
constant int64_t & nb,
|
|
225
|
+
constant int64_t & nb [[buffer(27)]],
|
|
123
226
|
uint tpig[[thread_position_in_grid]]) {
|
|
124
|
-
dst[tpig] = src0[tpig]
|
|
227
|
+
dst[tpig] = src0[tpig] / src1[tpig % nb];
|
|
125
228
|
}
|
|
126
229
|
|
|
127
230
|
kernel void kernel_scale(
|
|
@@ -162,6 +265,54 @@ kernel void kernel_sqr(
|
|
|
162
265
|
dst[tpig] = src0[tpig] * src0[tpig];
|
|
163
266
|
}
|
|
164
267
|
|
|
268
|
+
kernel void kernel_sum_rows(
|
|
269
|
+
device const float * src0,
|
|
270
|
+
device float * dst,
|
|
271
|
+
constant int64_t & ne00,
|
|
272
|
+
constant int64_t & ne01,
|
|
273
|
+
constant int64_t & ne02,
|
|
274
|
+
constant int64_t & ne03,
|
|
275
|
+
constant int64_t & nb00,
|
|
276
|
+
constant int64_t & nb01,
|
|
277
|
+
constant int64_t & nb02,
|
|
278
|
+
constant int64_t & nb03,
|
|
279
|
+
constant int64_t & ne10,
|
|
280
|
+
constant int64_t & ne11,
|
|
281
|
+
constant int64_t & ne12,
|
|
282
|
+
constant int64_t & ne13,
|
|
283
|
+
constant int64_t & nb10,
|
|
284
|
+
constant int64_t & nb11,
|
|
285
|
+
constant int64_t & nb12,
|
|
286
|
+
constant int64_t & nb13,
|
|
287
|
+
constant int64_t & ne0,
|
|
288
|
+
constant int64_t & ne1,
|
|
289
|
+
constant int64_t & ne2,
|
|
290
|
+
constant int64_t & ne3,
|
|
291
|
+
constant int64_t & nb0,
|
|
292
|
+
constant int64_t & nb1,
|
|
293
|
+
constant int64_t & nb2,
|
|
294
|
+
constant int64_t & nb3,
|
|
295
|
+
uint3 tpig[[thread_position_in_grid]]) {
|
|
296
|
+
int64_t i3 = tpig.z;
|
|
297
|
+
int64_t i2 = tpig.y;
|
|
298
|
+
int64_t i1 = tpig.x;
|
|
299
|
+
|
|
300
|
+
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
|
|
301
|
+
return;
|
|
302
|
+
}
|
|
303
|
+
|
|
304
|
+
device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
|
|
305
|
+
device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
|
|
306
|
+
|
|
307
|
+
float row_sum = 0;
|
|
308
|
+
|
|
309
|
+
for (int64_t i0 = 0; i0 < ne00; i0++) {
|
|
310
|
+
row_sum += src_row[i0];
|
|
311
|
+
}
|
|
312
|
+
|
|
313
|
+
dst_row[0] = row_sum;
|
|
314
|
+
}
|
|
315
|
+
|
|
165
316
|
constant float GELU_COEF_A = 0.044715f;
|
|
166
317
|
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
|
167
318
|
|
|
@@ -180,10 +331,12 @@ kernel void kernel_gelu(
|
|
|
180
331
|
|
|
181
332
|
kernel void kernel_soft_max(
|
|
182
333
|
device const float * src0,
|
|
334
|
+
device const float * src1,
|
|
183
335
|
device float * dst,
|
|
184
336
|
constant int64_t & ne00,
|
|
185
337
|
constant int64_t & ne01,
|
|
186
338
|
constant int64_t & ne02,
|
|
339
|
+
constant float & scale,
|
|
187
340
|
threadgroup float * buf [[threadgroup(0)]],
|
|
188
341
|
uint tgpig[[threadgroup_position_in_grid]],
|
|
189
342
|
uint tpitg[[thread_position_in_threadgroup]],
|
|
@@ -194,73 +347,77 @@ kernel void kernel_soft_max(
|
|
|
194
347
|
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
|
195
348
|
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
|
196
349
|
|
|
197
|
-
device const float * psrc0 =
|
|
198
|
-
device
|
|
350
|
+
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
351
|
+
device const float * pmask = src1 ? src1 + i01*ne00 : nullptr;
|
|
352
|
+
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
199
353
|
|
|
200
354
|
// parallel max
|
|
201
|
-
float lmax =
|
|
355
|
+
float lmax = -INFINITY;
|
|
202
356
|
|
|
203
|
-
for (int i00 = tpitg
|
|
204
|
-
lmax = MAX(lmax, psrc0[i00]);
|
|
357
|
+
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
|
358
|
+
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
|
205
359
|
}
|
|
206
360
|
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
361
|
+
// find the max value in the block
|
|
362
|
+
float max_val = simd_max(lmax);
|
|
363
|
+
if (ntg > N_SIMDWIDTH) {
|
|
364
|
+
if (sgitg == 0) {
|
|
365
|
+
buf[tiisg] = -INFINITY;
|
|
366
|
+
}
|
|
211
367
|
|
|
212
|
-
|
|
368
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
213
369
|
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
|
|
218
|
-
}
|
|
219
|
-
}
|
|
370
|
+
if (tiisg == 0) {
|
|
371
|
+
buf[sgitg] = max_val;
|
|
372
|
+
}
|
|
220
373
|
|
|
221
|
-
|
|
374
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
222
375
|
|
|
223
|
-
|
|
376
|
+
max_val = buf[tiisg];
|
|
377
|
+
max_val = simd_max(max_val);
|
|
378
|
+
}
|
|
224
379
|
|
|
225
380
|
// parallel sum
|
|
226
381
|
float lsum = 0.0f;
|
|
227
382
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
|
228
|
-
const float exp_psrc0 = exp(psrc0[i00] -
|
|
383
|
+
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
|
|
229
384
|
lsum += exp_psrc0;
|
|
230
|
-
// Remember the result of exp here. exp is expensive, so we really do not
|
|
231
|
-
// wish to compute it twice.
|
|
232
385
|
pdst[i00] = exp_psrc0;
|
|
233
386
|
}
|
|
234
387
|
|
|
235
388
|
float sum = simd_sum(lsum);
|
|
236
|
-
if (
|
|
237
|
-
|
|
238
|
-
|
|
389
|
+
if (ntg > N_SIMDWIDTH) {
|
|
390
|
+
if (sgitg == 0) {
|
|
391
|
+
buf[tiisg] = 0.0f;
|
|
392
|
+
}
|
|
239
393
|
|
|
240
|
-
|
|
394
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
241
395
|
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
buf[tpitg] += buf[tpitg + i];
|
|
246
|
-
}
|
|
247
|
-
}
|
|
396
|
+
if (tiisg == 0) {
|
|
397
|
+
buf[sgitg] = sum;
|
|
398
|
+
}
|
|
248
399
|
|
|
249
|
-
|
|
400
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
250
401
|
|
|
251
|
-
|
|
402
|
+
sum = buf[tiisg];
|
|
403
|
+
sum = simd_sum(sum);
|
|
404
|
+
}
|
|
405
|
+
|
|
406
|
+
const float inv_sum = 1.0f/sum;
|
|
252
407
|
|
|
253
408
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
|
254
|
-
pdst[i00]
|
|
409
|
+
pdst[i00] *= inv_sum;
|
|
255
410
|
}
|
|
256
411
|
}
|
|
257
412
|
|
|
258
413
|
kernel void kernel_soft_max_4(
|
|
259
414
|
device const float * src0,
|
|
415
|
+
device const float * src1,
|
|
260
416
|
device float * dst,
|
|
261
417
|
constant int64_t & ne00,
|
|
262
418
|
constant int64_t & ne01,
|
|
263
419
|
constant int64_t & ne02,
|
|
420
|
+
constant float & scale,
|
|
264
421
|
threadgroup float * buf [[threadgroup(0)]],
|
|
265
422
|
uint tgpig[[threadgroup_position_in_grid]],
|
|
266
423
|
uint tpitg[[thread_position_in_threadgroup]],
|
|
@@ -271,64 +428,68 @@ kernel void kernel_soft_max_4(
|
|
|
271
428
|
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
|
272
429
|
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
|
273
430
|
|
|
274
|
-
device const float4 * psrc4 =
|
|
275
|
-
device
|
|
431
|
+
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
|
432
|
+
device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
|
|
433
|
+
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
|
276
434
|
|
|
277
435
|
// parallel max
|
|
278
|
-
float4 lmax4 =
|
|
436
|
+
float4 lmax4 = -INFINITY;
|
|
279
437
|
|
|
280
|
-
for (int i00 = tpitg
|
|
281
|
-
lmax4 = fmax(lmax4, psrc4[i00]);
|
|
438
|
+
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
|
439
|
+
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
|
282
440
|
}
|
|
283
441
|
|
|
284
442
|
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
|
285
|
-
float max = simd_max(lmax);
|
|
286
|
-
if (tiisg == 0) {
|
|
287
|
-
buf[sgitg] = max;
|
|
288
|
-
}
|
|
289
443
|
|
|
290
|
-
|
|
444
|
+
float max_val = simd_max(lmax);
|
|
445
|
+
if (ntg > N_SIMDWIDTH) {
|
|
446
|
+
if (sgitg == 0) {
|
|
447
|
+
buf[tiisg] = -INFINITY;
|
|
448
|
+
}
|
|
291
449
|
|
|
292
|
-
|
|
293
|
-
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
|
294
|
-
if (tpitg < i) {
|
|
295
|
-
buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
|
|
296
|
-
}
|
|
297
|
-
}
|
|
450
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
298
451
|
|
|
299
|
-
|
|
452
|
+
if (tiisg == 0) {
|
|
453
|
+
buf[sgitg] = max_val;
|
|
454
|
+
}
|
|
300
455
|
|
|
301
|
-
|
|
456
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
457
|
+
|
|
458
|
+
max_val = buf[tiisg];
|
|
459
|
+
max_val = simd_max(max_val);
|
|
460
|
+
}
|
|
302
461
|
|
|
303
462
|
// parallel sum
|
|
304
463
|
float4 lsum4 = 0.0f;
|
|
305
464
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
|
306
|
-
const float4 exp_psrc4 = exp(psrc4[i00] -
|
|
465
|
+
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
|
|
307
466
|
lsum4 += exp_psrc4;
|
|
308
467
|
pdst4[i00] = exp_psrc4;
|
|
309
468
|
}
|
|
310
469
|
|
|
311
470
|
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
|
312
471
|
float sum = simd_sum(lsum);
|
|
313
|
-
if (
|
|
314
|
-
|
|
315
|
-
|
|
472
|
+
if (ntg > N_SIMDWIDTH) {
|
|
473
|
+
if (sgitg == 0) {
|
|
474
|
+
buf[tiisg] = 0.0f;
|
|
475
|
+
}
|
|
316
476
|
|
|
317
|
-
|
|
477
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
318
478
|
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
buf[tpitg] += buf[tpitg + i];
|
|
323
|
-
}
|
|
324
|
-
}
|
|
479
|
+
if (tiisg == 0) {
|
|
480
|
+
buf[sgitg] = sum;
|
|
481
|
+
}
|
|
325
482
|
|
|
326
|
-
|
|
483
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
327
484
|
|
|
328
|
-
|
|
485
|
+
sum = buf[tiisg];
|
|
486
|
+
sum = simd_sum(sum);
|
|
487
|
+
}
|
|
488
|
+
|
|
489
|
+
const float inv_sum = 1.0f/sum;
|
|
329
490
|
|
|
330
491
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
|
331
|
-
pdst4[i00]
|
|
492
|
+
pdst4[i00] *= inv_sum;
|
|
332
493
|
}
|
|
333
494
|
}
|
|
334
495
|
|
|
@@ -435,14 +596,13 @@ kernel void kernel_rms_norm(
|
|
|
435
596
|
constant int64_t & ne00,
|
|
436
597
|
constant uint64_t & nb01,
|
|
437
598
|
constant float & eps,
|
|
438
|
-
threadgroup float *
|
|
599
|
+
threadgroup float * buf [[threadgroup(0)]],
|
|
439
600
|
uint tgpig[[threadgroup_position_in_grid]],
|
|
440
601
|
uint tpitg[[thread_position_in_threadgroup]],
|
|
441
602
|
uint sgitg[[simdgroup_index_in_threadgroup]],
|
|
442
603
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
443
604
|
uint ntg[[threads_per_threadgroup]]) {
|
|
444
|
-
device const float4 * x
|
|
445
|
-
device const float * x_scalar = (device const float *) x;
|
|
605
|
+
device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
|
|
446
606
|
|
|
447
607
|
float4 sumf = 0;
|
|
448
608
|
float all_sum = 0;
|
|
@@ -453,40 +613,30 @@ kernel void kernel_rms_norm(
|
|
|
453
613
|
}
|
|
454
614
|
all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
|
|
455
615
|
all_sum = simd_sum(all_sum);
|
|
456
|
-
if (
|
|
457
|
-
|
|
458
|
-
|
|
616
|
+
if (ntg > N_SIMDWIDTH) {
|
|
617
|
+
if (sgitg == 0) {
|
|
618
|
+
buf[tiisg] = 0.0f;
|
|
619
|
+
}
|
|
459
620
|
|
|
460
|
-
|
|
621
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
461
622
|
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
if (tpitg < i) {
|
|
465
|
-
sum[tpitg] += sum[tpitg + i];
|
|
466
|
-
}
|
|
467
|
-
}
|
|
468
|
-
if (tpitg == 0) {
|
|
469
|
-
for (int i = 4 * (ne00 / 4); i < ne00; i++) {
|
|
470
|
-
sum[0] += x_scalar[i];
|
|
623
|
+
if (tiisg == 0) {
|
|
624
|
+
buf[sgitg] = all_sum;
|
|
471
625
|
}
|
|
472
|
-
sum[0] /= ne00;
|
|
473
|
-
}
|
|
474
626
|
|
|
475
|
-
|
|
627
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
476
628
|
|
|
477
|
-
|
|
629
|
+
all_sum = buf[tiisg];
|
|
630
|
+
all_sum = simd_sum(all_sum);
|
|
631
|
+
}
|
|
632
|
+
|
|
633
|
+
const float mean = all_sum/ne00;
|
|
478
634
|
const float scale = 1.0f/sqrt(mean + eps);
|
|
479
635
|
|
|
480
636
|
device float4 * y = (device float4 *) (dst + tgpig*ne00);
|
|
481
|
-
device float * y_scalar = (device float *) y;
|
|
482
637
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
|
483
638
|
y[i00] = x[i00] * scale;
|
|
484
639
|
}
|
|
485
|
-
if (tpitg == 0) {
|
|
486
|
-
for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
|
|
487
|
-
y_scalar[i00] = x_scalar[i00] * scale;
|
|
488
|
-
}
|
|
489
|
-
}
|
|
490
640
|
}
|
|
491
641
|
|
|
492
642
|
// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
|
|
@@ -576,15 +726,25 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
|
|
|
576
726
|
// putting them in the kernel cause a significant performance penalty
|
|
577
727
|
#define N_DST 4 // each SIMD group works on 4 rows
|
|
578
728
|
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
|
|
579
|
-
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
|
580
729
|
//Note: This is a template, but strictly speaking it only applies to
|
|
581
730
|
// quantizations where the block size is 32. It also does not
|
|
582
731
|
// giard against the number of rows not being divisible by
|
|
583
732
|
// N_DST, so this is another explicit assumption of the implementation.
|
|
584
733
|
template<typename block_q_type, int nr, int nsg, int nw>
|
|
585
|
-
void mul_vec_q_n_f32(
|
|
586
|
-
|
|
587
|
-
|
|
734
|
+
void mul_vec_q_n_f32(
|
|
735
|
+
device const void * src0,
|
|
736
|
+
device const float * src1,
|
|
737
|
+
device float * dst,
|
|
738
|
+
int64_t ne00,
|
|
739
|
+
int64_t ne01,
|
|
740
|
+
int64_t ne02,
|
|
741
|
+
int64_t ne10,
|
|
742
|
+
int64_t ne12,
|
|
743
|
+
int64_t ne0,
|
|
744
|
+
int64_t ne1,
|
|
745
|
+
uint r2,
|
|
746
|
+
uint r3,
|
|
747
|
+
uint3 tgpig, uint tiisg, uint sgitg) {
|
|
588
748
|
const int nb = ne00/QK4_0;
|
|
589
749
|
|
|
590
750
|
const int r0 = tgpig.x;
|
|
@@ -593,7 +753,10 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
|
|
|
593
753
|
|
|
594
754
|
const int first_row = (r0 * nsg + sgitg) * nr;
|
|
595
755
|
|
|
596
|
-
const uint
|
|
756
|
+
const uint i12 = im%ne12;
|
|
757
|
+
const uint i13 = im/ne12;
|
|
758
|
+
|
|
759
|
+
const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
597
760
|
|
|
598
761
|
device const block_q_type * x = (device const block_q_type *) src0 + offset0;
|
|
599
762
|
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
@@ -643,13 +806,14 @@ kernel void kernel_mul_mv_q4_0_f32(
|
|
|
643
806
|
constant int64_t & ne02[[buffer(5)]],
|
|
644
807
|
constant int64_t & ne10[[buffer(9)]],
|
|
645
808
|
constant int64_t & ne12[[buffer(11)]],
|
|
646
|
-
constant int64_t & ne0[[buffer(15)]],
|
|
647
|
-
constant int64_t & ne1[[buffer(16)]],
|
|
648
|
-
constant uint &
|
|
809
|
+
constant int64_t & ne0 [[buffer(15)]],
|
|
810
|
+
constant int64_t & ne1 [[buffer(16)]],
|
|
811
|
+
constant uint & r2 [[buffer(17)]],
|
|
812
|
+
constant uint & r3 [[buffer(18)]],
|
|
649
813
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
650
814
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
651
815
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
652
|
-
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,
|
|
816
|
+
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
|
653
817
|
}
|
|
654
818
|
|
|
655
819
|
kernel void kernel_mul_mv_q4_1_f32(
|
|
@@ -661,13 +825,14 @@ kernel void kernel_mul_mv_q4_1_f32(
|
|
|
661
825
|
constant int64_t & ne02[[buffer(5)]],
|
|
662
826
|
constant int64_t & ne10[[buffer(9)]],
|
|
663
827
|
constant int64_t & ne12[[buffer(11)]],
|
|
664
|
-
constant int64_t & ne0[[buffer(15)]],
|
|
665
|
-
constant int64_t & ne1[[buffer(16)]],
|
|
666
|
-
constant uint &
|
|
828
|
+
constant int64_t & ne0 [[buffer(15)]],
|
|
829
|
+
constant int64_t & ne1 [[buffer(16)]],
|
|
830
|
+
constant uint & r2 [[buffer(17)]],
|
|
831
|
+
constant uint & r3 [[buffer(18)]],
|
|
667
832
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
668
833
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
669
834
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
670
|
-
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,
|
|
835
|
+
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
|
671
836
|
}
|
|
672
837
|
|
|
673
838
|
kernel void kernel_mul_mv_q5_0_f32(
|
|
@@ -679,13 +844,14 @@ kernel void kernel_mul_mv_q5_0_f32(
|
|
|
679
844
|
constant int64_t & ne02[[buffer(5)]],
|
|
680
845
|
constant int64_t & ne10[[buffer(9)]],
|
|
681
846
|
constant int64_t & ne12[[buffer(11)]],
|
|
682
|
-
constant int64_t & ne0[[buffer(15)]],
|
|
683
|
-
constant int64_t & ne1[[buffer(16)]],
|
|
684
|
-
constant uint &
|
|
847
|
+
constant int64_t & ne0 [[buffer(15)]],
|
|
848
|
+
constant int64_t & ne1 [[buffer(16)]],
|
|
849
|
+
constant uint & r2 [[buffer(17)]],
|
|
850
|
+
constant uint & r3 [[buffer(18)]],
|
|
685
851
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
686
852
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
687
853
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
688
|
-
mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,
|
|
854
|
+
mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
|
689
855
|
}
|
|
690
856
|
|
|
691
857
|
kernel void kernel_mul_mv_q5_1_f32(
|
|
@@ -697,13 +863,14 @@ kernel void kernel_mul_mv_q5_1_f32(
|
|
|
697
863
|
constant int64_t & ne02[[buffer(5)]],
|
|
698
864
|
constant int64_t & ne10[[buffer(9)]],
|
|
699
865
|
constant int64_t & ne12[[buffer(11)]],
|
|
700
|
-
constant int64_t & ne0[[buffer(15)]],
|
|
701
|
-
constant int64_t & ne1[[buffer(16)]],
|
|
702
|
-
constant uint &
|
|
866
|
+
constant int64_t & ne0 [[buffer(15)]],
|
|
867
|
+
constant int64_t & ne1 [[buffer(16)]],
|
|
868
|
+
constant uint & r2 [[buffer(17)]],
|
|
869
|
+
constant uint & r3 [[buffer(18)]],
|
|
703
870
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
704
871
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
705
872
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
706
|
-
mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,
|
|
873
|
+
mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
|
|
707
874
|
}
|
|
708
875
|
|
|
709
876
|
|
|
@@ -718,9 +885,10 @@ kernel void kernel_mul_mv_q8_0_f32(
|
|
|
718
885
|
constant int64_t & ne02[[buffer(5)]],
|
|
719
886
|
constant int64_t & ne10[[buffer(9)]],
|
|
720
887
|
constant int64_t & ne12[[buffer(11)]],
|
|
721
|
-
constant int64_t & ne0[[buffer(15)]],
|
|
722
|
-
constant int64_t & ne1[[buffer(16)]],
|
|
723
|
-
constant uint &
|
|
888
|
+
constant int64_t & ne0 [[buffer(15)]],
|
|
889
|
+
constant int64_t & ne1 [[buffer(16)]],
|
|
890
|
+
constant uint & r2 [[buffer(17)]],
|
|
891
|
+
constant uint & r3 [[buffer(18)]],
|
|
724
892
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
725
893
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
726
894
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -732,8 +900,14 @@ kernel void kernel_mul_mv_q8_0_f32(
|
|
|
732
900
|
const int r0 = tgpig.x;
|
|
733
901
|
const int r1 = tgpig.y;
|
|
734
902
|
const int im = tgpig.z;
|
|
903
|
+
|
|
735
904
|
const int first_row = (r0 * nsg + sgitg) * nr;
|
|
736
|
-
|
|
905
|
+
|
|
906
|
+
const uint i12 = im%ne12;
|
|
907
|
+
const uint i13 = im/ne12;
|
|
908
|
+
|
|
909
|
+
const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
910
|
+
|
|
737
911
|
device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
|
|
738
912
|
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
739
913
|
|
|
@@ -791,14 +965,21 @@ kernel void kernel_mul_mv_f32_f32(
|
|
|
791
965
|
constant uint64_t & nb12,
|
|
792
966
|
constant int64_t & ne0,
|
|
793
967
|
constant int64_t & ne1,
|
|
968
|
+
constant uint & r2 [[buffer(17)]],
|
|
969
|
+
constant uint & r3 [[buffer(18)]],
|
|
794
970
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
795
|
-
uint
|
|
971
|
+
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
796
972
|
|
|
797
973
|
const int64_t r0 = tgpig.x;
|
|
798
974
|
const int64_t rb = tgpig.y*N_F32_F32;
|
|
799
975
|
const int64_t im = tgpig.z;
|
|
800
976
|
|
|
801
|
-
|
|
977
|
+
const uint i12 = im%ne12;
|
|
978
|
+
const uint i13 = im/ne12;
|
|
979
|
+
|
|
980
|
+
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
|
981
|
+
|
|
982
|
+
device const float * x = (device const float *) (src0 + offset0);
|
|
802
983
|
|
|
803
984
|
if (ne00 < 128) {
|
|
804
985
|
for (int row = 0; row < N_F32_F32; ++row) {
|
|
@@ -844,6 +1025,86 @@ kernel void kernel_mul_mv_f32_f32(
|
|
|
844
1025
|
}
|
|
845
1026
|
}
|
|
846
1027
|
|
|
1028
|
+
#define N_F16_F16 4
|
|
1029
|
+
|
|
1030
|
+
kernel void kernel_mul_mv_f16_f16(
|
|
1031
|
+
device const char * src0,
|
|
1032
|
+
device const char * src1,
|
|
1033
|
+
device float * dst,
|
|
1034
|
+
constant int64_t & ne00,
|
|
1035
|
+
constant int64_t & ne01,
|
|
1036
|
+
constant int64_t & ne02,
|
|
1037
|
+
constant uint64_t & nb00,
|
|
1038
|
+
constant uint64_t & nb01,
|
|
1039
|
+
constant uint64_t & nb02,
|
|
1040
|
+
constant int64_t & ne10,
|
|
1041
|
+
constant int64_t & ne11,
|
|
1042
|
+
constant int64_t & ne12,
|
|
1043
|
+
constant uint64_t & nb10,
|
|
1044
|
+
constant uint64_t & nb11,
|
|
1045
|
+
constant uint64_t & nb12,
|
|
1046
|
+
constant int64_t & ne0,
|
|
1047
|
+
constant int64_t & ne1,
|
|
1048
|
+
constant uint & r2 [[buffer(17)]],
|
|
1049
|
+
constant uint & r3 [[buffer(18)]],
|
|
1050
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1051
|
+
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
1052
|
+
|
|
1053
|
+
const int64_t r0 = tgpig.x;
|
|
1054
|
+
const int64_t rb = tgpig.y*N_F16_F16;
|
|
1055
|
+
const int64_t im = tgpig.z;
|
|
1056
|
+
|
|
1057
|
+
const uint i12 = im%ne12;
|
|
1058
|
+
const uint i13 = im/ne12;
|
|
1059
|
+
|
|
1060
|
+
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
|
1061
|
+
|
|
1062
|
+
device const half * x = (device const half *) (src0 + offset0);
|
|
1063
|
+
|
|
1064
|
+
if (ne00 < 128) {
|
|
1065
|
+
for (int row = 0; row < N_F16_F16; ++row) {
|
|
1066
|
+
int r1 = rb + row;
|
|
1067
|
+
if (r1 >= ne11) {
|
|
1068
|
+
break;
|
|
1069
|
+
}
|
|
1070
|
+
|
|
1071
|
+
device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
|
|
1072
|
+
|
|
1073
|
+
float sumf = 0;
|
|
1074
|
+
for (int i = tiisg; i < ne00; i += 32) {
|
|
1075
|
+
sumf += (half) x[i] * (half) y[i];
|
|
1076
|
+
}
|
|
1077
|
+
|
|
1078
|
+
float all_sum = simd_sum(sumf);
|
|
1079
|
+
if (tiisg == 0) {
|
|
1080
|
+
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
|
1081
|
+
}
|
|
1082
|
+
}
|
|
1083
|
+
} else {
|
|
1084
|
+
device const half4 * x4 = (device const half4 *)x;
|
|
1085
|
+
for (int row = 0; row < N_F16_F16; ++row) {
|
|
1086
|
+
int r1 = rb + row;
|
|
1087
|
+
if (r1 >= ne11) {
|
|
1088
|
+
break;
|
|
1089
|
+
}
|
|
1090
|
+
|
|
1091
|
+
device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
|
|
1092
|
+
device const half4 * y4 = (device const half4 *) y;
|
|
1093
|
+
|
|
1094
|
+
float sumf = 0;
|
|
1095
|
+
for (int i = tiisg; i < ne00/4; i += 32) {
|
|
1096
|
+
for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
|
|
1097
|
+
}
|
|
1098
|
+
|
|
1099
|
+
float all_sum = simd_sum(sumf);
|
|
1100
|
+
if (tiisg == 0) {
|
|
1101
|
+
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
|
|
1102
|
+
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
|
1103
|
+
}
|
|
1104
|
+
}
|
|
1105
|
+
}
|
|
1106
|
+
}
|
|
1107
|
+
|
|
847
1108
|
kernel void kernel_mul_mv_f16_f32_1row(
|
|
848
1109
|
device const char * src0,
|
|
849
1110
|
device const char * src1,
|
|
@@ -862,6 +1123,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
|
|
|
862
1123
|
constant uint64_t & nb12,
|
|
863
1124
|
constant int64_t & ne0,
|
|
864
1125
|
constant int64_t & ne1,
|
|
1126
|
+
constant uint & r2 [[buffer(17)]],
|
|
1127
|
+
constant uint & r3 [[buffer(18)]],
|
|
865
1128
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
866
1129
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
867
1130
|
|
|
@@ -869,7 +1132,12 @@ kernel void kernel_mul_mv_f16_f32_1row(
|
|
|
869
1132
|
const int64_t r1 = tgpig.y;
|
|
870
1133
|
const int64_t im = tgpig.z;
|
|
871
1134
|
|
|
872
|
-
|
|
1135
|
+
const uint i12 = im%ne12;
|
|
1136
|
+
const uint i13 = im/ne12;
|
|
1137
|
+
|
|
1138
|
+
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
|
1139
|
+
|
|
1140
|
+
device const half * x = (device const half *) (src0 + offset0);
|
|
873
1141
|
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
|
874
1142
|
|
|
875
1143
|
float sumf = 0;
|
|
@@ -916,6 +1184,8 @@ kernel void kernel_mul_mv_f16_f32(
|
|
|
916
1184
|
constant uint64_t & nb12,
|
|
917
1185
|
constant int64_t & ne0,
|
|
918
1186
|
constant int64_t & ne1,
|
|
1187
|
+
constant uint & r2 [[buffer(17)]],
|
|
1188
|
+
constant uint & r3 [[buffer(18)]],
|
|
919
1189
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
920
1190
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
921
1191
|
|
|
@@ -923,7 +1193,12 @@ kernel void kernel_mul_mv_f16_f32(
|
|
|
923
1193
|
const int64_t rb = tgpig.y*N_F16_F32;
|
|
924
1194
|
const int64_t im = tgpig.z;
|
|
925
1195
|
|
|
926
|
-
|
|
1196
|
+
const uint i12 = im%ne12;
|
|
1197
|
+
const uint i13 = im/ne12;
|
|
1198
|
+
|
|
1199
|
+
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
|
1200
|
+
|
|
1201
|
+
device const half * x = (device const half *) (src0 + offset0);
|
|
927
1202
|
|
|
928
1203
|
if (ne00 < 128) {
|
|
929
1204
|
for (int row = 0; row < N_F16_F32; ++row) {
|
|
@@ -988,6 +1263,8 @@ kernel void kernel_mul_mv_f16_f32_l4(
|
|
|
988
1263
|
constant uint64_t & nb12,
|
|
989
1264
|
constant int64_t & ne0,
|
|
990
1265
|
constant int64_t & ne1,
|
|
1266
|
+
constant uint & r2 [[buffer(17)]],
|
|
1267
|
+
constant uint & r3 [[buffer(18)]],
|
|
991
1268
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
992
1269
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
993
1270
|
|
|
@@ -995,7 +1272,12 @@ kernel void kernel_mul_mv_f16_f32_l4(
|
|
|
995
1272
|
const int64_t r0 = tgpig.x;
|
|
996
1273
|
const int64_t im = tgpig.z;
|
|
997
1274
|
|
|
998
|
-
|
|
1275
|
+
const uint i12 = im%ne12;
|
|
1276
|
+
const uint i13 = im/ne12;
|
|
1277
|
+
|
|
1278
|
+
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
|
1279
|
+
|
|
1280
|
+
device const half4 * x4 = (device const half4 *) (src0 + offset0);
|
|
999
1281
|
|
|
1000
1282
|
for (int r1 = 0; r1 < nrows; ++r1) {
|
|
1001
1283
|
device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
|
|
@@ -1047,17 +1329,21 @@ kernel void kernel_alibi_f32(
|
|
|
1047
1329
|
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
|
1048
1330
|
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
|
1049
1331
|
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
|
1332
|
+
const int64_t k = i3*ne3 + i2;
|
|
1050
1333
|
|
|
1051
|
-
device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
1052
1334
|
float m_k;
|
|
1053
|
-
if (
|
|
1054
|
-
m_k = pow(m0,
|
|
1335
|
+
if (k < n_heads_log2_floor) {
|
|
1336
|
+
m_k = pow(m0, k + 1);
|
|
1055
1337
|
} else {
|
|
1056
|
-
m_k = pow(m1, 2 * (
|
|
1338
|
+
m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
|
|
1057
1339
|
}
|
|
1340
|
+
|
|
1341
|
+
device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
|
|
1342
|
+
device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
|
1058
1343
|
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
|
1059
|
-
|
|
1060
|
-
|
|
1344
|
+
const float src_v = *(device float *)(src_row + i00*nb00);
|
|
1345
|
+
device float * dst_v = (device float *)(dst_row + i00*nb0);
|
|
1346
|
+
*dst_v = i00 * m_k + src_v;
|
|
1061
1347
|
}
|
|
1062
1348
|
}
|
|
1063
1349
|
|
|
@@ -1201,33 +1487,118 @@ kernel void kernel_rope(
|
|
|
1201
1487
|
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
|
1202
1488
|
}
|
|
1203
1489
|
} else {
|
|
1204
|
-
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
|
|
1205
|
-
for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
|
|
1490
|
+
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
|
|
1491
|
+
for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
|
|
1492
|
+
|
|
1493
|
+
// simplified from `(ib * n_dims + ic) * inv_ndims`
|
|
1494
|
+
const float cur_rot = inv_ndims*ic - ib;
|
|
1495
|
+
|
|
1496
|
+
const float theta = theta_0 * pow(freq_base, cur_rot);
|
|
1497
|
+
float cos_theta, sin_theta;
|
|
1498
|
+
rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
|
1499
|
+
|
|
1500
|
+
const int64_t i0 = ib*n_dims + ic/2;
|
|
1501
|
+
|
|
1502
|
+
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
1503
|
+
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
1504
|
+
|
|
1505
|
+
const float x0 = src[0];
|
|
1506
|
+
const float x1 = src[n_dims/2];
|
|
1507
|
+
|
|
1508
|
+
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
|
1509
|
+
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
|
1510
|
+
}
|
|
1511
|
+
}
|
|
1512
|
+
}
|
|
1513
|
+
}
|
|
1514
|
+
|
|
1515
|
+
template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
|
|
1516
|
+
template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
|
|
1517
|
+
|
|
1518
|
+
kernel void kernel_im2col_f16(
|
|
1519
|
+
device const float * x,
|
|
1520
|
+
device half * dst,
|
|
1521
|
+
constant int32_t & ofs0,
|
|
1522
|
+
constant int32_t & ofs1,
|
|
1523
|
+
constant int32_t & IW,
|
|
1524
|
+
constant int32_t & IH,
|
|
1525
|
+
constant int32_t & CHW,
|
|
1526
|
+
constant int32_t & s0,
|
|
1527
|
+
constant int32_t & s1,
|
|
1528
|
+
constant int32_t & p0,
|
|
1529
|
+
constant int32_t & p1,
|
|
1530
|
+
constant int32_t & d0,
|
|
1531
|
+
constant int32_t & d1,
|
|
1532
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1533
|
+
uint3 tgpg[[threadgroups_per_grid]],
|
|
1534
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
1535
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
1536
|
+
const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0;
|
|
1537
|
+
const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1;
|
|
1538
|
+
|
|
1539
|
+
const int32_t offset_dst =
|
|
1540
|
+
(tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
|
|
1541
|
+
(tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
|
|
1542
|
+
|
|
1543
|
+
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
|
1544
|
+
dst[offset_dst] = 0.0f;
|
|
1545
|
+
} else {
|
|
1546
|
+
const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
|
|
1547
|
+
dst[offset_dst] = x[offset_src + iih * IW + iiw];
|
|
1548
|
+
}
|
|
1549
|
+
}
|
|
1206
1550
|
|
|
1207
|
-
|
|
1208
|
-
|
|
1551
|
+
// bitonic sort implementation following the CUDA kernels as reference
|
|
1552
|
+
typedef void (argsort_t)(
|
|
1553
|
+
device const float * x,
|
|
1554
|
+
device int32_t * dst,
|
|
1555
|
+
constant int64_t & ncols,
|
|
1556
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1557
|
+
uint3 tpitg[[thread_position_in_threadgroup]]);
|
|
1209
1558
|
|
|
1210
|
-
|
|
1211
|
-
|
|
1212
|
-
|
|
1559
|
+
template<ggml_sort_order order>
|
|
1560
|
+
kernel void kernel_argsort_f32_i32(
|
|
1561
|
+
device const float * x,
|
|
1562
|
+
device int32_t * dst,
|
|
1563
|
+
constant int64_t & ncols,
|
|
1564
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1565
|
+
uint3 tpitg[[thread_position_in_threadgroup]]) {
|
|
1566
|
+
// bitonic sort
|
|
1567
|
+
int col = tpitg[0];
|
|
1568
|
+
int row = tgpig[1];
|
|
1213
1569
|
|
|
1214
|
-
|
|
1570
|
+
if (col >= ncols) return;
|
|
1215
1571
|
|
|
1216
|
-
|
|
1217
|
-
|
|
1572
|
+
device const float * x_row = x + row * ncols;
|
|
1573
|
+
device int32_t * dst_row = dst + row * ncols;
|
|
1218
1574
|
|
|
1219
|
-
|
|
1220
|
-
|
|
1575
|
+
// initialize indices
|
|
1576
|
+
if (col < ncols) {
|
|
1577
|
+
dst_row[col] = col;
|
|
1578
|
+
}
|
|
1579
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
1221
1580
|
|
|
1222
|
-
|
|
1223
|
-
|
|
1581
|
+
for (int k = 2; k <= ncols; k *= 2) {
|
|
1582
|
+
for (int j = k / 2; j > 0; j /= 2) {
|
|
1583
|
+
int ixj = col ^ j;
|
|
1584
|
+
if (ixj > col) {
|
|
1585
|
+
if ((col & k) == 0) {
|
|
1586
|
+
if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
|
|
1587
|
+
SWAP(dst_row[col], dst_row[ixj]);
|
|
1588
|
+
}
|
|
1589
|
+
} else {
|
|
1590
|
+
if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
|
|
1591
|
+
SWAP(dst_row[col], dst_row[ixj]);
|
|
1592
|
+
}
|
|
1593
|
+
}
|
|
1224
1594
|
}
|
|
1595
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
1225
1596
|
}
|
|
1226
1597
|
}
|
|
1227
1598
|
}
|
|
1228
1599
|
|
|
1229
|
-
template [[host_name("
|
|
1230
|
-
template [[host_name("
|
|
1600
|
+
template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
|
|
1601
|
+
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
|
|
1231
1602
|
|
|
1232
1603
|
kernel void kernel_cpy_f16_f16(
|
|
1233
1604
|
device const half * src0,
|
|
@@ -1354,6 +1725,197 @@ kernel void kernel_cpy_f32_f32(
|
|
|
1354
1725
|
}
|
|
1355
1726
|
}
|
|
1356
1727
|
|
|
1728
|
+
kernel void kernel_cpy_f32_q8_0(
|
|
1729
|
+
device const float * src0,
|
|
1730
|
+
device void * dst,
|
|
1731
|
+
constant int64_t & ne00,
|
|
1732
|
+
constant int64_t & ne01,
|
|
1733
|
+
constant int64_t & ne02,
|
|
1734
|
+
constant int64_t & ne03,
|
|
1735
|
+
constant uint64_t & nb00,
|
|
1736
|
+
constant uint64_t & nb01,
|
|
1737
|
+
constant uint64_t & nb02,
|
|
1738
|
+
constant uint64_t & nb03,
|
|
1739
|
+
constant int64_t & ne0,
|
|
1740
|
+
constant int64_t & ne1,
|
|
1741
|
+
constant int64_t & ne2,
|
|
1742
|
+
constant int64_t & ne3,
|
|
1743
|
+
constant uint64_t & nb0,
|
|
1744
|
+
constant uint64_t & nb1,
|
|
1745
|
+
constant uint64_t & nb2,
|
|
1746
|
+
constant uint64_t & nb3,
|
|
1747
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1748
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
1749
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
1750
|
+
const int64_t i03 = tgpig[2];
|
|
1751
|
+
const int64_t i02 = tgpig[1];
|
|
1752
|
+
const int64_t i01 = tgpig[0];
|
|
1753
|
+
|
|
1754
|
+
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
1755
|
+
|
|
1756
|
+
const int64_t i3 = n / (ne2*ne1*ne0);
|
|
1757
|
+
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
|
1758
|
+
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
|
1759
|
+
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0;
|
|
1760
|
+
|
|
1761
|
+
device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
1762
|
+
|
|
1763
|
+
for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) {
|
|
1764
|
+
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
|
1765
|
+
|
|
1766
|
+
float amax = 0.0f; // absolute max
|
|
1767
|
+
|
|
1768
|
+
for (int j = 0; j < QK8_0; j++) {
|
|
1769
|
+
const float v = src[j];
|
|
1770
|
+
amax = MAX(amax, fabs(v));
|
|
1771
|
+
}
|
|
1772
|
+
|
|
1773
|
+
const float d = amax / ((1 << 7) - 1);
|
|
1774
|
+
const float id = d ? 1.0f/d : 0.0f;
|
|
1775
|
+
|
|
1776
|
+
dst_data[i00/QK8_0].d = d;
|
|
1777
|
+
|
|
1778
|
+
for (int j = 0; j < QK8_0; ++j) {
|
|
1779
|
+
const float x0 = src[j]*id;
|
|
1780
|
+
|
|
1781
|
+
dst_data[i00/QK8_0].qs[j] = round(x0);
|
|
1782
|
+
}
|
|
1783
|
+
}
|
|
1784
|
+
}
|
|
1785
|
+
|
|
1786
|
+
kernel void kernel_cpy_f32_q4_0(
|
|
1787
|
+
device const float * src0,
|
|
1788
|
+
device void * dst,
|
|
1789
|
+
constant int64_t & ne00,
|
|
1790
|
+
constant int64_t & ne01,
|
|
1791
|
+
constant int64_t & ne02,
|
|
1792
|
+
constant int64_t & ne03,
|
|
1793
|
+
constant uint64_t & nb00,
|
|
1794
|
+
constant uint64_t & nb01,
|
|
1795
|
+
constant uint64_t & nb02,
|
|
1796
|
+
constant uint64_t & nb03,
|
|
1797
|
+
constant int64_t & ne0,
|
|
1798
|
+
constant int64_t & ne1,
|
|
1799
|
+
constant int64_t & ne2,
|
|
1800
|
+
constant int64_t & ne3,
|
|
1801
|
+
constant uint64_t & nb0,
|
|
1802
|
+
constant uint64_t & nb1,
|
|
1803
|
+
constant uint64_t & nb2,
|
|
1804
|
+
constant uint64_t & nb3,
|
|
1805
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1806
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
1807
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
1808
|
+
const int64_t i03 = tgpig[2];
|
|
1809
|
+
const int64_t i02 = tgpig[1];
|
|
1810
|
+
const int64_t i01 = tgpig[0];
|
|
1811
|
+
|
|
1812
|
+
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
1813
|
+
|
|
1814
|
+
const int64_t i3 = n / (ne2*ne1*ne0);
|
|
1815
|
+
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
|
1816
|
+
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
|
1817
|
+
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0;
|
|
1818
|
+
|
|
1819
|
+
device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
1820
|
+
|
|
1821
|
+
for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) {
|
|
1822
|
+
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
|
1823
|
+
|
|
1824
|
+
float amax = 0.0f; // absolute max
|
|
1825
|
+
float max = 0.0f;
|
|
1826
|
+
|
|
1827
|
+
for (int j = 0; j < QK4_0; j++) {
|
|
1828
|
+
const float v = src[j];
|
|
1829
|
+
if (amax < fabs(v)) {
|
|
1830
|
+
amax = fabs(v);
|
|
1831
|
+
max = v;
|
|
1832
|
+
}
|
|
1833
|
+
}
|
|
1834
|
+
|
|
1835
|
+
const float d = max / -8;
|
|
1836
|
+
const float id = d ? 1.0f/d : 0.0f;
|
|
1837
|
+
|
|
1838
|
+
dst_data[i00/QK4_0].d = d;
|
|
1839
|
+
|
|
1840
|
+
for (int j = 0; j < QK4_0/2; ++j) {
|
|
1841
|
+
const float x0 = src[0 + j]*id;
|
|
1842
|
+
const float x1 = src[QK4_0/2 + j]*id;
|
|
1843
|
+
|
|
1844
|
+
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
|
|
1845
|
+
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
|
|
1846
|
+
|
|
1847
|
+
dst_data[i00/QK4_0].qs[j] = xi0;
|
|
1848
|
+
dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
|
|
1849
|
+
}
|
|
1850
|
+
}
|
|
1851
|
+
}
|
|
1852
|
+
|
|
1853
|
+
kernel void kernel_cpy_f32_q4_1(
|
|
1854
|
+
device const float * src0,
|
|
1855
|
+
device void * dst,
|
|
1856
|
+
constant int64_t & ne00,
|
|
1857
|
+
constant int64_t & ne01,
|
|
1858
|
+
constant int64_t & ne02,
|
|
1859
|
+
constant int64_t & ne03,
|
|
1860
|
+
constant uint64_t & nb00,
|
|
1861
|
+
constant uint64_t & nb01,
|
|
1862
|
+
constant uint64_t & nb02,
|
|
1863
|
+
constant uint64_t & nb03,
|
|
1864
|
+
constant int64_t & ne0,
|
|
1865
|
+
constant int64_t & ne1,
|
|
1866
|
+
constant int64_t & ne2,
|
|
1867
|
+
constant int64_t & ne3,
|
|
1868
|
+
constant uint64_t & nb0,
|
|
1869
|
+
constant uint64_t & nb1,
|
|
1870
|
+
constant uint64_t & nb2,
|
|
1871
|
+
constant uint64_t & nb3,
|
|
1872
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1873
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
1874
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
1875
|
+
const int64_t i03 = tgpig[2];
|
|
1876
|
+
const int64_t i02 = tgpig[1];
|
|
1877
|
+
const int64_t i01 = tgpig[0];
|
|
1878
|
+
|
|
1879
|
+
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
1880
|
+
|
|
1881
|
+
const int64_t i3 = n / (ne2*ne1*ne0);
|
|
1882
|
+
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
|
1883
|
+
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
|
1884
|
+
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1;
|
|
1885
|
+
|
|
1886
|
+
device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
1887
|
+
|
|
1888
|
+
for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) {
|
|
1889
|
+
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
|
1890
|
+
|
|
1891
|
+
float min = FLT_MAX;
|
|
1892
|
+
float max = -FLT_MAX;
|
|
1893
|
+
|
|
1894
|
+
for (int j = 0; j < QK4_1; j++) {
|
|
1895
|
+
const float v = src[j];
|
|
1896
|
+
if (min > v) min = v;
|
|
1897
|
+
if (max < v) max = v;
|
|
1898
|
+
}
|
|
1899
|
+
|
|
1900
|
+
const float d = (max - min) / ((1 << 4) - 1);
|
|
1901
|
+
const float id = d ? 1.0f/d : 0.0f;
|
|
1902
|
+
|
|
1903
|
+
dst_data[i00/QK4_1].d = d;
|
|
1904
|
+
dst_data[i00/QK4_1].m = min;
|
|
1905
|
+
|
|
1906
|
+
for (int j = 0; j < QK4_1/2; ++j) {
|
|
1907
|
+
const float x0 = (src[0 + j] - min)*id;
|
|
1908
|
+
const float x1 = (src[QK4_1/2 + j] - min)*id;
|
|
1909
|
+
|
|
1910
|
+
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
|
|
1911
|
+
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
|
|
1912
|
+
|
|
1913
|
+
dst_data[i00/QK4_1].qs[j] = xi0;
|
|
1914
|
+
dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
|
|
1915
|
+
}
|
|
1916
|
+
}
|
|
1917
|
+
}
|
|
1918
|
+
|
|
1357
1919
|
kernel void kernel_concat(
|
|
1358
1920
|
device const char * src0,
|
|
1359
1921
|
device const char * src1,
|
|
@@ -1511,23 +2073,30 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
|
1511
2073
|
constant int64_t & ne02[[buffer(5)]],
|
|
1512
2074
|
constant int64_t & ne10[[buffer(9)]],
|
|
1513
2075
|
constant int64_t & ne12[[buffer(11)]],
|
|
1514
|
-
constant int64_t & ne0[[buffer(15)]],
|
|
1515
|
-
constant int64_t & ne1[[buffer(16)]],
|
|
1516
|
-
constant uint &
|
|
2076
|
+
constant int64_t & ne0 [[buffer(15)]],
|
|
2077
|
+
constant int64_t & ne1 [[buffer(16)]],
|
|
2078
|
+
constant uint & r2 [[buffer(17)]],
|
|
2079
|
+
constant uint & r3 [[buffer(18)]],
|
|
1517
2080
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1518
|
-
uint
|
|
1519
|
-
uint
|
|
2081
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
2082
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
1520
2083
|
|
|
1521
2084
|
const int nb = ne00/QK_K;
|
|
1522
2085
|
const int r0 = tgpig.x;
|
|
1523
2086
|
const int r1 = tgpig.y;
|
|
1524
|
-
const int
|
|
2087
|
+
const int im = tgpig.z;
|
|
1525
2088
|
|
|
1526
2089
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
|
1527
2090
|
const int ib_row = first_row * nb;
|
|
1528
|
-
|
|
2091
|
+
|
|
2092
|
+
const uint i12 = im%ne12;
|
|
2093
|
+
const uint i13 = im/ne12;
|
|
2094
|
+
|
|
2095
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
2096
|
+
|
|
1529
2097
|
device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
|
|
1530
|
-
device const float * y = (device const float *) src1 + r1*ne10 +
|
|
2098
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
2099
|
+
|
|
1531
2100
|
float yl[32];
|
|
1532
2101
|
float sumf[N_DST]={0.f}, all_sum;
|
|
1533
2102
|
|
|
@@ -1536,11 +2105,11 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
|
1536
2105
|
#if QK_K == 256
|
|
1537
2106
|
const int ix = tiisg/8; // 0...3
|
|
1538
2107
|
const int it = tiisg%8; // 0...7
|
|
1539
|
-
const int
|
|
2108
|
+
const int iq = it/4; // 0 or 1
|
|
1540
2109
|
const int ir = it%4; // 0...3
|
|
1541
2110
|
const int is = (8*ir)/16;// 0 or 1
|
|
1542
2111
|
|
|
1543
|
-
device const float * y4 = y + ix * QK_K + 128 *
|
|
2112
|
+
device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
|
|
1544
2113
|
|
|
1545
2114
|
for (int ib = ix; ib < nb; ib += 4) {
|
|
1546
2115
|
|
|
@@ -1552,8 +2121,8 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
|
1552
2121
|
yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
|
|
1553
2122
|
}
|
|
1554
2123
|
|
|
1555
|
-
device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*
|
|
1556
|
-
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 *
|
|
2124
|
+
device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*iq + is;
|
|
2125
|
+
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
|
|
1557
2126
|
device const half * dh = &x[ib].d;
|
|
1558
2127
|
|
|
1559
2128
|
for (int row = 0; row < N_DST; row++) {
|
|
@@ -1640,7 +2209,7 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
|
1640
2209
|
for (int row = 0; row < N_DST; ++row) {
|
|
1641
2210
|
all_sum = simd_sum(sumf[row]);
|
|
1642
2211
|
if (tiisg == 0) {
|
|
1643
|
-
dst[r1*ne0 +
|
|
2212
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
|
1644
2213
|
}
|
|
1645
2214
|
}
|
|
1646
2215
|
}
|
|
@@ -1655,9 +2224,10 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
|
1655
2224
|
constant int64_t & ne02[[buffer(5)]],
|
|
1656
2225
|
constant int64_t & ne10[[buffer(9)]],
|
|
1657
2226
|
constant int64_t & ne12[[buffer(11)]],
|
|
1658
|
-
constant int64_t & ne0[[buffer(15)]],
|
|
1659
|
-
constant int64_t & ne1[[buffer(16)]],
|
|
1660
|
-
constant uint &
|
|
2227
|
+
constant int64_t & ne0 [[buffer(15)]],
|
|
2228
|
+
constant int64_t & ne1 [[buffer(16)]],
|
|
2229
|
+
constant uint & r2 [[buffer(17)]],
|
|
2230
|
+
constant uint & r3 [[buffer(18)]],
|
|
1661
2231
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1662
2232
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
1663
2233
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -1666,12 +2236,17 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
|
1666
2236
|
|
|
1667
2237
|
const int64_t r0 = tgpig.x;
|
|
1668
2238
|
const int64_t r1 = tgpig.y;
|
|
1669
|
-
const int64_t
|
|
2239
|
+
const int64_t im = tgpig.z;
|
|
1670
2240
|
|
|
1671
2241
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
|
1672
|
-
|
|
2242
|
+
|
|
2243
|
+
const uint i12 = im%ne12;
|
|
2244
|
+
const uint i13 = im/ne12;
|
|
2245
|
+
|
|
2246
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
2247
|
+
|
|
1673
2248
|
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
|
|
1674
|
-
device const float * yy = (device const float *) src1 + r1*ne10 +
|
|
2249
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
1675
2250
|
|
|
1676
2251
|
float yl[32];
|
|
1677
2252
|
|
|
@@ -1793,7 +2368,7 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
|
1793
2368
|
}
|
|
1794
2369
|
if (tiisg == 0) {
|
|
1795
2370
|
for (int row = 0; row < 2; ++row) {
|
|
1796
|
-
dst[r1*ne0 +
|
|
2371
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row];
|
|
1797
2372
|
}
|
|
1798
2373
|
}
|
|
1799
2374
|
}
|
|
@@ -1807,26 +2382,33 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
|
1807
2382
|
constant int64_t & ne02[[buffer(5)]],
|
|
1808
2383
|
constant int64_t & ne10[[buffer(9)]],
|
|
1809
2384
|
constant int64_t & ne12[[buffer(11)]],
|
|
1810
|
-
constant int64_t & ne0[[buffer(15)]],
|
|
1811
|
-
constant int64_t & ne1[[buffer(16)]],
|
|
1812
|
-
constant uint &
|
|
2385
|
+
constant int64_t & ne0 [[buffer(15)]],
|
|
2386
|
+
constant int64_t & ne1 [[buffer(16)]],
|
|
2387
|
+
constant uint & r2 [[buffer(17)]],
|
|
2388
|
+
constant uint & r3 [[buffer(18)]],
|
|
1813
2389
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1814
|
-
uint
|
|
1815
|
-
uint
|
|
2390
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
2391
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
1816
2392
|
|
|
1817
2393
|
const int nb = ne00/QK_K;
|
|
1818
2394
|
|
|
1819
2395
|
const int64_t r0 = tgpig.x;
|
|
1820
2396
|
const int64_t r1 = tgpig.y;
|
|
1821
|
-
const int64_t
|
|
2397
|
+
const int64_t im = tgpig.z;
|
|
1822
2398
|
|
|
1823
2399
|
const int row = 2 * r0 + sgitg;
|
|
1824
|
-
|
|
2400
|
+
|
|
2401
|
+
const uint i12 = im%ne12;
|
|
2402
|
+
const uint i13 = im/ne12;
|
|
2403
|
+
|
|
2404
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
2405
|
+
|
|
1825
2406
|
device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
|
|
1826
|
-
device const float * yy = (device const float *) src1 + r1*ne10 +
|
|
2407
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
2408
|
+
|
|
1827
2409
|
const int ix = tiisg/4;
|
|
1828
2410
|
const int il = 4 * (tiisg%4);// 0, 4, 8, 12
|
|
1829
|
-
const int
|
|
2411
|
+
const int iq = il/8; // 0, 0, 1, 1
|
|
1830
2412
|
const int in = il%8; // 0, 4, 0, 4
|
|
1831
2413
|
|
|
1832
2414
|
float2 sum = {0.f, 0.f};
|
|
@@ -1846,7 +2428,7 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
|
1846
2428
|
const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
|
|
1847
2429
|
|
|
1848
2430
|
for (int l = 0; l < 4; l += 2) {
|
|
1849
|
-
const uint16_t hm = h[l/2] >>
|
|
2431
|
+
const uint16_t hm = h[l/2] >> iq;
|
|
1850
2432
|
sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
|
|
1851
2433
|
+ y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
|
|
1852
2434
|
+ y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
|
|
@@ -1862,7 +2444,7 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
|
1862
2444
|
|
|
1863
2445
|
const float tot = simd_sum(sumf);
|
|
1864
2446
|
if (tiisg == 0) {
|
|
1865
|
-
dst[r1*ne0 +
|
|
2447
|
+
dst[r1*ne0 + im*ne0*ne1 + row] = tot;
|
|
1866
2448
|
}
|
|
1867
2449
|
|
|
1868
2450
|
}
|
|
@@ -1880,10 +2462,11 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
|
1880
2462
|
constant int64_t & ne12 [[buffer(11)]],
|
|
1881
2463
|
constant int64_t & ne0 [[buffer(15)]],
|
|
1882
2464
|
constant int64_t & ne1 [[buffer(16)]],
|
|
1883
|
-
constant uint &
|
|
2465
|
+
constant uint & r2 [[buffer(17)]],
|
|
2466
|
+
constant uint & r3 [[buffer(18)]],
|
|
1884
2467
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1885
|
-
uint
|
|
1886
|
-
uint
|
|
2468
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
2469
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
1887
2470
|
|
|
1888
2471
|
const uint16_t kmask1 = 0x3f3f;
|
|
1889
2472
|
const uint16_t kmask2 = 0x0f0f;
|
|
@@ -1891,26 +2474,32 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
|
1891
2474
|
|
|
1892
2475
|
const int ix = tiisg/8; // 0...3
|
|
1893
2476
|
const int it = tiisg%8; // 0...7
|
|
1894
|
-
const int
|
|
2477
|
+
const int iq = it/4; // 0 or 1
|
|
1895
2478
|
const int ir = it%4; // 0...3
|
|
1896
2479
|
|
|
1897
2480
|
const int nb = ne00/QK_K;
|
|
1898
2481
|
const int r0 = tgpig.x;
|
|
1899
2482
|
const int r1 = tgpig.y;
|
|
1900
|
-
const int
|
|
2483
|
+
const int im = tgpig.z;
|
|
1901
2484
|
//const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
|
1902
2485
|
const int first_row = r0 * N_DST;
|
|
1903
2486
|
const int ib_row = first_row * nb;
|
|
1904
|
-
|
|
2487
|
+
|
|
2488
|
+
const uint i12 = im%ne12;
|
|
2489
|
+
const uint i13 = im/ne12;
|
|
2490
|
+
|
|
2491
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
2492
|
+
|
|
1905
2493
|
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
|
|
1906
|
-
device const float * y = (device const float *) src1 + r1*ne10 +
|
|
2494
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
2495
|
+
|
|
1907
2496
|
float yl[16];
|
|
1908
2497
|
float yh[16];
|
|
1909
2498
|
float sumf[N_DST]={0.f}, all_sum;
|
|
1910
2499
|
|
|
1911
2500
|
const int step = sizeof(block_q4_K) * nb / 2;
|
|
1912
2501
|
|
|
1913
|
-
device const float * y4 = y + ix * QK_K + 64 *
|
|
2502
|
+
device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
|
|
1914
2503
|
|
|
1915
2504
|
uint16_t sc16[4];
|
|
1916
2505
|
thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
|
|
@@ -1925,8 +2514,8 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
|
1925
2514
|
yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
|
|
1926
2515
|
}
|
|
1927
2516
|
|
|
1928
|
-
device const uint16_t * sc = (device const uint16_t *)x[ib].scales +
|
|
1929
|
-
device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 *
|
|
2517
|
+
device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq;
|
|
2518
|
+
device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
|
|
1930
2519
|
device const half * dh = &x[ib].d;
|
|
1931
2520
|
|
|
1932
2521
|
for (int row = 0; row < N_DST; row++) {
|
|
@@ -1970,7 +2559,7 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
|
1970
2559
|
for (int row = 0; row < N_DST; ++row) {
|
|
1971
2560
|
all_sum = simd_sum(sumf[row]);
|
|
1972
2561
|
if (tiisg == 0) {
|
|
1973
|
-
dst[r1*ne0 +
|
|
2562
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
|
1974
2563
|
}
|
|
1975
2564
|
}
|
|
1976
2565
|
}
|
|
@@ -1984,9 +2573,10 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
|
1984
2573
|
constant int64_t & ne02[[buffer(5)]],
|
|
1985
2574
|
constant int64_t & ne10[[buffer(9)]],
|
|
1986
2575
|
constant int64_t & ne12[[buffer(11)]],
|
|
1987
|
-
constant int64_t & ne0[[buffer(15)]],
|
|
1988
|
-
constant int64_t & ne1[[buffer(16)]],
|
|
1989
|
-
constant uint &
|
|
2576
|
+
constant int64_t & ne0 [[buffer(15)]],
|
|
2577
|
+
constant int64_t & ne1 [[buffer(16)]],
|
|
2578
|
+
constant uint & r2 [[buffer(17)]],
|
|
2579
|
+
constant uint & r3 [[buffer(18)]],
|
|
1990
2580
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1991
2581
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
1992
2582
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -1997,12 +2587,18 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
|
1997
2587
|
const int nb = ne00/QK_K;
|
|
1998
2588
|
const int r0 = tgpig.x;
|
|
1999
2589
|
const int r1 = tgpig.y;
|
|
2000
|
-
const int
|
|
2590
|
+
const int im = tgpig.z;
|
|
2001
2591
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
|
2002
2592
|
const int ib_row = first_row * nb;
|
|
2003
|
-
|
|
2593
|
+
|
|
2594
|
+
const uint i12 = im%ne12;
|
|
2595
|
+
const uint i13 = im/ne12;
|
|
2596
|
+
|
|
2597
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
2598
|
+
|
|
2004
2599
|
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
|
|
2005
|
-
device const float * y = (device const float *) src1 + r1*ne10 +
|
|
2600
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
2601
|
+
|
|
2006
2602
|
float yl[8];
|
|
2007
2603
|
float yh[8];
|
|
2008
2604
|
float sumf[N_DST]={0.f}, all_sum;
|
|
@@ -2058,7 +2654,7 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
|
2058
2654
|
for (int row = 0; row < N_DST; ++row) {
|
|
2059
2655
|
all_sum = simd_sum(sumf[row]);
|
|
2060
2656
|
if (tiisg == 0) {
|
|
2061
|
-
dst[r1*ne0+
|
|
2657
|
+
dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum;
|
|
2062
2658
|
}
|
|
2063
2659
|
}
|
|
2064
2660
|
}
|
|
@@ -2073,9 +2669,10 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
|
2073
2669
|
constant int64_t & ne02[[buffer(5)]],
|
|
2074
2670
|
constant int64_t & ne10[[buffer(9)]],
|
|
2075
2671
|
constant int64_t & ne12[[buffer(11)]],
|
|
2076
|
-
constant int64_t & ne0[[buffer(15)]],
|
|
2077
|
-
constant int64_t & ne1[[buffer(16)]],
|
|
2078
|
-
constant uint &
|
|
2672
|
+
constant int64_t & ne0 [[buffer(15)]],
|
|
2673
|
+
constant int64_t & ne1 [[buffer(16)]],
|
|
2674
|
+
constant uint & r2 [[buffer(17)]],
|
|
2675
|
+
constant uint & r3 [[buffer(18)]],
|
|
2079
2676
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2080
2677
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
2081
2678
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -2084,12 +2681,17 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
|
2084
2681
|
|
|
2085
2682
|
const int64_t r0 = tgpig.x;
|
|
2086
2683
|
const int64_t r1 = tgpig.y;
|
|
2087
|
-
const int
|
|
2684
|
+
const int im = tgpig.z;
|
|
2088
2685
|
|
|
2089
2686
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
|
2090
|
-
|
|
2687
|
+
|
|
2688
|
+
const uint i12 = im%ne12;
|
|
2689
|
+
const uint i13 = im/ne12;
|
|
2690
|
+
|
|
2691
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
2692
|
+
|
|
2091
2693
|
device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
|
|
2092
|
-
device const float * yy = (device const float *) src1 + r1*ne10 +
|
|
2694
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
2093
2695
|
|
|
2094
2696
|
float sumf[2]={0.f};
|
|
2095
2697
|
|
|
@@ -2105,15 +2707,15 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
|
2105
2707
|
|
|
2106
2708
|
const int tid = tiisg/4;
|
|
2107
2709
|
const int ix = tiisg%4;
|
|
2108
|
-
const int
|
|
2710
|
+
const int iq = tid/4;
|
|
2109
2711
|
const int ir = tid%4;
|
|
2110
2712
|
const int n = 8;
|
|
2111
2713
|
|
|
2112
2714
|
const int l0 = n*ir;
|
|
2113
|
-
const int q_offset = 32*
|
|
2114
|
-
const int y_offset = 64*
|
|
2715
|
+
const int q_offset = 32*iq + l0;
|
|
2716
|
+
const int y_offset = 64*iq + l0;
|
|
2115
2717
|
|
|
2116
|
-
const uint8_t hm1 = 1u << (2*
|
|
2718
|
+
const uint8_t hm1 = 1u << (2*iq);
|
|
2117
2719
|
const uint8_t hm2 = hm1 << 1;
|
|
2118
2720
|
const uint8_t hm3 = hm1 << 4;
|
|
2119
2721
|
const uint8_t hm4 = hm2 << 4;
|
|
@@ -2128,7 +2730,7 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
|
2128
2730
|
device const uint8_t * q1 = x[i].qs + q_offset;
|
|
2129
2731
|
device const uint8_t * qh = x[i].qh + l0;
|
|
2130
2732
|
device const half * dh = &x[i].d;
|
|
2131
|
-
device const uint16_t * a = (device const uint16_t *)x[i].scales +
|
|
2733
|
+
device const uint16_t * a = (device const uint16_t *)x[i].scales + iq;
|
|
2132
2734
|
|
|
2133
2735
|
device const float * y2 = y1 + 128;
|
|
2134
2736
|
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
|
@@ -2184,7 +2786,7 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
|
2184
2786
|
|
|
2185
2787
|
const int il = 4 * (tiisg/8); // 0, 4, 8, 12
|
|
2186
2788
|
const int ix = tiisg%8;
|
|
2187
|
-
const int
|
|
2789
|
+
const int iq = il/8; // 0, 0, 1, 1
|
|
2188
2790
|
const int in = il%8; // 0, 4, 0, 4
|
|
2189
2791
|
|
|
2190
2792
|
device const float * y = yy + ix*QK_K + il;
|
|
@@ -2209,7 +2811,7 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
|
2209
2811
|
|
|
2210
2812
|
float2 acc = {0.f, 0.f};
|
|
2211
2813
|
for (int l = 0; l < 4; ++l) {
|
|
2212
|
-
const uint8_t hl = h[l] >>
|
|
2814
|
+
const uint8_t hl = h[l] >> iq;
|
|
2213
2815
|
acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
|
|
2214
2816
|
+ yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
|
|
2215
2817
|
acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
|
|
@@ -2231,7 +2833,7 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
|
2231
2833
|
for (int row = 0; row < 2; ++row) {
|
|
2232
2834
|
const float tot = simd_sum(sumf[row]);
|
|
2233
2835
|
if (tiisg == 0) {
|
|
2234
|
-
dst[r1*ne0 +
|
|
2836
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
|
|
2235
2837
|
}
|
|
2236
2838
|
}
|
|
2237
2839
|
|
|
@@ -2246,9 +2848,10 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
|
2246
2848
|
constant int64_t & ne02[[buffer(5)]],
|
|
2247
2849
|
constant int64_t & ne10[[buffer(9)]],
|
|
2248
2850
|
constant int64_t & ne12[[buffer(11)]],
|
|
2249
|
-
constant int64_t & ne0[[buffer(15)]],
|
|
2250
|
-
constant int64_t & ne1[[buffer(16)]],
|
|
2251
|
-
constant uint &
|
|
2851
|
+
constant int64_t & ne0 [[buffer(15)]],
|
|
2852
|
+
constant int64_t & ne1 [[buffer(16)]],
|
|
2853
|
+
constant uint & r2 [[buffer(17)]],
|
|
2854
|
+
constant uint & r3 [[buffer(18)]],
|
|
2252
2855
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2253
2856
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
2254
2857
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -2262,12 +2865,17 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
|
2262
2865
|
|
|
2263
2866
|
const int64_t r0 = tgpig.x;
|
|
2264
2867
|
const int64_t r1 = tgpig.y;
|
|
2265
|
-
const int
|
|
2868
|
+
const int im = tgpig.z;
|
|
2266
2869
|
|
|
2267
2870
|
const int row = 2 * r0 + sgitg;
|
|
2268
|
-
|
|
2871
|
+
|
|
2872
|
+
const uint i12 = im%ne12;
|
|
2873
|
+
const uint i13 = im/ne12;
|
|
2874
|
+
|
|
2875
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
2876
|
+
|
|
2269
2877
|
device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
|
|
2270
|
-
device const float * yy = (device const float *) src1 + r1*ne10 +
|
|
2878
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
2271
2879
|
|
|
2272
2880
|
float sumf = 0;
|
|
2273
2881
|
|
|
@@ -2333,7 +2941,7 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
|
2333
2941
|
|
|
2334
2942
|
const float tot = simd_sum(sumf);
|
|
2335
2943
|
if (tiisg == 0) {
|
|
2336
|
-
dst[r1*ne0 +
|
|
2944
|
+
dst[r1*ne0 + im*ne0*ne1 + row] = tot;
|
|
2337
2945
|
}
|
|
2338
2946
|
}
|
|
2339
2947
|
|
|
@@ -2643,24 +3251,25 @@ kernel void kernel_get_rows(
|
|
|
2643
3251
|
|
|
2644
3252
|
// each block_q contains 16*nl weights
|
|
2645
3253
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
|
2646
|
-
|
|
2647
|
-
|
|
2648
|
-
|
|
2649
|
-
|
|
2650
|
-
|
|
2651
|
-
|
|
2652
|
-
|
|
2653
|
-
|
|
2654
|
-
|
|
2655
|
-
|
|
2656
|
-
|
|
2657
|
-
|
|
2658
|
-
|
|
2659
|
-
|
|
2660
|
-
|
|
2661
|
-
|
|
2662
|
-
|
|
2663
|
-
|
|
3254
|
+
void kernel_mul_mm_impl(device const uchar * src0,
|
|
3255
|
+
device const uchar * src1,
|
|
3256
|
+
device float * dst,
|
|
3257
|
+
constant int64_t & ne00,
|
|
3258
|
+
constant int64_t & ne02,
|
|
3259
|
+
constant int64_t & nb01,
|
|
3260
|
+
constant int64_t & nb02,
|
|
3261
|
+
constant int64_t & ne12,
|
|
3262
|
+
constant int64_t & nb10,
|
|
3263
|
+
constant int64_t & nb11,
|
|
3264
|
+
constant int64_t & nb12,
|
|
3265
|
+
constant int64_t & ne0,
|
|
3266
|
+
constant int64_t & ne1,
|
|
3267
|
+
constant uint & r2,
|
|
3268
|
+
constant uint & r3,
|
|
3269
|
+
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
|
3270
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3271
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
3272
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
2664
3273
|
|
|
2665
3274
|
threadgroup half * sa = (threadgroup half *)(shared_memory);
|
|
2666
3275
|
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
|
|
@@ -2686,7 +3295,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
|
2686
3295
|
|
|
2687
3296
|
short il = (tiitg % THREAD_PER_ROW);
|
|
2688
3297
|
|
|
2689
|
-
uint
|
|
3298
|
+
const uint i12 = im%ne12;
|
|
3299
|
+
const uint i13 = im/ne12;
|
|
3300
|
+
|
|
3301
|
+
uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
|
|
2690
3302
|
ushort offset1 = il/nl;
|
|
2691
3303
|
|
|
2692
3304
|
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
|
|
@@ -2770,14 +3382,116 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
|
2770
3382
|
}
|
|
2771
3383
|
}
|
|
2772
3384
|
|
|
3385
|
+
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
|
3386
|
+
kernel void kernel_mul_mm(device const uchar * src0,
|
|
3387
|
+
device const uchar * src1,
|
|
3388
|
+
device float * dst,
|
|
3389
|
+
constant int64_t & ne00,
|
|
3390
|
+
constant int64_t & ne02,
|
|
3391
|
+
constant int64_t & nb01,
|
|
3392
|
+
constant int64_t & nb02,
|
|
3393
|
+
constant int64_t & ne12,
|
|
3394
|
+
constant int64_t & nb10,
|
|
3395
|
+
constant int64_t & nb11,
|
|
3396
|
+
constant int64_t & nb12,
|
|
3397
|
+
constant int64_t & ne0,
|
|
3398
|
+
constant int64_t & ne1,
|
|
3399
|
+
constant uint & r2,
|
|
3400
|
+
constant uint & r3,
|
|
3401
|
+
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
|
3402
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3403
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
3404
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
3405
|
+
kernel_mul_mm_impl<block_q, nl, dequantize_func>(
|
|
3406
|
+
src0,
|
|
3407
|
+
src1,
|
|
3408
|
+
dst,
|
|
3409
|
+
ne00,
|
|
3410
|
+
ne02,
|
|
3411
|
+
nb01,
|
|
3412
|
+
nb02,
|
|
3413
|
+
ne12,
|
|
3414
|
+
nb10,
|
|
3415
|
+
nb11,
|
|
3416
|
+
nb12,
|
|
3417
|
+
ne0,
|
|
3418
|
+
ne1,
|
|
3419
|
+
r2,
|
|
3420
|
+
r3,
|
|
3421
|
+
shared_memory,
|
|
3422
|
+
tgpig,
|
|
3423
|
+
tiitg,
|
|
3424
|
+
sgitg);
|
|
3425
|
+
}
|
|
3426
|
+
|
|
3427
|
+
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
|
3428
|
+
kernel void kernel_mul_mm_id(
|
|
3429
|
+
device const int32_t * ids,
|
|
3430
|
+
device const uchar * src1,
|
|
3431
|
+
device float * dst,
|
|
3432
|
+
constant int64_t & ne00,
|
|
3433
|
+
constant int64_t & ne02,
|
|
3434
|
+
constant int64_t & nb01,
|
|
3435
|
+
constant int64_t & nb02,
|
|
3436
|
+
constant int64_t & ne12,
|
|
3437
|
+
constant int64_t & nb10,
|
|
3438
|
+
constant int64_t & nb11,
|
|
3439
|
+
constant int64_t & nb12,
|
|
3440
|
+
constant int64_t & ne0,
|
|
3441
|
+
constant int64_t & ne1,
|
|
3442
|
+
constant uint & r2,
|
|
3443
|
+
constant uint & r3,
|
|
3444
|
+
constant int & idx,
|
|
3445
|
+
device const uchar * src00,
|
|
3446
|
+
device const uchar * src01,
|
|
3447
|
+
device const uchar * src02,
|
|
3448
|
+
device const uchar * src03,
|
|
3449
|
+
device const uchar * src04,
|
|
3450
|
+
device const uchar * src05,
|
|
3451
|
+
device const uchar * src06,
|
|
3452
|
+
device const uchar * src07,
|
|
3453
|
+
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
|
3454
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3455
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
3456
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
3457
|
+
device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
3458
|
+
|
|
3459
|
+
kernel_mul_mm_impl<block_q, nl, dequantize_func>(
|
|
3460
|
+
src0[ids[idx]],
|
|
3461
|
+
src1,
|
|
3462
|
+
dst,
|
|
3463
|
+
ne00,
|
|
3464
|
+
ne02,
|
|
3465
|
+
nb01,
|
|
3466
|
+
nb02,
|
|
3467
|
+
ne12,
|
|
3468
|
+
nb10,
|
|
3469
|
+
nb11,
|
|
3470
|
+
nb12,
|
|
3471
|
+
ne0,
|
|
3472
|
+
ne1,
|
|
3473
|
+
r2,
|
|
3474
|
+
r3,
|
|
3475
|
+
shared_memory,
|
|
3476
|
+
tgpig,
|
|
3477
|
+
tiitg,
|
|
3478
|
+
sgitg);
|
|
3479
|
+
}
|
|
3480
|
+
|
|
2773
3481
|
#if QK_K == 256
|
|
2774
3482
|
#define QK_NL 16
|
|
2775
3483
|
#else
|
|
2776
3484
|
#define QK_NL 4
|
|
2777
3485
|
#endif
|
|
2778
3486
|
|
|
2779
|
-
typedef void (get_rows_t)(
|
|
2780
|
-
|
|
3487
|
+
typedef void (get_rows_t)(
|
|
3488
|
+
device const void * src0,
|
|
3489
|
+
device const int * src1,
|
|
3490
|
+
device float * dst,
|
|
3491
|
+
constant int64_t & ne00,
|
|
3492
|
+
constant uint64_t & nb01,
|
|
3493
|
+
constant uint64_t & nb1,
|
|
3494
|
+
uint, uint, uint);
|
|
2781
3495
|
|
|
2782
3496
|
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
|
|
2783
3497
|
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
|
@@ -2806,8 +3520,10 @@ typedef void (mat_mm_t)(
|
|
|
2806
3520
|
constant int64_t & nb12,
|
|
2807
3521
|
constant int64_t & ne0,
|
|
2808
3522
|
constant int64_t & ne1,
|
|
2809
|
-
constant uint &
|
|
2810
|
-
|
|
3523
|
+
constant uint & r2,
|
|
3524
|
+
constant uint & r3,
|
|
3525
|
+
threadgroup uchar *,
|
|
3526
|
+
uint3, uint, uint);
|
|
2811
3527
|
|
|
2812
3528
|
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
|
|
2813
3529
|
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
|
@@ -2821,3 +3537,44 @@ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
|
|
|
2821
3537
|
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
|
|
2822
3538
|
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
|
|
2823
3539
|
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|
|
3540
|
+
|
|
3541
|
+
typedef void (mat_mm_id_t)(
|
|
3542
|
+
device const int32_t * ids,
|
|
3543
|
+
device const uchar * src1,
|
|
3544
|
+
device float * dst,
|
|
3545
|
+
constant int64_t & ne00,
|
|
3546
|
+
constant int64_t & ne02,
|
|
3547
|
+
constant int64_t & nb01,
|
|
3548
|
+
constant int64_t & nb02,
|
|
3549
|
+
constant int64_t & ne12,
|
|
3550
|
+
constant int64_t & nb10,
|
|
3551
|
+
constant int64_t & nb11,
|
|
3552
|
+
constant int64_t & nb12,
|
|
3553
|
+
constant int64_t & ne0,
|
|
3554
|
+
constant int64_t & ne1,
|
|
3555
|
+
constant uint & r2,
|
|
3556
|
+
constant uint & r3,
|
|
3557
|
+
constant int & idx,
|
|
3558
|
+
device const uchar * src00,
|
|
3559
|
+
device const uchar * src01,
|
|
3560
|
+
device const uchar * src02,
|
|
3561
|
+
device const uchar * src03,
|
|
3562
|
+
device const uchar * src04,
|
|
3563
|
+
device const uchar * src05,
|
|
3564
|
+
device const uchar * src06,
|
|
3565
|
+
device const uchar * src07,
|
|
3566
|
+
threadgroup uchar *,
|
|
3567
|
+
uint3, uint, uint);
|
|
3568
|
+
|
|
3569
|
+
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
|
|
3570
|
+
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
|
|
3571
|
+
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
|
|
3572
|
+
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
|
|
3573
|
+
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
|
|
3574
|
+
template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
|
|
3575
|
+
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
|
|
3576
|
+
template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
|
|
3577
|
+
template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
|
|
3578
|
+
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
|
|
3579
|
+
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
|
|
3580
|
+
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
|