whisper.rn 0.4.0-rc.4 → 0.4.0-rc.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +6 -6
- package/android/build.gradle +4 -0
- package/android/src/main/CMakeLists.txt +5 -0
- package/android/src/main/java/com/rnwhisper/AudioUtils.java +0 -80
- package/android/src/main/java/com/rnwhisper/WhisperContext.java +57 -134
- package/android/src/main/jni-utils.h +76 -0
- package/android/src/main/jni.cpp +188 -112
- package/cpp/README.md +1 -1
- package/cpp/coreml/whisper-encoder-impl.h +1 -1
- package/cpp/coreml/whisper-encoder.h +4 -0
- package/cpp/coreml/whisper-encoder.mm +4 -2
- package/cpp/ggml-alloc.c +55 -19
- package/cpp/ggml-alloc.h +8 -1
- package/cpp/ggml-backend-impl.h +46 -21
- package/cpp/ggml-backend.c +563 -156
- package/cpp/ggml-backend.h +62 -17
- package/cpp/ggml-impl.h +1 -1
- package/cpp/ggml-metal-whisper.metal +2444 -359
- package/cpp/ggml-metal.h +7 -1
- package/cpp/ggml-metal.m +1105 -197
- package/cpp/ggml-quants.c +66 -61
- package/cpp/ggml-quants.h +40 -40
- package/cpp/ggml.c +1040 -1590
- package/cpp/ggml.h +109 -30
- package/cpp/rn-audioutils.cpp +68 -0
- package/cpp/rn-audioutils.h +14 -0
- package/cpp/rn-whisper-log.h +11 -0
- package/cpp/rn-whisper.cpp +143 -59
- package/cpp/rn-whisper.h +48 -15
- package/cpp/whisper.cpp +1635 -928
- package/cpp/whisper.h +55 -10
- package/ios/RNWhisper.mm +7 -7
- package/ios/RNWhisperAudioUtils.h +0 -2
- package/ios/RNWhisperAudioUtils.m +0 -56
- package/ios/RNWhisperContext.h +3 -11
- package/ios/RNWhisperContext.mm +68 -137
- package/lib/commonjs/index.js.map +1 -1
- package/lib/commonjs/version.json +1 -1
- package/lib/module/index.js.map +1 -1
- package/lib/module/version.json +1 -1
- package/lib/typescript/index.d.ts +5 -0
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +6 -5
- package/src/index.ts +5 -0
- package/src/version.json +1 -1
- package/ios/RNWhisper.xcodeproj/project.xcworkspace/contents.xcworkspacedata +0 -4
- package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +0 -8
- package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcuserdata/jhen.xcuserdatad/UserInterfaceState.xcuserstate +0 -0
- package/ios/RNWhisper.xcodeproj/xcuserdata/jhen.xcuserdatad/xcschemes/xcschememanagement.plist +0 -19
package/cpp/ggml-metal-whisper.metal:
@@ -3,6 +3,8 @@
 using namespace metal;

 #define MAX(x, y) ((x) > (y) ? (x) : (y))
+#define MIN(x, y) ((x) < (y) ? (x) : (y))
+#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }

 #define QK4_0 32
 #define QR4_0 2
@@ -39,8 +41,15 @@ typedef struct {
     int8_t qs[QK8_0]; // quants
 } block_q8_0;

-//
-
+#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
+
+enum ggml_sort_order {
+    GGML_SORT_ASC,
+    GGML_SORT_DESC,
+};
+
+// general-purpose kernel for addition, multiplication and division of two tensors
+// pros: works for non-contiguous tensors, supports broadcast across all dims
 // cons: not very efficient
 kernel void kernel_add(
         device const char * src0,
@@ -70,6 +79,7 @@ kernel void kernel_add(
         constant int64_t & nb1,
         constant int64_t & nb2,
         constant int64_t & nb3,
+        constant int64_t & offs,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3 ntg[[threads_per_threadgroup]]) {
@@ -81,16 +91,111 @@ kernel void kernel_add(
     const int64_t i12 = i02 % ne12;
     const int64_t i11 = i01 % ne11;

-    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 +
-    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11
-    device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 +
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
+    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        const int i10 = i0 % ne10;
+        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10));
+    }
+}
+
+kernel void kernel_mul(
+        device const char * src0,
+        device const char * src1,
+        device char * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne03,
+        constant int64_t & nb00,
+        constant int64_t & nb01,
+        constant int64_t & nb02,
+        constant int64_t & nb03,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant int64_t & ne13,
+        constant int64_t & nb10,
+        constant int64_t & nb11,
+        constant int64_t & nb12,
+        constant int64_t & nb13,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & ne2,
+        constant int64_t & ne3,
+        constant int64_t & nb0,
+        constant int64_t & nb1,
+        constant int64_t & nb2,
+        constant int64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig.z;
+    const int64_t i02 = tgpig.y;
+    const int64_t i01 = tgpig.x;
+
+    const int64_t i13 = i03 % ne13;
+    const int64_t i12 = i02 % ne12;
+    const int64_t i11 = i01 % ne11;
+
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;

     for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-
+        const int i10 = i0 % ne10;
+        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10));
+    }
+}
+
+kernel void kernel_div(
+        device const char * src0,
+        device const char * src1,
+        device char * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne03,
+        constant int64_t & nb00,
+        constant int64_t & nb01,
+        constant int64_t & nb02,
+        constant int64_t & nb03,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant int64_t & ne13,
+        constant int64_t & nb10,
+        constant int64_t & nb11,
+        constant int64_t & nb12,
+        constant int64_t & nb13,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & ne2,
+        constant int64_t & ne3,
+        constant int64_t & nb0,
+        constant int64_t & nb1,
+        constant int64_t & nb2,
+        constant int64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig.z;
+    const int64_t i02 = tgpig.y;
+    const int64_t i01 = tgpig.x;
+
+    const int64_t i13 = i03 % ne13;
+    const int64_t i12 = i02 % ne12;
+    const int64_t i11 = i01 % ne11;

-
-
-
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        const int i10 = i0 % ne10;
+        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10));
     }
 }

@@ -100,28 +205,27 @@ kernel void kernel_add_row(
         device const float4 * src0,
         device const float4 * src1,
         device float4 * dst,
-        constant int64_t & nb [[buffer(
+        constant int64_t & nb [[buffer(28)]],
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] + src1[tpig % nb];
 }

-kernel void
+kernel void kernel_mul_row(
         device const float4 * src0,
         device const float4 * src1,
         device float4 * dst,
+        constant int64_t & nb [[buffer(28)]],
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * src1[tpig];
+    dst[tpig] = src0[tpig] * src1[tpig % nb];
 }

-
-// broadcast src1 into src0
-kernel void kernel_mul_row(
+kernel void kernel_div_row(
         device const float4 * src0,
         device const float4 * src1,
         device float4 * dst,
-        constant int64_t & nb,
+        constant int64_t & nb [[buffer(28)]],
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig]
+    dst[tpig] = src0[tpig] / src1[tpig % nb];
 }

 kernel void kernel_scale(
@@ -140,14 +244,6 @@ kernel void kernel_scale_4(
     dst[tpig] = src0[tpig] * scale;
 }

-kernel void kernel_silu(
-        device const float4 * src0,
-        device float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
-    dst[tpig] = x / (1.0f + exp(-x));
-}
-
 kernel void kernel_relu(
         device const float * src0,
         device float * dst,
@@ -155,15 +251,17 @@ kernel void kernel_relu(
     dst[tpig] = max(0.0f, src0[tpig]);
 }

-kernel void
+kernel void kernel_tanh(
         device const float * src0,
         device float * dst,
         uint tpig[[thread_position_in_grid]]) {
-
+    device const float & x = src0[tpig];
+    dst[tpig] = precise::tanh(x);
 }

-constant float GELU_COEF_A
-constant float
+constant float GELU_COEF_A = 0.044715f;
+constant float GELU_QUICK_COEF = -1.702f;
+constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;

 kernel void kernel_gelu(
         device const float4 * src0,
@@ -178,12 +276,86 @@ kernel void kernel_gelu(
     dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
 }

+kernel void kernel_gelu_quick(
+        device const float4 * src0,
+        device float4 * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+
+    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
+}
+
+kernel void kernel_silu(
+        device const float4 * src0,
+        device float4 * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+    dst[tpig] = x / (1.0f + exp(-x));
+}
+
+kernel void kernel_sqr(
+        device const float * src0,
+        device float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * src0[tpig];
+}
+
+kernel void kernel_sum_rows(
+        device const float * src0,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne03,
+        constant int64_t & nb00,
+        constant int64_t & nb01,
+        constant int64_t & nb02,
+        constant int64_t & nb03,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant int64_t & ne13,
+        constant int64_t & nb10,
+        constant int64_t & nb11,
+        constant int64_t & nb12,
+        constant int64_t & nb13,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & ne2,
+        constant int64_t & ne3,
+        constant int64_t & nb0,
+        constant int64_t & nb1,
+        constant int64_t & nb2,
+        constant int64_t & nb3,
+        uint3 tpig[[thread_position_in_grid]]) {
+    int64_t i3 = tpig.z;
+    int64_t i2 = tpig.y;
+    int64_t i1 = tpig.x;
+
+    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+        return;
+    }
+
+    device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
+    device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
+
+    float row_sum = 0;
+
+    for (int64_t i0 = 0; i0 < ne00; i0++) {
+        row_sum += src_row[i0];
+    }
+
+    dst_row[0] = row_sum;
+}
+
 kernel void kernel_soft_max(
         device const float * src0,
+        device const float * src1,
         device float * dst,
         constant int64_t & ne00,
         constant int64_t & ne01,
         constant int64_t & ne02,
+        constant float & scale,
         threadgroup float * buf [[threadgroup(0)]],
         uint tgpig[[threadgroup_position_in_grid]],
         uint tpitg[[thread_position_in_threadgroup]],
@@ -194,73 +366,82 @@ kernel void kernel_soft_max(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

-    device const float * psrc0 =
-    device
+    device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
+    device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

     // parallel max
-    float lmax =
+    float lmax = -INFINITY;

-    for (int i00 = tpitg
-        lmax = MAX(lmax, psrc0[i00]);
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
     }

-
-
-
-
+    // find the max value in the block
+    float max_val = simd_max(lmax);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = -INFINITY;
+        }

-
+        threadgroup_barrier(mem_flags::mem_threadgroup);

-
-
-
-        buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
-    }
-    }
+        if (tiisg == 0) {
+            buf[sgitg] = max_val;
+        }

-
+        threadgroup_barrier(mem_flags::mem_threadgroup);

-
+        max_val = buf[tiisg];
+        max_val = simd_max(max_val);
+    }

     // parallel sum
     float lsum = 0.0f;
     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        const float exp_psrc0 = exp(psrc0[i00] -
+        const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
         lsum += exp_psrc0;
-        // Remember the result of exp here. exp is expensive, so we really do not
-        // wish to compute it twice.
         pdst[i00] = exp_psrc0;
     }

+    // This barrier fixes a failing test
+    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+    threadgroup_barrier(mem_flags::mem_none);
+
     float sum = simd_sum(lsum);
-    if (tiisg == 0) {
-        buf[sgitg] = sum;
-    }

-
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }

-
-    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
-        if (tpitg < i) {
-            buf[tpitg] += buf[tpitg + i];
-        }
-    }
+        threadgroup_barrier(mem_flags::mem_threadgroup);

-
+        if (tiisg == 0) {
+            buf[sgitg] = sum;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        sum = buf[tiisg];
+        sum = simd_sum(sum);
+    }

-
+    const float inv_sum = 1.0f/sum;

     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        pdst[i00]
+        pdst[i00] *= inv_sum;
     }
 }

 kernel void kernel_soft_max_4(
         device const float * src0,
+        device const float * src1,
         device float * dst,
         constant int64_t & ne00,
         constant int64_t & ne01,
         constant int64_t & ne02,
+        constant float & scale,
         threadgroup float * buf [[threadgroup(0)]],
         uint tgpig[[threadgroup_position_in_grid]],
         uint tpitg[[thread_position_in_threadgroup]],
@@ -271,64 +452,74 @@ kernel void kernel_soft_max_4(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

-    device const float4 * psrc4 =
-    device
+    device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
+    device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);

     // parallel max
-    float4 lmax4 =
+    float4 lmax4 = -INFINITY;

-    for (int i00 = tpitg
-        lmax4 = fmax(lmax4, psrc4[i00]);
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+        lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
     }

     const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
-    float max = simd_max(lmax);
-    if (tiisg == 0) {
-        buf[sgitg] = max;
-    }

-
+    float max_val = simd_max(lmax);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = -INFINITY;
+        }

-
-    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
-        if (tpitg < i) {
-            buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
-        }
-    }
+        threadgroup_barrier(mem_flags::mem_threadgroup);

-
+        if (tiisg == 0) {
+            buf[sgitg] = max_val;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);

-
+        max_val = buf[tiisg];
+        max_val = simd_max(max_val);
+    }

     // parallel sum
     float4 lsum4 = 0.0f;
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        const float4 exp_psrc4 = exp(psrc4[i00] -
+        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
         lsum4 += exp_psrc4;
         pdst4[i00] = exp_psrc4;
     }

     const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+
+    // This barrier fixes a failing test
+    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+    threadgroup_barrier(mem_flags::mem_none);
+
     float sum = simd_sum(lsum);
-    if (tiisg == 0) {
-        buf[sgitg] = sum;
-    }

-
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }

-
-    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
-        if (tpitg < i) {
-            buf[tpitg] += buf[tpitg + i];
-        }
-    }
+        threadgroup_barrier(mem_flags::mem_threadgroup);

-
+        if (tiisg == 0) {
+            buf[sgitg] = sum;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        sum = buf[tiisg];
+        sum = simd_sum(sum);
+    }

-
+    const float inv_sum = 1.0f/sum;

     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        pdst4[i00]
+        pdst4[i00] *= inv_sum;
     }
 }

|
@@ -435,14 +626,13 @@ kernel void kernel_rms_norm(
|
|
|
435
626
|
constant int64_t & ne00,
|
|
436
627
|
constant uint64_t & nb01,
|
|
437
628
|
constant float & eps,
|
|
438
|
-
threadgroup float *
|
|
629
|
+
threadgroup float * buf [[threadgroup(0)]],
|
|
439
630
|
uint tgpig[[threadgroup_position_in_grid]],
|
|
440
631
|
uint tpitg[[thread_position_in_threadgroup]],
|
|
441
632
|
uint sgitg[[simdgroup_index_in_threadgroup]],
|
|
442
633
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
443
634
|
uint ntg[[threads_per_threadgroup]]) {
|
|
444
|
-
device const float4 * x
|
|
445
|
-
device const float * x_scalar = (device const float *) x;
|
|
635
|
+
device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
|
|
446
636
|
|
|
447
637
|
float4 sumf = 0;
|
|
448
638
|
float all_sum = 0;
|
|
@@ -453,52 +643,130 @@ kernel void kernel_rms_norm(
     }
     all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
     all_sum = simd_sum(all_sum);
-    if (
-
-
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }

-
+        threadgroup_barrier(mem_flags::mem_threadgroup);

-
-
-        if (tpitg < i) {
-            sum[tpitg] += sum[tpitg + i];
-        }
-    }
-    if (tpitg == 0) {
-        for (int i = 4 * (ne00 / 4); i < ne00; i++) {
-            sum[0] += x_scalar[i];
+        if (tiisg == 0) {
+            buf[sgitg] = all_sum;
         }
-        sum[0] /= ne00;
-    }

-
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        all_sum = buf[tiisg];
+        all_sum = simd_sum(all_sum);
+    }

-    const float mean =
+    const float mean = all_sum/ne00;
     const float scale = 1.0f/sqrt(mean + eps);

     device float4 * y = (device float4 *) (dst + tgpig*ne00);
-    device float * y_scalar = (device float *) y;
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
         y[i00] = x[i00] * scale;
     }
-    if (tpitg == 0) {
-        for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
-            y_scalar[i00] = x_scalar[i00] * scale;
-        }
-    }
 }

-
-
-
-
-
-
-
-
-
-
+kernel void kernel_group_norm(
+        device const float * src0,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int32_t & n_groups,
+        constant float & eps,
+        threadgroup float * buf [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint ntg[[threads_per_threadgroup]]) {
+    const int64_t ne = ne00*ne01*ne02;
+    const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
+
+    int start = tgpig * gs;
+    int end = start + gs;
+
+    start += tpitg;
+
+    if (end >= ne) {
+        end = ne;
+    }
+
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int j = start; j < end; j += ntg) {
+        tmp += src0[j];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    tmp = simd_sum(tmp);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = tmp;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        tmp = buf[tiisg];
+        tmp = simd_sum(tmp);
+    }
+
+    const float mean = tmp / gs;
+    tmp = 0.0f;
+
+    for (int j = start; j < end; j += ntg) {
+        float xi = src0[j] - mean;
+        dst[j] = xi;
+        tmp += xi * xi;
+    }
+
+    tmp = simd_sum(tmp);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = tmp;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        tmp = buf[tiisg];
+        tmp = simd_sum(tmp);
+    }
+
+    const float variance = tmp / gs;
+    const float scale = 1.0f/sqrt(variance + eps);
+    for (int j = start; j < end; j += ntg) {
+        dst[j] *= scale;
+    }
+}
+
+// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q4 quants begin (0 or QK4_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
+    float d = qb_curr->d;
+
+    float2 acc = 0.f;
+
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);

     for (int i = 0; i < 8; i+=2) {
         acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
@@ -576,15 +844,25 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
 // putting them in the kernel cause a significant performance penalty
 #define N_DST 4 // each SIMD group works on 4 rows
 #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
-#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
 //Note: This is a template, but strictly speaking it only applies to
 // quantizations where the block size is 32. It also does not
 // giard against the number of rows not being divisible by
 // N_DST, so this is another explicit assumption of the implementation.
 template<typename block_q_type, int nr, int nsg, int nw>
-void
-
-
+void mul_vec_q_n_f32_impl(
+        device const void * src0,
+        device const float * src1,
+        device float * dst,
+        int64_t ne00,
+        int64_t ne01,
+        int64_t ne02,
+        int64_t ne10,
+        int64_t ne12,
+        int64_t ne0,
+        int64_t ne1,
+        uint r2,
+        uint r3,
+        uint3 tgpig, uint tiisg, uint sgitg) {
     const int nb = ne00/QK4_0;

     const int r0 = tgpig.x;
@@ -593,7 +871,10 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device

     const int first_row = (r0 * nsg + sgitg) * nr;

-    const uint
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);

     device const block_q_type * x = (device const block_q_type *) src0 + offset0;
     device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
@@ -643,13 +924,14 @@ kernel void kernel_mul_mv_q4_0_f32(
         constant int64_t & ne02[[buffer(5)]],
         constant int64_t & ne10[[buffer(9)]],
         constant int64_t & ne12[[buffer(11)]],
-        constant int64_t & ne0[[buffer(15)]],
-        constant int64_t & ne1[[buffer(16)]],
-        constant uint &
+        constant int64_t & ne0 [[buffer(15)]],
+        constant int64_t & ne1 [[buffer(16)]],
+        constant uint & r2 [[buffer(17)]],
+        constant uint & r3 [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
+    mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }

 kernel void kernel_mul_mv_q4_1_f32(
@@ -661,13 +943,14 @@ kernel void kernel_mul_mv_q4_1_f32(
         constant int64_t & ne02[[buffer(5)]],
         constant int64_t & ne10[[buffer(9)]],
         constant int64_t & ne12[[buffer(11)]],
-        constant int64_t & ne0[[buffer(15)]],
-        constant int64_t & ne1[[buffer(16)]],
-        constant uint &
+        constant int64_t & ne0 [[buffer(15)]],
+        constant int64_t & ne1 [[buffer(16)]],
+        constant uint & r2 [[buffer(17)]],
+        constant uint & r3 [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
+    mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }

 kernel void kernel_mul_mv_q5_0_f32(
@@ -679,13 +962,14 @@ kernel void kernel_mul_mv_q5_0_f32(
         constant int64_t & ne02[[buffer(5)]],
         constant int64_t & ne10[[buffer(9)]],
         constant int64_t & ne12[[buffer(11)]],
-        constant int64_t & ne0[[buffer(15)]],
-        constant int64_t & ne1[[buffer(16)]],
-        constant uint &
+        constant int64_t & ne0 [[buffer(15)]],
+        constant int64_t & ne1 [[buffer(16)]],
+        constant uint & r2 [[buffer(17)]],
+        constant uint & r3 [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
+    mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }

 kernel void kernel_mul_mv_q5_1_f32(
@@ -697,33 +981,35 @@ kernel void kernel_mul_mv_q5_1_f32(
         constant int64_t & ne02[[buffer(5)]],
         constant int64_t & ne10[[buffer(9)]],
         constant int64_t & ne12[[buffer(11)]],
-        constant int64_t & ne0[[buffer(15)]],
-        constant int64_t & ne1[[buffer(16)]],
-        constant uint &
+        constant int64_t & ne0 [[buffer(15)]],
+        constant int64_t & ne1 [[buffer(16)]],
+        constant uint & r2 [[buffer(17)]],
+        constant uint & r3 [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
+    mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }


 #define NB_Q8_0 8

-
+void kernel_mul_mv_q8_0_f32_impl(
         device const void * src0,
         device const float * src1,
         device float * dst,
         constant int64_t & ne00,
-        constant int64_t & ne01
-        constant int64_t & ne02
-        constant int64_t & ne10
-        constant int64_t & ne12
-        constant int64_t & ne0
-        constant int64_t & ne1
-        constant uint &
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne10,
+        constant int64_t & ne12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant uint & r2,
+        constant uint & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint
-        uint
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
     const int nr = N_DST;
     const int nsg = N_SIMDGROUP;
     const int nw = N_SIMDWIDTH;
@@ -732,8 +1018,14 @@ kernel void kernel_mul_mv_q8_0_f32(
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
+
     const int first_row = (r0 * nsg + sgitg) * nr;
-
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
     device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
     device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;

@@ -771,9 +1063,29 @@ kernel void kernel_mul_mv_q8_0_f32(
     }
 }

+[[host_name("kernel_mul_mv_q8_0_f32")]]
+kernel void kernel_mul_mv_q8_0_f32(
+        device const void * src0,
+        device const float * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne10,
+        constant int64_t & ne12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant uint & r2 [[buffer(17)]],
+        constant uint & r3 [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+}
+
 #define N_F32_F32 4

-
+void kernel_mul_mv_f32_f32_impl(
         device const char * src0,
         device const char * src1,
         device float * dst,
@@ -791,14 +1103,21 @@ kernel void kernel_mul_mv_f32_f32(
         constant uint64_t & nb12,
         constant int64_t & ne0,
         constant int64_t & ne1,
+        constant uint & r2,
+        constant uint & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint
+        uint tiisg[[thread_index_in_simdgroup]]) {

     const int64_t r0 = tgpig.x;
     const int64_t rb = tgpig.y*N_F32_F32;
     const int64_t im = tgpig.z;

-
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+    device const float * x = (device const float *) (src0 + offset0);

     if (ne00 < 128) {
         for (int row = 0; row < N_F32_F32; ++row) {
@@ -844,7 +1163,8 @@ kernel void kernel_mul_mv_f32_f32(
     }
 }

-
+[[host_name("kernel_mul_mv_f32_f32")]]
+kernel void kernel_mul_mv_f32_f32(
         device const char * src0,
         device const char * src1,
         device float * dst,
@@ -862,6 +1182,113 @@ kernel void kernel_mul_mv_f16_f32_1row(
         constant uint64_t & nb12,
         constant int64_t & ne0,
         constant int64_t & ne1,
+        constant uint & r2 [[buffer(17)]],
+        constant uint & r3 [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]]) {
+    kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+}
+
+#define N_F16_F16 4
+
+kernel void kernel_mul_mv_f16_f16(
+        device const char * src0,
+        device const char * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant uint & r2 [[buffer(17)]],
+        constant uint & r3 [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]]) {
+
+    const int64_t r0 = tgpig.x;
+    const int64_t rb = tgpig.y*N_F16_F16;
+    const int64_t im = tgpig.z;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+    device const half * x = (device const half *) (src0 + offset0);
+
+    if (ne00 < 128) {
+        for (int row = 0; row < N_F16_F16; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00; i += 32) {
+                sumf += (half) x[i] * (half) y[i];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    } else {
+        device const half4 * x4 = (device const half4 *)x;
+        for (int row = 0; row < N_F16_F16; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
+            device const half4 * y4 = (device const half4 *) y;
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00/4; i += 32) {
+                for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    }
+}
+
+void kernel_mul_mv_f16_f32_1row_impl(
+        device const char * src0,
+        device const char * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant uint & r2,
+        constant uint & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]]) {

@@ -869,7 +1296,12 @@ kernel void kernel_mul_mv_f16_f32_1row(
     const int64_t r1 = tgpig.y;
     const int64_t im = tgpig.z;

-
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+    device const half * x = (device const half *) (src0 + offset0);
     device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);

     float sumf = 0;
@@ -893,12 +1325,37 @@ kernel void kernel_mul_mv_f16_f32_1row(
         dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
     }
 }
+}

+[[host_name("kernel_mul_mv_f16_f32_1row")]]
+kernel void kernel_mul_mv_f16_f32_1row(
+        device const char * src0,
+        device const char * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant uint & r2 [[buffer(17)]],
+        constant uint & r3 [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]]) {
+    kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
 }

 #define N_F16_F32 4

-
+void kernel_mul_mv_f16_f32_impl(
         device const char * src0,
         device const char * src1,
         device float * dst,
@@ -916,6 +1373,8 @@ kernel void kernel_mul_mv_f16_f32(
         constant uint64_t & nb12,
         constant int64_t & ne0,
         constant int64_t & ne1,
+        constant uint & r2,
+        constant uint & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]]) {

@@ -923,7 +1382,12 @@ kernel void kernel_mul_mv_f16_f32(
     const int64_t rb = tgpig.y*N_F16_F32;
     const int64_t im = tgpig.z;

-
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+    device const half * x = (device const half *) (src0 + offset0);

     if (ne00 < 128) {
         for (int row = 0; row < N_F16_F32; ++row) {
@@ -969,6 +1433,32 @@ kernel void kernel_mul_mv_f16_f32(
     }
 }

+[[host_name("kernel_mul_mv_f16_f32")]]
+kernel void kernel_mul_mv_f16_f32(
+        device const char * src0,
+        device const char * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant uint & r2 [[buffer(17)]],
+        constant uint & r3 [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]]) {
+    kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+}
+
 // Assumes row size (ne00) is a multiple of 4
 kernel void kernel_mul_mv_f16_f32_l4(
         device const char * src0,
@@ -988,6 +1478,8 @@ kernel void kernel_mul_mv_f16_f32_l4(
         constant uint64_t & nb12,
         constant int64_t & ne0,
         constant int64_t & ne1,
+        constant uint & r2 [[buffer(17)]],
+        constant uint & r3 [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]]) {

@@ -995,7 +1487,12 @@ kernel void kernel_mul_mv_f16_f32_l4(
     const int64_t r0 = tgpig.x;
     const int64_t im = tgpig.z;

-
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+    device const half4 * x4 = (device const half4 *) (src0 + offset0);

     for (int r1 = 0; r1 < nrows; ++r1) {
         device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
@@ -1047,17 +1544,21 @@ kernel void kernel_alibi_f32(
     const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
     const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
     const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+    const int64_t k = i3*ne3 + i2;

-    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
     float m_k;
-    if (
-        m_k = pow(m0,
+    if (k < n_heads_log2_floor) {
+        m_k = pow(m0, k + 1);
     } else {
-        m_k = pow(m1, 2 * (
+        m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
     }
+
+    device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
+    device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
     for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
-
-
+        const float src_v = *(device float *)(src_row + i00*nb00);
+        device float * dst_v = (device float *)(dst_row + i00*nb0);
+        *dst_v = i00 * m_k + src_v;
     }
 }

@@ -1213,25 +1714,333 @@ kernel void kernel_rope(

         const int64_t i0 = ib*n_dims + ic/2;

-        device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-        device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+        device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+        device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+        const float x0 = src[0];
+        const float x1 = src[n_dims/2];
+
+        dst_data[0] = x0*cos_theta - x1*sin_theta;
+        dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+            }
+        }
+    }
+}
+
+template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
+template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
+
+kernel void kernel_im2col_f16(
+        device const float * x,
+        device half * dst,
+        constant int32_t & ofs0,
+        constant int32_t & ofs1,
+        constant int32_t & IW,
+        constant int32_t & IH,
+        constant int32_t & CHW,
+        constant int32_t & s0,
+        constant int32_t & s1,
+        constant int32_t & p0,
+        constant int32_t & p1,
+        constant int32_t & d0,
+        constant int32_t & d1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tgpg[[threadgroups_per_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+    const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0;
+    const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1;
+
+    const int32_t offset_dst =
+        (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
+        (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
+
+    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+        dst[offset_dst] = 0.0f;
+    } else {
+        const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
+        dst[offset_dst] = x[offset_src + iih * IW + iiw];
+    }
+}
+
+kernel void kernel_upscale_f32(
+        device const char * src0,
+        device char * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & ne2,
+        constant int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        constant int32_t & sf,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3;
+    const int64_t i02 = i2;
+    const int64_t i01 = i1/sf;
+
+    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        dst_ptr[i0] = src0_ptr[i0/sf];
+    }
+}
+
+kernel void kernel_pad_f32(
+        device const char * src0,
+        device char * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & ne2,
+        constant int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3;
+    const int64_t i02 = i2;
+    const int64_t i01 = i1;
+
+    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
+
+    if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
+        for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+            if (i0 < ne00) {
+                dst_ptr[i0] = src0_ptr[i0];
+            } else {
+                dst_ptr[i0] = 0.0f;
+            }
+        }
+
+        return;
+    }
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        dst_ptr[i0] = 0.0f;
+    }
+}
+
+// bitonic sort implementation following the CUDA kernels as reference
+typedef void (argsort_t)(
+        device const float * x,
+        device int32_t * dst,
+        constant int64_t & ncols,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]]);
+
+template<ggml_sort_order order>
+kernel void kernel_argsort_f32_i32(
+        device const float * x,
+        device int32_t * dst,
+        constant int64_t & ncols,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]]) {
+    // bitonic sort
+    int col = tpitg[0];
+    int row = tgpig[1];
+
+    if (col >= ncols) return;
+
+    device const float * x_row = x + row * ncols;
+    device int32_t * dst_row = dst + row * ncols;
+
+    // initialize indices
+    if (col < ncols) {
+        dst_row[col] = col;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    for (int k = 2; k <= ncols; k *= 2) {
+        for (int j = k / 2; j > 0; j /= 2) {
+            int ixj = col ^ j;
+            if (ixj > col) {
+                if ((col & k) == 0) {
+                    if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
+                        SWAP(dst_row[col], dst_row[ixj]);
+                    }
+                } else {
+                    if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
+                        SWAP(dst_row[col], dst_row[ixj]);
+                    }
+                }
+            }
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+        }
+    }
+}
+
+template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
+template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
+
+kernel void kernel_leaky_relu_f32(
+        device const float * src0,
+        device float * dst,
+        constant float & slope,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
+}
+
+kernel void kernel_cpy_f16_f16(
+        device const half * src0,
+        device half * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & ne2,
+        constant int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0];
+    }
+}
+
+kernel void kernel_cpy_f16_f32(
+        device const half * src0,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & ne2,
+        constant int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1978
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
1979
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
1980
|
+
const int64_t i03 = tgpig[2];
|
|
1981
|
+
const int64_t i02 = tgpig[1];
|
|
1982
|
+
const int64_t i01 = tgpig[0];
|
|
1983
|
+
|
|
1984
|
+
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
1985
|
+
|
|
1986
|
+
const int64_t i3 = n / (ne2*ne1*ne0);
|
|
1987
|
+
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
|
1988
|
+
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
|
1989
|
+
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
|
1990
|
+
|
|
1991
|
+
device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
1992
|
+
|
|
1993
|
+
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
|
1994
|
+
device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
|
1995
|
+
dst_data[i00] = src[0];
|
|
1996
|
+
}
|
|
1997
|
+
}
|
|
1998
|
+
|
|
1999
|
+
kernel void kernel_cpy_f32_f16(
|
|
2000
|
+
device const float * src0,
|
|
2001
|
+
device half * dst,
|
|
2002
|
+
constant int64_t & ne00,
|
|
2003
|
+
constant int64_t & ne01,
|
|
2004
|
+
constant int64_t & ne02,
|
|
2005
|
+
constant int64_t & ne03,
|
|
2006
|
+
constant uint64_t & nb00,
|
|
2007
|
+
constant uint64_t & nb01,
|
|
2008
|
+
constant uint64_t & nb02,
|
|
2009
|
+
constant uint64_t & nb03,
|
|
2010
|
+
constant int64_t & ne0,
|
|
2011
|
+
constant int64_t & ne1,
|
|
2012
|
+
constant int64_t & ne2,
|
|
2013
|
+
constant int64_t & ne3,
|
|
2014
|
+
constant uint64_t & nb0,
|
|
2015
|
+
constant uint64_t & nb1,
|
|
2016
|
+
constant uint64_t & nb2,
|
|
2017
|
+
constant uint64_t & nb3,
|
|
2018
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2019
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
2020
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
2021
|
+
const int64_t i03 = tgpig[2];
|
|
2022
|
+
const int64_t i02 = tgpig[1];
|
|
2023
|
+
const int64_t i01 = tgpig[0];
|
|
2024
|
+
|
|
2025
|
+
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
2026
|
+
|
|
2027
|
+
const int64_t i3 = n / (ne2*ne1*ne0);
|
|
2028
|
+
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
|
2029
|
+
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
|
2030
|
+
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
|
2031
|
+
|
|
2032
|
+
device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
1218
2033
|
|
|
1219
|
-
|
|
1220
|
-
|
|
2034
|
+
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
|
2035
|
+
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
|
1221
2036
|
|
|
1222
|
-
|
|
1223
|
-
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
|
1224
|
-
}
|
|
1225
|
-
}
|
|
2037
|
+
dst_data[i00] = src[0];
|
|
1226
2038
|
}
|
|
1227
2039
|
}
|
|
1228
2040
|
|
|
1229
|
-
|
|
1230
|
-
|
|
1231
|
-
|
|
1232
|
-
kernel void kernel_cpy_f16_f16(
|
|
1233
|
-
device const half * src0,
|
|
1234
|
-
device half * dst,
|
|
2041
|
+
kernel void kernel_cpy_f32_f32(
|
|
2042
|
+
device const float * src0,
|
|
2043
|
+
device float * dst,
|
|
1235
2044
|
constant int64_t & ne00,
|
|
1236
2045
|
constant int64_t & ne01,
|
|
1237
2046
|
constant int64_t & ne02,
|
|
@@ -1262,17 +2071,18 @@ kernel void kernel_cpy_f16_f16(
|
|
|
1262
2071
|
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
|
1263
2072
|
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
|
1264
2073
|
|
|
1265
|
-
device
|
|
2074
|
+
device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
1266
2075
|
|
|
1267
2076
|
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
|
1268
|
-
device const
|
|
2077
|
+
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
|
2078
|
+
|
|
1269
2079
|
dst_data[i00] = src[0];
|
|
1270
2080
|
}
|
|
1271
2081
|
}
|
|
1272
2082
|
|
|
1273
|
-
kernel void
|
|
2083
|
+
kernel void kernel_cpy_f32_q8_0(
|
|
1274
2084
|
device const float * src0,
|
|
1275
|
-
device
|
|
2085
|
+
device void * dst,
|
|
1276
2086
|
constant int64_t & ne00,
|
|
1277
2087
|
constant int64_t & ne01,
|
|
1278
2088
|
constant int64_t & ne02,
|
|
@@ -1301,20 +2111,36 @@ kernel void kernel_cpy_f32_f16(
|
|
|
1301
2111
|
const int64_t i3 = n / (ne2*ne1*ne0);
|
|
1302
2112
|
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
|
1303
2113
|
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
|
1304
|
-
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
|
2114
|
+
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0;
|
|
1305
2115
|
|
|
1306
|
-
device
|
|
2116
|
+
device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
1307
2117
|
|
|
1308
|
-
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
|
2118
|
+
for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) {
|
|
1309
2119
|
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
|
1310
2120
|
|
|
1311
|
-
|
|
2121
|
+
float amax = 0.0f; // absolute max
|
|
2122
|
+
|
|
2123
|
+
for (int j = 0; j < QK8_0; j++) {
|
|
2124
|
+
const float v = src[j];
|
|
2125
|
+
amax = MAX(amax, fabs(v));
|
|
2126
|
+
}
|
|
2127
|
+
|
|
2128
|
+
const float d = amax / ((1 << 7) - 1);
|
|
2129
|
+
const float id = d ? 1.0f/d : 0.0f;
|
|
2130
|
+
|
|
2131
|
+
dst_data[i00/QK8_0].d = d;
|
|
2132
|
+
|
|
2133
|
+
for (int j = 0; j < QK8_0; ++j) {
|
|
2134
|
+
const float x0 = src[j]*id;
|
|
2135
|
+
|
|
2136
|
+
dst_data[i00/QK8_0].qs[j] = round(x0);
|
|
2137
|
+
}
|
|
1312
2138
|
}
|
|
1313
2139
|
}
|
|
1314
2140
|
|
|
1315
|
-
kernel void
|
|
2141
|
+
kernel void kernel_cpy_f32_q4_0(
|
|
1316
2142
|
device const float * src0,
|
|
1317
|
-
device
|
|
2143
|
+
device void * dst,
|
|
1318
2144
|
constant int64_t & ne00,
|
|
1319
2145
|
constant int64_t & ne01,
|
|
1320
2146
|
constant int64_t & ne02,
|
|
@@ -1343,21 +2169,112 @@ kernel void kernel_cpy_f32_f32(
|
|
|
1343
2169
|
const int64_t i3 = n / (ne2*ne1*ne0);
|
|
1344
2170
|
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
|
1345
2171
|
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
|
1346
|
-
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
|
2172
|
+
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0;
|
|
1347
2173
|
|
|
1348
|
-
device
|
|
2174
|
+
device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
1349
2175
|
|
|
1350
|
-
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
|
2176
|
+
for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) {
|
|
1351
2177
|
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
|
1352
2178
|
|
|
1353
|
-
|
|
2179
|
+
float amax = 0.0f; // absolute max
|
|
2180
|
+
float max = 0.0f;
|
|
2181
|
+
|
|
2182
|
+
for (int j = 0; j < QK4_0; j++) {
|
|
2183
|
+
const float v = src[j];
|
|
2184
|
+
if (amax < fabs(v)) {
|
|
2185
|
+
amax = fabs(v);
|
|
2186
|
+
max = v;
|
|
2187
|
+
}
|
|
2188
|
+
}
|
|
2189
|
+
|
|
2190
|
+
const float d = max / -8;
|
|
2191
|
+
const float id = d ? 1.0f/d : 0.0f;
|
|
2192
|
+
|
|
2193
|
+
dst_data[i00/QK4_0].d = d;
|
|
2194
|
+
|
|
2195
|
+
for (int j = 0; j < QK4_0/2; ++j) {
|
|
2196
|
+
const float x0 = src[0 + j]*id;
|
|
2197
|
+
const float x1 = src[QK4_0/2 + j]*id;
|
|
2198
|
+
|
|
2199
|
+
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
|
|
2200
|
+
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
|
|
2201
|
+
|
|
2202
|
+
dst_data[i00/QK4_0].qs[j] = xi0;
|
|
2203
|
+
dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
|
|
2204
|
+
}
|
|
2205
|
+
}
|
|
2206
|
+
}
|
|
2207
|
+
|
|
2208
|
+
kernel void kernel_cpy_f32_q4_1(
|
|
2209
|
+
device const float * src0,
|
|
2210
|
+
device void * dst,
|
|
2211
|
+
constant int64_t & ne00,
|
|
2212
|
+
constant int64_t & ne01,
|
|
2213
|
+
constant int64_t & ne02,
|
|
2214
|
+
constant int64_t & ne03,
|
|
2215
|
+
constant uint64_t & nb00,
|
|
2216
|
+
constant uint64_t & nb01,
|
|
2217
|
+
constant uint64_t & nb02,
|
|
2218
|
+
constant uint64_t & nb03,
|
|
2219
|
+
constant int64_t & ne0,
|
|
2220
|
+
constant int64_t & ne1,
|
|
2221
|
+
constant int64_t & ne2,
|
|
2222
|
+
constant int64_t & ne3,
|
|
2223
|
+
constant uint64_t & nb0,
|
|
2224
|
+
constant uint64_t & nb1,
|
|
2225
|
+
constant uint64_t & nb2,
|
|
2226
|
+
constant uint64_t & nb3,
|
|
2227
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2228
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
2229
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
2230
|
+
const int64_t i03 = tgpig[2];
|
|
2231
|
+
const int64_t i02 = tgpig[1];
|
|
2232
|
+
const int64_t i01 = tgpig[0];
|
|
2233
|
+
|
|
2234
|
+
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
2235
|
+
|
|
2236
|
+
const int64_t i3 = n / (ne2*ne1*ne0);
|
|
2237
|
+
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
|
2238
|
+
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
|
2239
|
+
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1;
|
|
2240
|
+
|
|
2241
|
+
device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
2242
|
+
|
|
2243
|
+
for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) {
|
|
2244
|
+
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
|
2245
|
+
|
|
2246
|
+
float min = FLT_MAX;
|
|
2247
|
+
float max = -FLT_MAX;
|
|
2248
|
+
|
|
2249
|
+
for (int j = 0; j < QK4_1; j++) {
|
|
2250
|
+
const float v = src[j];
|
|
2251
|
+
if (min > v) min = v;
|
|
2252
|
+
if (max < v) max = v;
|
|
2253
|
+
}
|
|
2254
|
+
|
|
2255
|
+
const float d = (max - min) / ((1 << 4) - 1);
|
|
2256
|
+
const float id = d ? 1.0f/d : 0.0f;
|
|
2257
|
+
|
|
2258
|
+
dst_data[i00/QK4_1].d = d;
|
|
2259
|
+
dst_data[i00/QK4_1].m = min;
|
|
2260
|
+
|
|
2261
|
+
for (int j = 0; j < QK4_1/2; ++j) {
|
|
2262
|
+
const float x0 = (src[0 + j] - min)*id;
|
|
2263
|
+
const float x1 = (src[QK4_1/2 + j] - min)*id;
|
|
2264
|
+
|
|
2265
|
+
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
|
|
2266
|
+
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
|
|
2267
|
+
|
|
2268
|
+
dst_data[i00/QK4_1].qs[j] = xi0;
|
|
2269
|
+
dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
|
|
2270
|
+
}
|
|
1354
2271
|
}
|
|
1355
2272
|
}
|
|
1356
2273
|
|
|
1357
2274
|
kernel void kernel_concat(
|
|
1358
|
-
device
|
|
1359
|
-
device
|
|
1360
|
-
device
|
|
2275
|
+
device const char * src0,
|
|
2276
|
+
device const char * src1,
|
|
2277
|
+
device char * dst,
|
|
1361
2278
|
constant int64_t & ne00,
|
|
1362
2279
|
constant int64_t & ne01,
|
|
1363
2280
|
constant int64_t & ne02,
|
|
@@ -1394,7 +2311,7 @@ kernel void kernel_concat(
|
|
|
1394
2311
|
const int64_t i12 = i02 % ne12;
|
|
1395
2312
|
const int64_t i11 = i01 % ne11;
|
|
1396
2313
|
|
|
1397
|
-
device const char * src0_ptr = src0 + i03
|
|
2314
|
+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
|
|
1398
2315
|
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
|
|
1399
2316
|
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
|
|
1400
2317
|
|
|
@@ -1502,32 +2419,39 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
|
|
|
1502
2419
|
|
|
1503
2420
|
//====================================== dot products =========================
|
|
1504
2421
|
|
|
1505
|
-
|
|
2422
|
+
void kernel_mul_mv_q2_K_f32_impl(
|
|
1506
2423
|
device const void * src0,
|
|
1507
2424
|
device const float * src1,
|
|
1508
2425
|
device float * dst,
|
|
1509
2426
|
constant int64_t & ne00,
|
|
1510
|
-
constant int64_t & ne01
|
|
1511
|
-
constant int64_t & ne02
|
|
1512
|
-
constant int64_t & ne10
|
|
1513
|
-
constant int64_t & ne12
|
|
1514
|
-
constant int64_t & ne0
|
|
1515
|
-
constant int64_t & ne1
|
|
1516
|
-
constant uint &
|
|
2427
|
+
constant int64_t & ne01,
|
|
2428
|
+
constant int64_t & ne02,
|
|
2429
|
+
constant int64_t & ne10,
|
|
2430
|
+
constant int64_t & ne12,
|
|
2431
|
+
constant int64_t & ne0,
|
|
2432
|
+
constant int64_t & ne1,
|
|
2433
|
+
constant uint & r2,
|
|
2434
|
+
constant uint & r3,
|
|
1517
2435
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1518
|
-
uint
|
|
1519
|
-
uint
|
|
2436
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
2437
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
1520
2438
|
|
|
1521
2439
|
const int nb = ne00/QK_K;
|
|
1522
2440
|
const int r0 = tgpig.x;
|
|
1523
2441
|
const int r1 = tgpig.y;
|
|
1524
|
-
const int
|
|
2442
|
+
const int im = tgpig.z;
|
|
1525
2443
|
|
|
1526
2444
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
|
1527
2445
|
const int ib_row = first_row * nb;
|
|
1528
|
-
|
|
2446
|
+
|
|
2447
|
+
const uint i12 = im%ne12;
|
|
2448
|
+
const uint i13 = im/ne12;
|
|
2449
|
+
|
|
2450
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
2451
|
+
|
|
1529
2452
|
device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
|
|
1530
|
-
device const float * y = (device const float *) src1 + r1*ne10 +
|
|
2453
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
2454
|
+
|
|
1531
2455
|
float yl[32];
|
|
1532
2456
|
float sumf[N_DST]={0.f}, all_sum;
|
|
1533
2457
|
|
|
@@ -1536,11 +2460,11 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
|
1536
2460
|
#if QK_K == 256
|
|
1537
2461
|
const int ix = tiisg/8; // 0...3
|
|
1538
2462
|
const int it = tiisg%8; // 0...7
|
|
1539
|
-
const int
|
|
2463
|
+
const int iq = it/4; // 0 or 1
|
|
1540
2464
|
const int ir = it%4; // 0...3
|
|
1541
2465
|
const int is = (8*ir)/16;// 0 or 1
|
|
1542
2466
|
|
|
1543
|
-
device const float * y4 = y + ix * QK_K + 128 *
|
|
2467
|
+
device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
|
|
1544
2468
|
|
|
1545
2469
|
for (int ib = ix; ib < nb; ib += 4) {
|
|
1546
2470
|
|
|
@@ -1552,8 +2476,8 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
|
1552
2476
|
yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
|
|
1553
2477
|
}
|
|
1554
2478
|
|
|
1555
|
-
device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*
|
|
1556
|
-
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 *
|
|
2479
|
+
device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*iq + is;
|
|
2480
|
+
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
|
|
1557
2481
|
device const half * dh = &x[ib].d;
|
|
1558
2482
|
|
|
1559
2483
|
for (int row = 0; row < N_DST; row++) {
|
|
@@ -1640,13 +2564,13 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
|
1640
2564
|
for (int row = 0; row < N_DST; ++row) {
|
|
1641
2565
|
all_sum = simd_sum(sumf[row]);
|
|
1642
2566
|
if (tiisg == 0) {
|
|
1643
|
-
dst[r1*ne0 +
|
|
2567
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
|
1644
2568
|
}
|
|
1645
2569
|
}
|
|
1646
2570
|
}
|
|
1647
2571
|
|
|
1648
|
-
|
|
1649
|
-
kernel void
|
|
2572
|
+
[[host_name("kernel_mul_mv_q2_K_f32")]]
|
|
2573
|
+
kernel void kernel_mul_mv_q2_K_f32(
|
|
1650
2574
|
device const void * src0,
|
|
1651
2575
|
device const float * src1,
|
|
1652
2576
|
device float * dst,
|
|
@@ -1655,23 +2579,50 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
|
1655
2579
|
constant int64_t & ne02[[buffer(5)]],
|
|
1656
2580
|
constant int64_t & ne10[[buffer(9)]],
|
|
1657
2581
|
constant int64_t & ne12[[buffer(11)]],
|
|
1658
|
-
constant int64_t & ne0[[buffer(15)]],
|
|
1659
|
-
constant int64_t & ne1[[buffer(16)]],
|
|
1660
|
-
constant uint &
|
|
2582
|
+
constant int64_t & ne0 [[buffer(15)]],
|
|
2583
|
+
constant int64_t & ne1 [[buffer(16)]],
|
|
2584
|
+
constant uint & r2 [[buffer(17)]],
|
|
2585
|
+
constant uint & r3 [[buffer(18)]],
|
|
1661
2586
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1662
|
-
uint
|
|
1663
|
-
uint
|
|
2587
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
2588
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
2589
|
+
|
|
2590
|
+
kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
|
2591
|
+
}
|
|
2592
|
+
|
|
2593
|
+
#if QK_K == 256
|
|
2594
|
+
void kernel_mul_mv_q3_K_f32_impl(
|
|
2595
|
+
device const void * src0,
|
|
2596
|
+
device const float * src1,
|
|
2597
|
+
device float * dst,
|
|
2598
|
+
constant int64_t & ne00,
|
|
2599
|
+
constant int64_t & ne01,
|
|
2600
|
+
constant int64_t & ne02,
|
|
2601
|
+
constant int64_t & ne10,
|
|
2602
|
+
constant int64_t & ne12,
|
|
2603
|
+
constant int64_t & ne0,
|
|
2604
|
+
constant int64_t & ne1,
|
|
2605
|
+
constant uint & r2,
|
|
2606
|
+
constant uint & r3,
|
|
2607
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2608
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
2609
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
1664
2610
|
|
|
1665
2611
|
const int nb = ne00/QK_K;
|
|
1666
2612
|
|
|
1667
2613
|
const int64_t r0 = tgpig.x;
|
|
1668
2614
|
const int64_t r1 = tgpig.y;
|
|
1669
|
-
const int64_t
|
|
2615
|
+
const int64_t im = tgpig.z;
|
|
1670
2616
|
|
|
1671
2617
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
|
1672
|
-
|
|
2618
|
+
|
|
2619
|
+
const uint i12 = im%ne12;
|
|
2620
|
+
const uint i13 = im/ne12;
|
|
2621
|
+
|
|
2622
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
2623
|
+
|
|
1673
2624
|
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
|
|
1674
|
-
device const float * yy = (device const float *) src1 + r1*ne10 +
|
|
2625
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
1675
2626
|
|
|
1676
2627
|
float yl[32];
|
|
1677
2628
|
|
|
@@ -1793,40 +2744,47 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
|
1793
2744
|
}
|
|
1794
2745
|
if (tiisg == 0) {
|
|
1795
2746
|
for (int row = 0; row < 2; ++row) {
|
|
1796
|
-
dst[r1*ne0 +
|
|
2747
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row];
|
|
1797
2748
|
}
|
|
1798
2749
|
}
|
|
1799
2750
|
}
|
|
1800
2751
|
#else
|
|
1801
|
-
|
|
2752
|
+
void kernel_mul_mv_q3_K_f32_impl(
|
|
1802
2753
|
device const void * src0,
|
|
1803
2754
|
device const float * src1,
|
|
1804
2755
|
device float * dst,
|
|
1805
2756
|
constant int64_t & ne00,
|
|
1806
|
-
constant int64_t & ne01
|
|
1807
|
-
constant int64_t & ne02
|
|
1808
|
-
constant int64_t & ne10
|
|
1809
|
-
constant int64_t & ne12
|
|
1810
|
-
constant int64_t & ne0
|
|
1811
|
-
constant int64_t & ne1
|
|
1812
|
-
constant uint &
|
|
2757
|
+
constant int64_t & ne01,
|
|
2758
|
+
constant int64_t & ne02,
|
|
2759
|
+
constant int64_t & ne10,
|
|
2760
|
+
constant int64_t & ne12,
|
|
2761
|
+
constant int64_t & ne0,
|
|
2762
|
+
constant int64_t & ne1,
|
|
2763
|
+
constant uint & r2,
|
|
2764
|
+
constant uint & r3,
|
|
1813
2765
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1814
|
-
uint
|
|
1815
|
-
uint
|
|
2766
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
2767
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
1816
2768
|
|
|
1817
2769
|
const int nb = ne00/QK_K;
|
|
1818
2770
|
|
|
1819
2771
|
const int64_t r0 = tgpig.x;
|
|
1820
2772
|
const int64_t r1 = tgpig.y;
|
|
1821
|
-
const int64_t
|
|
2773
|
+
const int64_t im = tgpig.z;
|
|
1822
2774
|
|
|
1823
2775
|
const int row = 2 * r0 + sgitg;
|
|
1824
|
-
|
|
2776
|
+
|
|
2777
|
+
const uint i12 = im%ne12;
|
|
2778
|
+
const uint i13 = im/ne12;
|
|
2779
|
+
|
|
2780
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
2781
|
+
|
|
1825
2782
|
device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
|
|
1826
|
-
device const float * yy = (device const float *) src1 + r1*ne10 +
|
|
2783
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
2784
|
+
|
|
1827
2785
|
const int ix = tiisg/4;
|
|
1828
2786
|
const int il = 4 * (tiisg%4);// 0, 4, 8, 12
|
|
1829
|
-
const int
|
|
2787
|
+
const int iq = il/8; // 0, 0, 1, 1
|
|
1830
2788
|
const int in = il%8; // 0, 4, 0, 4
|
|
1831
2789
|
|
|
1832
2790
|
float2 sum = {0.f, 0.f};
|
|
@@ -1846,7 +2804,7 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
|
1846
2804
|
const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
|
|
1847
2805
|
|
|
1848
2806
|
for (int l = 0; l < 4; l += 2) {
|
|
1849
|
-
const uint16_t hm = h[l/2] >>
|
|
2807
|
+
const uint16_t hm = h[l/2] >> iq;
|
|
1850
2808
|
sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
|
|
1851
2809
|
+ y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
|
|
1852
2810
|
+ y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
|
|
@@ -1862,28 +2820,50 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
|
1862
2820
|
|
|
1863
2821
|
const float tot = simd_sum(sumf);
|
|
1864
2822
|
if (tiisg == 0) {
|
|
1865
|
-
dst[r1*ne0 +
|
|
2823
|
+
dst[r1*ne0 + im*ne0*ne1 + row] = tot;
|
|
1866
2824
|
}
|
|
1867
2825
|
|
|
1868
2826
|
}
|
|
1869
2827
|
#endif
|
|
1870
2828
|
|
|
2829
|
+
[[host_name("kernel_mul_mv_q3_K_f32")]]
|
|
2830
|
+
kernel void kernel_mul_mv_q3_K_f32(
|
|
2831
|
+
device const void * src0,
|
|
2832
|
+
device const float * src1,
|
|
2833
|
+
device float * dst,
|
|
2834
|
+
constant int64_t & ne00,
|
|
2835
|
+
constant int64_t & ne01[[buffer(4)]],
|
|
2836
|
+
constant int64_t & ne02[[buffer(5)]],
|
|
2837
|
+
constant int64_t & ne10[[buffer(9)]],
|
|
2838
|
+
constant int64_t & ne12[[buffer(11)]],
|
|
2839
|
+
constant int64_t & ne0 [[buffer(15)]],
|
|
2840
|
+
constant int64_t & ne1 [[buffer(16)]],
|
|
2841
|
+
constant uint & r2 [[buffer(17)]],
|
|
2842
|
+
constant uint & r3 [[buffer(18)]],
|
|
2843
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2844
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
2845
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
2846
|
+
|
|
2847
|
+
kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
|
2848
|
+
}
|
|
2849
|
+
|
|
1871
2850
|
#if QK_K == 256
|
|
1872
|
-
|
|
2851
|
+
void kernel_mul_mv_q4_K_f32_impl(
|
|
1873
2852
|
device const void * src0,
|
|
1874
2853
|
device const float * src1,
|
|
1875
2854
|
device float * dst,
|
|
1876
2855
|
constant int64_t & ne00,
|
|
1877
|
-
constant int64_t & ne01
|
|
1878
|
-
constant int64_t & ne02
|
|
1879
|
-
constant int64_t & ne10
|
|
1880
|
-
constant int64_t & ne12
|
|
1881
|
-
constant int64_t & ne0
|
|
1882
|
-
constant int64_t & ne1
|
|
1883
|
-
constant uint &
|
|
2856
|
+
constant int64_t & ne01,
|
|
2857
|
+
constant int64_t & ne02,
|
|
2858
|
+
constant int64_t & ne10,
|
|
2859
|
+
constant int64_t & ne12,
|
|
2860
|
+
constant int64_t & ne0,
|
|
2861
|
+
constant int64_t & ne1,
|
|
2862
|
+
constant uint & r2,
|
|
2863
|
+
constant uint & r3,
|
|
1884
2864
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1885
|
-
uint
|
|
1886
|
-
uint
|
|
2865
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
2866
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
1887
2867
|
|
|
1888
2868
|
const uint16_t kmask1 = 0x3f3f;
|
|
1889
2869
|
const uint16_t kmask2 = 0x0f0f;
|
|
@@ -1891,26 +2871,32 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
|
1891
2871
|
|
|
1892
2872
|
const int ix = tiisg/8; // 0...3
|
|
1893
2873
|
const int it = tiisg%8; // 0...7
|
|
1894
|
-
const int
|
|
2874
|
+
const int iq = it/4; // 0 or 1
|
|
1895
2875
|
const int ir = it%4; // 0...3
|
|
1896
2876
|
|
|
1897
2877
|
const int nb = ne00/QK_K;
|
|
1898
2878
|
const int r0 = tgpig.x;
|
|
1899
2879
|
const int r1 = tgpig.y;
|
|
1900
|
-
const int
|
|
2880
|
+
const int im = tgpig.z;
|
|
1901
2881
|
//const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
|
1902
2882
|
const int first_row = r0 * N_DST;
|
|
1903
2883
|
const int ib_row = first_row * nb;
|
|
1904
|
-
|
|
2884
|
+
|
|
2885
|
+
const uint i12 = im%ne12;
|
|
2886
|
+
const uint i13 = im/ne12;
|
|
2887
|
+
|
|
2888
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
2889
|
+
|
|
1905
2890
|
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
|
|
1906
|
-
device const float * y = (device const float *) src1 + r1*ne10 +
|
|
2891
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
2892
|
+
|
|
1907
2893
|
float yl[16];
|
|
1908
2894
|
float yh[16];
|
|
1909
2895
|
float sumf[N_DST]={0.f}, all_sum;
|
|
1910
2896
|
|
|
1911
2897
|
const int step = sizeof(block_q4_K) * nb / 2;
|
|
1912
2898
|
|
|
1913
|
-
device const float * y4 = y + ix * QK_K + 64 *
|
|
2899
|
+
device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
|
|
1914
2900
|
|
|
1915
2901
|
uint16_t sc16[4];
|
|
1916
2902
|
thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
|
|
@@ -1925,8 +2911,8 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
|
1925
2911
|
yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
|
|
1926
2912
|
}
|
|
1927
2913
|
|
|
1928
|
-
device const uint16_t * sc = (device const uint16_t *)x[ib].scales +
|
|
1929
|
-
device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 *
|
|
2914
|
+
device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq;
|
|
2915
|
+
device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
|
|
1930
2916
|
device const half * dh = &x[ib].d;
|
|
1931
2917
|
|
|
1932
2918
|
for (int row = 0; row < N_DST; row++) {
|
|
@@ -1970,23 +2956,24 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
|
1970
2956
|
for (int row = 0; row < N_DST; ++row) {
|
|
1971
2957
|
all_sum = simd_sum(sumf[row]);
|
|
1972
2958
|
if (tiisg == 0) {
|
|
1973
|
-
dst[r1*ne0 +
|
|
2959
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
|
1974
2960
|
}
|
|
1975
2961
|
}
|
|
1976
2962
|
}
|
|
1977
2963
|
#else
|
|
1978
|
-
|
|
2964
|
+
void kernel_mul_mv_q4_K_f32_impl(
|
|
1979
2965
|
device const void * src0,
|
|
1980
2966
|
device const float * src1,
|
|
1981
2967
|
device float * dst,
|
|
1982
2968
|
constant int64_t & ne00,
|
|
1983
|
-
constant int64_t & ne01
|
|
1984
|
-
constant int64_t & ne02
|
|
1985
|
-
constant int64_t & ne10
|
|
1986
|
-
constant int64_t & ne12
|
|
1987
|
-
constant int64_t & ne0
|
|
1988
|
-
constant int64_t & ne1
|
|
1989
|
-
constant uint &
|
|
2969
|
+
constant int64_t & ne01,
|
|
2970
|
+
constant int64_t & ne02,
|
|
2971
|
+
constant int64_t & ne10,
|
|
2972
|
+
constant int64_t & ne12,
|
|
2973
|
+
constant int64_t & ne0,
|
|
2974
|
+
constant int64_t & ne1,
|
|
2975
|
+
constant uint & r2,
|
|
2976
|
+
constant uint & r3,
|
|
1990
2977
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1991
2978
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
1992
2979
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -1997,12 +2984,18 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
|
1997
2984
|
const int nb = ne00/QK_K;
|
|
1998
2985
|
const int r0 = tgpig.x;
|
|
1999
2986
|
const int r1 = tgpig.y;
|
|
2000
|
-
const int
|
|
2987
|
+
const int im = tgpig.z;
|
|
2001
2988
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
|
2002
2989
|
const int ib_row = first_row * nb;
|
|
2003
|
-
|
|
2990
|
+
|
|
2991
|
+
const uint i12 = im%ne12;
|
|
2992
|
+
const uint i13 = im/ne12;
|
|
2993
|
+
|
|
2994
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
2995
|
+
|
|
2004
2996
|
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
|
|
2005
|
-
device const float * y = (device const float *) src1 + r1*ne10 +
|
|
2997
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
2998
|
+
|
|
2006
2999
|
float yl[8];
|
|
2007
3000
|
float yh[8];
|
|
2008
3001
|
float sumf[N_DST]={0.f}, all_sum;
|
|
@@ -2058,13 +3051,14 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
|
2058
3051
|
for (int row = 0; row < N_DST; ++row) {
|
|
2059
3052
|
all_sum = simd_sum(sumf[row]);
|
|
2060
3053
|
if (tiisg == 0) {
|
|
2061
|
-
dst[r1*ne0+
|
|
3054
|
+
dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum;
|
|
2062
3055
|
}
|
|
2063
3056
|
}
|
|
2064
3057
|
}
|
|
2065
3058
|
#endif
|
|
2066
3059
|
|
|
2067
|
-
|
|
3060
|
+
[[host_name("kernel_mul_mv_q4_K_f32")]]
|
|
3061
|
+
kernel void kernel_mul_mv_q4_K_f32(
|
|
2068
3062
|
device const void * src0,
|
|
2069
3063
|
device const float * src1,
|
|
2070
3064
|
device float * dst,
|
|
@@ -2073,23 +3067,49 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
|
2073
3067
|
constant int64_t & ne02[[buffer(5)]],
|
|
2074
3068
|
constant int64_t & ne10[[buffer(9)]],
|
|
2075
3069
|
constant int64_t & ne12[[buffer(11)]],
|
|
2076
|
-
constant int64_t & ne0[[buffer(15)]],
|
|
2077
|
-
constant int64_t & ne1[[buffer(16)]],
|
|
2078
|
-
constant uint &
|
|
3070
|
+
constant int64_t & ne0 [[buffer(15)]],
|
|
3071
|
+
constant int64_t & ne1 [[buffer(16)]],
|
|
3072
|
+
constant uint & r2 [[buffer(17)]],
|
|
3073
|
+
constant uint & r3 [[buffer(18)]],
|
|
2079
3074
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2080
3075
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
2081
3076
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
2082
3077
|
|
|
3078
|
+
kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
|
3079
|
+
}
|
|
3080
|
+
|
|
3081
|
+
void kernel_mul_mv_q5_K_f32_impl(
|
|
3082
|
+
device const void * src0,
|
|
3083
|
+
device const float * src1,
|
|
3084
|
+
device float * dst,
|
|
3085
|
+
constant int64_t & ne00,
|
|
3086
|
+
constant int64_t & ne01,
|
|
3087
|
+
constant int64_t & ne02,
|
|
3088
|
+
constant int64_t & ne10,
|
|
3089
|
+
constant int64_t & ne12,
|
|
3090
|
+
constant int64_t & ne0,
|
|
3091
|
+
constant int64_t & ne1,
|
|
3092
|
+
constant uint & r2,
|
|
3093
|
+
constant uint & r3,
|
|
3094
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3095
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
3096
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
3097
|
+
|
|
2083
3098
|
const int nb = ne00/QK_K;
|
|
2084
3099
|
|
|
2085
3100
|
const int64_t r0 = tgpig.x;
|
|
2086
3101
|
const int64_t r1 = tgpig.y;
|
|
2087
|
-
const int
|
|
3102
|
+
const int im = tgpig.z;
|
|
2088
3103
|
|
|
2089
3104
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
|
2090
|
-
|
|
3105
|
+
|
|
3106
|
+
const uint i12 = im%ne12;
|
|
3107
|
+
const uint i13 = im/ne12;
|
|
3108
|
+
|
|
3109
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
3110
|
+
|
|
2091
3111
|
device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
|
|
2092
|
-
device const float * yy = (device const float *) src1 + r1*ne10 +
|
|
3112
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
2093
3113
|
|
|
2094
3114
|
float sumf[2]={0.f};
|
|
2095
3115
|
|
|
@@ -2105,15 +3125,15 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
|
2105
3125
|
|
|
2106
3126
|
const int tid = tiisg/4;
|
|
2107
3127
|
const int ix = tiisg%4;
|
|
2108
|
-
const int
|
|
3128
|
+
const int iq = tid/4;
|
|
2109
3129
|
const int ir = tid%4;
|
|
2110
3130
|
const int n = 8;
|
|
2111
3131
|
|
|
2112
3132
|
const int l0 = n*ir;
|
|
2113
|
-
const int q_offset = 32*
|
|
2114
|
-
const int y_offset = 64*
|
|
3133
|
+
const int q_offset = 32*iq + l0;
|
|
3134
|
+
const int y_offset = 64*iq + l0;
|
|
2115
3135
|
|
|
2116
|
-
const uint8_t hm1 = 1u << (2*
|
|
3136
|
+
const uint8_t hm1 = 1u << (2*iq);
|
|
2117
3137
|
const uint8_t hm2 = hm1 << 1;
|
|
2118
3138
|
const uint8_t hm3 = hm1 << 4;
|
|
2119
3139
|
const uint8_t hm4 = hm2 << 4;
|
|
@@ -2128,7 +3148,7 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
|
2128
3148
|
device const uint8_t * q1 = x[i].qs + q_offset;
|
|
2129
3149
|
device const uint8_t * qh = x[i].qh + l0;
|
|
2130
3150
|
device const half * dh = &x[i].d;
|
|
2131
|
-
device const uint16_t * a = (device const uint16_t *)x[i].scales +
|
|
3151
|
+
device const uint16_t * a = (device const uint16_t *)x[i].scales + iq;
|
|
2132
3152
|
|
|
2133
3153
|
device const float * y2 = y1 + 128;
|
|
2134
3154
|
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
|
@@ -2184,7 +3204,7 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
|
2184
3204
|
|
|
2185
3205
|
const int il = 4 * (tiisg/8); // 0, 4, 8, 12
|
|
2186
3206
|
const int ix = tiisg%8;
|
|
2187
|
-
const int
|
|
3207
|
+
const int iq = il/8; // 0, 0, 1, 1
|
|
2188
3208
|
const int in = il%8; // 0, 4, 0, 4
|
|
2189
3209
|
|
|
2190
3210
|
device const float * y = yy + ix*QK_K + il;
|
|
@@ -2209,7 +3229,7 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
|
2209
3229
|
|
|
2210
3230
|
float2 acc = {0.f, 0.f};
|
|
2211
3231
|
for (int l = 0; l < 4; ++l) {
|
|
2212
|
-
const uint8_t hl = h[l] >>
|
|
3232
|
+
const uint8_t hl = h[l] >> iq;
|
|
2213
3233
|
acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
|
|
2214
3234
|
+ yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
|
|
2215
3235
|
acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
|
|
@@ -2231,27 +3251,48 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
|
2231
3251
|
for (int row = 0; row < 2; ++row) {
|
|
2232
3252
|
const float tot = simd_sum(sumf[row]);
|
|
2233
3253
|
if (tiisg == 0) {
|
|
2234
|
-
dst[r1*ne0 +
|
|
3254
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
|
|
2235
3255
|
}
|
|
2236
3256
|
}
|
|
3257
|
+
}
|
|
3258
|
+
|
|
3259
|
+
[[host_name("kernel_mul_mv_q5_K_f32")]]
|
|
3260
|
+
kernel void kernel_mul_mv_q5_K_f32(
|
|
3261
|
+
device const void * src0,
|
|
3262
|
+
device const float * src1,
|
|
3263
|
+
device float * dst,
|
|
3264
|
+
constant int64_t & ne00,
|
|
3265
|
+
constant int64_t & ne01[[buffer(4)]],
|
|
3266
|
+
constant int64_t & ne02[[buffer(5)]],
|
|
3267
|
+
constant int64_t & ne10[[buffer(9)]],
|
|
3268
|
+
constant int64_t & ne12[[buffer(11)]],
|
|
3269
|
+
constant int64_t & ne0 [[buffer(15)]],
|
|
3270
|
+
constant int64_t & ne1 [[buffer(16)]],
|
|
3271
|
+
constant uint & r2 [[buffer(17)]],
|
|
3272
|
+
constant uint & r3 [[buffer(18)]],
|
|
3273
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3274
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
3275
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
2237
3276
|
|
|
3277
|
+
kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
|
2238
3278
|
}
|
|
2239
3279
|
|
|
2240
|
-
|
|
3280
|
+
void kernel_mul_mv_q6_K_f32_impl(
|
|
2241
3281
|
device const void * src0,
|
|
2242
3282
|
device const float * src1,
|
|
2243
3283
|
device float * dst,
|
|
2244
3284
|
constant int64_t & ne00,
|
|
2245
|
-
constant int64_t & ne01
|
|
2246
|
-
constant int64_t & ne02
|
|
2247
|
-
constant int64_t & ne10
|
|
2248
|
-
constant int64_t & ne12
|
|
2249
|
-
constant int64_t & ne0
|
|
2250
|
-
constant int64_t & ne1
|
|
2251
|
-
constant uint &
|
|
3285
|
+
constant int64_t & ne01,
|
|
3286
|
+
constant int64_t & ne02,
|
|
3287
|
+
constant int64_t & ne10,
|
|
3288
|
+
constant int64_t & ne12,
|
|
3289
|
+
constant int64_t & ne0,
|
|
3290
|
+
constant int64_t & ne1,
|
|
3291
|
+
constant uint & r2,
|
|
3292
|
+
constant uint & r3,
|
|
2252
3293
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2253
|
-
uint
|
|
2254
|
-
uint
|
|
3294
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
3295
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
2255
3296
|
|
|
2256
3297
|
const uint8_t kmask1 = 0x03;
|
|
2257
3298
|
const uint8_t kmask2 = 0x0C;
|
|
@@ -2262,12 +3303,17 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
|
2262
3303
|
|
|
2263
3304
|
const int64_t r0 = tgpig.x;
|
|
2264
3305
|
const int64_t r1 = tgpig.y;
|
|
2265
|
-
const int
|
|
3306
|
+
const int im = tgpig.z;
|
|
2266
3307
|
|
|
2267
3308
|
const int row = 2 * r0 + sgitg;
|
|
2268
|
-
|
|
3309
|
+
|
|
3310
|
+
const uint i12 = im%ne12;
|
|
3311
|
+
const uint i13 = im/ne12;
|
|
3312
|
+
|
|
3313
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
3314
|
+
|
|
2269
3315
|
device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
|
|
2270
|
-
device const float * yy = (device const float *) src1 + r1*ne10 +
|
|
3316
|
+
device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
2271
3317
|
|
|
2272
3318
|
float sumf = 0;
|
|
2273
3319
|
|
|
@@ -2333,10 +3379,31 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
|
2333
3379
|
|
|
2334
3380
|
const float tot = simd_sum(sumf);
|
|
2335
3381
|
if (tiisg == 0) {
|
|
2336
|
-
dst[r1*ne0 +
|
|
3382
|
+
dst[r1*ne0 + im*ne0*ne1 + row] = tot;
|
|
2337
3383
|
}
|
|
2338
3384
|
}
|
|
2339
3385
|
|
|
3386
|
+
[[host_name("kernel_mul_mv_q6_K_f32")]]
|
|
3387
|
+
kernel void kernel_mul_mv_q6_K_f32(
|
|
3388
|
+
device const void * src0,
|
|
3389
|
+
device const float * src1,
|
|
3390
|
+
device float * dst,
|
|
3391
|
+
constant int64_t & ne00,
|
|
3392
|
+
constant int64_t & ne01[[buffer(4)]],
|
|
3393
|
+
constant int64_t & ne02[[buffer(5)]],
|
|
3394
|
+
constant int64_t & ne10[[buffer(9)]],
|
|
3395
|
+
constant int64_t & ne12[[buffer(11)]],
|
|
3396
|
+
constant int64_t & ne0 [[buffer(15)]],
|
|
3397
|
+
constant int64_t & ne1 [[buffer(16)]],
|
|
3398
|
+
constant uint & r2 [[buffer(17)]],
|
|
3399
|
+
constant uint & r3 [[buffer(18)]],
|
|
3400
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3401
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
3402
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
3403
|
+
|
|
3404
|
+
kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
|
3405
|
+
}
|
|
3406
|
+
|
|
2340
3407
|
//============================= templates and their specializations =============================
|
|
2341
3408
|
|
|
2342
3409
|
// NOTE: this is not dequantizing - we are simply fitting the template
|
|
@@ -2454,10 +3521,10 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
|
|
|
2454
3521
|
|
|
2455
3522
|
template <typename type4x4>
|
|
2456
3523
|
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
|
|
2457
|
-
const
|
|
2458
|
-
const
|
|
3524
|
+
const float d = xb->d;
|
|
3525
|
+
const float min = xb->dmin;
|
|
2459
3526
|
device const uint8_t * q = (device const uint8_t *)xb->qs;
|
|
2460
|
-
|
|
3527
|
+
float dl, ml;
|
|
2461
3528
|
uint8_t sc = xb->scales[il];
|
|
2462
3529
|
|
|
2463
3530
|
#if QK_K == 256
|
|
@@ -2527,10 +3594,10 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
|
|
|
2527
3594
|
q = q + (il/4) * 32 + 16 * (il&1);
|
|
2528
3595
|
il = il & 3;
|
|
2529
3596
|
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
|
|
2530
|
-
const
|
|
2531
|
-
const
|
|
2532
|
-
const
|
|
2533
|
-
const
|
|
3597
|
+
const float d = il < 2 ? xb->d : xb->d / 16.h;
|
|
3598
|
+
const float min = xb->dmin;
|
|
3599
|
+
const float dl = d * sc[0];
|
|
3600
|
+
const float ml = min * sc[1];
|
|
2534
3601
|
#else
|
|
2535
3602
|
q = q + 16 * (il&1);
|
|
2536
3603
|
device const uint8_t * s = xb->scales;
|
|
@@ -2557,13 +3624,13 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
|
|
|
2557
3624
|
uint8_t ul = 1 << (il/2);
|
|
2558
3625
|
il = il & 3;
|
|
2559
3626
|
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
|
|
2560
|
-
const
|
|
2561
|
-
const
|
|
2562
|
-
const
|
|
2563
|
-
const
|
|
3627
|
+
const float d = il < 2 ? xb->d : xb->d / 16.h;
|
|
3628
|
+
const float min = xb->dmin;
|
|
3629
|
+
const float dl = d * sc[0];
|
|
3630
|
+
const float ml = min * sc[1];
|
|
2564
3631
|
|
|
2565
|
-
const ushort mask
|
|
2566
|
-
const
|
|
3632
|
+
const ushort mask = il<2 ? 0x0F : 0xF0;
|
|
3633
|
+
const float qh_val = il<2 ? 16.f : 256.f;
|
|
2567
3634
|
for (int i = 0; i < 16; ++i) {
|
|
2568
3635
|
reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
|
|
2569
3636
|
}
|
|
@@ -2611,22 +3678,90 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
|
|
|
2611
3678
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
|
2612
3679
|
kernel void kernel_get_rows(
|
|
2613
3680
|
device const void * src0,
|
|
2614
|
-
device const
|
|
3681
|
+
device const char * src1,
|
|
2615
3682
|
device float * dst,
|
|
2616
3683
|
constant int64_t & ne00,
|
|
2617
3684
|
constant uint64_t & nb01,
|
|
3685
|
+
constant uint64_t & nb02,
|
|
3686
|
+
constant int64_t & ne10,
|
|
3687
|
+
constant uint64_t & nb10,
|
|
3688
|
+
constant uint64_t & nb11,
|
|
2618
3689
|
constant uint64_t & nb1,
|
|
2619
|
-
|
|
3690
|
+
constant uint64_t & nb2,
|
|
3691
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2620
3692
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
2621
|
-
|
|
2622
|
-
const
|
|
2623
|
-
const
|
|
3693
|
+
uint3 tptg [[threads_per_threadgroup]]) {
|
|
3694
|
+
//const int64_t i = tgpig;
|
|
3695
|
+
//const int64_t r = ((device int32_t *) src1)[i];
|
|
3696
|
+
|
|
3697
|
+
const int64_t i10 = tgpig.x;
|
|
3698
|
+
const int64_t i11 = tgpig.y;
|
|
2624
3699
|
|
|
2625
|
-
|
|
3700
|
+
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
|
3701
|
+
|
|
3702
|
+
const int64_t i02 = i11;
|
|
3703
|
+
|
|
3704
|
+
for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
|
|
2626
3705
|
float4x4 temp;
|
|
2627
3706
|
dequantize_func(
|
|
2628
|
-
((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
|
|
2629
|
-
*(((device float4x4 *) ((device char *) dst +
|
|
3707
|
+
((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
|
|
3708
|
+
*(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
|
|
3709
|
+
}
|
|
3710
|
+
}
|
|
3711
|
+
|
|
3712
|
+
kernel void kernel_get_rows_f32(
|
|
3713
|
+
device const void * src0,
|
|
3714
|
+
device const char * src1,
|
|
3715
|
+
device float * dst,
|
|
3716
|
+
constant int64_t & ne00,
|
|
3717
|
+
constant uint64_t & nb01,
|
|
3718
|
+
constant uint64_t & nb02,
|
|
3719
|
+
constant int64_t & ne10,
|
|
3720
|
+
constant uint64_t & nb10,
|
|
3721
|
+
constant uint64_t & nb11,
|
|
3722
|
+
constant uint64_t & nb1,
|
|
3723
|
+
constant uint64_t & nb2,
|
|
3724
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3725
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
3726
|
+
uint3 tptg [[threads_per_threadgroup]]) {
|
|
3727
|
+
const int64_t i10 = tgpig.x;
|
|
3728
|
+
const int64_t i11 = tgpig.y;
|
|
3729
|
+
|
|
3730
|
+
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
|
3731
|
+
|
|
3732
|
+
const int64_t i02 = i11;
|
|
3733
|
+
|
|
3734
|
+
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
|
3735
|
+
((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
|
3736
|
+
((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
|
3737
|
+
}
|
|
3738
|
+
}
|
|
3739
|
+
|
|
3740
|
+
kernel void kernel_get_rows_f16(
|
|
3741
|
+
device const void * src0,
|
|
3742
|
+
device const char * src1,
|
|
3743
|
+
device float * dst,
|
|
3744
|
+
constant int64_t & ne00,
|
|
3745
|
+
constant uint64_t & nb01,
|
|
3746
|
+
constant uint64_t & nb02,
|
|
3747
|
+
constant int64_t & ne10,
|
|
3748
|
+
constant uint64_t & nb10,
|
|
3749
|
+
constant uint64_t & nb11,
|
|
3750
|
+
constant uint64_t & nb1,
|
|
3751
|
+
constant uint64_t & nb2,
|
|
3752
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3753
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
3754
|
+
uint3 tptg [[threads_per_threadgroup]]) {
|
|
3755
|
+
const int64_t i10 = tgpig.x;
|
|
3756
|
+
const int64_t i11 = tgpig.y;
|
|
3757
|
+
|
|
3758
|
+
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
|
3759
|
+
|
|
3760
|
+
const int64_t i02 = i11;
|
|
3761
|
+
|
|
3762
|
+
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
|
3763
|
+
((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
|
3764
|
+
((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
|
2630
3765
|
}
|
|
2631
3766
|
}
|
|
2632
3767
|
|
|
@@ -2643,24 +3778,25 @@ kernel void kernel_get_rows(
|
|
|
2643
3778
|
|
|
2644
3779
|
// each block_q contains 16*nl weights
|
|
2645
3780
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
|
2646
|
-
|
|
2647
|
-
|
|
2648
|
-
|
|
2649
|
-
|
|
2650
|
-
|
|
2651
|
-
|
|
2652
|
-
|
|
2653
|
-
|
|
2654
|
-
|
|
2655
|
-
|
|
2656
|
-
|
|
2657
|
-
|
|
2658
|
-
|
|
2659
|
-
|
|
2660
|
-
|
|
2661
|
-
|
|
2662
|
-
|
|
2663
|
-
|
|
3781
|
+
void kernel_mul_mm_impl(device const uchar * src0,
|
|
3782
|
+
device const uchar * src1,
|
|
3783
|
+
device float * dst,
|
|
3784
|
+
constant int64_t & ne00,
|
|
3785
|
+
constant int64_t & ne02,
|
|
3786
|
+
constant int64_t & nb01,
|
|
3787
|
+
constant int64_t & nb02,
|
|
3788
|
+
constant int64_t & ne12,
|
|
3789
|
+
constant int64_t & nb10,
|
|
3790
|
+
constant int64_t & nb11,
|
|
3791
|
+
constant int64_t & nb12,
|
|
3792
|
+
constant int64_t & ne0,
|
|
3793
|
+
constant int64_t & ne1,
|
|
3794
|
+
constant uint & r2,
|
|
3795
|
+
constant uint & r3,
|
|
3796
|
+
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
|
3797
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3798
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
3799
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
2664
3800
|
|
|
2665
3801
|
threadgroup half * sa = (threadgroup half *)(shared_memory);
|
|
2666
3802
|
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
|
|
@@ -2686,7 +3822,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
|
2686
3822
|
|
|
2687
3823
|
short il = (tiitg % THREAD_PER_ROW);
|
|
2688
3824
|
|
|
2689
|
-
uint
|
|
3825
|
+
const uint i12 = im%ne12;
|
|
3826
|
+
const uint i13 = im/ne12;
|
|
3827
|
+
|
|
3828
|
+
uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
|
|
2690
3829
|
ushort offset1 = il/nl;
|
|
2691
3830
|
|
|
2692
3831
|
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
|
|
@@ -2770,17 +3909,137 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
|
2770
3909
|
}
|
|
2771
3910
|
}
|
|
2772
3911
|
|
|
3912
|
+
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
|
3913
|
+
kernel void kernel_mul_mm(device const uchar * src0,
|
|
3914
|
+
device const uchar * src1,
|
|
3915
|
+
device float * dst,
|
|
3916
|
+
constant int64_t & ne00,
|
|
3917
|
+
constant int64_t & ne02,
|
|
3918
|
+
constant int64_t & nb01,
|
|
3919
|
+
constant int64_t & nb02,
|
|
3920
|
+
constant int64_t & ne12,
|
|
3921
|
+
constant int64_t & nb10,
|
|
3922
|
+
constant int64_t & nb11,
|
|
3923
|
+
constant int64_t & nb12,
|
|
3924
|
+
constant int64_t & ne0,
|
|
3925
|
+
constant int64_t & ne1,
|
|
3926
|
+
constant uint & r2,
|
|
3927
|
+
constant uint & r3,
|
|
3928
|
+
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
|
3929
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3930
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
3931
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
3932
|
+
kernel_mul_mm_impl<block_q, nl, dequantize_func>(
|
|
3933
|
+
src0,
|
|
3934
|
+
src1,
|
|
3935
|
+
dst,
|
|
3936
|
+
ne00,
|
|
3937
|
+
ne02,
|
|
3938
|
+
nb01,
|
|
3939
|
+
nb02,
|
|
3940
|
+
ne12,
|
|
3941
|
+
nb10,
|
|
3942
|
+
nb11,
|
|
3943
|
+
nb12,
|
|
3944
|
+
ne0,
|
|
3945
|
+
ne1,
|
|
3946
|
+
r2,
|
|
3947
|
+
r3,
|
|
3948
|
+
shared_memory,
|
|
3949
|
+
tgpig,
|
|
3950
|
+
tiitg,
|
|
3951
|
+
sgitg);
|
|
3952
|
+
}
|
|
3953
|
+
|
|
3954
|
+
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
|
3955
|
+
kernel void kernel_mul_mm_id(
|
|
3956
|
+
device const uchar * ids,
|
|
3957
|
+
device const uchar * src1,
|
|
3958
|
+
device uchar * dst,
|
|
3959
|
+
constant int64_t & nbi1,
|
|
3960
|
+
constant int64_t & ne00,
|
|
3961
|
+
constant int64_t & ne02,
|
|
3962
|
+
constant int64_t & nb01,
|
|
3963
|
+
constant int64_t & nb02,
|
|
3964
|
+
constant int64_t & ne12,
|
|
3965
|
+
constant int64_t & ne13,
|
|
3966
|
+
constant int64_t & nb10,
|
|
3967
|
+
constant int64_t & nb11,
|
|
3968
|
+
constant int64_t & nb12,
|
|
3969
|
+
constant int64_t & ne0,
|
|
3970
|
+
constant int64_t & ne1,
|
|
3971
|
+
constant int64_t & nb1,
|
|
3972
|
+
constant uint & r2,
|
|
3973
|
+
constant uint & r3,
|
|
3974
|
+
constant int & idx,
|
|
3975
|
+
device const uchar * src00,
|
|
3976
|
+
device const uchar * src01,
|
|
3977
|
+
device const uchar * src02,
|
|
3978
|
+
device const uchar * src03,
|
|
3979
|
+
device const uchar * src04,
|
|
3980
|
+
device const uchar * src05,
|
|
3981
|
+
device const uchar * src06,
|
|
3982
|
+
device const uchar * src07,
|
|
3983
|
+
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
|
3984
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3985
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
3986
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
3987
|
+
device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
3988
|
+
|
|
3989
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
|
3990
|
+
|
|
3991
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
|
3992
|
+
|
|
3993
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
3994
|
+
|
|
3995
|
+
kernel_mul_mm_impl<block_q, nl, dequantize_func>(
|
|
3996
|
+
src0[id],
|
|
3997
|
+
src1 + bid*nb11,
|
|
3998
|
+
(device float *) (dst + bid*nb1),
|
|
3999
|
+
ne00,
|
|
4000
|
+
ne02,
|
|
4001
|
+
nb01,
|
|
4002
|
+
nb02,
|
|
4003
|
+
ne12,
|
|
4004
|
+
nb10,
|
|
4005
|
+
nb11,
|
|
4006
|
+
nb12,
|
|
4007
|
+
ne0,
|
|
4008
|
+
ne1,
|
|
4009
|
+
r2,
|
|
4010
|
+
r3,
|
|
4011
|
+
shared_memory,
|
|
4012
|
+
tgpig,
|
|
4013
|
+
tiitg,
|
|
4014
|
+
sgitg);
|
|
4015
|
+
}
|
|
4016
|
+
|
|
2773
4017
|
#if QK_K == 256
|
|
2774
4018
|
#define QK_NL 16
|
|
2775
4019
|
#else
|
|
2776
4020
|
#define QK_NL 4
|
|
2777
4021
|
#endif
|
|
2778
4022
|
|
|
2779
|
-
|
|
2780
|
-
|
|
4023
|
+
//
|
|
4024
|
+
// get rows
|
|
4025
|
+
//
|
|
2781
4026
|
|
|
2782
|
-
|
|
2783
|
-
|
|
4027
|
+
typedef void (get_rows_t)(
|
|
4028
|
+
device const void * src0,
|
|
4029
|
+
device const char * src1,
|
|
4030
|
+
device float * dst,
|
|
4031
|
+
constant int64_t & ne00,
|
|
4032
|
+
constant uint64_t & nb01,
|
|
4033
|
+
constant uint64_t & nb02,
|
|
4034
|
+
constant int64_t & ne10,
|
|
4035
|
+
constant uint64_t & nb10,
|
|
4036
|
+
constant uint64_t & nb11,
|
|
4037
|
+
constant uint64_t & nb1,
|
|
4038
|
+
constant uint64_t & nb2,
|
|
4039
|
+
uint3, uint, uint3);
|
|
4040
|
+
|
|
4041
|
+
//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
|
|
4042
|
+
//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
|
2784
4043
|
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
|
|
2785
4044
|
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
|
|
2786
4045
|
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
|
|
@@ -2792,6 +4051,10 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
|
|
|
2792
4051
|
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
|
|
2793
4052
|
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
|
|
2794
4053
|
|
|
4054
|
+
//
|
|
4055
|
+
// matrix-matrix multiplication
|
|
4056
|
+
//
|
|
4057
|
+
|
|
2795
4058
|
typedef void (mat_mm_t)(
|
|
2796
4059
|
device const uchar * src0,
|
|
2797
4060
|
device const uchar * src1,
|
|
@@ -2806,8 +4069,10 @@ typedef void (mat_mm_t)(
|
|
|
2806
4069
|
constant int64_t & nb12,
|
|
2807
4070
|
constant int64_t & ne0,
|
|
2808
4071
|
constant int64_t & ne1,
|
|
2809
|
-
constant uint &
|
|
2810
|
-
|
|
4072
|
+
constant uint & r2,
|
|
4073
|
+
constant uint & r3,
|
|
4074
|
+
threadgroup uchar *,
|
|
4075
|
+
uint3, uint, uint);
|
|
2811
4076
|
|
|
2812
4077
|
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
|
|
2813
4078
|
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
|
@@ -2821,3 +4086,823 @@ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
+
+//
+// indirect matrix-matrix multiplication
+//
+
+typedef void (mat_mm_id_t)(
+        device const uchar * ids,
+        device const uchar * src1,
+        device uchar * dst,
+        constant int64_t & nbi1,
+        constant int64_t & ne00,
+        constant int64_t & ne02,
+        constant int64_t & nb01,
+        constant int64_t & nb02,
+        constant int64_t & ne12,
+        constant int64_t & ne13,
+        constant int64_t & nb10,
+        constant int64_t & nb11,
+        constant int64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & nb1,
+        constant uint & r2,
+        constant uint & r3,
+        constant int & idx,
+        device const uchar * src00,
+        device const uchar * src01,
+        device const uchar * src02,
+        device const uchar * src03,
+        device const uchar * src04,
+        device const uchar * src05,
+        device const uchar * src06,
+        device const uchar * src07,
+        threadgroup uchar *,
+        uint3, uint, uint);
+
+template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
+template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
+template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
+template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
+template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
+template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
+template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
+template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
+template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
+
+//
+// matrix-vector multiplication
+//
+
+[[host_name("kernel_mul_mv_id_f32_f32")]]
+kernel void kernel_mul_mv_id_f32_f32(
+        device const char * ids,
+        device const char * src1,
+        device uchar * dst,
+        constant int64_t & nbi1,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & nb1,
+        constant uint & r2,
+        constant uint & r3,
+        constant int & idx,
+        device const char * src00,
+        device const char * src01,
+        device const char * src02,
+        device const char * src03,
+        device const char * src04,
+        device const char * src05,
+        device const char * src06,
+        device const char * src07,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiitg[[thread_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_f32_f32_impl(
+        src0[id],
+        src1 + bid*nb11,
+        (device float *) (dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        nb00,
+        nb01,
+        nb02,
+        ne10,
+        ne11,
+        ne12,
+        nb10,
+        nb11,
+        nb12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg);
+}
+
+[[host_name("kernel_mul_mv_id_f16_f32")]]
+kernel void kernel_mul_mv_id_f16_f32(
+        device const char * ids,
+        device const char * src1,
+        device uchar * dst,
+        constant int64_t & nbi1,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & nb1,
+        constant uint & r2,
+        constant uint & r3,
+        constant int & idx,
+        device const char * src00,
+        device const char * src01,
+        device const char * src02,
+        device const char * src03,
+        device const char * src04,
+        device const char * src05,
+        device const char * src06,
+        device const char * src07,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiitg[[thread_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_f16_f32_impl(
+        src0[id],
+        src1 + bid*nb11,
+        (device float *) (dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        nb00,
+        nb01,
+        nb02,
+        ne10,
+        ne11,
+        ne12,
+        nb10,
+        nb11,
+        nb12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg);
+}
+
+[[host_name("kernel_mul_mv_id_q8_0_f32")]]
+kernel void kernel_mul_mv_id_q8_0_f32(
+        device const char * ids,
+        device const char * src1,
+        device uchar * dst,
+        constant int64_t & nbi1,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & nb1,
+        constant uint & r2,
+        constant uint & r3,
+        constant int & idx,
+        device const char * src00,
+        device const char * src01,
+        device const char * src02,
+        device const char * src03,
+        device const char * src04,
+        device const char * src05,
+        device const char * src06,
+        device const char * src07,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiitg[[thread_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_q8_0_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q4_0_f32")]]
+kernel void kernel_mul_mv_id_q4_0_f32(
+        device const char * ids,
+        device const char * src1,
+        device uchar * dst,
+        constant int64_t & nbi1,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & nb1,
+        constant uint & r2,
+        constant uint & r3,
+        constant int & idx,
+        device const char * src00,
+        device const char * src01,
+        device const char * src02,
+        device const char * src03,
+        device const char * src04,
+        device const char * src05,
+        device const char * src06,
+        device const char * src07,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiitg[[thread_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q4_1_f32")]]
+kernel void kernel_mul_mv_id_q4_1_f32(
+        device const char * ids,
+        device const char * src1,
+        device uchar * dst,
+        constant int64_t & nbi1,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & nb1,
+        constant uint & r2,
+        constant uint & r3,
+        constant int & idx,
+        device const char * src00,
+        device const char * src01,
+        device const char * src02,
+        device const char * src03,
+        device const char * src04,
+        device const char * src05,
+        device const char * src06,
+        device const char * src07,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiitg[[thread_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q5_0_f32")]]
+kernel void kernel_mul_mv_id_q5_0_f32(
+        device const char * ids,
+        device const char * src1,
+        device uchar * dst,
+        constant int64_t & nbi1,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & nb1,
+        constant uint & r2,
+        constant uint & r3,
+        constant int & idx,
+        device const char * src00,
+        device const char * src01,
+        device const char * src02,
+        device const char * src03,
+        device const char * src04,
+        device const char * src05,
+        device const char * src06,
+        device const char * src07,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiitg[[thread_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q5_1_f32")]]
+kernel void kernel_mul_mv_id_q5_1_f32(
+        device const char * ids,
+        device const char * src1,
+        device uchar * dst,
+        constant int64_t & nbi1,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & nb1,
+        constant uint & r2,
+        constant uint & r3,
+        constant int & idx,
+        device const char * src00,
+        device const char * src01,
+        device const char * src02,
+        device const char * src03,
+        device const char * src04,
+        device const char * src05,
+        device const char * src06,
+        device const char * src07,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiitg[[thread_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q2_K_f32")]]
+kernel void kernel_mul_mv_id_q2_K_f32(
+        device const char * ids,
+        device const char * src1,
+        device uchar * dst,
+        constant int64_t & nbi1,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & nb1,
+        constant uint & r2,
+        constant uint & r3,
+        constant int & idx,
+        device const char * src00,
+        device const char * src01,
+        device const char * src02,
+        device const char * src03,
+        device const char * src04,
+        device const char * src05,
+        device const char * src06,
+        device const char * src07,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiitg[[thread_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_q2_K_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q3_K_f32")]]
+kernel void kernel_mul_mv_id_q3_K_f32(
+        device const char * ids,
+        device const char * src1,
+        device uchar * dst,
+        constant int64_t & nbi1,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & nb1,
+        constant uint & r2,
+        constant uint & r3,
+        constant int & idx,
+        device const char * src00,
+        device const char * src01,
+        device const char * src02,
+        device const char * src03,
+        device const char * src04,
+        device const char * src05,
+        device const char * src06,
+        device const char * src07,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiitg[[thread_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_q3_K_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q4_K_f32")]]
+kernel void kernel_mul_mv_id_q4_K_f32(
+        device const char * ids,
+        device const char * src1,
+        device uchar * dst,
+        constant int64_t & nbi1,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & nb1,
+        constant uint & r2,
+        constant uint & r3,
+        constant int & idx,
+        device const char * src00,
+        device const char * src01,
+        device const char * src02,
+        device const char * src03,
+        device const char * src04,
+        device const char * src05,
+        device const char * src06,
+        device const char * src07,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiitg[[thread_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_q4_K_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q5_K_f32")]]
+kernel void kernel_mul_mv_id_q5_K_f32(
+        device const char * ids,
+        device const char * src1,
+        device uchar * dst,
+        constant int64_t & nbi1,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & nb1,
+        constant uint & r2,
+        constant uint & r3,
+        constant int & idx,
+        device const char * src00,
+        device const char * src01,
+        device const char * src02,
+        device const char * src03,
+        device const char * src04,
+        device const char * src05,
+        device const char * src06,
+        device const char * src07,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiitg[[thread_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_q5_K_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q6_K_f32")]]
+kernel void kernel_mul_mv_id_q6_K_f32(
+        device const char * ids,
+        device const char * src1,
+        device uchar * dst,
+        constant int64_t & nbi1,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & nb1,
+        constant uint & r2,
+        constant uint & r3,
+        constant int & idx,
+        device const char * src00,
+        device const char * src01,
+        device const char * src02,
+        device const char * src03,
+        device const char * src04,
+        device const char * src05,
+        device const char * src06,
+        device const char * src07,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiitg[[thread_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_q6_K_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
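Note: the kernel_mul_mv_id_* wrappers added in this hunk all share the same selection step before delegating to the existing non-indexed kernels: the grid's z coordinate is split into a batch row (bid) and an inner index, an expert id is read from the ids tensor, and the matching pointer out of src00..src07 is forwarded to the corresponding *_impl function. The following is a minimal host-side C++ mock of that index arithmetic for illustration only; the ne12/ne13/nbi1/idx values and the ids contents below are assumptions, not values taken from the package.

    // Illustrative sketch of the id-selection arithmetic used by the indirect kernels above.
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        const int64_t ne12 = 2, ne13 = 1;        // assumed broadcast dims of src1
        const int64_t nbi1 = sizeof(int32_t);    // assumed row stride of the ids tensor, in bytes
        const int     idx  = 0;                  // assumed column of ids to read

        std::vector<int32_t> ids = {3, 5};       // one expert id per batch row (toy data)

        for (int64_t z = 0; z < (int64_t)ids.size()*ne12*ne13; ++z) {
            const int64_t bid   = z/(ne12*ne13); // which batch row, as in the kernels
            const int64_t inner = z%(ne12*ne13); // z value passed on to the impl
            const int32_t id    = ((const int32_t *)((const char *)ids.data() + bid*nbi1))[idx];
            std::printf("tgpig.z=%lld -> bid=%lld, inner z=%lld, selected src0[%d]\n",
                        (long long)z, (long long)bid, (long long)inner, (int)id);
        }
        return 0;
    }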